package org.xnio.ssl;
import static org.xnio._private.Messages.msg;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;
final class JsseSslStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> {
private final JsseSslConduitEngine sslEngine;
private volatile boolean tls;
protected JsseSslStreamSinkConduit(StreamSinkConduit next, JsseSslConduitEngine sslEngine, boolean tls) {
super(next);
if (sslEngine == null) {
throw msg.nullParameter("sslEngine");
}
this.sslEngine = sslEngine;
this.tls = tls;
}
public void enableTls() {
tls = true;
if (isWriteResumed()) {
wakeupWrites();
}
}
@Override
public long transferFrom(final FileChannel src, final long position, final long count) throws IOException {
return src.transferTo(position, count, new ConduitWritableByteChannel(this));
}
@Override
public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException {
return Conduits.transfer(source, count, throughBuffer, this);
}
@Override
public int write(ByteBuffer src) throws IOException {
return write(src, false);
}
@Override
public long write(ByteBuffer[] srcs, int offs, int len) throws IOException {
return write(srcs, offs, len, false);
}
@Override
public int writeFinal(ByteBuffer src) throws IOException {
return write(src, true);
}
@Override
public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException {
return write(srcs, offset, length, true);
}
private int write(ByteBuffer src, boolean writeFinal) throws IOException {
if (!tls) {
if(writeFinal) {
return next.writeFinal(src);
} else {
return next.write(src);
}
}
final int wrappedBytes = sslEngine.wrap(src);
if (wrappedBytes > 0) {
writeWrappedBuffer(writeFinal);
}
return wrappedBytes;
}
private long write(ByteBuffer[] srcs, int offs, int len, boolean writeFinal) throws IOException {
if (!tls) {
if(writeFinal) {
return super.writeFinal(srcs, offs, len);
} else {
return super.write(srcs, offs, len);
}
}
final long wrappedBytes = sslEngine.wrap(srcs, offs, len);
if (wrappedBytes > 0) {
writeWrappedBuffer(writeFinal);
}
return wrappedBytes;
}
@Override
public void resumeWrites() {
if (tls && sslEngine.isFirstHandshake()) {
super.wakeupWrites();
} else {
super.resumeWrites();
}
}
@Override
public void terminateWrites() throws IOException {
if (!tls) {
super.terminateWrites();
return;
}
try {
sslEngine.closeOutbound();
flush();
} catch (IOException e) {
try {
super.truncateWrites();
} catch (IOException e2) {
e2.addSuppressed(e);
throw e2;
}
throw e;
}
}
@Override
public void awaitWritable() throws IOException {
if (tls) {
sslEngine.awaitCanWrap();
}
super.awaitWritable();
}
@Override
public void awaitWritable(long time, TimeUnit timeUnit) throws IOException {
if (!tls) {
super.awaitWritable(time, timeUnit);
return;
}
long duration = timeUnit.toNanos(time);
long awaited = System.nanoTime();
sslEngine.awaitCanWrap(time, timeUnit);
awaited = System.nanoTime() - awaited;
if (awaited > duration) {
return;
}
super.awaitWritable(duration - awaited, TimeUnit.NANOSECONDS);
}
@Override
public void truncateWrites() throws IOException {
if (tls) try {
sslEngine.closeOutbound();
} finally {
try {
super.truncateWrites();
} catch (IOException ignored) {
}
}
super.truncateWrites();
}
@Override
public boolean flush() throws IOException {
if (!tls) {
return super.flush();
}
if (sslEngine.isOutboundClosed()) {
if (sslEngine.flush() && writeWrappedBuffer(false) && super.flush()) {
super.terminateWrites();
return true;
} else {
return false;
}
}
return sslEngine.flush() && writeWrappedBuffer(false) && super.flush();
}
private boolean writeWrappedBuffer(boolean writeFinal) throws IOException {
synchronized (sslEngine.getWrapLock()) {
final ByteBuffer wrapBuffer = sslEngine.getWrappedBuffer();
for (;;) {
try {
if (!wrapBuffer.flip().hasRemaining()) {
if (writeFinal) {
terminateWrites();
}
return true;
}
if(writeFinal) {
if (super.writeFinal(wrapBuffer) == 0) {
return false;
}
} else {
if (super.write(wrapBuffer) == 0) {
return false;
}
}
} finally {
wrapBuffer.compact();
}
}
}
}
}