package io.dropwizard.jetty;
import com.google.common.collect.ImmutableSortedSet;
import org.eclipse.jetty.http.HttpHeader;
import org.eclipse.jetty.server.Request;
import org.eclipse.jetty.server.handler.gzip.GzipHandler;
import javax.annotation.Nullable;
import javax.servlet.ReadListener;
import javax.servlet.ServletException;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.Inflater;
import java.util.zip.InflaterInputStream;
/**
 * A {@link GzipHandler} that also decompresses incoming request bodies.
 * <p>
 * Requests whose {@code Content-Encoding} header equals {@code gzip} or {@code deflate}
 * (case-insensitively) are wrapped. In the wrapped request, {@link HttpServletRequest#getInputStream()}
 * and {@link HttpServletRequest#getReader()} return the decompressed body. The wrapper also hides the
 * {@code Content-Encoding} and {@code Content-Length} headers, because they describe the compressed
 * payload rather than the bytes the application reads. Every request, wrapped or not, is then passed to
 * {@link GzipHandler#handle}.
 */
public class BiDiGzipHandler extends GzipHandler {
    /**
     * Per-thread cache of a reusable {@link Inflater} created with {@code nowrap = true}.
     * <p>
     * An inflater's nowrap mode is fixed when it is constructed. These caches are static, so every
     * handler instance on a thread shares them. Raw and zlib-wrapped inflaters are therefore cached
     * separately; otherwise a handler could receive an inflater built for the other mode and
     * mis-decode the request body.
     */
    private static final ThreadLocal<Inflater> localNoWrapInflater = new ThreadLocal<>();

    /** Per-thread cache of a reusable {@link Inflater} created with {@code nowrap = false}. */
    private static final ThreadLocal<Inflater> localZlibInflater = new ThreadLocal<>();

    /**
     * Headers hidden from decompressed requests, matched case-insensitively.
     * The set is immutable, so it is built once instead of on every request.
     */
    private static final ImmutableSortedSet<String> CONTENT_HEADERS =
        ImmutableSortedSet.orderedBy(String::compareToIgnoreCase)
            .add(HttpHeader.CONTENT_ENCODING.asString())
            .add(HttpHeader.CONTENT_LENGTH.asString())
            .build();

    /** Size, in bytes, of the buffer used when decompressing request bodies. */
    private int inputBufferSize = 8192;

    /**
     * Whether {@code deflate} request bodies are inflated in nowrap mode, which expects raw deflate
     * data without the zlib header and checksum.
     */
    private boolean inflateNoWrap = true;

    /**
     * @return {@code true} if {@code deflate} request bodies are inflated as raw deflate data (no zlib
     *         header/checksum), {@code false} if they are expected to be zlib-wrapped
     */
    public boolean isInflateNoWrap() {
        return inflateNoWrap;
    }

    /**
     * @param inflateNoWrap {@code true} to inflate {@code deflate} request bodies as raw deflate data,
     *                      {@code false} to expect zlib-wrapped data
     */
    public void setInflateNoWrap(boolean inflateNoWrap) {
        this.inflateNoWrap = inflateNoWrap;
    }

    /**
     * @return the size, in bytes, of the buffer used when decompressing request bodies
     */
    public int getInputBufferSize() {
        return inputBufferSize;
    }

    /**
     * @param inputBufferSize size, in bytes, of the buffer used when decompressing request bodies.
     *                        It is passed to {@link GZIPInputStream} / {@link InflaterInputStream},
     *                        which reject non-positive sizes, so it must be positive.
     */
    public void setInputBufferSize(int inputBufferSize) {
        this.inputBufferSize = inputBufferSize;
    }

    /**
     * Wraps gzip- or deflate-encoded requests in a decompressing request, then delegates to
     * {@link GzipHandler#handle}. Requests with any other (or no) {@code Content-Encoding} are
     * delegated unchanged.
     */
    @Override
    public void handle(String target, Request baseRequest, HttpServletRequest request, HttpServletResponse response)
            throws IOException, ServletException {
        final String encoding = request.getHeader(HttpHeader.CONTENT_ENCODING.asString());
        // GZIP and DEFLATE are not declared here; they are inherited from the superclass hierarchy.
        if (GZIP.equalsIgnoreCase(encoding)) {
            super.handle(target, baseRequest, wrapGzippedRequest(removeContentHeaders(request)), response);
        } else if (DEFLATE.equalsIgnoreCase(encoding)) {
            super.handle(target, baseRequest, wrapDeflatedRequest(removeContentHeaders(request)), response);
        } else {
            super.handle(target, baseRequest, request, response);
        }
    }

    /**
     * Returns this thread's inflater cache for the given mode.
     *
     * @param nowrap the inflater mode
     */
    private static ThreadLocal<Inflater> inflaterCache(boolean nowrap) {
        return nowrap ? localNoWrapInflater : localZlibInflater;
    }

    /**
     * Returns an {@link Inflater} in the requested mode. If this thread has a cached one, it is reset
     * and reused.
     * <p>
     * The cache slot is cleared when the inflater is checked out. Only a successful stream close
     * returns it to the cache, so an inflater whose stream is never closed (for example, after a
     * failure part-way through decompression) is never handed out again.
     *
     * @param nowrap {@code true} for raw deflate data, {@code false} for zlib-wrapped data
     * @return a ready-to-use inflater owned by the caller
     */
    private static Inflater buildInflater(boolean nowrap) {
        final ThreadLocal<Inflater> cache = inflaterCache(nowrap);
        final Inflater cached = cache.get();
        if (cached == null) {
            return new Inflater(nowrap);
        }
        cache.set(null);
        cached.reset();
        return cached;
    }

    /**
     * Wraps {@code request} so that its body is inflated as {@code deflate} data.
     * <p>
     * When the decompressing stream is closed, its inflater is returned to this thread's cache for
     * reuse. If the wrapper cannot be built, the inflater is released with {@link Inflater#end()}
     * instead of being dropped with its native resources still allocated.
     *
     * @param request the request whose body is deflate-compressed
     * @return a request exposing the inflated body
     * @throws IOException if the request body cannot be opened, translated by
     *                     {@code ZipExceptionHandlingInputStream.handleException}
     */
    private WrappedServletRequest wrapDeflatedRequest(HttpServletRequest request) throws IOException {
        // Read the mode once, so the inflater is built for and returned to the same cache even if
        // setInflateNoWrap is called while this request is in flight.
        final boolean nowrap = inflateNoWrap;
        final ThreadLocal<Inflater> cache = inflaterCache(nowrap);
        final Inflater inflater = buildInflater(nowrap);
        boolean handedOff = false;
        try {
            final InflaterInputStream input = new InflaterInputStream(request.getInputStream(), inflater, inputBufferSize) {
                private boolean returned;

                @Override
                public void close() throws IOException {
                    super.close();
                    // Return the inflater at most once. On a repeated close(), publishing it again
                    // could hand out an inflater that another stream on this thread has already
                    // checked out.
                    if (!returned) {
                        returned = true;
                        cache.set(inflater);
                    }
                }
            };
            final WrappedServletRequest wrapped =
                new WrappedServletRequest(request, new ZipExceptionHandlingInputStream(input, DEFLATE));
            handedOff = true;
            return wrapped;
        } catch (IOException e) {
            throw ZipExceptionHandlingInputStream.handleException(DEFLATE, e);
        } finally {
            if (!handedOff) {
                // Nothing owns the inflater any more; free its native resources now.
                inflater.end();
            }
        }
    }

    /**
     * Wraps {@code request} so that its body is decompressed as {@code gzip} data.
     * <p>
     * {@link GZIPInputStream} reads the gzip header in its constructor, so a malformed body can fail
     * here rather than on the first read.
     *
     * @param request the request whose body is gzip-compressed
     * @return a request exposing the decompressed body
     * @throws IOException if the body cannot be opened or its gzip header is invalid, translated by
     *                     {@code ZipExceptionHandlingInputStream.handleException}
     */
    private WrappedServletRequest wrapGzippedRequest(HttpServletRequest request) throws IOException {
        try {
            final GZIPInputStream input = new GZIPInputStream(request.getInputStream(), inputBufferSize);
            return new WrappedServletRequest(request, new ZipExceptionHandlingInputStream(input, GZIP));
        } catch (IOException e) {
            throw ZipExceptionHandlingInputStream.handleException(GZIP, e);
        }
    }

    /**
     * Hides {@code Content-Encoding} and {@code Content-Length} from {@code request}, because neither
     * applies to the decompressed body.
     *
     * @param request the original request
     * @return a wrapper that reports those headers as absent
     */
    private HttpServletRequest removeContentHeaders(final HttpServletRequest request) {
        return new RemoveHttpHeadersWrapper(request, CONTENT_HEADERS);
    }

    /**
     * Request wrapper whose body is read from a replacement (decompressed) stream.
     * <p>
     * {@link #getInputStream()} and {@link #getReader()} share the same underlying stream. The content
     * length is reported as unknown ({@code -1}).
     */
    private static class WrappedServletRequest extends HttpServletRequestWrapper {
        private final ServletInputStream input;
        private final BufferedReader reader;

        /**
         * @param request     the request being wrapped
         * @param inputStream the stream yielding the decompressed body
         * @throws IOException declared for callers; nothing in this constructor reads from the stream
         */
        private WrappedServletRequest(HttpServletRequest request,
                                      InputStream inputStream) throws IOException {
            super(request);
            this.input = new WrappedServletInputStream(inputStream);
            this.reader = new BufferedReader(new InputStreamReader(input, getCharset()));
        }

        /**
         * Resolves the request's declared character encoding. Falls back to ISO-8859-1 when the
         * encoding is absent, unsupported or syntactically illegal.
         */
        private Charset getCharset() {
            final String encoding = getCharacterEncoding();
            if (encoding == null) {
                return StandardCharsets.ISO_8859_1;
            }
            try {
                return Charset.forName(encoding);
            } catch (IllegalArgumentException e) {
                // Covers UnsupportedCharsetException and IllegalCharsetNameException. The name comes
                // from the client, so a malformed value must not fail the request.
                return StandardCharsets.ISO_8859_1;
            }
        }

        @Override
        public ServletInputStream getInputStream() throws IOException {
            return input;
        }

        @Override
        public BufferedReader getReader() throws IOException {
            return reader;
        }

        /** @return {@code -1}: the decompressed length is not known in advance */
        @Override
        public int getContentLength() {
            return -1;
        }

        /** @return {@code -1L}: the decompressed length is not known in advance */
        @Override
        public long getContentLengthLong() {
            return -1L;
        }
    }

    /**
     * Adapts a plain {@link InputStream} to {@link ServletInputStream} by delegating every stream
     * operation. Non-blocking I/O ({@link #setReadListener}) is not supported.
     */
    private static class WrappedServletInputStream extends ServletInputStream {
        private final InputStream input;

        private WrappedServletInputStream(InputStream input) {
            this.input = input;
        }

        @Override
        public void close() throws IOException {
            input.close();
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            return input.read(b, off, len);
        }

        @Override
        public int available() throws IOException {
            return input.available();
        }

        @Override
        public void mark(int readlimit) {
            input.mark(readlimit);
        }

        @Override
        public boolean markSupported() {
            return input.markSupported();
        }

        @Override
        public int read() throws IOException {
            return input.read();
        }

        @Override
        public void reset() throws IOException {
            input.reset();
        }

        @Override
        public long skip(long n) throws IOException {
            return input.skip(n);
        }

        @Override
        public int read(byte[] b) throws IOException {
            return input.read(b);
        }

        /**
         * Returns {@code true} when the delegate reports no available bytes, or when
         * {@code available()} itself fails.
         */
        @Override
        public boolean isFinished() {
            try {
                return input.available() == 0;
            } catch (IOException ignored) {
                // Treat an unreadable stream as finished.
            }
            return true;
        }

        /**
         * Returns {@code true} when the delegate reports at least one available byte, and
         * {@code false} when it reports none or {@code available()} fails.
         */
        @Override
        public boolean isReady() {
            try {
                return input.available() > 0;
            } catch (IOException ignored) {
                // Treat an unreadable stream as not ready.
            }
            return false;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Request wrapper that makes a fixed set of headers look absent. Lookups use
     * {@code hiddenHeaders.contains}, so case sensitivity follows that set's comparator.
     */
    private static class RemoveHttpHeadersWrapper extends HttpServletRequestWrapper {
        private final ImmutableSortedSet<String> hiddenHeaders;

        RemoveHttpHeadersWrapper(final HttpServletRequest request, final ImmutableSortedSet<String> hiddenHeaders) {
            super(request);
            this.hiddenHeaders = hiddenHeaders;
        }

        /** @return {@code -1} for a hidden header, as for a missing one */
        @Override
        public int getIntHeader(final String name) {
            if (hiddenHeaders.contains(name)) {
                return -1;
            } else {
                return super.getIntHeader(name);
            }
        }

        /** @return an empty enumeration for a hidden header */
        @Override
        public Enumeration<String> getHeaders(final String name) {
            if (hiddenHeaders.contains(name)) {
                return Collections.emptyEnumeration();
            } else {
                return super.getHeaders(name);
            }
        }

        /** @return {@code null} for a hidden header */
        @Override
        @Nullable
        public String getHeader(final String name) {
            if (hiddenHeaders.contains(name)) {
                return null;
            } else {
                return super.getHeader(name);
            }
        }

        /** @return {@code -1L} for a hidden header, as for a missing one */
        @Override
        public long getDateHeader(final String name) {
            if (hiddenHeaders.contains(name)) {
                return -1L;
            } else {
                return super.getDateHeader(name);
            }
        }

        /**
         * Lists the underlying header names minus the hidden ones, so this view agrees with
         * {@link #getHeader}.
         *
         * @return the visible header names, or {@code null} if the container does not permit header
         *         access
         */
        @Override
        @Nullable
        public Enumeration<String> getHeaderNames() {
            final Enumeration<String> names = super.getHeaderNames();
            if (names == null) {
                return null;
            }
            final List<String> visible = Collections.list(names);
            visible.removeIf(hiddenHeaders::contains);
            return Collections.enumeration(visible);
        }
    }
}