package io.vertx.ext.jwt;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import javax.crypto.Mac;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.cert.X509Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Signs and verifies JSON Web Tokens.
 *
 * <p>Every instance starts with only the {@code "none"} algorithm registered. While that is the
 * sole registered algorithm the instance is considered "unsecure" (see {@link #isUnsecure()}):
 * tokens are expected to have two segments and no signature is checked. As soon as another
 * algorithm is registered, three segments and a valid signature are required and
 * {@code "none"} is rejected by {@link #decode(String)}.
 */
public final class JWT {

  /** Used only to pick one of several keys registered for the same algorithm when signing. */
  private static final Random RND = new Random();

  /** Maps JWT algorithm names to the JCA algorithm names used to build the crypto objects. */
  private static final Map<String, String> ALGORITHM_ALIAS;

  static {
    // A static initializer avoids the anonymous HashMap subclass created by double-brace
    // initialization and lets the table be published as read-only.
    final Map<String, String> alias = new HashMap<>();
    alias.put("HS256", "HMacSHA256");
    alias.put("HS384", "HMacSHA384");
    alias.put("HS512", "HMacSHA512");
    alias.put("RS256", "SHA256withRSA");
    alias.put("RS384", "SHA384withRSA");
    alias.put("RS512", "SHA512withRSA");
    alias.put("ES256", "SHA256withECDSA");
    alias.put("ES384", "SHA384withECDSA");
    alias.put("ES512", "SHA512withECDSA");
    ALGORITHM_ALIAS = Collections.unmodifiableMap(alias);
  }

  private static final Charset UTF8 = StandardCharsets.UTF_8;

  private static final Logger log = LoggerFactory.getLogger(JWT.class);

  // JWT segments are base64url encoded without padding.
  private static final Base64.Encoder encoder = Base64.getUrlEncoder().withoutPadding();
  private static final Base64.Decoder decoder = Base64.getUrlDecoder();

  /** Registered crypto implementations, keyed by the JWT {@code alg} name. */
  private final Map<String, List<Crypto>> cryptoMap = new ConcurrentHashMap<>();

  /**
   * Creates an instance with only the {@code "none"} algorithm registered.
   */
  public JWT() {
    cryptoMap.put("none", Collections.singletonList(new CryptoNone()));
  }

  /**
   * Creates an instance and registers every supported algorithm for which the keystore holds
   * an entry. The keystore alias looked up is the JWT algorithm name itself (e.g.
   * {@code "HS256"}, {@code "RS256"}).
   *
   * <p>Algorithms with no matching entry are logged at info level; algorithms whose key
   * material cannot be loaded are logged at warn level. Neither case aborts construction.
   *
   * @param keyStore         the keystore to read keys and certificates from
   * @param keyStorePassword the password protecting the keystore entries
   */
  @Deprecated
  public JWT(final KeyStore keyStore, final char[] keyStorePassword) {
    this();

    // HMAC based algorithms: a secret key stored under the algorithm name.
    for (String alg : Arrays.asList("HS256", "HS384", "HS512")) {
      try {
        Mac mac = getMac(keyStore, keyStorePassword, alg);
        if (mac != null) {
          List<Crypto> l = cryptoMap.computeIfAbsent(alg, k -> new ArrayList<>());
          l.add(new CryptoMac(mac));
        } else {
          log.info(alg + " not available");
        }
      } catch (RuntimeException e) {
        log.warn(alg + " not supported", e);
      }
    }

    // Signature based algorithms: a certificate plus private key stored under the algorithm name.
    for (String alg : Arrays.asList("RS256", "RS384", "RS512", "ES256", "ES384", "ES512")) {
      try {
        X509Certificate certificate = getCertificate(keyStore, alg);
        PrivateKey privateKey = getPrivateKey(keyStore, keyStorePassword, alg);
        if (certificate != null && privateKey != null) {
          List<Crypto> l = cryptoMap.computeIfAbsent(alg, k -> new ArrayList<>());
          l.add(new CryptoSignature(ALGORITHM_ALIAS.get(alg), certificate, privateKey));
        } else {
          log.info(alg + " not available");
        }
      } catch (RuntimeException e) {
        // Report through the logger (with the cause) instead of printing to stderr,
        // matching the handling of the HMAC algorithms above.
        log.warn(alg + " not supported", e);
      }
    }
  }

  /**
   * Creates an instance with a single {@code RS256} key.
   *
   * @param key        the key material
   * @param keyPrivate {@code true} to register {@code key} through
   *                   {@link #addSecretKey(String, String)}, {@code false} to register it
   *                   through {@link #addPublicKey(String, String)}
   */
  @Deprecated
  public JWT(String key, boolean keyPrivate) {
    this();
    if (keyPrivate) {
      addSecretKey("RS256", key);
    } else {
      addPublicKey("RS256", key);
    }
  }

  /**
   * Registers a key for the algorithm it declares. If a key with the same id is already
   * registered for that algorithm it is replaced, otherwise the key is appended.
   *
   * @param jwk the key to register
   * @return this instance, for chaining
   */
  public JWT addJWK(JWK jwk) {
    List<Crypto> current = cryptoMap.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>());

    boolean replaced = false;
    for (int i = 0; i < current.size(); i++) {
      if (current.get(i).getId().equals(jwk.getId())) {
        current.set(i, jwk);
        replaced = true;
        // The matching entry has been swapped; no need to keep scanning.
        break;
      }
    }

    if (!replaced) {
      current.add(jwk);
    }

    return this;
  }

  /**
   * Registers a public key for the given algorithm.
   *
   * @param algorithm the JWT algorithm name
   * @param key       the public key material
   * @return this instance, for chaining
   * @deprecated delegates to {@link #addJWK(JWK)}; prefer calling it directly
   */
  @Deprecated
  public JWT addPublicKey(String algorithm, String key) {
    return addJWK(new JWK(algorithm, key, null));
  }

  /**
   * Registers a public/private key pair for the given algorithm.
   *
   * @param algorithm  the JWT algorithm name
   * @param publicKey  the public key material
   * @param privateKey the private key material
   * @return this instance, for chaining
   * @deprecated delegates to {@link #addJWK(JWK)}; prefer calling it directly
   */
  @Deprecated
  public JWT addKeyPair(String algorithm, String publicKey, String privateKey) {
    return addJWK(new JWK(algorithm, publicKey, privateKey));
  }

  /**
   * Registers a private key (without its public counterpart) for the given algorithm.
   *
   * @param algorithm the JWT algorithm name
   * @param key       the private key material
   * @return this instance, for chaining
   * @deprecated delegates to {@link #addJWK(JWK)}; prefer calling it directly
   */
  @Deprecated
  public JWT addSecretKey(String algorithm, String key) {
    return addJWK(new JWK(algorithm, null, key));
  }

  /**
   * Registers a certificate for the given algorithm.
   *
   * @param algorithm the JWT algorithm name
   * @param cert      the certificate material
   * @return this instance, for chaining
   * @deprecated delegates to {@link #addJWK(JWK)}; prefer calling it directly
   */
  @Deprecated
  public JWT addCertificate(String algorithm, String cert) {
    return addJWK(new JWK(algorithm, true, cert, null));
  }

  /**
   * Registers a shared secret for the given algorithm.
   *
   * @param algorithm the JWT algorithm name
   * @param key       the secret
   * @return this instance, for chaining
   * @deprecated delegates to {@link #addJWK(JWK)}; prefer calling it directly
   */
  @Deprecated
  public JWT addSecret(String algorithm, String key) {
    return addJWK(new JWK(algorithm, key));
  }

  /**
   * Loads the secret key stored under {@code alias} and returns a {@link Mac} initialised
   * with it.
   *
   * @return the initialised MAC, or {@code null} when the keystore has no key for the alias
   * @throws RuntimeException wrapping any keystore or algorithm failure
   */
  private Mac getMac(final KeyStore keyStore, final char[] keyStorePassword, final String alias) {
    try {
      final Key secretKey = keyStore.getKey(alias, keyStorePassword);
      if (secretKey == null) {
        return null;
      }
      Mac mac = Mac.getInstance(secretKey.getAlgorithm());
      mac.init(secretKey);
      return mac;
    } catch (NoSuchAlgorithmException | InvalidKeyException | UnrecoverableKeyException | KeyStoreException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Loads the certificate stored under {@code alias}.
   *
   * @return the certificate, or {@code null} when the keystore has none for the alias
   * @throws RuntimeException wrapping a {@link KeyStoreException}
   */
  private X509Certificate getCertificate(final KeyStore keyStore, final String alias) {
    try {
      return (X509Certificate) keyStore.getCertificate(alias);
    } catch (KeyStoreException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Loads the private key stored under {@code alias}.
   *
   * @return the private key, or {@code null} when the keystore has none for the alias
   * @throws RuntimeException wrapping any keystore or algorithm failure
   */
  private PrivateKey getPrivateKey(final KeyStore keyStore, final char[] keyStorePassword, final String alias) {
    try {
      return (PrivateKey) keyStore.getKey(alias, keyStorePassword);
    } catch (NoSuchAlgorithmException | UnrecoverableKeyException | KeyStoreException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Parses a compact serialized token and, unless this instance is unsecure, verifies its
   * signature against every key registered for the algorithm named in the token header.
   *
   * @param token the token in {@code header.payload[.signature]} form
   * @return the decoded payload
   * @throws RuntimeException if the number of segments is wrong, the signature is missing,
   *                          the header algorithm is absent, unsupported or not allowed, or no
   *                          registered key verifies the signature
   */
  public JsonObject decode(final String token) {
    // Evaluate once so that a concurrent key registration cannot change the mode mid-decode.
    final boolean unsecure = isUnsecure();

    String[] segments = token.split("\\.");
    if (segments.length != (unsecure ? 2 : 3)) {
      throw new RuntimeException("Not enough or too many segments");
    }

    // All segments are base64url encoded.
    String headerSeg = segments[0];
    String payloadSeg = segments[1];
    String signatureSeg = unsecure ? null : segments[2];

    if ("".equals(signatureSeg)) {
      throw new RuntimeException("Signature is required");
    }

    JsonObject header = new JsonObject(new String(base64urlDecode(headerSeg), UTF8));
    JsonObject payload = new JsonObject(new String(base64urlDecode(payloadSeg), UTF8));

    String alg = header.getString("alg");

    // ConcurrentHashMap rejects null keys with a NullPointerException, so a header without
    // "alg" is treated like any other unsupported algorithm.
    List<Crypto> cryptos = alg == null ? null : cryptoMap.get(alg);

    if (cryptos == null || cryptos.size() == 0) {
      throw new RuntimeException("Algorithm not supported");
    }

    // Once real keys are registered, "none" must not be accepted.
    if (!unsecure && "none".equals(alg)) {
      throw new RuntimeException("Algorithm \"none\" not allowed");
    }

    if (!unsecure) {
      byte[] signature = base64urlDecode(signatureSeg);
      byte[] signingInput = (headerSeg + "." + payloadSeg).getBytes(UTF8);

      // Any registered key for the algorithm may validate the token.
      for (Crypto c : cryptos) {
        if (c.verify(signature, signingInput)) {
          return payload;
        }
      }

      throw new RuntimeException("Signature verification failed");
    }

    return payload;
  }

  /**
   * Validates the time based claims ({@code exp}, {@code iat}, {@code nbf}) of a decoded
   * token against the current time in seconds, allowing for the leeway configured in
   * {@code options}.
   *
   * <p>Note that violations are reported by throwing; this method never returns
   * {@code true}.
   *
   * @param jwt     the decoded payload, may be {@code null}
   * @param options supplies the leeway and whether expiration is ignored
   * @return always {@code false}
   * @throws RuntimeException if the token is expired, issued in the future or not yet valid
   */
  public boolean isExpired(JsonObject jwt, JWTOptions options) {
    if (jwt == null) {
      return false;
    }

    // Claim values are compared as seconds since the epoch.
    final long now = (System.currentTimeMillis() / 1000);

    if (jwt.containsKey("exp") && !options.isIgnoreExpiration()) {
      if (now - options.getLeeway() >= jwt.getLong("exp")) {
        throw new RuntimeException("Expired JWT token: exp <= now");
      }
    }

    if (jwt.containsKey("iat")) {
      Long iat = jwt.getLong("iat");
      // The issue date must not be in the future.
      if (iat > now + options.getLeeway()) {
        throw new RuntimeException("Invalid JWT token: iat > now");
      }
    }

    if (jwt.containsKey("nbf")) {
      Long nbf = jwt.getLong("nbf");
      // The token must already be valid.
      if (nbf > now + options.getLeeway()) {
        throw new RuntimeException("Invalid JWT token: nbf > now");
      }
    }

    return false;
  }

  /**
   * Builds and signs a token for the given payload.
   *
   * <p>The supplied {@code payload} object is modified in place: depending on
   * {@code options} the {@code iat}, {@code exp}, {@code aud}, {@code iss} and {@code sub}
   * claims are written into it before it is encoded.
   *
   * @param payload the claims to sign; mutated as described above
   * @param options selects the algorithm, extra header fields and the claims to add
   * @return the compact serialized token
   * @throws RuntimeException if no key is registered for the requested algorithm
   */
  public String sign(JsonObject payload, JWTOptions options) {
    final String algorithm = options.getAlgorithm();

    List<Crypto> cryptos = cryptoMap.get(algorithm);
    if (cryptos == null || cryptos.size() == 0) {
      throw new RuntimeException("Algorithm not supported");
    }

    // "typ" and "alg" are written last so custom header fields cannot override them.
    JsonObject header = new JsonObject()
      .mergeIn(options.getHeader())
      .put("typ", "JWT")
      .put("alg", algorithm);

    long timestamp = System.currentTimeMillis() / 1000;

    if (!options.isNoTimestamp()) {
      // Keep an "iat" already present in the payload.
      payload.put("iat", payload.getValue("iat", timestamp));
    }

    if (options.getExpiresInSeconds() > 0) {
      payload.put("exp", timestamp + options.getExpiresInSeconds());
    }

    if (options.getAudience() != null && options.getAudience().size() >= 1) {
      if (options.getAudience().size() > 1) {
        payload.put("aud", new JsonArray(options.getAudience()));
      } else {
        // A single audience is written as a plain value rather than an array.
        payload.put("aud", options.getAudience().get(0));
      }
    }

    if (options.getIssuer() != null) {
      payload.put("iss", options.getIssuer());
    }

    if (options.getSubject() != null) {
      payload.put("sub", options.getSubject());
    }

    String headerSegment = base64urlEncode(header.encode());
    String payloadSegment = base64urlEncode(payload.encode());
    String signingInput = headerSegment + "." + payloadSegment;

    // When several keys are registered for the algorithm one is picked at random.
    Crypto signer = cryptos.get(RND.nextInt(cryptos.size()));
    String signSegment = base64urlEncode(signer.sign(signingInput.getBytes(UTF8)));

    return headerSegment + "." + payloadSegment + "." + signSegment;
  }

  private static byte[] base64urlDecode(String str) {
    return decoder.decode(str.getBytes(UTF8));
  }

  private static String base64urlEncode(String str) {
    return base64urlEncode(str.getBytes(UTF8));
  }

  private static String base64urlEncode(byte[] bytes) {
    return encoder.encodeToString(bytes);
  }

  /**
   * Tells whether this instance only knows the {@code "none"} algorithm registered by the
   * constructor, i.e. no other algorithm has been added.
   *
   * @return {@code true} when exactly one algorithm is registered
   */
  public boolean isUnsecure() {
    return cryptoMap.size() == 1;
  }

  /**
   * Lists the names of the registered algorithms.
   *
   * @return a read-only view of the registered algorithm names
   */
  public Collection<String> availableAlgorithms() {
    // Read-only wrapper: the raw key set would let callers remove registered algorithms.
    return Collections.unmodifiableSet(cryptoMap.keySet());
  }
}