package org.springframework.web.util;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import org.springframework.http.HttpMethod;
import org.springframework.lang.Nullable;
/**
 * {@link HttpServletRequestWrapper} that copies every byte read through
 * {@link #getInputStream()} (and therefore {@link #getReader()}) into an in-memory
 * buffer, which can later be retrieved via {@link #getContentAsByteArray()}.
 *
 * <p>For form POSTs ({@code application/x-www-form-urlencoded}) whose body has not
 * been read through this wrapper, accessing any {@code getParameter*} method
 * reconstructs a URL-encoded body from the parsed parameters and caches that instead.
 *
 * <p>An optional content cache limit caps how many bytes are copied from the input
 * stream; once exceeded, {@link #handleContentOverflow(int)} is invoked.
 */
public class ContentCachingRequestWrapper extends HttpServletRequestWrapper {

    private static final String FORM_CONTENT_TYPE = "application/x-www-form-urlencoded";

    /** Initial buffer capacity used when the request does not declare a content length. */
    private static final int DEFAULT_INITIAL_CAPACITY = 1024;

    private final ByteArrayOutputStream cachedContent;

    /** Maximum number of stream bytes to cache, or {@code null} for no limit. */
    @Nullable
    private final Integer contentCacheLimit;

    @Nullable
    private ServletInputStream inputStream;

    @Nullable
    private BufferedReader reader;

    /**
     * Create a wrapper that caches the request body without a size limit.
     * @param request the original servlet request
     */
    public ContentCachingRequestWrapper(HttpServletRequest request) {
        super(request);
        int contentLength = request.getContentLength();
        this.cachedContent = new ByteArrayOutputStream(
                contentLength >= 0 ? contentLength : DEFAULT_INITIAL_CAPACITY);
        this.contentCacheLimit = null;
    }

    /**
     * Create a wrapper that caches at most {@code contentCacheLimit} bytes read from
     * the request body.
     * @param request the original servlet request
     * @param contentCacheLimit the maximum number of bytes to cache per request
     * @throws IllegalArgumentException if {@code contentCacheLimit} is negative
     * @see #handleContentOverflow(int)
     */
    public ContentCachingRequestWrapper(HttpServletRequest request, int contentCacheLimit) {
        super(request);
        if (contentCacheLimit < 0) {
            throw new IllegalArgumentException("Negative content cache limit: " + contentCacheLimit);
        }
        // Size the buffer from the declared body length rather than eagerly allocating
        // the full limit, which may be far larger than the actual request.
        int contentLength = request.getContentLength();
        int initialCapacity = (contentLength >= 0 ? contentLength : DEFAULT_INITIAL_CAPACITY);
        this.cachedContent = new ByteArrayOutputStream(Math.min(initialCapacity, contentCacheLimit));
        this.contentCacheLimit = contentCacheLimit;
    }

    /**
     * Return a stream over the original body that records bytes as they are read.
     * The same stream instance is returned on every call.
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (this.inputStream == null) {
            this.inputStream = new ContentCachingInputStream(getRequest().getInputStream());
        }
        return this.inputStream;
    }

    /**
     * Return the underlying request's character encoding, falling back to
     * {@code WebUtils.DEFAULT_CHARACTER_ENCODING} when none is set.
     */
    @Override
    public String getCharacterEncoding() {
        String enc = super.getCharacterEncoding();
        return (enc != null ? enc : WebUtils.DEFAULT_CHARACTER_ENCODING);
    }

    /**
     * Return a reader over the caching input stream, decoded with
     * {@link #getCharacterEncoding()}. The same reader is returned on every call.
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (this.reader == null) {
            this.reader = new BufferedReader(new InputStreamReader(getInputStream(), getCharacterEncoding()));
        }
        return this.reader;
    }

    @Override
    public String getParameter(String name) {
        cacheFormParametersIfNecessary();
        return super.getParameter(name);
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        cacheFormParametersIfNecessary();
        return super.getParameterMap();
    }

    @Override
    public Enumeration<String> getParameterNames() {
        cacheFormParametersIfNecessary();
        return super.getParameterNames();
    }

    @Override
    public String[] getParameterValues(String name) {
        cacheFormParametersIfNecessary();
        return super.getParameterValues(name);
    }

    /** Populate the cache from the parsed form parameters if nothing has been cached yet. */
    private void cacheFormParametersIfNecessary() {
        if (this.cachedContent.size() == 0 && isFormPost()) {
            writeRequestParametersToCachedContent();
        }
    }

    /** Whether this is a POST whose content type contains the URL-encoded form type. */
    private boolean isFormPost() {
        String contentType = getContentType();
        return (contentType != null && contentType.contains(FORM_CONTENT_TYPE) &&
                HttpMethod.POST.matches(getMethod()));
    }

    /**
     * Serialize the parsed request parameters as an {@code application/x-www-form-urlencoded}
     * body ({@code name=value} pairs joined by {@code &}) into the cache. A {@code null}
     * value is written as a bare {@code name}. Note: this reconstruction does not apply
     * the content cache limit.
     */
    private void writeRequestParametersToCachedContent() {
        try {
            String requestEncoding = getCharacterEncoding();
            boolean first = true;
            // super.getParameterMap() avoids re-entering this class's overridden accessor.
            for (Map.Entry<String, String[]> entry : super.getParameterMap().entrySet()) {
                String[] values = entry.getValue();
                if (values == null) {
                    continue;
                }
                // URLEncoder output is pure ASCII, so the byte conversion is charset-safe.
                byte[] encodedName = URLEncoder.encode(entry.getKey(), requestEncoding)
                        .getBytes(StandardCharsets.US_ASCII);
                for (String value : values) {
                    // Separator precedes every pair but the first, independent of null values.
                    if (!first) {
                        this.cachedContent.write('&');
                    }
                    first = false;
                    this.cachedContent.write(encodedName);
                    if (value != null) {
                        this.cachedContent.write('=');
                        this.cachedContent.write(URLEncoder.encode(value, requestEncoding)
                                .getBytes(StandardCharsets.US_ASCII));
                    }
                }
            }
        }
        catch (IOException ex) {
            throw new IllegalStateException("Failed to write request parameters to cached content", ex);
        }
    }

    /**
     * Return a copy of the cached request content. The content reflects only what has
     * been read so far (or the reconstructed form parameters), and is truncated at the
     * content cache limit if one was configured.
     */
    public byte[] getContentAsByteArray() {
        return this.cachedContent.toByteArray();
    }

    /**
     * Template method invoked once when the request body being read exceeds the
     * configured content cache limit. The default implementation does nothing.
     * @param contentCacheLimit the configured maximum number of bytes to cache
     */
    protected void handleContentOverflow(int contentCacheLimit) {
    }

    /** Input stream decorator that copies every byte read into {@code cachedContent}. */
    private class ContentCachingInputStream extends ServletInputStream {

        private final ServletInputStream is;

        /** Set once the cache limit has been hit; no further bytes are cached afterwards. */
        private boolean overflow = false;

        public ContentCachingInputStream(ServletInputStream is) {
            this.is = is;
        }

        @Override
        public int read() throws IOException {
            int ch = this.is.read();
            if (ch != -1 && !this.overflow) {
                // '>=' also covers a cache already overfilled by form reconstruction.
                if (contentCacheLimit != null && cachedContent.size() >= contentCacheLimit) {
                    this.overflow = true;
                    handleContentOverflow(contentCacheLimit);
                }
                else {
                    cachedContent.write(ch);
                }
            }
            return ch;
        }

        @Override
        public int read(byte[] b) throws IOException {
            int count = this.is.read(b);
            writeToCache(b, 0, count);
            return count;
        }

        @Override
        public int read(final byte[] b, final int off, final int len) throws IOException {
            int count = this.is.read(b, off, len);
            writeToCache(b, off, count);
            return count;
        }

        @Override
        public int readLine(final byte[] b, final int off, final int len) throws IOException {
            int count = this.is.readLine(b, off, len);
            writeToCache(b, off, count);
            return count;
        }

        /**
         * Append {@code count} bytes from {@code b} at {@code off} to the cache, truncating
         * at the limit and signalling overflow. Non-positive counts (e.g. -1 at EOF) are ignored.
         */
        private void writeToCache(final byte[] b, final int off, int count) {
            if (this.overflow || count <= 0) {
                return;
            }
            if (contentCacheLimit != null) {
                // Subtraction form avoids int overflow of 'count + size'.
                int remaining = contentCacheLimit - cachedContent.size();
                if (count > remaining) {
                    this.overflow = true;
                    cachedContent.write(b, off, Math.max(0, remaining));
                    handleContentOverflow(contentCacheLimit);
                    return;
                }
            }
            cachedContent.write(b, off, count);
        }

        @Override
        public boolean isFinished() {
            return this.is.isFinished();
        }

        @Override
        public boolean isReady() {
            return this.is.isReady();
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            this.is.setReadListener(readListener);
        }
    }
}