package io.vertx.ext.auth.impl.jose;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.JWTOptions;
import io.vertx.ext.auth.NoSuchKeyIdException;
import io.vertx.ext.auth.impl.CertificateHelper;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * JSON Web Token helper: signs payloads and decodes/verifies compact tokens using the keys
 * registered through {@link #addJWK(JWK)}.
 *
 * <p>Keys are kept per algorithm name in two maps, one used by {@link #sign} and one used by
 * {@link #decode}. Both initially contain only the {@code "none"} algorithm; in that state the
 * instance reports itself as unsecure (see {@link #isUnsecure()}).
 *
 * <p>NOTE(review): the per-algorithm key lists are plain {@link ArrayList}s, so registering keys
 * while other threads sign/decode is not guarded here — presumably keys are only added during
 * setup; confirm with callers.
 */
public final class JWT {

  private static final Logger LOG = LoggerFactory.getLogger(JWT.class);

  private static final Random RND = new Random();
  private static final Charset UTF8 = StandardCharsets.UTF_8;
  private static final Base64.Encoder urlEncoder = Base64.getUrlEncoder().withoutPadding();
  private static final Base64.Decoder urlDecoder = Base64.getUrlDecoder();
  // x5c entries use standard (not URL-safe) base64
  private static final Base64.Decoder decoder = Base64.getDecoder();

  // when true, decode() may verify with a certificate chain embedded in the "x5c" header
  private boolean allowEmbeddedKey = false;

  // algorithm name -> keys usable with that algorithm for signing / verifying
  private final Map<String, List<Crypto>> SIGN = new ConcurrentHashMap<>();
  private final Map<String, List<Crypto>> VERIFY = new ConcurrentHashMap<>();

  /**
   * Creates an instance that only knows the {@code "none"} algorithm, i.e. one that is unsecure
   * until keys are added.
   */
  public JWT() {
    SIGN.put("none", Collections.singletonList(new CryptoNone()));
    VERIFY.put("none", Collections.singletonList(new CryptoNone()));
  }

  /**
   * Registers a key under its algorithm. A key whose use includes {@code JWK.USE_ENC} is added to
   * the verification keys; one whose use includes {@code JWK.USE_SIG} is added to the signing
   * keys. A key with the same label already registered for that algorithm is replaced.
   *
   * @param jwk the key to register
   * @return this instance, for chaining
   * @throws IllegalStateException if the key is for neither use
   */
  public JWT addJWK(JWK jwk) {
    List<Crypto> current = null;

    if (jwk.isFor(JWK.USE_ENC)) {
      current = VERIFY.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>());
      addJWK(current, jwk);
    }

    if (jwk.isFor(JWK.USE_SIG)) {
      current = SIGN.computeIfAbsent(jwk.getAlgorithm(), k -> new ArrayList<>());
      addJWK(current, jwk);
    }

    if (current == null) {
      throw new IllegalStateException("unknown JWK use: " + jwk.getUse());
    }

    return this;
  }

  /**
   * Enables or disables verification with a certificate chain embedded in the token header
   * ({@code x5c}).
   *
   * @param allowEmbeddedKey whether embedded certificate chains may be used by {@link #decode}
   * @return this instance, for chaining
   */
  public JWT allowEmbeddedKey(boolean allowEmbeddedKey) {
    this.allowEmbeddedKey = allowEmbeddedKey;
    return this;
  }

  /**
   * Replaces the entry of {@code current} whose label equals the key's label, or appends the key
   * when no such entry exists.
   */
  private void addJWK(List<Crypto> current, JWK jwk) {
    for (int i = 0; i < current.size(); i++) {
      if (current.get(i).getLabel().equals(jwk.getLabel())) {
        current.set(i, jwk);
        return;
      }
    }
    current.add(jwk);
  }

  /**
   * Splits a token into its decoded parts without verifying anything.
   *
   * @param token UTF-8 encoded compact token
   * @return see {@link #parse(String)}
   */
  public static JsonObject parse(final byte[] token) {
    return parse(new String(token, UTF8));
  }

  /**
   * Splits a token into its decoded parts without verifying anything.
   *
   * @param token compact token with two or three dot-separated segments
   * @return an object with {@code header}, {@code payload} (both decoded JSON),
   *     {@code signatureBase} (the {@code header.payload} text) and {@code signature} (the raw
   *     third segment, or {@code null} when absent)
   * @throws RuntimeException if the segment count is not 2 or 3
   */
  public static JsonObject parse(final String token) {
    String[] segments = token.split("\\.");
    if (segments.length < 2 || segments.length > 3) {
      throw new RuntimeException("Not enough or too many segments [" + segments.length + "]");
    }

    String headerSeg = segments[0];
    String payloadSeg = segments[1];
    String signatureSeg = segments.length == 2 ? null : segments[2];

    JsonObject header = new JsonObject(new String(base64urlDecode(headerSeg), UTF8));
    JsonObject payload = new JsonObject(new String(base64urlDecode(payloadSeg), UTF8));

    return new JsonObject()
      .put("header", header)
      .put("payload", payload)
      .put("signatureBase", (headerSeg + "." + payloadSeg))
      .put("signature", signatureSeg);
  }

  /**
   * Decodes and verifies a token, returning only its payload.
   *
   * @see #decode(String, boolean)
   */
  public JsonObject decode(final String token) {
    return decode(token, false);
  }

  /**
   * Decodes a compact token and, unless this instance is unsecure, verifies its signature.
   *
   * @param token compact token ({@code header.payload[.signature]})
   * @param full when true return {@code {header, payload}}, otherwise only the payload
   * @return the decoded payload, or header and payload when {@code full} is set
   * @throws IllegalStateException on malformed tokens, a missing {@code alg}, a signature/mode
   *     mismatch, or {@code alg: none} on a secure instance
   * @throws NoSuchKeyIdException if no verification key matches the algorithm / {@code kid}
   * @throws RuntimeException if signature verification fails
   */
  public JsonObject decode(final String token, boolean full) {
    String[] segments = token.split("\\.");
    // A compact JWS has at most three parts. split() drops trailing empty strings, so an
    // unsecured token of the form "header.payload." (as produced by sign() with "none") yields
    // two segments.
    if (segments.length < 2 || segments.length > 3) {
      throw new IllegalStateException("Invalid format for JWT");
    }

    String headerSeg = segments[0];
    String payloadSeg = segments[1];
    String signatureSeg = segments.length == 3 ? segments[2] : null;

    // defensive check; given the split() semantics above the last segment is not empty
    if ("".equals(signatureSeg)) {
      throw new IllegalStateException("Signature is required");
    }

    JsonObject header = new JsonObject(Buffer.buffer(base64urlDecode(headerSeg)));

    final boolean unsecure = isUnsecure();
    if (!allowEmbeddedKey) {
      if (unsecure && segments.length != 2) {
        throw new IllegalStateException("JWT is in unsecured mode but token is signed.");
      }
      if (!unsecure && segments.length != 3) {
        throw new IllegalStateException("JWT is in secure mode but token is not signed.");
      }
    }

    JsonObject payload = new JsonObject(Buffer.buffer(base64urlDecode(payloadSeg)));

    String alg = header.getString("alg");
    if (alg == null) {
      // ConcurrentHashMap rejects null keys: fail with a meaningful error instead of an NPE
      throw new IllegalStateException("Missing \"alg\" in JWT header");
    }

    if (!unsecure && "none".equals(alg)) {
      throw new IllegalStateException("Algorithm \"none\" not allowed");
    }

    if (allowEmbeddedKey && header.containsKey("x5c")) {
      verifyWithEmbeddedChain(header, alg, headerSeg, payloadSeg, signatureSeg);
      return result(header, payload, full);
    }

    List<Crypto> cryptos = VERIFY.get(alg);
    if (cryptos == null || cryptos.isEmpty()) {
      throw new NoSuchKeyIdException(alg);
    }

    if (!unsecure) {
      if (signatureSeg == null) {
        throw new IllegalStateException("missing signature segment");
      }

      byte[] signature = base64urlDecode(signatureSeg);
      byte[] signingInput = (headerSeg + "." + payloadSeg).getBytes(UTF8);
      String kid = header.getString("kid");

      boolean hasKey = false;
      for (Crypto c : cryptos) {
        // when both the token and the key carry an id they must match
        if (kid != null && c.getId() != null && !kid.equals(c.getId())) {
          continue;
        }
        hasKey = true;
        if (c.verify(signature, signingInput)) {
          return result(header, payload, full);
        }
      }

      if (hasKey) {
        throw new RuntimeException("Signature verification failed");
      }
      throw new NoSuchKeyIdException(alg, kid);
    }

    return result(header, payload, full);
  }

  /**
   * Verifies the token signature with the first certificate of the chain embedded in the
   * header's {@code x5c} member, after checking the chain with {@code CertificateHelper}.
   *
   * @throws IllegalStateException if the signature segment or the chain is missing
   * @throws RuntimeException if the chain cannot be processed or the signature does not match
   */
  private static void verifyWithEmbeddedChain(JsonObject header, String alg, String headerSeg, String payloadSeg, String signatureSeg) {
    if (signatureSeg == null) {
      throw new IllegalStateException("missing signature segment");
    }

    try {
      JsonArray chain = header.getJsonArray("x5c");
      if (chain == null || chain.size() == 0) {
        throw new IllegalStateException("x5c chain is null or empty");
      }

      List<X509Certificate> certChain = new ArrayList<>(chain.size());
      for (int i = 0; i < chain.size(); i++) {
        certChain.add(JWS.parseX5c(decoder.decode(chain.getString(i).getBytes(UTF8))));
      }

      CertificateHelper.checkValidity(certChain, false, null);

      byte[] signingInput = (headerSeg + "." + payloadSeg).getBytes(UTF8);
      if (!JWS.verifySignature(alg, certChain.get(0), base64urlDecode(signatureSeg), signingInput)) {
        throw new RuntimeException("Signature verification failed");
      }
    } catch (CertificateException | NoSuchAlgorithmException | InvalidKeyException | SignatureException | InvalidAlgorithmParameterException | NoSuchProviderException e) {
      throw new RuntimeException("Signature verification failed", e);
    }
  }

  /** Builds the value returned by {@link #decode(String, boolean)}. */
  private static JsonObject result(JsonObject header, JsonObject payload, boolean full) {
    return full ? new JsonObject().put("header", header).put("payload", payload) : payload;
  }

  /**
   * Checks that the token's {@code scope} claim contains every scope required by the options.
   *
   * @param jwt the decoded payload; {@code null} is never granted
   * @param options supplies the required scopes and, for string claims, the scope delimiter
   * @return {@code true} when no scopes are required or all required scopes are present
   */
  public boolean isScopeGranted(JsonObject jwt, JWTOptions options) {
    if (jwt == null) {
      return false;
    }

    if (options.getScopes() == null || options.getScopes().isEmpty()) {
      return true;
    }

    Object scope = jwt.getValue("scope");
    if (scope == null) {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Invalid JWT: scope claim is required");
      }
      return false;
    }

    JsonArray target;
    if (scope instanceof String) {
      // the delimiter is literal text (sign() joins scopes with it), not a regular expression
      String delimiter = Pattern.quote(options.getScopeDelimiter());
      target = new JsonArray(
        Stream.of(((String) scope).split(delimiter))
          .collect(Collectors.toList())
      );
    } else if (scope instanceof JsonArray) {
      target = (JsonArray) scope;
    } else {
      if (LOG.isDebugEnabled()) {
        LOG.debug("Invalid JWT: scope claim must be a string or an array");
      }
      return false;
    }

    if (!target.getList().containsAll(options.getScopes())) {
      if (LOG.isDebugEnabled()) {
        LOG.debug(String.format("Invalid JWT scopes expected[%s] actual[%s]", options.getScopes(), target.getList()));
      }
      return false;
    }

    return true;
  }

  /**
   * Signs a payload with a randomly chosen key registered for the options' algorithm.
   *
   * <p>Note: the given payload object is modified in place ({@code iat}, {@code exp},
   * {@code aud}, {@code scope}, {@code iss}, {@code sub} are written according to the options).
   *
   * @param payload the claims to sign
   * @param options algorithm, extra header entries and claim settings
   * @return the compact token {@code header.payload.signature}
   * @throws RuntimeException if no key is registered for the algorithm
   */
  public String sign(JsonObject payload, JWTOptions options) {
    final String algorithm = options.getAlgorithm();

    List<Crypto> cryptos = algorithm == null ? null : SIGN.get(algorithm);
    if (cryptos == null || cryptos.isEmpty()) {
      throw new RuntimeException("Algorithm not supported: " + algorithm);
    }

    final Crypto crypto = cryptos.get(RND.nextInt(cryptos.size()));

    JsonObject header = new JsonObject()
      .mergeIn(options.getHeader())
      .put("typ", "JWT")
      .put("alg", algorithm);

    if (crypto.getId() != null) {
      header.put("kid", crypto.getId());
    }

    // seconds since the epoch
    long timestamp = System.currentTimeMillis() / 1000;

    if (!options.isNoTimestamp()) {
      payload.put("iat", payload.getValue("iat", timestamp));
    }

    if (options.getExpiresInSeconds() > 0) {
      payload.put("exp", timestamp + options.getExpiresInSeconds());
    }

    if (options.getAudience() != null && !options.getAudience().isEmpty()) {
      if (options.getAudience().size() > 1) {
        payload.put("aud", new JsonArray(options.getAudience()));
      } else {
        payload.put("aud", options.getAudience().get(0));
      }
    }

    if (options.getScopes() != null && !options.getScopes().isEmpty()) {
      if (options.hasScopeDelimiter()) {
        payload.put("scope", String.join(options.getScopeDelimiter(), options.getScopes()));
      } else {
        payload.put("scope", new JsonArray(options.getScopes()));
      }
    }

    if (options.getIssuer() != null) {
      payload.put("iss", options.getIssuer());
    }

    if (options.getSubject() != null) {
      payload.put("sub", options.getSubject());
    }

    String headerSegment = base64urlEncode(header.encode());
    String payloadSegment = base64urlEncode(payload.encode());
    String signingInput = headerSegment + "." + payloadSegment;
    String signSegment = base64urlEncode(crypto.sign(signingInput.getBytes(UTF8)));

    return signingInput + "." + signSegment;
  }

  /** Decodes base64url text (padding optional). */
  private static byte[] base64urlDecode(String str) {
    return urlDecoder.decode(str.getBytes(UTF8));
  }

  /** Encodes the UTF-8 bytes of {@code str} as unpadded base64url. */
  private static String base64urlEncode(String str) {
    return base64urlEncode(str.getBytes(UTF8));
  }

  /** Encodes {@code bytes} as unpadded base64url. */
  private static String base64urlEncode(byte[] bytes) {
    return urlEncoder.encodeToString(bytes);
  }

  /**
   * Returns {@code true} when the signing and the verification maps each hold exactly one
   * algorithm entry, which is the state after construction (only {@code "none"}).
   */
  public boolean isUnsecure() {
    return VERIFY.size() == 1 && SIGN.size() == 1;
  }

  /**
   * Returns the union of the algorithm names that have signing or verification entries.
   *
   * @return a new mutable set of algorithm names
   */
  public Collection<String> availableAlgorithms() {
    Set<String> algorithms = new HashSet<>();

    algorithms.addAll(VERIFY.keySet());
    algorithms.addAll(SIGN.keySet());

    return algorithms;
  }
}