package io.undertow.client.ajp;
import io.undertow.UndertowOptions;
import io.undertow.client.ClientCallback;
import io.undertow.client.ClientConnection;
import io.undertow.client.ClientProvider;
import io.undertow.client.ClientStatistics;
import io.undertow.conduits.ByteActivityCallback;
import io.undertow.conduits.BytesReceivedStreamSourceConduit;
import io.undertow.conduits.BytesSentStreamSinkConduit;
import io.undertow.protocols.ajp.AjpClientChannel;
import org.xnio.ChannelListener;
import org.xnio.IoFuture;
import org.xnio.OptionMap;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;
import org.xnio.XnioIoThread;
import org.xnio.XnioWorker;
import org.xnio.ssl.XnioSsl;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
public class AjpClientProvider implements ClientProvider {
@Override
public Set<String> handlesSchemes() {
return new HashSet<>(Arrays.asList(new String[]{"ajp"}));
}
@Override
public void connect(final ClientCallback<ClientConnection> listener, final URI uri, final XnioWorker worker, final XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap options) {
connect(listener, null, uri, worker, ssl, bufferPool, options);
}
@Override
public void connect(final ClientCallback<ClientConnection> listener, final URI uri, final XnioIoThread ioThread, final XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap options) {
connect(listener, null, uri, ioThread, ssl, bufferPool, options);
}
@Override
public void connect(final ClientCallback<ClientConnection> listener, InetSocketAddress bindAddress, final URI uri, final XnioWorker worker, final XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap options) {
ChannelListener<StreamConnection> openListener = new ChannelListener<StreamConnection>() {
@Override
public void handleEvent(StreamConnection connection) {
handleConnected(connection, listener, uri, ssl, bufferPool, options);
}
};
IoFuture.Notifier<StreamConnection, Object> notifier = new IoFuture.Notifier<StreamConnection, Object>() {
@Override
public void notify(IoFuture<? extends StreamConnection> ioFuture, Object o) {
if (ioFuture.getStatus() == IoFuture.Status.FAILED) {
listener.failed(ioFuture.getException());
}
}
};
if(bindAddress == null) {
worker.openStreamConnection(new InetSocketAddress(uri.getHost(), uri.getPort() == -1 ? 8009 : uri.getPort()), openListener, options).addNotifier(notifier, null);
} else {
worker.openStreamConnection(bindAddress, new InetSocketAddress(uri.getHost(), uri.getPort() == -1 ? 8009 : uri.getPort()), openListener, null, options).addNotifier(notifier, null);
}
}
@Override
public void connect(final ClientCallback<ClientConnection> listener, InetSocketAddress bindAddress,final URI uri, final XnioIoThread ioThread, final XnioSsl ssl, final ByteBufferPool bufferPool, final OptionMap options) {
ChannelListener<StreamConnection> openListener = new ChannelListener<StreamConnection>() {
@Override
public void handleEvent(StreamConnection connection) {
handleConnected(connection, listener, uri, ssl, bufferPool, options);
}
};
IoFuture.Notifier<StreamConnection, Object> notifier = new IoFuture.Notifier<StreamConnection, Object>() {
@Override
public void notify(IoFuture<? extends StreamConnection> ioFuture, Object o) {
if (ioFuture.getStatus() == IoFuture.Status.FAILED) {
listener.failed(ioFuture.getException());
}
}
};
if(bindAddress == null) {
ioThread.openStreamConnection(new InetSocketAddress(uri.getHost(), uri.getPort() == -1 ? 8009 : uri.getPort()), openListener, options).addNotifier(notifier, null);
} else {
ioThread.openStreamConnection(bindAddress, new InetSocketAddress(uri.getHost(), uri.getPort() == -1 ? 8009 : uri.getPort()), openListener, null, options).addNotifier(notifier, null);
}
}
private void handleConnected(StreamConnection connection, ClientCallback<ClientConnection> listener, URI uri, XnioSsl ssl, ByteBufferPool bufferPool, OptionMap options) {
final ClientStatisticsImpl clientStatistics;
if (options.get(UndertowOptions.ENABLE_STATISTICS, false)) {
clientStatistics = new ClientStatisticsImpl();
connection.getSinkChannel().setConduit(new BytesSentStreamSinkConduit(connection.getSinkChannel().getConduit(), new ByteActivityCallback() {
@Override
public void activity(long bytes) {
clientStatistics.written += bytes;
}
}));
connection.getSourceChannel().setConduit(new BytesReceivedStreamSourceConduit(connection.getSourceChannel().getConduit(), new ByteActivityCallback() {
@Override
public void activity(long bytes) {
clientStatistics.read += bytes;
}
}));
} else {
clientStatistics = null;
}
listener.completed(new AjpClientConnection(new AjpClientChannel(connection, bufferPool, options), options, bufferPool, clientStatistics));
}
private static class ClientStatisticsImpl implements ClientStatistics {
private long requestCount, read, written;
@Override
public long getRequests() {
return requestCount;
}
@Override
public long getRead() {
return read;
}
@Override
public long getWritten() {
return written;
}
@Override
public void reset() {
read = 0;
written = 0;
requestCount = 0;
}
}
}