package org.xnio.ssl;
import static org.xnio._private.Messages.msg;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;
import org.xnio.channels.StreamSinkChannel;
import org.xnio.conduits.AbstractStreamSourceConduit;
import org.xnio.conduits.ConduitReadableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSourceConduit;
/**
 * A {@link StreamSourceConduit} that decrypts inbound data through a {@link JsseSslConduitEngine}.
 * <p>
 * While TLS is not enabled, reads pass straight through to the next conduit. Once TLS is enabled
 * (at construction or via {@link #enableTls()}), each read first pulls bytes from the next conduit
 * into the engine's unwrap buffer, then returns whatever the engine unwraps into the caller's
 * buffer(s).
 */
final class JsseSslStreamSourceConduit extends AbstractStreamSourceConduit<StreamSourceConduit> {

    /** The engine that owns the unwrap buffer and lock and performs unwrapping; never {@code null}. */
    private final JsseSslConduitEngine sslEngine;

    /**
     * Whether reads go through {@link #sslEngine}. Within this class it is only ever switched from
     * {@code false} to {@code true}. It is volatile so a switch made by {@link #enableTls()} is
     * visible to reading threads.
     */
    private volatile boolean tls;

    /**
     * Creates a new SSL source conduit.
     *
     * @param next the conduit that raw (possibly encrypted) bytes are read from
     * @param sslEngine the engine used to unwrap inbound data; must not be {@code null}
     * @param tls {@code true} to start in TLS mode; {@code false} to pass reads through until
     *     {@link #enableTls()} is called
     * @throws RuntimeException the exception produced by {@code msg.nullParameter} if
     *     {@code sslEngine} is {@code null} (NOTE(review): exact type not visible here)
     */
    protected JsseSslStreamSourceConduit(StreamSourceConduit next, JsseSslConduitEngine sslEngine, boolean tls) {
        super(next);
        if (sslEngine == null) {
            throw msg.nullParameter("sslEngine");
        }
        this.sslEngine = sslEngine;
        this.tls = tls;
    }

    /**
     * Switches this conduit into TLS mode. If reads are currently resumed, they are woken up
     * (presumably so the read listener runs again under the new mode).
     */
    void enableTls() {
        tls = true;
        if (isReadResumed()) {
            wakeupReads();
        }
    }

    /**
     * Transfers into a file channel. This conduit itself is wrapped (rather than delegating to the
     * next conduit), so the transfer goes through this conduit's own read path.
     */
    @Override
    public long transferTo(final long position, final long count, final FileChannel target) throws IOException {
        return target.transferFrom(new ConduitReadableByteChannel(this), position, count);
    }

    /**
     * Transfers into a stream sink channel via {@link Conduits#transfer}. Like the other
     * {@code transferTo}, it reads from this conduit rather than from the next one.
     */
    @Override
    public long transferTo(final long count, final ByteBuffer throughBuffer, final StreamSinkChannel target) throws IOException {
        return Conduits.transfer(this, count, throughBuffer, target);
    }

    /**
     * Reads into a single buffer.
     * <p>
     * In non-TLS mode this delegates to the next conduit and terminates reads on end of stream.
     * In TLS mode it does three things. It returns {@code -1} if the engine can deliver no more
     * data. Otherwise it fills the unwrap buffer from the next conduit and unwraps into
     * {@code dst}. Finally, it reports end of stream only when nothing was unwrapped and the next
     * conduit is at end of stream.
     *
     * @param dst the destination buffer
     * @return the number of bytes read/unwrapped, or {@code -1} at end of stream
     * @throws IOException if reading or unwrapping fails
     */
    @Override
    public int read(ByteBuffer dst) throws IOException {
        if (!tls) {
            final int res = super.read(dst);
            if (res == -1) {
                terminateReads();
            }
            return res;
        }
        if (isInboundExhausted()) {
            return -1;
        }
        final int readResult = fillUnwrapBuffer();
        final int unwrapResult = sslEngine.unwrap(dst);
        // Nothing was unwrapped and the next conduit hit end of stream: terminate and report EOF.
        if (unwrapResult == 0 && readResult == -1) {
            terminateReads();
            return -1;
        }
        return unwrapResult;
    }

    /**
     * Scattering read; it behaves like {@link #read(ByteBuffer)} over
     * {@code dsts[offs] .. dsts[offs + len - 1]}.
     *
     * @param dsts the destination buffers
     * @param offs the index of the first buffer to use
     * @param len the number of buffers to use
     * @return the number of bytes read/unwrapped, or {@code -1} at end of stream
     * @throws ArrayIndexOutOfBoundsException in TLS mode, if {@code offs}/{@code len} do not
     *     describe a valid range of {@code dsts}
     * @throws IOException if reading or unwrapping fails
     */
    @Override
    public long read(ByteBuffer[] dsts, int offs, int len) throws IOException {
        if (!tls) {
            final long res = super.read(dsts, offs, len);
            if (res == -1) {
                terminateReads();
            }
            return res;
        }
        // Standard scattering-read preconditions. Written as a subtraction so offs + len cannot
        // overflow. An offset larger than len is legal (e.g. offs = 3, len = 1 on 4 buffers).
        if (offs < 0 || len < 0 || offs > dsts.length - len) {
            throw new ArrayIndexOutOfBoundsException();
        }
        if (isInboundExhausted()) {
            return -1;
        }
        final int readResult = fillUnwrapBuffer();
        final long unwrapResult = sslEngine.unwrap(dsts, offs, len);
        // Nothing was unwrapped and the next conduit hit end of stream: terminate and report EOF.
        if (unwrapResult == 0 && readResult == -1) {
            terminateReads();
            return -1;
        }
        return unwrapResult;
    }

    /**
     * Returns {@code true} when a TLS read should report end of stream. That is the case when the
     * engine is closed, or when it has no data available and its inbound side is closed.
     */
    private boolean isInboundExhausted() {
        return (!sslEngine.isDataAvailable() && sslEngine.isInboundClosed()) || sslEngine.isClosed();
    }

    /**
     * Reads from the next conduit into the engine's unwrap buffer while holding the engine's
     * unwrap lock. The buffer is compacted before the read and flipped afterwards, even if the
     * read throws.
     *
     * @return the result of the underlying read ({@code -1} at end of stream)
     * @throws IOException if the underlying read fails
     */
    private int fillUnwrapBuffer() throws IOException {
        synchronized (sslEngine.getUnwrapLock()) {
            final ByteBuffer unwrapBuffer = sslEngine.getUnwrapBuffer().compact();
            try {
                // super.read, not read: this must hit the next conduit, not recurse into TLS mode.
                return super.read(unwrapBuffer);
            } finally {
                unwrapBuffer.flip();
            }
        }
    }

    /**
     * Resumes reads. During the engine's first handshake in TLS mode, reads are only woken up
     * instead (NOTE(review): presumably so the listener drives the handshake; confirm).
     */
    @Override
    public void resumeReads() {
        if (tls && sslEngine.isFirstHandshake()) {
            super.wakeupReads();
        } else {
            super.resumeReads();
        }
    }

    /**
     * Terminates reads. In TLS mode the engine's inbound side is closed first.
     * <p>
     * If closing the engine fails, the next conduit is still terminated and the engine's exception
     * is rethrown. If both fail, the next conduit's exception is thrown with the engine's
     * exception attached as suppressed.
     *
     * @throws IOException if closing the engine inbound or terminating the next conduit fails
     */
    @Override
    public void terminateReads() throws IOException {
        if (tls) {
            try {
                sslEngine.closeInbound();
            } catch (IOException ex) {
                try {
                    super.terminateReads();
                } catch (IOException e2) {
                    e2.addSuppressed(ex);
                    throw e2;
                }
                throw ex;
            }
        }
        super.terminateReads();
    }

    /**
     * Blocks until this conduit is readable.
     * <p>
     * In non-TLS mode this simply delegates, matching the timed variant. In TLS mode it first
     * waits until the engine can unwrap. It returns immediately if the engine then has data
     * available; otherwise it waits on the next conduit.
     *
     * @throws IOException if waiting fails
     */
    @Override
    public void awaitReadable() throws IOException {
        if (!tls) {
            super.awaitReadable();
            return;
        }
        sslEngine.awaitCanUnwrap();
        if (sslEngine.isDataAvailable()) {
            return;
        }
        super.awaitReadable();
    }

    /**
     * Blocks until this conduit is readable or the timeout expires.
     * <p>
     * In TLS mode it returns at once if the unwrap buffer still has bytes. Otherwise it waits for
     * the engine to be able to unwrap, then spends whatever remains of the timeout waiting on the
     * next conduit.
     *
     * @param time the maximum time to wait
     * @param timeUnit the unit of {@code time}
     * @throws IOException if waiting fails
     */
    @Override
    public void awaitReadable(long time, TimeUnit timeUnit) throws IOException {
        if (!tls) {
            super.awaitReadable(time, timeUnit);
            return;
        }
        synchronized (sslEngine.getUnwrapLock()) {
            if (sslEngine.getUnwrapBuffer().hasRemaining()) {
                return;
            }
        }
        final long budget = timeUnit.toNanos(time);
        final long start = System.nanoTime();
        sslEngine.awaitCanUnwrap(time, timeUnit);
        final long elapsed = System.nanoTime() - start;
        // With no budget left, return rather than forwarding a zero timeout, whose meaning is
        // left to the next conduit.
        if (elapsed >= budget) {
            return;
        }
        super.awaitReadable(budget - elapsed, TimeUnit.NANOSECONDS);
    }
}