/*
 * JBoss, Home of Professional Open Source
 *
 * Copyright 2013 Red Hat, Inc. and/or its affiliates.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.xnio.ssl;

import static org.xnio._private.Messages.msg;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.FileChannel;
import java.util.concurrent.TimeUnit;

import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.AbstractStreamSinkConduit;
import org.xnio.conduits.ConduitWritableByteChannel;
import org.xnio.conduits.Conduits;
import org.xnio.conduits.StreamSinkConduit;

Jsse SSL source conduit implementation based on JsseSslConduitEngine.
Author:Flavia Rainone
/** * Jsse SSL source conduit implementation based on {@link JsseSslConduitEngine}. * * @author <a href="mailto:frainone@redhat.com">Flavia Rainone</a> * */
final class JsseSslStreamSinkConduit extends AbstractStreamSinkConduit<StreamSinkConduit> { private final JsseSslConduitEngine sslEngine; private volatile boolean tls; protected JsseSslStreamSinkConduit(StreamSinkConduit next, JsseSslConduitEngine sslEngine, boolean tls) { super(next); if (sslEngine == null) { throw msg.nullParameter("sslEngine"); } this.sslEngine = sslEngine; this.tls = tls; } public void enableTls() { tls = true; if (isWriteResumed()) { wakeupWrites(); } } @Override public long transferFrom(final FileChannel src, final long position, final long count) throws IOException { return src.transferTo(position, count, new ConduitWritableByteChannel(this)); } @Override public long transferFrom(final StreamSourceChannel source, final long count, final ByteBuffer throughBuffer) throws IOException { return Conduits.transfer(source, count, throughBuffer, this); } @Override public int write(ByteBuffer src) throws IOException { return write(src, false); } @Override public long write(ByteBuffer[] srcs, int offs, int len) throws IOException { return write(srcs, offs, len, false); } @Override public int writeFinal(ByteBuffer src) throws IOException { return write(src, true); } @Override public long writeFinal(ByteBuffer[] srcs, int offset, int length) throws IOException { return write(srcs, offset, length, true); } private int write(ByteBuffer src, boolean writeFinal) throws IOException { if (!tls) { if(writeFinal) { return next.writeFinal(src); } else { return next.write(src); } } final int wrappedBytes = sslEngine.wrap(src); if (wrappedBytes > 0) { writeWrappedBuffer(writeFinal); } return wrappedBytes; } private long write(ByteBuffer[] srcs, int offs, int len, boolean writeFinal) throws IOException { if (!tls) { if(writeFinal) { return super.writeFinal(srcs, offs, len); } else { return super.write(srcs, offs, len); } } final long wrappedBytes = sslEngine.wrap(srcs, offs, len); if (wrappedBytes > 0) { writeWrappedBuffer(writeFinal); } return wrappedBytes; } @Override public void resumeWrites() { if (tls && sslEngine.isFirstHandshake()) { super.wakeupWrites(); } else { super.resumeWrites(); } } @Override public void terminateWrites() throws IOException { if (!tls) { super.terminateWrites(); return; } try { sslEngine.closeOutbound(); flush(); } catch (IOException e) { try { super.truncateWrites(); } catch (IOException e2) { e2.addSuppressed(e); throw e2; } throw e; } } @Override public void awaitWritable() throws IOException { if (tls) { sslEngine.awaitCanWrap(); } super.awaitWritable(); } @Override public void awaitWritable(long time, TimeUnit timeUnit) throws IOException { if (!tls) { super.awaitWritable(time, timeUnit); return; } long duration = timeUnit.toNanos(time); long awaited = System.nanoTime(); sslEngine.awaitCanWrap(time, timeUnit); awaited = System.nanoTime() - awaited; if (awaited > duration) { return; } super.awaitWritable(duration - awaited, TimeUnit.NANOSECONDS); } @Override public void truncateWrites() throws IOException { if (tls) try { sslEngine.closeOutbound(); } finally { try { super.truncateWrites(); } catch (IOException ignored) { } } super.truncateWrites(); } @Override public boolean flush() throws IOException { if (!tls) { return super.flush(); } if (sslEngine.isOutboundClosed()) { if (sslEngine.flush() && writeWrappedBuffer(false) && super.flush()) { super.terminateWrites(); return true; } else { return false; } } return sslEngine.flush() && writeWrappedBuffer(false) && super.flush(); } private boolean writeWrappedBuffer(boolean writeFinal) throws IOException { synchronized (sslEngine.getWrapLock()) { final ByteBuffer wrapBuffer = sslEngine.getWrappedBuffer(); for (;;) { try { if (!wrapBuffer.flip().hasRemaining()) { if (writeFinal) { terminateWrites(); } return true; } if(writeFinal) { if (super.writeFinal(wrapBuffer) == 0) { return false; } } else { if (super.write(wrapBuffer) == 0) { return false; } } } finally { wrapBuffer.compact(); } } } } }