package io.dropwizard.jersey.filter;
import io.dropwizard.util.Sets;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.Optional;
import java.util.Set;
public class AllowedMethodsFilter implements Filter {
public static final String ALLOWED_METHODS_PARAM = "allowedMethods";
public static final Set<String> DEFAULT_ALLOWED_METHODS = Sets.of(
"GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH"
);
private static final Logger LOG = LoggerFactory.getLogger(AllowedMethodsFilter.class);
private Set<String> allowedMethods = Collections.emptySet();
@Override
public void init(FilterConfig config) {
allowedMethods = Optional.ofNullable(config.getInitParameter(ALLOWED_METHODS_PARAM))
.map(p -> Sets.of(p.split(",")))
.orElse(DEFAULT_ALLOWED_METHODS);
}
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
handle((HttpServletRequest) request, (HttpServletResponse) response, chain);
}
private void handle(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
throws IOException, ServletException {
if (allowedMethods.contains(request.getMethod())) {
chain.doFilter(request, response);
} else {
LOG.debug("Request with disallowed method {} blocked", request.getMethod());
response.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
}
}
@Override
public void destroy() {
}
}