package io.undertow.server.protocol.ajp;
import io.undertow.UndertowMessages;
import io.undertow.UndertowOptions;
import io.undertow.conduits.AbstractFramedStreamSinkConduit;
import io.undertow.conduits.ConduitListener;
import io.undertow.server.Connectors;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.HeaderMap;
import io.undertow.util.Headers;
import io.undertow.util.HttpString;
import io.undertow.util.StatusCodes;
import org.jboss.logging.Logger;
import org.xnio.Buffers;
import org.xnio.IoUtils;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.StreamSinkConduit;
import org.xnio.conduits.WriteReadyHandler;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import static org.xnio.Bits.allAreClear;
import static org.xnio.Bits.anyAreSet;
final class AjpServerResponseConduit extends AbstractFramedStreamSinkConduit {
private static final Logger log = Logger.getLogger("io.undertow.server.channel.ajp.response");
private static final int DEFAULT_MAX_DATA_SIZE = 8192;
private static final Map<HttpString, Integer> ;
private static final ByteBuffer FLUSH_PACKET = ByteBuffer.allocateDirect(8);
static {
final Map<HttpString, Integer> headers = new HashMap<>();
headers.put(Headers.CONTENT_TYPE, 0xA001);
headers.put(Headers.CONTENT_LANGUAGE, 0xA002);
headers.put(Headers.CONTENT_LENGTH, 0xA003);
headers.put(Headers.DATE, 0xA004);
headers.put(Headers.LAST_MODIFIED, 0xA005);
headers.put(Headers.LOCATION, 0xA006);
headers.put(Headers.SET_COOKIE, 0xA007);
headers.put(Headers.SET_COOKIE2, 0xA008);
headers.put(Headers.SERVLET_ENGINE, 0xA009);
headers.put(Headers.STATUS, 0xA00A);
headers.put(Headers.WWW_AUTHENTICATE, 0xA00B);
HEADER_MAP = Collections.unmodifiableMap(headers);
FLUSH_PACKET.put((byte) 'A');
FLUSH_PACKET.put((byte) 'B');
FLUSH_PACKET.put((byte) 0);
FLUSH_PACKET.put((byte) 4);
FLUSH_PACKET.put((byte) 3);
FLUSH_PACKET.put((byte) 0);
FLUSH_PACKET.put((byte) 0);
FLUSH_PACKET.put((byte) 0);
FLUSH_PACKET.flip();
}
private static final int FLAG_START = 1;
private static final int FLAG_WRITE_RESUMED = 1 << 2;
private static final int FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER = 1 << 3;
private static final int FLAG_WRITE_SHUTDOWN = 1 << 4;
private static final int FLAG_READS_DONE = 1 << 5;
private static final int FLAG_FLUSH_QUEUED = 1 << 6;
private static final ByteBuffer CLOSE_FRAME_PERSISTENT;
private static final ByteBuffer CLOSE_FRAME_NON_PERSISTENT;
static {
ByteBuffer buffer = ByteBuffer.wrap(new byte[6]);
buffer.put((byte) 'A');
buffer.put((byte) 'B');
buffer.put((byte) 0);
buffer.put((byte) 2);
buffer.put((byte) 5);
buffer.put((byte) 1);
buffer.flip();
CLOSE_FRAME_PERSISTENT = buffer;
buffer = ByteBuffer.wrap(new byte[6]);
buffer.put(CLOSE_FRAME_PERSISTENT.duplicate());
buffer.put(5, (byte) 0);
buffer.flip();
CLOSE_FRAME_NON_PERSISTENT = buffer;
}
private final ByteBufferPool pool;
private int state = FLAG_START;
private final HttpServerExchange exchange;
private final ConduitListener<? super AjpServerResponseConduit> finishListener;
private final boolean headRequest;
AjpServerResponseConduit(final StreamSinkConduit next, final ByteBufferPool pool, final HttpServerExchange exchange, ConduitListener<? super AjpServerResponseConduit> finishListener, boolean headRequest) {
super(next);
this.pool = pool;
this.exchange = exchange;
this.finishListener = finishListener;
this.headRequest = headRequest;
state = FLAG_START;
}
private static void putInt(final ByteBuffer buf, int value) {
buf.put((byte) ((value >> 8) & 0xFF));
buf.put((byte) (value & 0xFF));
}
private static void putString(final ByteBuffer buf, String value) {
final int length = value.length();
putInt(buf, length);
for (int i = 0; i < length; ++i) {
char c = value.charAt(i);
if(c != '\r' && c != '\n'){
buf.put((byte) c);
} else {
buf.put((byte)' ');
}
}
buf.put((byte) 0);
}
private void putHttpString(final ByteBuffer buf, HttpString value) {
final int length = value.length();
putInt(buf, length);
value.appendTo(buf);
buf.put((byte) 0);
}
private void () {
int oldState = this.state;
if (anyAreSet(oldState, FLAG_START)) {
PooledByteBuffer[] byteBuffers = null;
Connectors.flattenCookies(exchange);
PooledByteBuffer pooled = pool.allocate();
ByteBuffer buffer = pooled.getBuffer();
buffer.put((byte) 'A');
buffer.put((byte) 'B');
buffer.put((byte) 0);
buffer.put((byte) 0);
buffer.put((byte) 4);
putInt(buffer, exchange.getStatusCode());
String reason = exchange.getReasonPhrase();
if(reason == null) {
reason = StatusCodes.getReason(exchange.getStatusCode());
}
if(reason.length() + 4 > buffer.remaining()) {
pooled.close();
throw UndertowMessages.MESSAGES.reasonPhraseToLargeForBuffer(reason);
}
putString(buffer, reason);
int headers = 0;
final HeaderMap responseHeaders = exchange.getResponseHeaders();
for (HttpString name : responseHeaders.getHeaderNames()) {
headers += responseHeaders.get(name).size();
}
putInt(buffer, headers);
for (final HttpString header : responseHeaders.getHeaderNames()) {
for (String headerValue : responseHeaders.get(header)) {
if(buffer.remaining() < header.length() + headerValue.length() + 6) {
buffer.flip();
if(byteBuffers == null) {
byteBuffers = new PooledByteBuffer[2];
byteBuffers[0] = pooled;
} else {
PooledByteBuffer[] old = byteBuffers;
byteBuffers = new PooledByteBuffer[old.length + 1];
System.arraycopy(old, 0, byteBuffers, 0, old.length);
}
pooled = pool.allocate();
byteBuffers[byteBuffers.length - 1] = pooled;
buffer = pooled.getBuffer();
}
Integer headerCode = HEADER_MAP.get(header);
if (headerCode != null) {
putInt(buffer, headerCode);
} else {
putHttpString(buffer, header);
}
putString(buffer, headerValue);
}
}
if(byteBuffers == null) {
int dataLength = buffer.position() - 4;
buffer.put(2, (byte) ((dataLength >> 8) & 0xFF));
buffer.put(3, (byte) (dataLength & 0xFF));
buffer.flip();
queueFrame(new PooledBufferFrameCallback(pooled), buffer);
} else {
ByteBuffer[] bufs = new ByteBuffer[byteBuffers.length];
for(int i = 0; i < bufs.length; ++i) {
bufs[i] = byteBuffers[i].getBuffer();
}
int dataLength = (int) (Buffers.remaining(bufs) - 4);
bufs[0].put(2, (byte) ((dataLength >> 8) & 0xFF));
bufs[0].put(3, (byte) (dataLength & 0xFF));
buffer.flip();
queueFrame(new PooledBuffersFrameCallback(byteBuffers), bufs);
}
state &= ~FLAG_START;
}
}
@Override
protected void queueCloseFrames() {
processAJPHeader();
final ByteBuffer buffer = exchange.isPersistent() ? CLOSE_FRAME_PERSISTENT.duplicate() : CLOSE_FRAME_NON_PERSISTENT.duplicate();
queueFrame(null, buffer);
}
public int write(final ByteBuffer src) throws IOException {
if(queuedDataLength() > 0) {
if(!flushQueuedData()) {
return 0;
}
}
processAJPHeader();
if (headRequest) {
int remaining = src.remaining();
src.position(src.position() + remaining);
return remaining;
}
int limit = src.limit();
try {
int maxData = exchange.getConnection().getUndertowOptions().get(UndertowOptions.MAX_AJP_PACKET_SIZE, DEFAULT_MAX_DATA_SIZE) - 8;
if (src.remaining() > maxData) {
src.limit(src.position() + maxData);
}
final int writeSize = src.remaining();
final ByteBuffer[] buffers = createHeader(src);
int toWrite = 0;
for (ByteBuffer buffer : buffers) {
toWrite += buffer.remaining();
}
final int originalPayloadSize = writeSize;
long r = 0;
do {
r = super.write(buffers, 0, buffers.length);
toWrite -= r;
if (r == -1) {
throw new ClosedChannelException();
} else if (r == 0) {
PooledByteBuffer newPooledBuffer = pool.allocate();
while (src.hasRemaining()) {
newPooledBuffer.getBuffer().put(src);
}
newPooledBuffer.getBuffer().flip();
ByteBuffer[] savedBuffers = new ByteBuffer[3];
savedBuffers[0] = buffers[0];
savedBuffers[1] = newPooledBuffer.getBuffer();
savedBuffers[2] = buffers[2];
queueFrame(new PooledBufferFrameCallback(newPooledBuffer), savedBuffers);
return originalPayloadSize;
}
} while (toWrite > 0);
return originalPayloadSize;
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
} finally {
src.limit(limit);
}
}
private ByteBuffer[] (final ByteBuffer src) {
int remaining = src.remaining();
int chunkSize = remaining + 4;
byte[] header = new byte[7];
header[0] = (byte) 'A';
header[1] = (byte) 'B';
header[2] = (byte) ((chunkSize >> 8) & 0xFF);
header[3] = (byte) (chunkSize & 0xFF);
header[4] = (byte) (3 & 0xFF);
header[5] = (byte) ((remaining >> 8) & 0xFF);
header[6] = (byte) (remaining & 0xFF);
byte[] footer = new byte[1];
footer[0] = 0;
final ByteBuffer[] buffers = new ByteBuffer[3];
buffers[0] = ByteBuffer.wrap(header);
buffers[1] = src;
buffers[2] = ByteBuffer.wrap(footer);
return buffers;
}
public long write(final ByteBuffer[] srcs) throws IOException {
return write(srcs, 0, srcs.length);
}
public long write(final ByteBuffer[] srcs, final int offset, final int length) throws IOException {
long total = 0;
for (int i = offset; i < offset + length; ++i) {
while (srcs[i].hasRemaining()) {
int written = write(srcs[i]);
if (written == 0) {
return total;
}
total += written;
}
}
return total;
}
public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
return src.transferTo(position, count, new ConduitWritableByteChannel(this));
}
public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
return IoUtils.transfer(source, count, throughBuffer, new ConduitWritableByteChannel(this));
}
@Override
protected void finished() {
if (finishListener != null) {
finishListener.handleEvent(this);
}
}
@Override
public void setWriteReadyHandler(WriteReadyHandler handler) {
next.setWriteReadyHandler(new AjpServerWriteReadyHandler(handler));
}
public void suspendWrites() {
log.trace("suspend");
state &= ~FLAG_WRITE_RESUMED;
if (allAreClear(state, FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER)) {
next.suspendWrites();
}
}
public void resumeWrites() {
log.trace("resume");
state |= FLAG_WRITE_RESUMED;
next.resumeWrites();
}
public boolean flush() throws IOException {
processAJPHeader();
if(allAreClear(state, FLAG_FLUSH_QUEUED) && !isWritesTerminated()) {
queueFrame(new FrameCallBack() {
@Override
public void done() {
state &= ~FLAG_FLUSH_QUEUED;
}
@Override
public void failed(IOException e) {
}
}, FLUSH_PACKET.duplicate());
state |= FLAG_FLUSH_QUEUED;
}
return flushQueuedData();
}
public boolean isWriteResumed() {
return anyAreSet(state, FLAG_WRITE_RESUMED);
}
public void wakeupWrites() {
log.trace("wakeup");
state |= FLAG_WRITE_RESUMED;
next.wakeupWrites();
}
@Override
protected void doTerminateWrites() throws IOException {
try {
if (!exchange.isPersistent()) {
next.terminateWrites();
}
state |= FLAG_WRITE_SHUTDOWN;
} catch (IOException | RuntimeException e) {
IoUtils.safeClose(exchange.getConnection());
throw e;
}
}
@Override
public boolean isWriteShutdown() {
return super.isWriteShutdown() || anyAreSet(state, FLAG_WRITE_SHUTDOWN);
}
boolean doGetRequestBodyChunk(ByteBuffer buffer, final AjpServerRequestConduit requestChannel) throws IOException {
if(isWriteShutdown()) {
throw UndertowMessages.MESSAGES.channelIsClosed();
}
super.write(buffer);
if (buffer.hasRemaining()) {
this.state |= FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER;
queueFrame(new FrameCallBack() {
@Override
public void done() {
state &= ~FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER;
if (allAreClear(state, FLAG_WRITE_RESUMED)) {
next.suspendWrites();
}
}
@Override
public void failed(IOException e) {
requestChannel.setReadBodyChunkError(e);
}
}, buffer);
next.resumeWrites();
return false;
}
return true;
}
private final class AjpServerWriteReadyHandler implements WriteReadyHandler {
private final WriteReadyHandler delegate;
private AjpServerWriteReadyHandler(WriteReadyHandler delegate) {
this.delegate = delegate;
}
@Override
public void writeReady() {
if (anyAreSet(state, FLAG_WRITE_READ_BODY_CHUNK_FROM_LISTENER)) {
try {
flushQueuedData();
} catch (IOException e) {
log.debug("Error flushing when doing async READ_BODY_CHUNK flush", e);
}
}
if (anyAreSet(state, FLAG_WRITE_RESUMED)) {
delegate.writeReady();
}
}
@Override
public void forceTermination() {
delegate.forceTermination();
}
@Override
public void terminated() {
delegate.terminated();
}
}
}