package sun.security.ssl;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ReadOnlyBufferException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLKeyException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLProtocolException;
import javax.net.ssl.SSLSession;
final class SSLEngineImpl extends SSLEngine implements SSLTransport {
/** The context that created this engine; queried for supported suites and protocols. */
private final SSLContextImpl sslContext;

/** Transport state (configuration, session, record layers, handshake) for this engine. */
final TransportContext conContext;

/**
 * Creates an engine with no peer host or port hint.
 *
 * @param sslContext the context that created this engine
 */
SSLEngineImpl(SSLContextImpl sslContext) {
    this(sslContext, null, -1);
}

/**
 * Creates an engine for the given peer.
 *
 * <p>When {@code host} is non-null, it is appended to the configured SNI
 * server-name list via {@code Utilities.addToSNIServerNameList}.
 *
 * @param sslContext the context that created this engine
 * @param host the peer host name, or {@code null} if none is known
 * @param port the peer port ({@code -1} is what the one-argument
 *        constructor passes when no port is known)
 */
SSLEngineImpl(SSLContextImpl sslContext,
        String host, int port) {
    super(host, port);
    this.sslContext = sslContext;

    // The inbound and outbound record layers share one HandshakeHash instance.
    final HandshakeHash sharedHash = new HandshakeHash();
    this.conContext = new TransportContext(sslContext, this,
            new SSLEngineInputRecord(sharedHash),
            new SSLEngineOutputRecord(sharedHash));

    if (host == null) {
        return;
    }
    conContext.sslConfig.serverNames = Utilities.addToSNIServerNameList(
            conContext.sslConfig.serverNames, host);
}
/**
 * Kickstarts handshaking on this engine. Also invoked internally by
 * {@code tryKeyUpdate} to trigger a key update.
 *
 * @throws IllegalStateException if client/server mode has not been set
 * @throws SSLException produced by {@code conContext.fatal}: a
 *         HANDSHAKE_FAILURE alert for I/O errors, INTERNAL_ERROR otherwise
 */
@Override
public synchronized void beginHandshake() throws SSLException {
    if (conContext.isUnsureMode) {
        throw new IllegalStateException(
                "Client/Server mode has not yet been set.");
    }

    try {
        conContext.kickstart();
    } catch (IOException kickstartFailure) {
        throw conContext.fatal(Alert.HANDSHAKE_FAILURE,
                "Couldn't kickstart handshaking", kickstartFailure);
    } catch (Exception unexpected) {
        // Any other failure is treated as an internal error.
        throw conContext.fatal(Alert.INTERNAL_ERROR,
                "Fail to begin handshake", unexpected);
    }
}
/**
 * Wraps application data into a single network buffer by delegating to
 * the multi-destination {@code wrap} with a one-element destination array.
 */
@Override
public synchronized SSLEngineResult wrap(ByteBuffer[] appData,
        int offset, int length, ByteBuffer netData) throws SSLException {
    final ByteBuffer[] netBuffers = { netData };
    return wrap(appData, offset, length, netBuffers, 0, netBuffers.length);
}
/**
 * Wraps application data from {@code srcs[srcsOffset..+srcsLength)} into
 * network data written to {@code dsts[dstsOffset..+dstsLength)}.
 *
 * <p>A failure recorded by a delegated task is reported before the
 * arguments are validated. Errors raised while writing are routed through
 * {@code conContext.fatal}.
 *
 * @throws IllegalStateException if client/server mode has not been set
 * @throws SSLException on a pending delegated-task failure or a write error
 */
public synchronized SSLEngineResult wrap(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws SSLException {
    if (conContext.isUnsureMode) {
        throw new IllegalStateException(
                "Client/Server mode has not yet been set.");
    }

    checkTaskThrown();
    checkParams(srcs, srcsOffset, srcsLength, dsts, dstsOffset, dstsLength);

    try {
        return writeRecord(
                srcs, srcsOffset, srcsLength, dsts, dstsOffset, dstsLength);
    } catch (SSLProtocolException protocolError) {
        throw conContext.fatal(Alert.UNEXPECTED_MESSAGE, protocolError);
    } catch (IOException ioError) {
        throw conContext.fatal(Alert.INTERNAL_ERROR,
                "problem wrapping app data", ioError);
    } catch (Exception other) {
        throw conContext.fatal(Alert.INTERNAL_ERROR,
                "Fail to wrap application data", other);
    }
}
/**
 * Does the work behind {@code wrap}: produces outbound records into
 * {@code dsts}. Records already buffered in the output record are
 * delivered first. Application data from {@code srcs} is encoded only
 * when there were none.
 *
 * <p>Early returns, in order:
 * <ul>
 *   <li>{@code CLOSED} if the outbound side is done;</li>
 *   <li>{@code OK}/{@code NEED_UNWRAP} right after kickstarting a
 *       not-yet-negotiated connection, when the peer must send next;</li>
 *   <li>{@code OK}/{@code NEED_TASK} while a delegated task is pending;</li>
 *   <li>{@code BUFFER_OVERFLOW} if the destination range has less room
 *       than {@code conSession.getPacketBufferSize()}.</li>
 * </ul>
 *
 * @return the result; the byte counts are the change in
 *         {@code remaining()} across the given source and destination ranges
 * @throws IOException if encoding fails. An I/O error that is not an
 *         {@link SSLException} is wrapped in {@code SSLException("Write problems")}.
 */
private SSLEngineResult writeRecord(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
    if (isOutboundDone()) {
        return new SSLEngineResult(
                Status.CLOSED, getHandshakeStatus(), 0, 0);
    }

    HandshakeStatus hsStatus = null;
    if (!conContext.isNegotiated && !conContext.isBroken &&
            !conContext.isInboundClosed() &&
            !conContext.isOutboundClosed()) {
        conContext.kickstart();

        // If the peer has to send next, there is nothing to wrap yet.
        hsStatus = getHandshakeStatus();
        if (hsStatus == HandshakeStatus.NEED_UNWRAP) {
            return new SSLEngineResult(Status.OK, hsStatus, 0, 0);
        }
    }

    if (hsStatus == null) {
        hsStatus = getHandshakeStatus();
    }

    // Do not produce records while a delegated task is outstanding.
    if (hsStatus == HandshakeStatus.NEED_TASK) {
        return new SSLEngineResult(Status.OK, hsStatus, 0, 0);
    }

    int dstsRemains = 0;
    for (int i = dstsOffset; i < dstsOffset + dstsLength; i++) {
        dstsRemains += dsts[i].remaining();
    }

    // Require room for a complete packet before encoding anything.
    if (dstsRemains < conContext.conSession.getPacketBufferSize()) {
        return new SSLEngineResult(
                Status.BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0);
    }

    int srcsRemains = 0;
    for (int i = srcsOffset; i < srcsOffset + srcsLength; i++) {
        srcsRemains += srcs[i].remaining();
    }

    Ciphertext ciphertext = null;
    try {
        // Deliver records already buffered in the output record first.
        if (!conContext.outputRecord.isEmpty()) {
            ciphertext = encode(null, 0, 0,
                    dsts, dstsOffset, dstsLength);
        }

        // encode() substitutes Ciphertext.CIPHERTEXT_NULL for a null
        // result rather than returning null. So ciphertext is null here
        // only if the branch above was not taken.
        if (ciphertext == null && srcsRemains != 0) {
            ciphertext = encode(srcs, srcsOffset, srcsLength,
                    dsts, dstsOffset, dstsLength);
        }
    } catch (IOException ioe) {
        if (ioe instanceof SSLException) {
            throw ioe;
        } else {
            throw new SSLException("Write problems", ioe);
        }
    }

    Status status = (isOutboundDone() ? Status.CLOSED : Status.OK);
    if (ciphertext != null && ciphertext.handshakeStatus != null) {
        hsStatus = ciphertext.handshakeStatus;
    } else {
        hsStatus = getHandshakeStatus();
    }

    int deltaSrcs = srcsRemains;
    for (int i = srcsOffset; i < srcsOffset + srcsLength; i++) {
        deltaSrcs -= srcs[i].remaining();
    }

    int deltaDsts = dstsRemains;
    for (int i = dstsOffset; i < dstsOffset + dstsLength; i++) {
        deltaDsts -= dsts[i].remaining();
    }

    return new SSLEngineResult(status, hsStatus, deltaSrcs, deltaDsts);
}
/**
 * Encodes via the output record and attaches the resulting handshake
 * status to the produced ciphertext.
 *
 * <p>When the output record produces nothing, the shared
 * {@code Ciphertext.CIPHERTEXT_NULL} is returned unmodified. Otherwise the
 * status comes from {@code tryToFinishHandshake}, falling back to
 * {@code conContext.getHandshakeStatus()}. It may then be replaced by
 * {@code tryKeyUpdate} when the write side reports a huge sequence number
 * or its cipher is at its key limit.
 *
 * @throws IOException produced by {@code conContext.fatal}: a
 *         HANDSHAKE_FAILURE alert for {@link SSLHandshakeException},
 *         UNEXPECTED_MESSAGE for any other I/O error
 */
private Ciphertext encode(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
    final Ciphertext produced;
    try {
        produced = conContext.outputRecord.encode(
                srcs, srcsOffset, srcsLength, dsts, dstsOffset, dstsLength);
    } catch (SSLHandshakeException handshakeFailure) {
        throw conContext.fatal(Alert.HANDSHAKE_FAILURE, handshakeFailure);
    } catch (IOException encodeFailure) {
        throw conContext.fatal(Alert.UNEXPECTED_MESSAGE, encodeFailure);
    }

    if (produced == null) {
        return Ciphertext.CIPHERTEXT_NULL;
    }

    HandshakeStatus status = tryToFinishHandshake(produced.contentType);
    if (status == null) {
        status = conContext.getHandshakeStatus();
    }

    final boolean keyUpdateWanted = conContext.outputRecord.seqNumIsHuge()
            || conContext.outputRecord.writeCipher.atKeyLimit();
    if (keyUpdateWanted) {
        status = tryKeyUpdate(status);
    }

    produced.handshakeStatus = status;
    return produced;
}
/**
 * Tries to complete a handshake after a record of {@code contentType} has
 * been processed.
 *
 * <p>Acts only for handshake records when the output record is empty.
 * Then it reports FINISHED if there is no handshake context, finishes a
 * post-handshake context, or finishes a handshake whose context is marked
 * {@code handshakeFinished}.
 *
 * @return the resulting status, or {@code null} if nothing was finished
 */
private HandshakeStatus tryToFinishHandshake(byte contentType) {
    if (contentType != ContentType.HANDSHAKE.id
            || !conContext.outputRecord.isEmpty()) {
        return null;
    }
    if (conContext.handshakeContext == null) {
        return HandshakeStatus.FINISHED;
    }
    if (conContext.isPostHandshakeContext()) {
        return conContext.finishPostHandshake();
    }
    if (conContext.handshakeContext.handshakeFinished) {
        return conContext.finishHandshake();
    }
    return null;
}
/**
 * Starts a key-update handshake if the connection allows it. That requires
 * no handshake in progress, both directions open, and the connection not
 * broken.
 *
 * @param currentHandshakeStatus returned unchanged when no update is started
 * @return the handshake status after starting the update, or
 *         {@code currentHandshakeStatus}
 */
private HandshakeStatus tryKeyUpdate(
        HandshakeStatus currentHandshakeStatus) throws IOException {
    final boolean canStart = conContext.handshakeContext == null
            && !conContext.isOutboundClosed()
            && !conContext.isInboundClosed()
            && !conContext.isBroken;
    if (!canStart) {
        return currentHandshakeStatus;
    }

    if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
        SSLLogger.finest("trigger key update");
    }
    beginHandshake();
    return conContext.getHandshakeStatus();
}
/**
 * Validates the buffer arrays and index ranges passed to wrap/unwrap.
 *
 * @throws IllegalArgumentException if either array, or any buffer inside
 *         the selected range, is {@code null}
 * @throws IndexOutOfBoundsException if an offset or length is negative,
 *         or the range extends past the end of its array
 * @throws ReadOnlyBufferException if a selected destination buffer is
 *         read-only
 */
private static void checkParams(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) {
    if (srcs == null || dsts == null) {
        throw new IllegalArgumentException(
                "source or destination buffer is null");
    }

    // "length > array.length - offset" cannot overflow once both values
    // are known to be non-negative.
    final boolean dstsRangeInvalid = dstsOffset < 0 || dstsLength < 0
            || dstsLength > dsts.length - dstsOffset;
    if (dstsRangeInvalid) {
        throw new IndexOutOfBoundsException(
                "index out of bound of the destination buffers");
    }

    final boolean srcsRangeInvalid = srcsOffset < 0 || srcsLength < 0
            || srcsLength > srcs.length - srcsOffset;
    if (srcsRangeInvalid) {
        throw new IndexOutOfBoundsException(
                "index out of bound of the source buffers");
    }

    final int dstsEnd = dstsOffset + dstsLength;
    for (int i = dstsOffset; i < dstsEnd; i++) {
        final ByteBuffer dst = dsts[i];
        if (dst == null) {
            throw new IllegalArgumentException(
                    "destination buffer[" + i + "] == null");
        }
        if (dst.isReadOnly()) {
            throw new ReadOnlyBufferException();
        }
    }

    final int srcsEnd = srcsOffset + srcsLength;
    for (int i = srcsOffset; i < srcsEnd; i++) {
        if (srcs[i] == null) {
            throw new IllegalArgumentException(
                    "source buffer[" + i + "] == null");
        }
    }
}
/**
 * Unwraps network data from a single source buffer by delegating to the
 * multi-source {@code unwrap} with a one-element source array.
 */
@Override
public synchronized SSLEngineResult unwrap(ByteBuffer src,
        ByteBuffer[] dsts, int offset, int length) throws SSLException {
    final ByteBuffer[] netBuffers = { src };
    return unwrap(netBuffers, 0, netBuffers.length, dsts, offset, length);
}
/**
 * Unwraps network data from {@code srcs[srcsOffset..+srcsLength)} into
 * application data written to {@code dsts[dstsOffset..+dstsLength)}.
 *
 * <p>A failure recorded by a delegated task is reported before the
 * arguments are validated. Errors raised while reading are routed through
 * {@code conContext.fatal}.
 *
 * @throws IllegalStateException if client/server mode has not been set
 * @throws SSLException on a pending delegated-task failure or a read error
 */
public synchronized SSLEngineResult unwrap(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws SSLException {
    if (conContext.isUnsureMode) {
        throw new IllegalStateException(
                "Client/Server mode has not yet been set.");
    }

    checkTaskThrown();
    checkParams(srcs, srcsOffset, srcsLength, dsts, dstsOffset, dstsLength);

    try {
        return readRecord(
                srcs, srcsOffset, srcsLength, dsts, dstsOffset, dstsLength);
    } catch (SSLProtocolException protocolError) {
        throw conContext.fatal(Alert.UNEXPECTED_MESSAGE,
                protocolError.getMessage(), protocolError);
    } catch (IOException ioError) {
        throw conContext.fatal(Alert.INTERNAL_ERROR,
                "problem unwrapping net record", ioError);
    } catch (Exception other) {
        throw conContext.fatal(Alert.INTERNAL_ERROR,
                "Fail to unwrap network record", other);
    }
}
/**
 * Does the work behind {@code unwrap}: decodes inbound network data from
 * {@code srcs} into {@code dsts}.
 *
 * <p>Early returns, in order:
 * <ul>
 *   <li>{@code CLOSED} if the inbound side is done;</li>
 *   <li>{@code OK}/{@code NEED_WRAP} right after kickstarting a
 *       not-yet-negotiated connection, when this side must send next;</li>
 *   <li>{@code OK}/{@code NEED_TASK} while a delegated task is pending;</li>
 *   <li>{@code BUFFER_UNDERFLOW} if the source range is empty;</li>
 *   <li>{@code BUFFER_OVERFLOW}, once negotiated, if the estimated
 *       fragment size exceeds the room in the destination range;</li>
 *   <li>{@code BUFFER_UNDERFLOW} if a complete packet is not available.</li>
 * </ul>
 * A packet larger than the current packet buffer size causes the session
 * buffers to be expanded when it fits {@code SSLRecord.maxLargeRecordSize}.
 * A packet still larger than the (possibly expanded) packet buffer size
 * fails with {@link SSLProtocolException}.
 *
 * @return the result; the byte counts are the change in
 *         {@code remaining()} across the given source and destination ranges
 * @throws IOException on decode failure. An I/O error that is not an
 *         {@link SSLException} is wrapped in {@code SSLException("readRecord")}.
 */
private SSLEngineResult readRecord(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
    if (isInboundDone()) {
        return new SSLEngineResult(
                Status.CLOSED, getHandshakeStatus(), 0, 0);
    }

    HandshakeStatus hsStatus = null;
    if (!conContext.isNegotiated && !conContext.isBroken &&
            !conContext.isInboundClosed() &&
            !conContext.isOutboundClosed()) {
        conContext.kickstart();

        // If this side has to send next, there is nothing to unwrap yet.
        hsStatus = getHandshakeStatus();
        if (hsStatus == HandshakeStatus.NEED_WRAP) {
            return new SSLEngineResult(Status.OK, hsStatus, 0, 0);
        }
    }

    if (hsStatus == null) {
        hsStatus = getHandshakeStatus();
    }

    // Do not consume records while a delegated task is outstanding.
    if (hsStatus == HandshakeStatus.NEED_TASK) {
        return new SSLEngineResult(Status.OK, hsStatus, 0, 0);
    }

    int srcsRemains = 0;
    for (int i = srcsOffset; i < srcsOffset + srcsLength; i++) {
        srcsRemains += srcs[i].remaining();
    }

    if (srcsRemains == 0) {
        return new SSLEngineResult(
                Status.BUFFER_UNDERFLOW, hsStatus, 0, 0);
    }

    // Length of the next packet; -1 is treated as "not enough data" below.
    int packetLen = conContext.inputRecord.bytesInCompletePacket(
            srcs, srcsOffset, srcsLength);

    // Oversized packet: grow the session buffers if the packet fits the
    // large-record limit, then reject it if it still does not fit.
    if (packetLen > conContext.conSession.getPacketBufferSize()) {
        int largestRecordSize = SSLRecord.maxLargeRecordSize;
        if (packetLen <= largestRecordSize) {
            conContext.conSession.expandBufferSizes();
        }
        largestRecordSize = conContext.conSession.getPacketBufferSize();
        if (packetLen > largestRecordSize) {
            throw new SSLProtocolException(
                    "Input record too big: max = " +
                    largestRecordSize + " len = " + packetLen);
        }
    }

    int dstsRemains = 0;
    for (int i = dstsOffset; i < dstsOffset + dstsLength; i++) {
        dstsRemains += dsts[i].remaining();
    }

    // The destination-space requirement is enforced only once negotiated.
    // NOTE(review): this runs before the packetLen == -1 check below, so
    // estimateFragmentSize may receive -1. Confirm it handles that case.
    if (conContext.isNegotiated) {
        int fragmentLen =
                conContext.inputRecord.estimateFragmentSize(packetLen);
        if (fragmentLen > dstsRemains) {
            return new SSLEngineResult(
                    Status.BUFFER_OVERFLOW, hsStatus, 0, 0);
        }
    }

    if ((packetLen == -1) || (srcsRemains < packetLen)) {
        return new SSLEngineResult(Status.BUFFER_UNDERFLOW, hsStatus, 0, 0);
    }

    Plaintext plainText = null;
    try {
        plainText = decode(srcs, srcsOffset, srcsLength,
                dsts, dstsOffset, dstsLength);
    } catch (IOException ioe) {
        if (ioe instanceof SSLException) {
            throw ioe;
        } else {
            throw new SSLException("readRecord", ioe);
        }
    }

    Status status = (isInboundDone() ? Status.CLOSED : Status.OK);
    if (plainText.handshakeStatus != null) {
        hsStatus = plainText.handshakeStatus;
    } else {
        hsStatus = getHandshakeStatus();
    }

    int deltaNet = srcsRemains;
    for (int i = srcsOffset; i < srcsOffset + srcsLength; i++) {
        deltaNet -= srcs[i].remaining();
    }

    int deltaApp = dstsRemains;
    for (int i = dstsOffset; i < dstsOffset + dstsLength; i++) {
        deltaApp -= dsts[i].remaining();
    }

    return new SSLEngineResult(status, hsStatus, deltaNet, deltaApp);
}
/**
 * Decodes inbound data via {@code SSLTransport.decode} and attaches the
 * resulting handshake status to the returned plaintext.
 *
 * <p>The sentinel {@code Plaintext.PLAINTEXT_NULL} is returned untouched.
 * Otherwise the status comes from {@code tryToFinishHandshake}, falling
 * back to {@code conContext.getHandshakeStatus()}. It may then be replaced
 * by {@code tryKeyUpdate} when the read side reports a huge sequence number
 * or its cipher is at its key limit.
 */
private Plaintext decode(
        ByteBuffer[] srcs, int srcsOffset, int srcsLength,
        ByteBuffer[] dsts, int dstsOffset, int dstsLength) throws IOException {
    final Plaintext plaintext = SSLTransport.decode(conContext,
            srcs, srcsOffset, srcsLength,
            dsts, dstsOffset, dstsLength);

    if (plaintext == Plaintext.PLAINTEXT_NULL) {
        return plaintext;
    }

    final HandshakeStatus finished =
            tryToFinishHandshake(plaintext.contentType);
    plaintext.handshakeStatus = (finished != null)
            ? finished
            : conContext.getHandshakeStatus();

    final boolean keyUpdateWanted = conContext.inputRecord.seqNumIsHuge()
            || conContext.inputRecord.readCipher.atKeyLimit();
    if (keyUpdateWanted) {
        plaintext.handshakeStatus = tryKeyUpdate(plaintext.handshakeStatus);
    }

    return plaintext;
}
/**
 * Returns a task that runs the queued delegated handshake actions, or
 * {@code null} in three cases: there is no handshake context, a task is
 * already out, or nothing is queued. Handing out a task marks the
 * handshake context as {@code taskDelegated}.
 */
@Override
public synchronized Runnable getDelegatedTask() {
    final HandshakeContext hc = conContext.handshakeContext;
    if (hc == null || hc.taskDelegated || hc.delegatedActions.isEmpty()) {
        return null;
    }
    hc.taskDelegated = true;
    return new DelegatedTask(this);
}
/**
 * Closes the inbound side of this engine. This is a no-op if inbound is
 * already done.
 *
 * <p>No error is raised if no handshake has started. If the connection is
 * negotiated or a handshake is in progress and the peer's close_notify has
 * not been received, an INTERNAL_ERROR is raised via {@code conContext.fatal}.
 *
 * @throws SSLException when closing before the peer's close_notify
 */
@Override
public synchronized void closeInbound() throws SSLException {
    if (isInboundDone()) {
        return;
    }
    if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
        SSLLogger.finest("Closing inbound of SSLEngine");
    }

    if (!conContext.isInputCloseNotified) {
        final boolean handshakeStartedOrDone = conContext.isNegotiated
                || conContext.handshakeContext != null;
        if (handshakeStartedOrDone) {
            throw conContext.fatal(Alert.INTERNAL_ERROR,
                    "closing inbound before receiving peer's close_notify");
        }
    }

    conContext.closeInbound();
}
/** Reports whether the transport's inbound side is closed. */
@Override
public synchronized boolean isInboundDone() {
    return conContext.isInboundClosed();
}

/** Closes the outbound side. This is a no-op if it is already closed. */
@Override
public synchronized void closeOutbound() {
    if (!conContext.isOutboundClosed()) {
        if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
            SSLLogger.finest("Closing outbound of SSLEngine");
        }
        conContext.closeOutbound();
    }
}

/**
 * Reports whether the outbound side is done.
 *
 * <p>NOTE(review): this delegates to {@code conContext.isOutboundDone()},
 * which is distinct from the {@code isOutboundClosed()} check used by
 * {@code closeOutbound()}. Presumably "done" also accounts for buffered
 * output; confirm.
 */
@Override
public synchronized boolean isOutboundDone() {
    return conContext.isOutboundDone();
}
/** Returns the names of every cipher suite the owning context supports. */
@Override
public String[] getSupportedCipherSuites() {
    return CipherSuite.namesOf(sslContext.getSupportedCipherSuites());
}

/** Returns the names of the cipher suites enabled on this engine. */
@Override
public synchronized String[] getEnabledCipherSuites() {
    return CipherSuite.namesOf(conContext.sslConfig.enabledCipherSuites);
}

/**
 * Replaces the enabled cipher suites with the values produced by
 * {@code CipherSuite.validValuesOf(suites)}. That helper presumably
 * rejects null or unknown names; confirm.
 */
@Override
public synchronized void setEnabledCipherSuites(String[] suites) {
    conContext.sslConfig.enabledCipherSuites =
            CipherSuite.validValuesOf(suites);
}

/** Returns the names of every protocol version the owning context supports. */
@Override
public String[] getSupportedProtocols() {
    return ProtocolVersion.toStringArray(
            sslContext.getSupportedProtocolVersions());
}

/** Returns the names of the protocol versions enabled on this engine. */
@Override
public synchronized String[] getEnabledProtocols() {
    return ProtocolVersion.toStringArray(
            conContext.sslConfig.enabledProtocols);
}

/**
 * Replaces the enabled protocol versions.
 *
 * @throws IllegalArgumentException if {@code protocols} is {@code null}
 */
@Override
public synchronized void setEnabledProtocols(String[] protocols) {
    if (protocols == null) {
        throw new IllegalArgumentException("Protocols cannot be null");
    }
    conContext.sslConfig.enabledProtocols = ProtocolVersion.namesOf(protocols);
}
/** Returns the session of the current connection. */
@Override
public synchronized SSLSession getSession() {
    return conContext.conSession;
}

/**
 * Returns the session being negotiated, or {@code null} when no handshake
 * context exists.
 */
@Override
public synchronized SSLSession getHandshakeSession() {
    final HandshakeContext hc = conContext.handshakeContext;
    if (hc == null) {
        return null;
    }
    return hc.handshakeSession;
}

/** Returns the transport context's current handshake status. */
@Override
public synchronized SSLEngineResult.HandshakeStatus getHandshakeStatus() {
    return conContext.getHandshakeStatus();
}
/** Selects client or server mode via the transport context. */
@Override
public synchronized void setUseClientMode(boolean mode) {
    conContext.setUseClientMode(mode);
}

/** Reports whether the engine is configured for client mode. */
@Override
public synchronized boolean getUseClientMode() {
    return conContext.sslConfig.isClientMode;
}

/**
 * Requires (or stops requiring) client authentication. Passing
 * {@code false} resets the type to CLIENT_AUTH_NONE, which also clears
 * any earlier "want" setting.
 */
@Override
public synchronized void setNeedClientAuth(boolean need) {
    if (need) {
        conContext.sslConfig.clientAuthType =
                ClientAuthType.CLIENT_AUTH_REQUIRED;
    } else {
        conContext.sslConfig.clientAuthType =
                ClientAuthType.CLIENT_AUTH_NONE;
    }
}

/** Reports whether client authentication is required. */
@Override
public synchronized boolean getNeedClientAuth() {
    return conContext.sslConfig.clientAuthType
            == ClientAuthType.CLIENT_AUTH_REQUIRED;
}

/**
 * Requests (or stops requesting) client authentication. Passing
 * {@code false} resets the type to CLIENT_AUTH_NONE, which also clears
 * any earlier "need" setting.
 */
@Override
public synchronized void setWantClientAuth(boolean want) {
    if (want) {
        conContext.sslConfig.clientAuthType =
                ClientAuthType.CLIENT_AUTH_REQUESTED;
    } else {
        conContext.sslConfig.clientAuthType =
                ClientAuthType.CLIENT_AUTH_NONE;
    }
}

/** Reports whether client authentication is requested. */
@Override
public synchronized boolean getWantClientAuth() {
    return conContext.sslConfig.clientAuthType
            == ClientAuthType.CLIENT_AUTH_REQUESTED;
}

/** Sets whether new sessions may be created. */
@Override
public synchronized void setEnableSessionCreation(boolean flag) {
    conContext.sslConfig.enableSessionCreation = flag;
}

/** Reports whether new sessions may be created. */
@Override
public synchronized boolean getEnableSessionCreation() {
    return conContext.sslConfig.enableSessionCreation;
}
/** Returns the engine's parameters as produced by its configuration. */
@Override
public synchronized SSLParameters getSSLParameters() {
    return conContext.sslConfig.getSSLParameters();
}

/**
 * Applies {@code params} to the configuration. If the resulting maximum
 * packet size is non-zero, it is also pushed to the output record.
 */
@Override
public synchronized void setSSLParameters(SSLParameters params) {
    conContext.sslConfig.setSSLParameters(params);

    if (conContext.sslConfig.maximumPacketSize == 0) {
        return;
    }
    conContext.outputRecord.changePacketSize(
            conContext.sslConfig.maximumPacketSize);
}

/** Returns the application protocol recorded on the transport context. */
@Override
public synchronized String getApplicationProtocol() {
    return conContext.applicationProtocol;
}

/**
 * Returns the application protocol chosen by the in-progress handshake,
 * or {@code null} when no handshake context exists.
 */
@Override
public synchronized String getHandshakeApplicationProtocol() {
    final HandshakeContext hc = conContext.handshakeContext;
    if (hc == null) {
        return null;
    }
    return hc.applicationProtocol;
}

/** Installs the callback used to select an application protocol. */
@Override
public synchronized void setHandshakeApplicationProtocolSelector(
        BiFunction<SSLEngine, List<String>, String> selector) {
    conContext.sslConfig.engineAPSelector = selector;
}

/** Returns the callback used to select an application protocol. */
@Override
public synchronized BiFunction<SSLEngine, List<String>, String>
        getHandshakeApplicationProtocolSelector() {
    return conContext.sslConfig.engineAPSelector;
}

/** Always {@code true}: this engine hands work out via getDelegatedTask(). */
@Override
public boolean useDelegatedTask() {
    return true;
}
/**
 * Rethrows one pending failure recorded by a delegated task, if any.
 *
 * The handshake context's delegatedThrown is checked first, then the
 * transport context's. Whatever is reported is cleared. When both are set
 * but hold different exception objects, only the handshake context's is
 * reported now; the transport context's is left for a later call.
 *
 * @throws SSLException the recorded exception if it is an SSLException,
 *         otherwise a non-runtime exception converted by getTaskThrown
 * @throws RuntimeException the recorded exception if it is unchecked
 */
private synchronized void checkTaskThrown() throws SSLException {
Exception exc = null;
// First consult (and clear) the handshake context's recorded failure.
HandshakeContext hc = conContext.handshakeContext;
if ((hc != null) && (hc.delegatedThrown != null)) {
exc = hc.delegatedThrown;
hc.delegatedThrown = null;
}
if (conContext.delegatedThrown != null) {
if (exc != null) {
// Identity comparison: clear the transport copy only when it is the
// same object being reported; otherwise keep it for a later call.
if (conContext.delegatedThrown == exc) {
conContext.delegatedThrown = null;
}
} else {
// Nothing on the handshake context; report the transport one.
exc = conContext.delegatedThrown;
conContext.delegatedThrown = null;
}
}
if (exc == null) {
return;
}
// SSLException and unchecked exceptions pass through as-is; any other
// exception is converted by getTaskThrown.
if (exc instanceof SSLException) {
throw (SSLException)exc;
} else if (exc instanceof RuntimeException) {
throw (RuntimeException)exc;
} else {
throw getTaskThrown(exc);
}
}
/**
 * Converts a delegated-task failure into an {@link SSLException}.
 *
 * <p>The message is taken from {@code taskThrown}, or a generic message if
 * it has none. These subtypes are re-created with that message and chained
 * to the original:
 * <ul>
 *   <li>{@link SSLHandshakeException}</li>
 *   <li>{@link SSLKeyException}</li>
 *   <li>{@link SSLPeerUnverifiedException}</li>
 *   <li>{@link SSLProtocolException}</li>
 * </ul>
 * Any other SSLException is returned unchanged, and any other exception is
 * wrapped.
 *
 * <p>A {@link RuntimeException} argument is not converted. It is thrown,
 * wrapped in a new RuntimeException. The callers in this class do not pass
 * one: checkTaskThrown rethrows unchecked exceptions itself, and
 * {@code PrivilegedActionException.getException()} carries checked
 * exceptions.
 */
private static SSLException getTaskThrown(Exception taskThrown) {
    final String detail = taskThrown.getMessage();
    final String msg = (detail == null)
            ? "Delegated task threw Exception or Error"
            : detail;

    if (taskThrown instanceof RuntimeException) {
        throw new RuntimeException(msg, taskThrown);
    }
    if (taskThrown instanceof SSLHandshakeException) {
        SSLHandshakeException copy = new SSLHandshakeException(msg);
        copy.initCause(taskThrown);
        return copy;
    }
    if (taskThrown instanceof SSLKeyException) {
        SSLKeyException copy = new SSLKeyException(msg);
        copy.initCause(taskThrown);
        return copy;
    }
    if (taskThrown instanceof SSLPeerUnverifiedException) {
        SSLPeerUnverifiedException copy = new SSLPeerUnverifiedException(msg);
        copy.initCause(taskThrown);
        return copy;
    }
    if (taskThrown instanceof SSLProtocolException) {
        SSLProtocolException copy = new SSLProtocolException(msg);
        copy.initCause(taskThrown);
        return copy;
    }
    if (taskThrown instanceof SSLException) {
        return (SSLException) taskThrown;
    }
    return new SSLException(msg, taskThrown);
}
/**
 * Runnable handed out by getDelegatedTask(). It runs the handshake
 * context's queued delegated actions while holding the engine's monitor.
 * Any failure is recorded so that checkTaskThrown() can report it from a
 * later wrap/unwrap call.
 */
private static class DelegatedTask implements Runnable {
private final SSLEngineImpl engine;
DelegatedTask(SSLEngineImpl engineInstance) {
this.engine = engineInstance;
}
@Override
public void run() {
synchronized (engine) {
// Nothing to do without a handshake context or queued actions.
HandshakeContext hc = engine.conContext.handshakeContext;
if (hc == null || hc.delegatedActions.isEmpty()) {
return;
}
try {
// Run the actions under the access-control context stored on
// the transport context (conContext.acc).
AccessController.doPrivileged(
new DelegatedAction(hc), engine.conContext.acc);
} catch (PrivilegedActionException pae) {
// doPrivileged wraps checked exceptions thrown by the action.
Exception reportedException = pae.getException();
// Keep only the first failure on the transport context.
if (engine.conContext.delegatedThrown == null) {
engine.conContext.delegatedThrown = reportedException;
}
// Re-read the handshake context after the actions ran instead of
// reusing the earlier reference.
hc = engine.conContext.handshakeContext;
if (hc != null) {
hc.delegatedThrown = reportedException;
} else if (engine.conContext.closeReason != null) {
// Only replaces an existing close reason; none is set if absent.
// NOTE(review): confirm that skipping the null case is intended.
engine.conContext.closeReason =
getTaskThrown(reportedException);
}
} catch (RuntimeException rte) {
// Same bookkeeping as above, for unchecked failures.
if (engine.conContext.delegatedThrown == null) {
engine.conContext.delegatedThrown = rte;
}
hc = engine.conContext.handshakeContext;
if (hc != null) {
hc.delegatedThrown = rte;
} else if (engine.conContext.closeReason != null) {
engine.conContext.closeReason = rte;
}
}
// Clear the flag so getDelegatedTask() can hand out another task.
hc = engine.conContext.handshakeContext;
if (hc != null) {
hc.taskDelegated = false;
}
}
}
/**
 * Drains the handshake context's delegatedActions queue. Each polled
 * entry's key and value are passed to context.dispatch (presumably the
 * handshake message type and its body; confirm). A null poll result is
 * skipped.
 */
private static class DelegatedAction
implements PrivilegedExceptionAction<Void> {
final HandshakeContext context;
DelegatedAction(HandshakeContext context) {
this.context = context;
}
@Override
public Void run() throws Exception {
while (!context.delegatedActions.isEmpty()) {
Map.Entry<Byte, ByteBuffer> me =
context.delegatedActions.poll();
if (me != null) {
context.dispatch(me.getKey(), me.getValue());
}
}
return null;
}
}
}
}