package org.springframework.http.codec.multipart;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import org.springframework.core.codec.DecodingException;
import org.springframework.core.io.buffer.DataBuffer;
import org.springframework.core.io.buffer.DataBufferLimitException;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.Nullable;
final class MultipartParser extends BaseSubscriber<DataBuffer> {
private static final byte CR = '\r';
private static final byte LF = '\n';
private static final byte[] CR_LF = {CR, LF};
private static final byte HYPHEN = '-';
private static final byte[] TWO_HYPHENS = {HYPHEN, HYPHEN};
private static final String = "\\r\\n";
private static final Log logger = LogFactory.getLog(MultipartParser.class);
private final AtomicReference<State> state;
private final FluxSink<Token> sink;
private final byte[] boundary;
private final int ;
private final AtomicBoolean requestOutstanding = new AtomicBoolean();
private MultipartParser(FluxSink<Token> sink, byte[] boundary, int maxHeadersSize) {
this.sink = sink;
this.boundary = boundary;
this.maxHeadersSize = maxHeadersSize;
this.state = new AtomicReference<>(new PreambleState());
}
public static Flux<Token> parse(Flux<DataBuffer> buffers, byte[] boundary, int maxHeadersSize) {
return Flux.create(sink -> {
MultipartParser parser = new MultipartParser(sink, boundary, maxHeadersSize);
sink.onCancel(parser::onSinkCancel);
sink.onRequest(n -> parser.requestBuffer());
buffers.subscribe(parser);
});
}
@Override
protected void hookOnSubscribe(Subscription subscription) {
requestBuffer();
}
@Override
protected void hookOnNext(DataBuffer value) {
this.requestOutstanding.set(false);
this.state.get().onNext(value);
}
@Override
protected void hookOnComplete() {
this.state.get().onComplete();
}
@Override
protected void hookOnError(Throwable throwable) {
State oldState = this.state.getAndSet(DisposedState.INSTANCE);
oldState.dispose();
this.sink.error(throwable);
}
private void onSinkCancel() {
State oldState = this.state.getAndSet(DisposedState.INSTANCE);
oldState.dispose();
cancel();
}
boolean changeState(State oldState, State newState, @Nullable DataBuffer remainder) {
if (this.state.compareAndSet(oldState, newState)) {
if (logger.isTraceEnabled()) {
logger.trace("Changed state: " + oldState + " -> " + newState);
}
oldState.dispose();
if (remainder != null) {
if (remainder.readableByteCount() > 0) {
newState.onNext(remainder);
}
else {
DataBufferUtils.release(remainder);
requestBuffer();
}
}
return true;
}
else {
DataBufferUtils.release(remainder);
return false;
}
}
void (HttpHeaders headers) {
if (logger.isTraceEnabled()) {
logger.trace("Emitting headers: " + headers);
}
this.sink.next(new HeadersToken(headers));
}
void emitBody(DataBuffer buffer) {
if (logger.isTraceEnabled()) {
logger.trace("Emitting body: " + buffer);
}
this.sink.next(new BodyToken(buffer));
}
void emitError(Throwable t) {
cancel();
this.sink.error(t);
}
void emitComplete() {
cancel();
this.sink.complete();
}
private void requestBuffer() {
if (upstream() != null &&
!this.sink.isCancelled() &&
this.sink.requestedFromDownstream() > 0 &&
this.requestOutstanding.compareAndSet(false, true)) {
request(1);
}
}
public abstract static class Token {
public abstract HttpHeaders ();
public abstract DataBuffer buffer();
}
public final static class extends Token {
private final HttpHeaders ;
public (HttpHeaders headers) {
this.headers = headers;
}
@Override
public HttpHeaders () {
return this.headers;
}
@Override
public DataBuffer () {
throw new IllegalStateException();
}
}
public final static class BodyToken extends Token {
private final DataBuffer buffer;
public BodyToken(DataBuffer buffer) {
this.buffer = buffer;
}
@Override
public HttpHeaders headers() {
throw new IllegalStateException();
}
@Override
public DataBuffer buffer() {
return this.buffer;
}
}
private interface State {
void onNext(DataBuffer buf);
void onComplete();
default void dispose() {
}
}
private final class PreambleState implements State {
private final DataBufferUtils.Matcher firstBoundary;
public PreambleState() {
this.firstBoundary = DataBufferUtils.matcher(
MultipartUtils.concat(TWO_HYPHENS, MultipartParser.this.boundary));
}
@Override
public void onNext(DataBuffer buf) {
int endIdx = this.firstBoundary.match(buf);
if (endIdx != -1) {
if (logger.isTraceEnabled()) {
logger.trace("First boundary found @" + endIdx + " in " + buf);
}
DataBuffer headersBuf = MultipartUtils.sliceFrom(buf, endIdx);
DataBufferUtils.release(buf);
changeState(this, new HeadersState(), headersBuf);
}
else {
DataBufferUtils.release(buf);
requestBuffer();
}
}
@Override
public void onComplete() {
if (changeState(this, DisposedState.INSTANCE, null)) {
emitError(new DecodingException("Could not find first boundary"));
}
}
@Override
public String toString() {
return "PREAMBLE";
}
}
private final class implements State {
private final DataBufferUtils.Matcher = DataBufferUtils.matcher(MultipartUtils.concat(CR_LF, CR_LF));
private final AtomicInteger = new AtomicInteger();
private final List<DataBuffer> = new ArrayList<>();
@Override
public void (DataBuffer buf) {
long prevCount = this.byteCount.get();
long count = this.byteCount.addAndGet(buf.readableByteCount());
if (prevCount < 2 && count >= 2) {
if (isLastBoundary(buf)) {
if (logger.isTraceEnabled()) {
logger.trace("Last boundary found in " + buf);
}
if (changeState(this, DisposedState.INSTANCE, buf)) {
emitComplete();
}
return;
}
}
else if (count > MultipartParser.this.maxHeadersSize) {
if (changeState(this, DisposedState.INSTANCE, buf)) {
emitError(new DataBufferLimitException("Part headers exceeded the memory usage limit of " +
MultipartParser.this.maxHeadersSize + " bytes"));
}
return;
}
int endIdx = this.endHeaders.match(buf);
if (endIdx != -1) {
if (logger.isTraceEnabled()) {
logger.trace("End of headers found @" + endIdx + " in " + buf);
}
DataBuffer headerBuf = MultipartUtils.sliceTo(buf, endIdx);
this.buffers.add(headerBuf);
DataBuffer bodyBuf = MultipartUtils.sliceFrom(buf, endIdx);
DataBufferUtils.release(buf);
emitHeaders(parseHeaders());
changeState(this, new BodyState(), bodyBuf);
}
else {
this.buffers.add(buf);
requestBuffer();
}
}
private boolean (DataBuffer buf) {
return (this.buffers.isEmpty() &&
buf.readableByteCount() >= 2 &&
buf.getByte(0) == HYPHEN && buf.getByte(1) == HYPHEN)
||
(this.buffers.size() == 1 &&
this.buffers.get(0).readableByteCount() == 1 &&
this.buffers.get(0).getByte(0) == HYPHEN &&
buf.readableByteCount() >= 1 &&
buf.getByte(0) == HYPHEN);
}
private HttpHeaders () {
if (this.buffers.isEmpty()) {
return HttpHeaders.EMPTY;
}
DataBuffer joined = this.buffers.get(0).factory().join(this.buffers);
this.buffers.clear();
String string = joined.toString(StandardCharsets.ISO_8859_1);
DataBufferUtils.release(joined);
String[] lines = string.split(HEADER_ENTRY_SEPARATOR);
HttpHeaders result = new HttpHeaders();
for (String line : lines) {
int idx = line.indexOf(':');
if (idx != -1) {
String name = line.substring(0, idx);
String value = line.substring(idx + 1);
while (value.startsWith(" ")) {
value = value.substring(1);
}
result.add(name, value);
}
}
return result;
}
@Override
public void () {
if (changeState(this, DisposedState.INSTANCE, null)) {
emitError(new DecodingException("Could not find end of headers"));
}
}
@Override
public void () {
this.buffers.forEach(DataBufferUtils::release);
}
@Override
public String () {
return "HEADERS";
}
}
private final class BodyState implements State {
private final DataBufferUtils.Matcher boundary;
private final AtomicReference<DataBuffer> previous = new AtomicReference<>();
public BodyState() {
this.boundary = DataBufferUtils.matcher(
MultipartUtils.concat(CR_LF, TWO_HYPHENS, MultipartParser.this.boundary));
}
@Override
public void onNext(DataBuffer buffer) {
int endIdx = this.boundary.match(buffer);
if (endIdx != -1) {
if (logger.isTraceEnabled()) {
logger.trace("Boundary found @" + endIdx + " in " + buffer);
}
int len = endIdx - buffer.readPosition() - this.boundary.delimiter().length + 1;
if (len > 0) {
DataBuffer body = buffer.retainedSlice(buffer.readPosition(), len);
enqueue(body);
enqueue(null);
}
else if (len < 0) {
DataBuffer previous = this.previous.get();
int prevLen = previous.readableByteCount() + len;
if (prevLen > 0) {
DataBuffer body = previous.retainedSlice(previous.readPosition(), prevLen);
DataBufferUtils.release(previous);
this.previous.set(body);
enqueue(null);
}
else {
DataBufferUtils.release(previous);
this.previous.set(null);
}
}
else {
enqueue(null);
}
DataBuffer remainder = MultipartUtils.sliceFrom(buffer, endIdx);
DataBufferUtils.release(buffer);
changeState(this, new HeadersState(), remainder);
}
else {
enqueue(buffer);
requestBuffer();
}
}
private void enqueue(@Nullable DataBuffer buf) {
DataBuffer previous = this.previous.getAndSet(buf);
if (previous != null) {
emitBody(previous);
}
}
@Override
public void onComplete() {
if (changeState(this, DisposedState.INSTANCE, null)) {
emitError(new DecodingException("Could not find end of body"));
}
}
@Override
public void dispose() {
DataBuffer previous = this.previous.getAndSet(null);
if (previous != null) {
DataBufferUtils.release(previous);
}
}
@Override
public String toString() {
return "BODY";
}
}
private static final class DisposedState implements State {
public static final DisposedState INSTANCE = new DisposedState();
private DisposedState() {
}
@Override
public void onNext(DataBuffer buf) {
DataBufferUtils.release(buf);
}
@Override
public void onComplete() {
}
@Override
public String toString() {
return "DISPOSED";
}
}
}