package sun.security.ssl;
import sun.security.action.GetPropertyAction;
import sun.security.ssl.SSLExtension.ExtensionConsumer;
import sun.security.ssl.SSLExtension.SSLExtensionSpec;
import sun.security.ssl.SSLHandshake.HandshakeMessage;
import sun.security.ssl.SupportedGroupsExtension.SupportedGroups;
import sun.security.util.HexDumpEncoder;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.net.ssl.SSLProtocolException;
import static sun.security.ssl.SSLExtension.CH_SESSION_TICKET;
import static sun.security.ssl.SSLExtension.SH_SESSION_TICKET;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.text.MessageFormat;
import java.util.Locale;
/**
 * TLS 1.2 "session_ticket" extension support.
 *
 * <p>Provides the ClientHello / ServerHello producers and consumers for the
 * extension, and the server-side encryption and decryption of stateless
 * session tickets.  Tickets are encrypted with AES/GCM using keys kept in
 * {@code sslContext.keyHashMap}, indexed by a key number that is also
 * written in clear at the front of each ticket.
 */
final class SessionTicketExtension {

    static final HandshakeProducer chNetworkProducer =
            new T12CHSessionTicketProducer();
    static final ExtensionConsumer chOnLoadConsumer =
            new T12CHSessionTicketConsumer();
    static final HandshakeProducer shNetworkProducer =
            new T12SHSessionTicketProducer();
    static final ExtensionConsumer shOnLoadConsumer =
            new T12SHSessionTicketConsumer();

    static final SSLStringizer steStringizer = new SessionTicketStringizer();

    // Default lifetime of a ticket-encryption key, in milliseconds (1 hour).
    private static final int TIMEOUT_DEFAULT = 3600 * 1000;
    // Lifetime of a ticket-encryption key, in milliseconds.
    private static final int keyTimeout;
    // Key number of the current encryption key.  Written while holding the
    // keyHashMap lock in nextKey() but read without it in getCurrentKey(),
    // hence volatile.
    // NOTE(review): this counter is shared by every SSLContext, while each
    // context has its own keyHashMap; several contexts rotating keys may
    // keep missing each other's "current" key -- confirm.
    private static volatile int currentKeyID = new SecureRandom().nextInt();
    // AES key size, in bits.
    private static final int KEYLEN = 256;

    static {
        String s = GetPropertyAction.privilegedGetProperty(
                "jdk.tls.server.statelessKeyTimeout");
        int kt = TIMEOUT_DEFAULT;
        if (s != null) {
            boolean valid;
            try {
                // The property is given in seconds.  multiplyExact rejects
                // values whose millisecond form does not fit in an int,
                // instead of letting them silently wrap around.
                kt = Math.multiplyExact(Integer.parseInt(s), 1000);
                // NOTE(review): MAX_TICKET_LIFETIME is declared in
                // NewSessionTicket (not visible here) while kt is in
                // milliseconds -- confirm both use the same unit.
                valid = kt >= 0 &&
                        kt <= NewSessionTicket.MAX_TICKET_LIFETIME;
            } catch (NumberFormatException | ArithmeticException e) {
                valid = false;
            }
            if (!valid) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl")) {
                    SSLLogger.warning("Invalid timeout for " +
                            "jdk.tls.server.statelessKeyTimeout: " + s +
                            ". Set to default value " +
                            (TIMEOUT_DEFAULT / 1000) + "sec");
                }
                kt = TIMEOUT_DEFAULT;
            }
        }
        keyTimeout = kt;
    }

    /**
     * An AES key used to encrypt and decrypt session tickets.  Constructing
     * one registers it in the context's keyHashMap under its key number.
     */
    static final class StatelessKey {
        // Absolute time (ms since the epoch) after which this key is no
        // longer used to encrypt new tickets.
        final long timeout;
        // The AES key, or null if AES key generation was unavailable.
        final SecretKey key;
        // Key number; stored in clear in each ticket and used as GCM AAD.
        final int num;

        StatelessKey(HandshakeContext hc, int newNum) {
            SecretKey k = null;
            try {
                KeyGenerator kg = KeyGenerator.getInstance("AES");
                kg.init(KEYLEN, hc.sslContext.getSecureRandom());
                k = kg.generateKey();
            } catch (NoSuchAlgorithmException e) {
                // AES is expected to always be available.  If it is not,
                // the key stays null; Cipher.init() in encrypt()/decrypt()
                // then fails inside their catch-all handlers.
            }
            key = k;
            timeout = System.currentTimeMillis() + keyTimeout;
            num = newNum;
            hc.sslContext.keyHashMap.put(Integer.valueOf(num), this);
        }

        /** Returns true once this key should stop encrypting new tickets. */
        boolean isExpired() {
            return System.currentTimeMillis() > timeout;
        }

        /**
         * Returns true once tickets encrypted with this key must no longer
         * be accepted: a full session lifetime has passed since expiry.
         *
         * @param sessionTimeout session lifetime in milliseconds
         */
        boolean isInvalid(long sessionTimeout) {
            return System.currentTimeMillis() > timeout + sessionTimeout;
        }
    }

    /** Lookup, rotation and cleanup of the per-context ticket keys. */
    private static final class KeyState {

        /**
         * Returns the key with number {@code num}, or null if it is unknown
         * or no longer valid for decryption.
         */
        static StatelessKey getKey(HandshakeContext hc, int num) {
            StatelessKey ssk = hc.sslContext.keyHashMap.get(num);
            if (ssk == null || ssk.isInvalid(getSessionTimeout(hc))) {
                return null;
            }
            return ssk;
        }

        /**
         * Returns the current encryption key, rotating to a new one when
         * the current key is missing or expired.
         */
        static StatelessKey getCurrentKey(HandshakeContext hc) {
            StatelessKey ssk = hc.sslContext.keyHashMap.get(currentKeyID);
            if (ssk != null && !ssk.isExpired()) {
                return ssk;
            }
            return nextKey(hc);
        }

        /**
         * Creates the next key under the keyHashMap lock, re-checking
         * first in case another thread already rotated, then purges keys
         * that are no longer valid.
         */
        private static StatelessKey nextKey(HandshakeContext hc) {
            StatelessKey ssk;

            synchronized (hc.sslContext.keyHashMap) {
                ssk = hc.sslContext.keyHashMap.get(currentKeyID);
                if (ssk != null && !ssk.isExpired()) {
                    return ssk;
                }
                // Wrap to 0 rather than overflowing into negative numbers.
                int newNum = (currentKeyID == Integer.MAX_VALUE) ?
                        0 : currentKeyID + 1;
                ssk = new StatelessKey(hc, newNum);
                currentKeyID = newNum;
            }
            cleanup(hc);
            return ssk;
        }

        /**
         * Destroys and removes every key that is no longer valid.  Works
         * on a snapshot of the key set and tolerates entries removed
         * concurrently by another thread.
         */
        static void cleanup(HandshakeContext hc) {
            long sessionTimeout = getSessionTimeout(hc);
            for (Object o : hc.sslContext.keyHashMap.keySet().toArray()) {
                Integer id = (Integer) o;
                StatelessKey ks = hc.sslContext.keyHashMap.get(id);
                if (ks == null || !ks.isInvalid(sessionTimeout)) {
                    // Already removed elsewhere, or still usable.
                    continue;
                }
                if (ks.key != null) {
                    try {
                        ks.key.destroy();
                    } catch (Exception e) {
                        // Best effort: many SecretKey implementations do
                        // not support destroy(); the key is dropped anyway.
                    }
                }
                hc.sslContext.keyHashMap.remove(id);
            }
        }

        /**
         * Returns the server session timeout in milliseconds.  Computed
         * as a long so that large timeouts (in seconds) do not overflow.
         */
        static long getSessionTimeout(HandshakeContext hc) {
            return hc.sslContext.engineGetServerSessionContext()
                    .getSessionTimeout() * 1000L;
        }
    }

    /**
     * Ticket payload of the extension.  On the server it also encrypts a
     * session into a ticket and decrypts a received ticket.
     *
     * <p>Ticket layout: key number (4 bytes, big-endian) || IV (16 bytes)
     * || AES/GCM ciphertext with a 128-bit tag.
     */
    static final class SessionTicketSpec implements SSLExtensionSpec {
        // GCM authentication tag length, in bits.
        private static final int GCM_TAG_LEN = 128;
        // Raw ticket bytes.  Methods of this class read from duplicates so
        // that the buffer position is left unchanged.
        ByteBuffer data;
        static final ByteBuffer zero = ByteBuffer.wrap(new byte[0]);

        /** Creates an empty ticket. */
        SessionTicketSpec() {
            data = zero;
        }

        SessionTicketSpec(HandshakeContext hc, byte[] b) throws IOException {
            this(hc, ByteBuffer.wrap(b));
        }

        /**
         * Wraps a received ticket.
         *
         * @throws IOException (a fatal DECODE_ERROR alert) if {@code buf}
         *         is null or holds more than 65536 bytes
         */
        SessionTicketSpec(HandshakeContext hc,
                ByteBuffer buf) throws IOException {
            if (buf == null) {
                throw hc.conContext.fatal(Alert.DECODE_ERROR,
                        new SSLProtocolException(
                                "SessionTicket buffer too small"));
            }
            if (buf.remaining() > 65536) {
                throw hc.conContext.fatal(Alert.DECODE_ERROR,
                        new SSLProtocolException(
                                "SessionTicket buffer too large. " +
                                buf.remaining()));
            }

            data = buf;
        }

        /** Big-endian encoding of a key number (ticket header and AAD). */
        private static byte[] keyIdBytes(int keyId) {
            return new byte[] {
                    (byte) (keyId >>> 24),
                    (byte) (keyId >>> 16),
                    (byte) (keyId >>> 8),
                    (byte) keyId
            };
        }

        /**
         * Encrypts {@code session} into a ticket with the current key.
         *
         * @return the ticket bytes, or an empty array when stateless
         *         resumption is off, the session is not statelessable,
         *         the session serializes to nothing, or encryption fails
         */
        public byte[] encrypt(HandshakeContext hc, SSLSessionImpl session) {
            if (!hc.statelessResumption ||
                    !hc.handshakeSession.isStatelessable()) {
                return new byte[0];
            }

            try {
                StatelessKey key = KeyState.getCurrentKey(hc);
                byte[] iv = new byte[16];

                SecureRandom random = hc.sslContext.getSecureRandom();
                random.nextBytes(iv);
                Cipher c = Cipher.getInstance("AES/GCM/NoPadding");
                c.init(Cipher.ENCRYPT_MODE, key.key,
                        new GCMParameterSpec(GCM_TAG_LEN, iv));
                // Authenticate the clear-text key number as well.
                byte[] aad = keyIdBytes(key.num);
                c.updateAAD(aad);

                byte[] plaintext = session.write();
                if (plaintext.length == 0) {
                    return plaintext;
                }
                byte[] encrypted = c.doFinal(plaintext);

                byte[] result = new byte[aad.length + iv.length +
                        encrypted.length];
                System.arraycopy(aad, 0, result, 0, aad.length);
                System.arraycopy(iv, 0, result, aad.length, iv.length);
                System.arraycopy(encrypted, 0, result,
                        aad.length + iv.length, encrypted.length);
                return result;
            } catch (Exception e) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Encryption failed." + e);
                }
                return new byte[0];
            }
        }

        /**
         * Decrypts this ticket.
         *
         * @return a buffer with the decrypted session bytes, or null if
         *         the key is unknown/invalid or decryption fails
         */
        ByteBuffer decrypt(HandshakeContext hc) {
            try {
                // Read from a duplicate so that data keeps its position and
                // toString()/getEncoded() still see the whole ticket.
                ByteBuffer in = data.duplicate();
                int keyID = in.getInt();
                StatelessKey key = KeyState.getKey(hc, keyID);
                if (key == null) {
                    return null;
                }

                byte[] iv = new byte[16];
                in.get(iv);
                Cipher c = Cipher.getInstance("AES/GCM/NoPadding");
                c.init(Cipher.DECRYPT_MODE, key.key,
                        new GCMParameterSpec(GCM_TAG_LEN, iv));
                c.updateAAD(keyIdBytes(keyID));

                // A ticket shorter than the tag yields a negative size;
                // allocate() then throws and the handler below returns null.
                ByteBuffer out =
                        ByteBuffer.allocate(in.remaining() - GCM_TAG_LEN / 8);
                c.doFinal(in, out);
                out.flip();
                return out;
            } catch (Exception e) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Decryption failed." + e.getMessage());
                }
            }
            return null;
        }

        /** Returns a copy of the remaining ticket bytes. */
        byte[] getEncoded() {
            byte[] out = new byte[data.remaining()];
            data.duplicate().get(out);
            return out;
        }

        @Override
        public String toString() {
            if (data == null) {
                return "<null>";
            }
            if (!data.hasRemaining()) {
                return "<empty>";
            }

            MessageFormat messageFormat = new MessageFormat(
                    "  \"ticket\" : '{'\n" +
                    "{0}\n" +
                    "  '}'",
                    Locale.ENGLISH);
            HexDumpEncoder hexEncoder = new HexDumpEncoder();

            Object[] messageFields = {
                    Utilities.indent(hexEncoder.encode(data.duplicate()),
                            "    "),
            };

            return messageFormat.format(messageFields);
        }
    }

    /** Debug-log formatter for the raw extension data. */
    static final class SessionTicketStringizer implements SSLStringizer {
        @Override
        public String toString(HandshakeContext hc, ByteBuffer buffer) {
            try {
                return new SessionTicketSpec(hc, buffer).toString();
            } catch (IOException e) {
                return e.getMessage();
            }
        }
    }

    /**
     * Client: produces the ClientHello extension.  Returns null (no
     * extension) when stateless resumption is disabled, an empty ticket
     * when not resuming, and otherwise the resuming session's PSK identity.
     */
    private static final class T12CHSessionTicketProducer
            extends SupportedGroups implements HandshakeProducer {
        T12CHSessionTicketProducer() {
        }

        @Override
        public byte[] produce(ConnectionContext context,
                HandshakeMessage message) throws IOException {

            ClientHandshakeContext chc = (ClientHandshakeContext) context;

            if (!((SSLSessionContextImpl) chc.sslContext.
                    engineGetClientSessionContext()).statelessEnabled()) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Stateless resumption not supported");
                }
                return null;
            }

            chc.statelessResumption = true;

            if (!chc.isResumption || chc.resumingSession == null) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Stateless resumption supported");
                }
                return new SessionTicketSpec().getEncoded();
            }

            if (chc.localSupportedSignAlgs == null) {
                chc.localSupportedSignAlgs =
                        SignatureScheme.getSupportedAlgorithms(
                                chc.sslConfig,
                                chc.algorithmConstraints, chc.activeProtocols);
            }

            return chc.resumingSession.getPskIdentity();
        }
    }

    /**
     * Server: consumes the ClientHello extension.  An empty extension only
     * signals support; a non-empty one is decrypted and, if valid, becomes
     * the session to resume.
     */
    private static final class T12CHSessionTicketConsumer
            implements ExtensionConsumer {
        T12CHSessionTicketConsumer() {
        }

        @Override
        public void consume(ConnectionContext context,
                HandshakeMessage message, ByteBuffer buffer)
                throws IOException {
            ServerHandshakeContext shc = (ServerHandshakeContext) context;

            if (!shc.sslConfig.isAvailable(CH_SESSION_TICKET)) {
                return;
            }

            // Already handled for this handshake.
            if (shc.statelessResumption) {
                return;
            }

            SSLSessionContextImpl cache = (SSLSessionContextImpl) shc.sslContext
                    .engineGetServerSessionContext();

            if (!cache.statelessEnabled()) {
                return;
            }

            shc.statelessResumption = true;

            if (buffer.remaining() == 0) {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Client accepts session tickets.");
                }
                return;
            }

            SessionTicketSpec spec = new SessionTicketSpec(shc, buffer);
            ByteBuffer b = spec.decrypt(shc);
            if (b != null) {
                shc.resumingSession = new SSLSessionImpl(shc, b);
                shc.isResumption = true;
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Valid stateless session ticket found");
                }
            } else {
                if (SSLLogger.isOn && SSLLogger.isOn("ssl,handshake")) {
                    SSLLogger.fine("Invalid stateless session ticket found");
                }
            }
        }
    }

    /**
     * Server: produces the (empty) ServerHello extension when stateless
     * resumption is in use and still enabled; otherwise returns null and
     * clears the stateless flag.
     */
    private static final class T12SHSessionTicketProducer
            extends SupportedGroups implements HandshakeProducer {
        T12SHSessionTicketProducer() {
        }

        @Override
        public byte[] produce(ConnectionContext context,
                HandshakeMessage message) {

            ServerHandshakeContext shc = (ServerHandshakeContext) context;

            if (!shc.statelessResumption) {
                return null;
            }
            SSLSessionContextImpl cache = (SSLSessionContextImpl) shc.sslContext
                    .engineGetServerSessionContext();
            if (cache.statelessEnabled()) {
                return new byte[0];
            }

            shc.statelessResumption = false;
            return null;
        }
    }

    /**
     * Client: consumes the ServerHello extension and records whether the
     * server agreed to stateless resumption.
     */
    private static final class T12SHSessionTicketConsumer
            implements ExtensionConsumer {
        T12SHSessionTicketConsumer() {
        }

        @Override
        public void consume(ConnectionContext context,
                HandshakeMessage message, ByteBuffer buffer)
                throws IOException {
            ClientHandshakeContext chc = (ClientHandshakeContext) context;

            if (!chc.sslConfig.isAvailable(SH_SESSION_TICKET)) {
                chc.statelessResumption = false;
                return;
            }

            if (!((SSLSessionContextImpl) chc.sslContext.
                    engineGetClientSessionContext()).statelessEnabled()) {
                chc.statelessResumption = false;
                return;
            }

            // Constructed only for its size validation (fatal alert on a
            // null or oversized buffer); the ticket itself is not used.
            new SessionTicketSpec(chc, buffer);
            chc.statelessResumption = true;
        }
    }
}