package org.eclipse.jetty.io;
import java.io.Closeable;
import java.io.IOException;
import java.net.SocketAddress;
import java.net.SocketException;
import java.nio.channels.SelectableChannel;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.time.Duration;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import org.eclipse.jetty.util.Promise;
import org.eclipse.jetty.util.component.ContainerLifeCycle;
import org.eclipse.jetty.util.log.Log;
import org.eclipse.jetty.util.log.Logger;
import org.eclipse.jetty.util.ssl.SslContextFactory;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.util.thread.ScheduledExecutorScheduler;
import org.eclipse.jetty.util.thread.Scheduler;
/**
 * <p>Client-side connector that opens (or adopts) {@link SocketChannel}s and hands them
 * over to a {@link SelectorManager} so that a {@link Connection} can be created for them.</p>
 * <p>Collaborators (executor, scheduler, byte buffer pool, SSL context factory) may be
 * configured before start; any that are still {@code null} when {@link #doStart()} runs
 * are replaced by defaults created there.</p>
 * <p>Per-connection information travels in a {@code Map<String, Object>} context, using the
 * {@code *_CONTEXT_KEY} constants of this class as keys.</p>
 */
public class ClientConnector extends ContainerLifeCycle
{
    /** Context key under which this connector stores itself. */
    public static final String CLIENT_CONNECTOR_CONTEXT_KEY = "org.eclipse.jetty.client.connector";
    /** Context key for the remote {@link SocketAddress}; set by {@link #connect(SocketAddress, Map)} if absent. */
    public static final String REMOTE_SOCKET_ADDRESS_CONTEXT_KEY = CLIENT_CONNECTOR_CONTEXT_KEY + ".remoteSocketAddress";
    /** Context key for the {@link ClientConnectionFactory} used to create the {@link Connection}. */
    public static final String CLIENT_CONNECTION_FACTORY_CONTEXT_KEY = CLIENT_CONNECTOR_CONTEXT_KEY + ".clientConnectionFactory";
    /** Context key for the {@link Promise} that is failed when connecting/accepting fails. */
    public static final String CONNECTION_PROMISE_CONTEXT_KEY = CLIENT_CONNECTOR_CONTEXT_KEY + ".connectionPromise";

    private static final Logger LOG = Log.getLogger(ClientConnector.class);

    private Executor executor;
    private Scheduler scheduler;
    private ByteBufferPool byteBufferPool;
    private SslContextFactory.Client sslContextFactory;
    private SelectorManager selectorManager;
    private int selectors = 1;
    private boolean connectBlocking;
    private Duration connectTimeout = Duration.ofSeconds(5);
    private Duration idleTimeout = Duration.ofSeconds(30);
    private SocketAddress bindAddress;

    /**
     * @return the configured executor, or {@code null} if none was set and this connector was never started
     */
    public Executor getExecutor()
    {
        return executor;
    }

    /**
     * Sets the executor, swapping it in as a bean of this container via {@code updateBean}.
     *
     * @param executor the executor to use
     * @throws IllegalStateException if this connector is started
     */
    public void setExecutor(Executor executor)
    {
        if (isStarted())
            throw new IllegalStateException();
        updateBean(this.executor, executor);
        this.executor = executor;
    }

    /**
     * @return the configured scheduler, or {@code null} if none was set and this connector was never started
     */
    public Scheduler getScheduler()
    {
        return scheduler;
    }

    /**
     * Sets the scheduler, swapping it in as a bean of this container via {@code updateBean}.
     *
     * @param scheduler the scheduler to use
     * @throws IllegalStateException if this connector is started
     */
    public void setScheduler(Scheduler scheduler)
    {
        if (isStarted())
            throw new IllegalStateException();
        updateBean(this.scheduler, scheduler);
        this.scheduler = scheduler;
    }

    /**
     * @return the configured byte buffer pool, or {@code null} if none was set and this connector was never started
     */
    public ByteBufferPool getByteBufferPool()
    {
        return byteBufferPool;
    }

    /**
     * Sets the byte buffer pool, swapping it in as a bean of this container via {@code updateBean}.
     *
     * @param byteBufferPool the pool to use
     * @throws IllegalStateException if this connector is started
     */
    public void setByteBufferPool(ByteBufferPool byteBufferPool)
    {
        if (isStarted())
            throw new IllegalStateException();
        updateBean(this.byteBufferPool, byteBufferPool);
        this.byteBufferPool = byteBufferPool;
    }

    /**
     * @return the configured SSL context factory, or {@code null} if none was set and this connector was never started
     */
    public SslContextFactory.Client getSslContextFactory()
    {
        return sslContextFactory;
    }

    /**
     * Sets the SSL context factory, swapping it in as a bean of this container via {@code updateBean}.
     *
     * @param sslContextFactory the factory to use
     * @throws IllegalStateException if this connector is started
     */
    public void setSslContextFactory(SslContextFactory.Client sslContextFactory)
    {
        if (isStarted())
            throw new IllegalStateException();
        updateBean(this.sslContextFactory, sslContextFactory);
        this.sslContextFactory = sslContextFactory;
    }

    /**
     * @return the number of selectors passed to the {@link SelectorManager} (default 1)
     */
    public int getSelectors()
    {
        return selectors;
    }

    /**
     * @param selectors the number of selectors passed to the {@link SelectorManager} created at start
     * @throws IllegalStateException if this connector is started
     */
    public void setSelectors(int selectors)
    {
        if (isStarted())
            throw new IllegalStateException();
        this.selectors = selectors;
    }

    /**
     * @return whether {@link #connect(SocketAddress, Map)} performs a blocking connect (default {@code false})
     */
    public boolean isConnectBlocking()
    {
        return connectBlocking;
    }

    /**
     * @param connectBlocking whether {@link #connect(SocketAddress, Map)} performs a blocking connect
     */
    public void setConnectBlocking(boolean connectBlocking)
    {
        this.connectBlocking = connectBlocking;
    }

    /**
     * @return the connect timeout (default 5 seconds)
     */
    public Duration getConnectTimeout()
    {
        return connectTimeout;
    }

    /**
     * Sets the connect timeout; if a selector manager has already been created,
     * the new value (in milliseconds) is forwarded to it as well.
     *
     * @param connectTimeout the connect timeout
     */
    public void setConnectTimeout(Duration connectTimeout)
    {
        this.connectTimeout = connectTimeout;
        if (selectorManager != null)
            selectorManager.setConnectTimeout(connectTimeout.toMillis());
    }

    /**
     * @return the idle timeout applied to newly created endpoints (default 30 seconds)
     */
    public Duration getIdleTimeout()
    {
        return idleTimeout;
    }

    /**
     * @param idleTimeout the idle timeout applied to endpoints created after this call
     */
    public void setIdleTimeout(Duration idleTimeout)
    {
        this.idleTimeout = idleTimeout;
    }

    /**
     * @return the local address channels are bound to before connecting, or {@code null} for no explicit bind
     */
    public SocketAddress getBindAddress()
    {
        return bindAddress;
    }

    /**
     * @param bindAddress the local address to bind to before connecting, or {@code null} for no explicit bind
     */
    public void setBindAddress(SocketAddress bindAddress)
    {
        this.bindAddress = bindAddress;
    }

    /**
     * Creates defaults for every collaborator that is still {@code null}, then creates the
     * selector manager, configures its connect timeout, adds it as a bean and starts the container.
     */
    @Override
    protected void doStart() throws Exception
    {
        if (executor == null)
        {
            QueuedThreadPool clientThreads = new QueuedThreadPool();
            clientThreads.setName(String.format("client-pool@%x", hashCode()));
            setExecutor(clientThreads);
        }
        if (scheduler == null)
            setScheduler(new ScheduledExecutorScheduler(String.format("client-scheduler@%x", hashCode()), false));
        if (byteBufferPool == null)
            setByteBufferPool(new MappedByteBufferPool());
        if (sslContextFactory == null)
            setSslContextFactory(newSslContextFactory());
        selectorManager = newSelectorManager();
        selectorManager.setConnectTimeout(getConnectTimeout().toMillis());
        addBean(selectorManager);
        super.doStart();
    }

    /**
     * Stops the container and then removes the selector manager bean,
     * since a new selector manager is created on every start.
     */
    @Override
    protected void doStop() throws Exception
    {
        super.doStop();
        removeBean(selectorManager);
    }

    /**
     * Creates the default SSL context factory, with endpoint identification algorithm {@code "HTTPS"}.
     * NOTE(review): the {@code false} constructor argument is presumably "trust all" disabled;
     * confirm against {@code SslContextFactory.Client}.
     *
     * @return a new SSL context factory
     */
    protected SslContextFactory.Client newSslContextFactory()
    {
        SslContextFactory.Client sslContextFactory = new SslContextFactory.Client(false);
        sslContextFactory.setEndpointIdentificationAlgorithm("HTTPS");
        return sslContextFactory;
    }

    /**
     * @return a new {@link ClientSelectorManager} built from the current executor, scheduler and selector count
     */
    protected SelectorManager newSelectorManager()
    {
        return new ClientSelectorManager(getExecutor(), getScheduler(), getSelectors());
    }

    /**
     * Opens a {@link SocketChannel}, optionally binds it, and connects it to the given address,
     * either blocking or non-blocking depending on {@link #isConnectBlocking()}.
     * <p>Any failure closes the channel and is reported through {@link #connectFailed(Throwable, Map)};
     * nothing is thrown to the caller for failures raised inside the connect sequence.</p>
     *
     * @param address the remote address to connect to
     * @param context the connection context; if {@code null} a new map is created. The connector
     * itself and (if absent) the remote address are stored into it.
     */
    public void connect(SocketAddress address, Map<String, Object> context)
    {
        SocketChannel channel = null;
        try
        {
            if (context == null)
                context = new HashMap<>();
            context.put(ClientConnector.CLIENT_CONNECTOR_CONTEXT_KEY, this);
            context.putIfAbsent(REMOTE_SOCKET_ADDRESS_CONTEXT_KEY, address);

            channel = SocketChannel.open();
            SocketAddress bindAddress = getBindAddress();
            if (bindAddress != null)
            {
                if (LOG.isDebugEnabled())
                    LOG.debug("Binding to {} to connect to {}", bindAddress, address);
                channel.bind(bindAddress);
            }
            configure(channel);

            boolean connected = true;
            boolean blocking = isConnectBlocking();
            if (LOG.isDebugEnabled())
                LOG.debug("Connecting {} to {}", blocking ? "blocking" : "non-blocking", address);
            if (blocking)
            {
                // Clamp rather than blindly narrow: a plain (int) cast would wrap
                // very large durations into an arbitrary (possibly negative) timeout.
                channel.socket().connect(address, toSocketTimeout(getConnectTimeout()));
                channel.configureBlocking(false);
            }
            else
            {
                channel.configureBlocking(false);
                connected = channel.connect(address);
            }

            // An already-connected channel is accepted directly; otherwise the
            // selector manager is asked to complete the pending connect.
            if (connected)
                selectorManager.accept(channel, context);
            else
                selectorManager.connect(channel, context);
        }
        catch (Throwable x)
        {
            // Plain SocketExceptions are re-wrapped so that the message names the target address.
            if (x.getClass() == SocketException.class)
                x = new SocketException("Could not connect to " + address).initCause(x);
            safeClose(channel);
            connectFailed(x, context);
        }
    }

    /**
     * Adopts an already connected {@link SocketChannel}: configures it, switches it to
     * non-blocking mode and hands it to the selector manager.
     * <p>On failure the channel is closed and the promise stored in the context under
     * {@link #CONNECTION_PROMISE_CONTEXT_KEY}, if any, is failed. A missing context or a
     * missing promise no longer causes a secondary {@link NullPointerException} that would
     * mask the original failure.</p>
     *
     * @param channel the connected channel to adopt
     * @param context the connection context; this connector is stored into it
     */
    public void accept(SocketChannel channel, Map<String, Object> context)
    {
        try
        {
            context.put(ClientConnector.CLIENT_CONNECTOR_CONTEXT_KEY, this);
            if (!channel.isConnected())
                throw new IllegalStateException("SocketChannel must be connected");
            configure(channel);
            channel.configureBlocking(false);
            selectorManager.accept(channel, context);
        }
        catch (Throwable failure)
        {
            if (LOG.isDebugEnabled())
                LOG.debug("Could not accept {}", channel);
            safeClose(channel);
            notifyFailure(context, failure);
        }
    }

    /**
     * Closes the given resource, ignoring {@code null} and handing any exception to {@code LOG.ignore}.
     *
     * @param closeable the resource to close, possibly {@code null}
     */
    protected void safeClose(Closeable closeable)
    {
        try
        {
            if (closeable != null)
                closeable.close();
        }
        catch (Throwable x)
        {
            LOG.ignore(x);
        }
    }

    /**
     * Applies socket options to a channel before it is connected/accepted;
     * this implementation enables {@code TCP_NODELAY}.
     *
     * @param channel the channel to configure
     * @throws IOException if the socket option cannot be set
     */
    protected void configure(SocketChannel channel) throws IOException
    {
        channel.socket().setTcpNoDelay(true);
    }

    /**
     * Reports a connect failure by failing the promise stored in the context under
     * {@link #CONNECTION_PROMISE_CONTEXT_KEY}, if there is one.
     *
     * @param failure the cause of the failure
     * @param context the connection context, tolerated to be {@code null} or to lack a promise
     */
    protected void connectFailed(Throwable failure, Map<String, Object> context)
    {
        if (LOG.isDebugEnabled())
        {
            Object remote = context == null ? null : context.get(REMOTE_SOCKET_ADDRESS_CONTEXT_KEY);
            LOG.debug("Could not connect to {}", remote);
        }
        notifyFailure(context, failure);
    }

    /**
     * Fails the connection promise found in the context, if any; otherwise the
     * failure is only logged at debug level so that it is not lost entirely.
     */
    private void notifyFailure(Map<String, Object> context, Throwable failure)
    {
        Promise<?> promise = context == null ? null : (Promise<?>)context.get(CONNECTION_PROMISE_CONTEXT_KEY);
        if (promise != null)
            promise.failed(failure);
        else if (LOG.isDebugEnabled())
            LOG.debug("No connection promise to notify of failure", failure);
    }

    /**
     * Converts a timeout to whole milliseconds clamped to the {@code int} range,
     * so that out-of-range values saturate instead of wrapping around.
     */
    private static int toSocketTimeout(Duration timeout)
    {
        long millis = timeout.toMillis();
        return (int)Math.max(Integer.MIN_VALUE, Math.min(Integer.MAX_VALUE, millis));
    }

    /**
     * {@link SelectorManager} that creates endpoints with this connector's idle timeout,
     * creates connections through the {@link ClientConnectionFactory} found in the context,
     * and routes connection failures to {@link ClientConnector#connectFailed(Throwable, Map)}.
     */
    protected class ClientSelectorManager extends SelectorManager
    {
        protected ClientSelectorManager(Executor executor, Scheduler scheduler, int selectors)
        {
            super(executor, scheduler, selectors);
        }

        /**
         * Creates a {@link SocketChannelEndPoint} whose idle timeout is the connector's
         * idle timeout at the time of creation.
         */
        @Override
        protected EndPoint newEndPoint(SelectableChannel channel, ManagedSelector selector, SelectionKey selectionKey)
        {
            SocketChannelEndPoint endPoint = new SocketChannelEndPoint(channel, selector, selectionKey, getScheduler());
            endPoint.setIdleTimeout(getIdleTimeout().toMillis());
            return endPoint;
        }

        /**
         * Creates the connection by delegating to the factory stored in the context
         * (the attachment) under {@link ClientConnector#CLIENT_CONNECTION_FACTORY_CONTEXT_KEY}.
         */
        @Override
        public Connection newConnection(SelectableChannel channel, EndPoint endPoint, Object attachment) throws IOException
        {
            // The attachment is the context map that was passed to connect()/accept().
            @SuppressWarnings("unchecked")
            Map<String, Object> context = (Map<String, Object>)attachment;
            ClientConnectionFactory factory = (ClientConnectionFactory)context.get(CLIENT_CONNECTION_FACTORY_CONTEXT_KEY);
            return factory.newConnection(endPoint, context);
        }

        /**
         * Forwards the failure to {@link ClientConnector#connectFailed(Throwable, Map)},
         * treating the attachment as the connection context.
         */
        @Override
        protected void connectionFailed(SelectableChannel channel, Throwable failure, Object attachment)
        {
            @SuppressWarnings("unchecked")
            Map<String, Object> context = (Map<String, Object>)attachment;
            connectFailed(failure, context);
        }
    }
}