package com.codahale.metrics.servlet;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Meter;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import javax.servlet.AsyncEvent;
import javax.servlet.AsyncListener;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;
import java.io.IOException;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import static com.codahale.metrics.MetricRegistry.name;
/**
 * Servlet {@link Filter} that records request metrics in a {@link MetricRegistry}.
 *
 * <p>For every request passing through the filter it maintains:
 * <ul>
 *   <li>a {@link Counter} of requests currently in flight ({@code activeRequests}),</li>
 *   <li>a {@link Timer} of request durations ({@code requests}),</li>
 *   <li>one {@link Meter} per configured HTTP status code, plus a catch-all meter for
 *       any status code that has no dedicated meter,</li>
 *   <li>a {@link Meter} of failed requests ({@code errors}) and a {@link Meter} of
 *       asynchronous requests that timed out ({@code timeouts}).</li>
 * </ul>
 *
 * <p>All metric names are prefixed with the value of the {@value #METRIC_PREFIX} filter
 * init parameter, or with the concrete filter class name when that parameter is absent
 * or empty.
 */
public abstract class AbstractInstrumentedFilter implements Filter {
    /** Name of the filter init parameter holding the metric name prefix. */
    static final String METRIC_PREFIX = "name-prefix";

    private final String otherMetricName;
    private final Map<Integer, String> meterNamesByStatusCode;
    private final String registryAttribute;

    // Metrics are created in init(), not in the constructor, because the registry is
    // looked up through the FilterConfig.
    private ConcurrentMap<Integer, Meter> metersByStatusCode;
    private Meter otherMeter;
    private Meter timeoutsMeter;
    private Meter errorsMeter;
    private Counter activeRequests;
    private Timer requestTimer;

    /**
     * Creates a new filter.
     *
     * @param registryAttribute      name of the servlet context attribute under which the
     *                               {@link MetricRegistry} is looked up
     * @param meterNamesByStatusCode mapping from HTTP status code to the name of the meter
     *                               that counts responses with that status
     * @param otherMetricName        name of the meter that counts responses whose status
     *                               code is not a key of {@code meterNamesByStatusCode}
     */
    protected AbstractInstrumentedFilter(String registryAttribute,
                                         Map<Integer, String> meterNamesByStatusCode,
                                         String otherMetricName) {
        this.registryAttribute = registryAttribute;
        this.otherMetricName = otherMetricName;
        this.meterNamesByStatusCode = meterNamesByStatusCode;
    }

    /**
     * Resolves the metric registry and the metric name prefix, then registers every
     * metric this filter updates.
     *
     * @param filterConfig the filter configuration supplied by the container
     * @throws ServletException declared by {@link Filter#init(FilterConfig)}
     */
    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        final MetricRegistry metricsRegistry = getMetricsFactory(filterConfig);

        String metricName = filterConfig.getInitParameter(METRIC_PREFIX);
        if (metricName == null || metricName.isEmpty()) {
            metricName = getClass().getName();
        }

        this.metersByStatusCode = new ConcurrentHashMap<>(meterNamesByStatusCode.size());
        for (Entry<Integer, String> entry : meterNamesByStatusCode.entrySet()) {
            metersByStatusCode.put(entry.getKey(),
                    metricsRegistry.meter(name(metricName, entry.getValue())));
        }
        this.otherMeter = metricsRegistry.meter(name(metricName, otherMetricName));
        this.timeoutsMeter = metricsRegistry.meter(name(metricName, "timeouts"));
        this.errorsMeter = metricsRegistry.meter(name(metricName, "errors"));
        this.activeRequests = metricsRegistry.counter(name(metricName, "activeRequests"));
        this.requestTimer = metricsRegistry.timer(name(metricName, "requests"));
    }

    /**
     * Returns the {@link MetricRegistry} stored in the servlet context under
     * {@code registryAttribute}, or a fresh private registry when the attribute is
     * missing or holds an object of another type.
     */
    private MetricRegistry getMetricsFactory(FilterConfig filterConfig) {
        final Object o = filterConfig.getServletContext().getAttribute(this.registryAttribute);
        if (o instanceof MetricRegistry) {
            return (MetricRegistry) o;
        }
        return new MetricRegistry();
    }

    /** No resources to release. */
    @Override
    public void destroy() {
    }

    /**
     * Passes the request down the chain while updating the metrics.
     *
     * <p>If the chain returns normally and the request has been put into asynchronous
     * mode, the final accounting is deferred to an {@link AsyncListener}. Otherwise the
     * timer is stopped immediately and either the error meter (the chain threw) or the
     * meter matching the response status is marked.
     *
     * @param request  the incoming request
     * @param response the outgoing response; must be an {@link HttpServletResponse}
     * @param chain    the remainder of the filter chain
     * @throws IOException      if thrown by the chain
     * @throws ServletException if thrown by the chain
     */
    @Override
    public void doFilter(ServletRequest request,
                         ServletResponse response,
                         FilterChain chain) throws IOException, ServletException {
        final HttpServletResponse httpResponse = (HttpServletResponse) response;
        final StatusExposingServletResponse wrappedResponse =
                new StatusExposingServletResponse(httpResponse);
        activeRequests.inc();
        final Timer.Context context = requestTimer.time();
        // Stays true unless the chain returns normally, so that anything thrown by the
        // chain (including Errors) is counted as a failed request.
        boolean failed = true;
        try {
            chain.doFilter(request, wrappedResponse);
            failed = false;
        } finally {
            if (!failed && request.isAsyncStarted()) {
                request.getAsyncContext()
                        .addListener(new AsyncResultListener(context, httpResponse));
            } else {
                context.stop();
                activeRequests.dec();
                if (failed) {
                    errorsMeter.mark();
                } else {
                    markMeterForStatusCode(wrappedResponse.getStatus());
                }
            }
        }
    }

    /** Marks the meter registered for {@code status}, or the catch-all meter if none is. */
    private void markMeterForStatusCode(int status) {
        final Meter metric = metersByStatusCode.get(status);
        if (metric != null) {
            metric.mark();
        } else {
            otherMeter.mark();
        }
    }

    /**
     * Response wrapper that remembers the most recent status code set through it, so the
     * filter can read it once the chain has returned.
     */
    private static class StatusExposingServletResponse extends HttpServletResponseWrapper {
        // Reported when nothing downstream sets a status explicitly.
        private int httpStatus = 200;

        public StatusExposingServletResponse(HttpServletResponse response) {
            super(response);
        }

        @Override
        public void sendError(int sc) throws IOException {
            httpStatus = sc;
            super.sendError(sc);
        }

        @Override
        public void sendError(int sc, String msg) throws IOException {
            httpStatus = sc;
            super.sendError(sc, msg);
        }

        @Override
        public void sendRedirect(String location) throws IOException {
            super.sendRedirect(location);
            // A temporary redirect is sent with status 302; record it only after the
            // delegate accepted the redirect.
            httpStatus = HttpServletResponse.SC_FOUND;
        }

        @Override
        public void setStatus(int sc) {
            httpStatus = sc;
            super.setStatus(sc);
        }

        @Override
        @SuppressWarnings("deprecation")
        public void setStatus(int sc, String sm) {
            httpStatus = sc;
            super.setStatus(sc, sm);
        }

        @Override
        public void reset() {
            super.reset();
            // reset() clears the status code along with buffer and headers, so fall
            // back to the initial value once the delegate has accepted the reset.
            httpStatus = 200;
        }

        @Override
        public int getStatus() {
            return httpStatus;
        }
    }

    /**
     * Finishes the accounting of a request that went asynchronous: stops the timer,
     * decrements the active-request counter and marks exactly one outcome meter
     * (status code, timeout or error) for the first terminal event it receives.
     */
    private class AsyncResultListener implements AsyncListener {
        private final Timer.Context context;
        private final HttpServletResponse fallbackResponse;
        private boolean done = false;

        /**
         * @param context          running timer context of the request
         * @param fallbackResponse response to read the status from when the completion
         *                         event carries no response of its own
         */
        AsyncResultListener(Timer.Context context, HttpServletResponse fallbackResponse) {
            this.context = context;
            this.fallbackResponse = fallbackResponse;
        }

        @Override
        public void onComplete(AsyncEvent event) throws IOException {
            if (finish()) {
                markMeterForStatusCode(responseOf(event).getStatus());
            }
        }

        @Override
        public void onTimeout(AsyncEvent event) throws IOException {
            if (finish()) {
                timeoutsMeter.mark();
            }
        }

        @Override
        public void onError(AsyncEvent event) throws IOException {
            if (finish()) {
                errorsMeter.mark();
            }
        }

        @Override
        public void onStartAsync(AsyncEvent event) throws IOException {
        }

        /**
         * Stops the timer and decrements the active-request counter the first time it
         * is called.
         *
         * @return {@code true} on the first call, {@code false} on every later call
         */
        private boolean finish() {
            if (done) {
                return false;
            }
            done = true;
            context.stop();
            activeRequests.dec();
            return true;
        }

        /**
         * Returns the response attached to {@code event}, or the response the filter
         * was invoked with when the event carries none. The Servlet API specifies a
         * {@code null} supplied response for listeners registered without an explicit
         * request/response pair, which is how this listener is registered.
         */
        private HttpServletResponse responseOf(AsyncEvent event) {
            final ServletResponse supplied = event.getSuppliedResponse();
            if (supplied instanceof HttpServletResponse) {
                return (HttpServletResponse) supplied;
            }
            return fallbackResponse;
        }
    }
}