package io.undertow.server.protocol.proxy;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.DelegateOpenListener;
import io.undertow.server.OpenListener;
import io.undertow.util.NetworkUtils;
import io.undertow.util.PooledAdaptor;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.StreamConnection;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.PushBackStreamSourceConduit;
import org.xnio.ssl.SslConnection;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
// Implementation of version 1 of the proxy protocol (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
// Even though it is not required by the spec this implementation provides a stateful parser, that can handle
// fragmentation of the header across multiple reads.
// Author: Stuart Douglas
/**
 * Implementation of version 1 of the proxy protocol (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
 * <p>
 * Even though it is not required by the spec this implementation provides a stateful parser, that can handle
 * fragmentation of the header across multiple reads: all parsing state is kept in instance fields, so
 * {@link #handleEvent(StreamSourceChannel)} can be invoked repeatedly until the terminating CRLF has been seen.
 * <p>
 * Once the header is complete the connection (wrapped so that it reports the proxied addresses, and optionally
 * wrapped in SSL) is handed to the configured {@link OpenListener}. Any malformed header results in the
 * connection being closed.
 *
 * @author Stuart Douglas
 */
class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel> {

    /** Number of header bytes after which parsing is aborted with a "header too large" error. */
    private static final int MAX_HEADER_LENGTH = 107;
    /** Mandatory prefix of every version 1 header, including the trailing space. */
    private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII);
    private static final String UNKNOWN = "UNKNOWN";
    private static final String TCP4 = "TCP4";
    private static final String TCP6 = "TCP6";
    /** Largest value accepted for a port field. */
    private static final int MAX_PORT = 65535;

    private final StreamConnection streamConnection;
    private final OpenListener openListener;
    private final UndertowXnioSsl ssl;
    private final ByteBufferPool bufferPool;
    private final OptionMap sslOptionMap;

    // ---- parser state, preserved between handleEvent invocations ----

    /** Number of header bytes consumed so far. */
    private int byteCount;
    /** Protocol token (first field after the prefix), or null if not parsed yet. */
    private String protocol;
    private InetAddress sourceAddress;
    private InetAddress destAddress;
    /** -1 means "not parsed yet"; parsed ports are always in the range 0..65535. */
    private int sourcePort = -1;
    private int destPort = -1;
    /** Accumulates the characters of the field currently being read. */
    private final StringBuilder stringBuilder = new StringBuilder();
    private boolean carriageReturnSeen = false;
    /** True once the protocol was recognised as UNKNOWN; the rest of the line is then ignored. */
    private boolean parsingUnknown = false;

    /**
     * @param streamConnection the raw connection the header is read from
     * @param openListener     the listener that receives the connection once the header has been parsed
     * @param ssl              if non-null, the connection is wrapped in SSL before being handed on
     * @param bufferPool       pool used to allocate the read buffer
     * @param sslOptionMap     options for the SSL wrapping; {@code null} is treated as {@link OptionMap#EMPTY}
     * @throws RuntimeException as produced by {@code UndertowMessages.MESSAGES.bufferPoolTooSmall} if the
     *                          pool's buffers are smaller than {@value #MAX_HEADER_LENGTH} bytes
     */
    ProxyProtocolReadListener(StreamConnection streamConnection, OpenListener openListener, UndertowXnioSsl ssl, ByteBufferPool bufferPool, OptionMap sslOptionMap) {
        this.streamConnection = streamConnection;
        this.openListener = openListener;
        this.ssl = ssl;
        this.bufferPool = bufferPool;
        this.sslOptionMap = sslOptionMap;
        if (bufferPool.getBufferSize() < MAX_HEADER_LENGTH) {
            throw UndertowMessages.MESSAGES.bufferPoolTooSmall(MAX_HEADER_LENGTH);
        }
    }

    /**
     * Reads as much of the header as is currently available and feeds it to the parser. When the header is
     * complete the connection is passed on via {@link #proxyAccept}; bytes that follow the header in the same
     * read are handed over together with the buffer. On end of stream or on any error the connection is closed.
     *
     * @param streamSourceChannel the channel to read the header from
     */
    @Override
    public void handleEvent(StreamSourceChannel streamSourceChannel) {
        PooledByteBuffer buffer = bufferPool.allocate();
        // ownership of the buffer moves to proxyAccept when it still holds data that follows the header
        boolean freeBuffer = true;
        try {
            for (; ; ) {
                int res = streamSourceChannel.read(buffer.getBuffer());
                if (res == -1) {
                    IoUtils.safeClose(streamConnection);
                    return;
                }
                if (res == 0) {
                    return;
                }
                buffer.getBuffer().flip();
                while (buffer.getBuffer().hasRemaining()) {
                    char c = (char) buffer.getBuffer().get();
                    if (parseHeaderByte(c)) {
                        SocketAddress source = null;
                        SocketAddress dest = null;
                        if (!parsingUnknown) {
                            source = new InetSocketAddress(sourceAddress, sourcePort);
                            dest = new InetSocketAddress(destAddress, destPort);
                        }
                        PooledByteBuffer additionalData = null;
                        if (buffer.getBuffer().hasRemaining()) {
                            freeBuffer = false;
                            additionalData = buffer;
                        }
                        proxyAccept(source, dest, additionalData);
                        return;
                    }
                    byteCount++;
                    if (byteCount == MAX_HEADER_LENGTH) {
                        throw UndertowMessages.MESSAGES.headerSizeToLarge();
                    }
                }
                // Everything was consumed but the header is not complete yet. Reset the buffer so the next
                // read has room; without this the buffer has no remaining space and the read would return 0
                // even though the channel may still have data.
                buffer.getBuffer().clear();
            }
        } catch (IOException e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
            IoUtils.safeClose(streamConnection);
        } catch (Exception e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(e));
            IoUtils.safeClose(streamConnection);
        } finally {
            if (freeBuffer) {
                buffer.close();
            }
        }
    }

    /**
     * Advances the parser by one header character.
     *
     * @param c the next character of the header
     * @return {@code true} if {@code c} was the line feed that terminates the header
     * @throws IOException if the header is malformed
     */
    private boolean parseHeaderByte(char c) throws IOException {
        if (byteCount < NAME.length) {
            //first we verify that we have the correct protocol
            if (c != NAME[byteCount]) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            return false;
        }
        if (parsingUnknown) {
            return parseUnknownByte(c);
        }
        if (carriageReturnSeen) {
            // the only thing allowed after the carriage return is the line feed
            if (c != '\n') {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            return true;
        }
        switch (c) {
            case ' ':
                fieldCompleted();
                return false;
            case '\r':
                carriageReturn();
                return false;
            case '\n':
                // line feed without a preceding carriage return
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            default:
                stringBuilder.append(c);
                return false;
        }
    }

    /**
     * Handles a character while in UNKNOWN mode, where everything up to the terminating CRLF is ignored.
     *
     * @return {@code true} if {@code c} was the terminating line feed
     * @throws IOException if a line feed is not preceded by a carriage return, or a carriage return is followed
     *                     by anything other than a carriage return or line feed
     */
    private boolean parseUnknownByte(char c) throws IOException {
        if (c == '\r') {
            carriageReturnSeen = true;
            return false;
        }
        if (c == '\n') {
            if (!carriageReturnSeen) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            return true;
        }
        if (carriageReturnSeen) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        return false;
    }

    /**
     * Called when a space is read: stores the accumulated text into the next unset field, in the order
     * protocol, source address, destination address, source port.
     *
     * @throws IOException if the field is empty, unexpected or cannot be parsed
     */
    private void fieldCompleted() throws IOException {
        if (sourcePort != -1 || stringBuilder.length() == 0) {
            //header was invalid, either we are expecting a \r or a \n, or the previous character was a space
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        String field = stringBuilder.toString();
        stringBuilder.setLength(0);
        if (protocol == null) {
            protocol = field;
            if (protocol.equals(UNKNOWN)) {
                parsingUnknown = true;
            } else if (!protocol.equals(TCP4) && !protocol.equals(TCP6)) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
        } else if (sourceAddress == null) {
            sourceAddress = parseAddress(field, protocol);
        } else if (destAddress == null) {
            destAddress = parseAddress(field, protocol);
        } else {
            sourcePort = parsePort(field);
        }
    }

    /**
     * Called when a carriage return is read outside of UNKNOWN mode. This is only valid directly after the
     * destination port, or directly after a bare {@code UNKNOWN} protocol token.
     *
     * @throws IOException if the carriage return appears anywhere else
     */
    private void carriageReturn() throws IOException {
        if (sourcePort != -1 && stringBuilder.length() > 0) {
            destPort = parsePort(stringBuilder.toString());
            stringBuilder.setLength(0);
            carriageReturnSeen = true;
        } else if (protocol == null && UNKNOWN.equals(stringBuilder.toString())) {
            parsingUnknown = true;
            carriageReturnSeen = true;
        } else {
            // a carriage return in the middle of the header must not be skipped silently
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
    }

    /**
     * Parses a port field strictly as an unsigned decimal number in the range 0..65535. Signs are rejected,
     * which also guarantees a parsed port can never equal the -1 "not parsed yet" marker.
     *
     * @throws IOException if the value is not a valid port
     */
    private static int parsePort(String value) throws IOException {
        if (value.isEmpty()) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        for (int i = 0; i < value.length(); i++) {
            char digit = value.charAt(i);
            if (digit < '0' || digit > '9') {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
        }
        int port;
        try {
            port = Integer.parseInt(value);
        } catch (NumberFormatException e) {
            // only reachable when the digit string overflows an int
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        if (port > MAX_PORT) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        return port;
    }

    /**
     * Hands the connection over to the open listener once the header has been parsed.
     *
     * @param source         the proxied source address, or {@code null} to keep the connection's own addresses
     * @param dest           the proxied destination address; only used when {@code source} is non-null
     * @param additionalData data that was read after the header, or {@code null} if there was none
     */
    private void proxyAccept(SocketAddress source, SocketAddress dest, PooledByteBuffer additionalData) {
        StreamConnection connection = this.streamConnection;
        if (source != null) {
            connection = new AddressWrappedConnection(connection, source, dest);
        }
        if (ssl == null) {
            callOpenListener(connection, additionalData);
            return;
        }
        //we need to apply the additional data before the SSL wrapping
        if (additionalData != null) {
            pushBack(connection, additionalData);
        }
        SslConnection sslConnection = ssl.wrapExistingConnection(connection, sslOptionMap == null ? OptionMap.EMPTY : sslOptionMap);
        UndertowXnioSsl.getSslEngine(sslConnection).setUseClientMode(false);
        callOpenListener(sslConnection, null);
    }

    /**
     * Invokes the open listener. A {@link DelegateOpenListener} receives the extra data directly; for any
     * other listener the data is pushed back onto the connection's source conduit first.
     */
    private void callOpenListener(StreamConnection streamConnection, final PooledByteBuffer buffer) {
        if (openListener instanceof DelegateOpenListener) {
            ((DelegateOpenListener) openListener).handleEvent(streamConnection, buffer);
            return;
        }
        if (buffer != null) {
            pushBack(streamConnection, buffer);
        }
        openListener.handleEvent(streamConnection);
    }

    /** Installs a push back conduit on the connection's source channel and pushes {@code data} back onto it. */
    private static void pushBack(StreamConnection connection, PooledByteBuffer data) {
        PushBackStreamSourceConduit conduit = new PushBackStreamSourceConduit(connection.getSourceChannel().getConduit());
        conduit.pushBack(new PooledAdaptor(data));
        connection.getSourceChannel().setConduit(conduit);
    }

    /**
     * Parses an address field of the header.
     *
     * @param addressString the textual address
     * @param protocol      the protocol token of the header; {@code TCP4} selects IPv4 parsing, anything else
     *                      IPv6 parsing
     * @return the parsed address
     * @throws IOException declared for failures of the underlying {@link NetworkUtils} parsers
     */
    static InetAddress parseAddress(String addressString, String protocol) throws IOException {
        if (protocol.equals(TCP4)) {
            return NetworkUtils.parseIpv4Address(addressString);
        } else {
            return NetworkUtils.parseIpv6Address(addressString);
        }
    }

    /**
     * Connection that shares the conduits and I/O thread of a delegate connection but reports the addresses
     * taken from the proxy header as its peer and local address.
     */
    private static final class AddressWrappedConnection extends StreamConnection {

        private final StreamConnection delegate;
        private final SocketAddress source;
        private final SocketAddress dest;

        AddressWrappedConnection(StreamConnection delegate, SocketAddress source, SocketAddress dest) {
            super(delegate.getIoThread());
            this.delegate = delegate;
            this.source = source;
            this.dest = dest;
            setSinkConduit(delegate.getSinkChannel().getConduit());
            setSourceConduit(delegate.getSourceChannel().getConduit());
        }

        /** Propagates the write shutdown by closing the delegate's sink channel. */
        @Override
        protected void notifyWriteClosed() {
            IoUtils.safeClose(delegate.getSinkChannel());
        }

        /** Propagates the read shutdown by closing the delegate's source channel. */
        @Override
        protected void notifyReadClosed() {
            IoUtils.safeClose(delegate.getSourceChannel());
        }

        /** @return the source address given in the proxy header */
        @Override
        public SocketAddress getPeerAddress() {
            return source;
        }

        /** @return the destination address given in the proxy header */
        @Override
        public SocketAddress getLocalAddress() {
            return dest;
        }
    }
}