package com.mongodb.internal.connection.tlschannel;
import com.mongodb.internal.connection.tlschannel.impl.BufferHolder;
import com.mongodb.internal.connection.tlschannel.impl.ByteBufferSet;
import com.mongodb.internal.connection.tlschannel.impl.TlsChannelImpl;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.Channel;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Supplier;
public final class ClientTlsChannel implements TlsChannel {
public static final class Builder extends TlsChannelBuilder<Builder> {
private Supplier<SSLEngine> sslEngineFactory;
private Builder(final ByteChannel underlying, final SSLEngine sslEngine) {
super(underlying);
this.sslEngineFactory = new Supplier<SSLEngine>() {
@Override
public SSLEngine get() {
return sslEngine;
}
};
}
private Builder(final ByteChannel underlying, final SSLContext sslContext) {
super(underlying);
this.sslEngineFactory = new Supplier<SSLEngine>() {
@Override
public SSLEngine get() {
return defaultSSLEngineFactory(sslContext);
}
};
}
@Override
Builder getThis() {
return this;
}
public ClientTlsChannel build() {
return new ClientTlsChannel(underlying, sslEngineFactory.get(), sessionInitCallback, runTasks,
plainBufferAllocator, encryptedBufferAllocator, releaseBuffers, waitForCloseConfirmation);
}
}
private static SSLEngine defaultSSLEngineFactory(final SSLContext sslContext) {
SSLEngine engine = sslContext.createSSLEngine();
engine.setUseClientMode(true);
return engine;
}
public static Builder newBuilder(final ByteChannel underlying, final SSLEngine sslEngine) {
return new Builder(underlying, sslEngine);
}
public static Builder newBuilder(final ByteChannel underlying, final SSLContext sslContext) {
return new Builder(underlying, sslContext);
}
private final ByteChannel underlying;
private final TlsChannelImpl impl;
private ClientTlsChannel(
final ByteChannel underlying,
final SSLEngine engine,
final Consumer<SSLSession> sessionInitCallback,
final boolean runTasks,
final BufferAllocator plainBufAllocator,
final BufferAllocator encryptedBufAllocator,
final boolean releaseBuffers,
final boolean waitForCloseNotifyOnClose) {
if (!engine.getUseClientMode()) {
throw new IllegalArgumentException("SSLEngine must be in client mode");
}
this.underlying = underlying;
TrackingAllocator trackingPlainBufAllocator = new TrackingAllocator(plainBufAllocator);
TrackingAllocator trackingEncryptedAllocator = new TrackingAllocator(encryptedBufAllocator);
impl = new TlsChannelImpl(underlying, underlying, engine, Optional.<BufferHolder>empty(), sessionInitCallback, runTasks,
trackingPlainBufAllocator, trackingEncryptedAllocator, releaseBuffers, waitForCloseNotifyOnClose);
}
@Override
public ByteChannel getUnderlying() {
return underlying;
}
@Override
public SSLEngine getSslEngine() {
return impl.engine();
}
@Override
public Consumer<SSLSession> getSessionInitCallback() {
return impl.getSessionInitCallback();
}
@Override
public TrackingAllocator getPlainBufferAllocator() {
return impl.getPlainBufferAllocator();
}
@Override
public TrackingAllocator getEncryptedBufferAllocator() {
return impl.getEncryptedBufferAllocator();
}
@Override
public boolean getRunTasks() {
return impl.getRunTasks();
}
@Override
public long read(final ByteBuffer[] dstBuffers, final int offset, final int length) throws IOException {
ByteBufferSet dest = new ByteBufferSet(dstBuffers, offset, length);
TlsChannelImpl.checkReadBuffer(dest);
return impl.read(dest);
}
@Override
public long read(final ByteBuffer[] dstBuffers) throws IOException {
return read(dstBuffers, 0, dstBuffers.length);
}
@Override
public int read(final ByteBuffer dstBuffer) throws IOException {
return (int) read(new ByteBuffer[]{dstBuffer});
}
@Override
public long write(final ByteBuffer[] srcBuffers, final int offset, final int length) throws IOException {
ByteBufferSet source = new ByteBufferSet(srcBuffers, offset, length);
return impl.write(source);
}
@Override
public long write(final ByteBuffer[] outs) throws IOException {
return write(outs, 0, outs.length);
}
@Override
public int write(final ByteBuffer srcBuffer) throws IOException {
return (int) write(new ByteBuffer[]{srcBuffer});
}
@Override
public void renegotiate() throws IOException {
impl.renegotiate();
}
@Override
public void handshake() throws IOException {
impl.handshake();
}
@Override
public void close() throws IOException {
impl.close();
}
@Override
public boolean isOpen() {
return impl.isOpen();
}
@Override
public boolean shutdown() throws IOException {
return impl.shutdown();
}
@Override
public boolean shutdownReceived() {
return impl.shutdownReceived();
}
@Override
public boolean shutdownSent() {
return impl.shutdownSent();
}
}