package io.vertx.ext.auth.oauth2.impl;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.json.DecodeException;
import io.vertx.core.json.Json;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.JWTOptions;
import io.vertx.ext.auth.NoSuchKeyIdException;
import io.vertx.ext.auth.PubSecKeyOptions;
import io.vertx.ext.auth.User;
import io.vertx.ext.auth.authentication.CredentialValidationException;
import io.vertx.ext.auth.authentication.Credentials;
import io.vertx.ext.auth.authentication.TokenCredentials;
import io.vertx.ext.auth.authentication.UsernamePasswordCredentials;
import io.vertx.ext.auth.impl.jose.JWK;
import io.vertx.ext.auth.impl.jose.JWT;
import io.vertx.ext.auth.oauth2.*;
import java.util.Collections;
/**
 * {@link OAuth2Auth} implementation that delegates HTTP exchanges with the authorization server to
 * {@link OAuth2API} and verifies/decodes tokens with a {@link JWT} holding the currently known keys.
 *
 * <p>Keys are loaded from {@link OAuth2Options#getPubSecKeys()} at construction time and can be
 * replaced at runtime with {@link #jWKSet(Handler)}.
 */
public class OAuth2AuthProviderImpl implements OAuth2Auth {

  private static final Logger LOG = LoggerFactory.getLogger(OAuth2AuthProviderImpl.class);

  private final Vertx vertx;
  private final OAuth2Options config;
  private final OAuth2API api;

  // Replaced as a whole by jWKSet(); volatile so the new reference is visible to later readers.
  private volatile JWT jwt = new JWT();
  // Id of the periodic JWK Set refresh timer, or -1 when none is scheduled.
  private long updateTimerId = -1;
  // Optional callback receiving the key id of tokens signed with a key not present in `jwt`.
  private Handler<String> missingKeyHandler;

  /**
   * Creates the provider.
   *
   * <p>{@code replaceVariables(true)} and {@code validate()} are invoked on {@code config}, then
   * any configured public/secret keys are registered for token verification.
   *
   * @param vertx  the Vert.x instance, passed to {@link OAuth2API} and used for the JWK refresh timer
   * @param config the provider options
   */
  public OAuth2AuthProviderImpl(Vertx vertx, OAuth2Options config) {
    this.vertx = vertx;
    this.config = config;
    this.api = new OAuth2API(vertx, config);
    this.config.replaceVariables(true);
    this.config.validate();

    if (config.getPubSecKeys() != null) {
      for (PubSecKeyOptions pubSecKey : config.getPubSecKeys()) {
        jwt.addJWK(new JWK(pubSecKey));
      }
    }
  }

  /**
   * Fetches the provider's JWK Set and replaces the current verification keys with it.
   *
   * <p>Keys that cannot be loaded are skipped with a warning. When the response carries a positive
   * {@code maxAge} (seconds), a periodic refresh is scheduled; any previously scheduled refresh is
   * cancelled on every successful load.
   *
   * @param handler completed once the new key set is installed, or failed when the request fails
   *                or the response has no {@code keys} array
   * @return this provider
   */
  @Override
  public OAuth2Auth jWKSet(Handler<AsyncResult<Void>> handler) {
    api.jwkSet(res -> {
      if (res.failed()) {
        handler.handle(Future.failedFuture(res.cause()));
        return;
      }

      final JsonObject json = res.result();
      final JsonArray keys = json.getJsonArray("keys");
      if (keys == null) {
        // "keys" is a required member of a JWK Set; keep the current keys and refresh timer.
        handler.handle(Future.failedFuture("Invalid JWK Set: missing \"keys\" member"));
        return;
      }

      // Build the replacement fully before publishing it.
      final JWT newJwt = new JWT();
      for (Object key : keys) {
        try {
          newJwt.addJWK(new JWK((JsonObject) key));
        } catch (RuntimeException e) {
          LOG.warn("Skipped unsupported JWK: " + e.getMessage());
        }
      }

      if (updateTimerId != -1) {
        vertx.cancelTimer(updateTimerId);
        // reset so a later call never cancels a stale id
        updateTimerId = -1;
      }

      this.jwt = newJwt;

      final Long maxAge = json.getLong("maxAge");
      if (maxAge != null) {
        final long delay = maxAge * 1000;
        if (delay > 0) {
          this.updateTimerId = vertx.setPeriodic(delay, t ->
            jWKSet(autoUpdateRes -> {
              if (autoUpdateRes.failed()) {
                LOG.warn("Failed to auto-update JWK Set", autoUpdateRes.cause());
              }
            }));
        }
      }

      handler.handle(Future.succeededFuture());
    });
    return this;
  }

  /**
   * Registers a callback invoked with the key id when a token references a key that is not loaded.
   *
   * @param handler the callback; when {@code null}, such events are only logged at trace level
   * @return this provider
   */
  @Override
  public OAuth2Auth missingKeyHandler(Handler<String> handler) {
    this.missingKeyHandler = handler;
    return this;
  }

  /**
   * @return the (variable-resolved) options this provider was created with
   */
  public OAuth2Options getConfig() {
    return config;
  }

  /**
   * Adapts a legacy JSON credential object to typed credentials and authenticates it.
   *
   * <ul>
   *   <li>{@code access_token} present: authenticated as {@link TokenCredentials}, unless the flow is
   *       {@code AUTH_JWT} or {@code IMPLICIT}, in which case authentication fails.</li>
   *   <li>{@code username} and {@code password} present: authenticated as
   *       {@link UsernamePasswordCredentials}, only for the {@code PASSWORD} flow.</li>
   *   <li>otherwise: authenticated as {@link Oauth2Credentials}.</li>
   * </ul>
   */
  @Override
  public void authenticate(JsonObject authInfo, Handler<AsyncResult<User>> handler) {
    final OAuth2FlowType flow = config.getFlow();

    if (authInfo.containsKey("access_token")) {
      if (flow != OAuth2FlowType.AUTH_JWT && flow != OAuth2FlowType.IMPLICIT) {
        authenticate(new TokenCredentials(authInfo.getString("access_token")), handler);
      } else {
        handler.handle(Future.failedFuture("access_token provided but provider is not configured for AUTH_CODE"));
      }
      return;
    }

    if (authInfo.containsKey("username") && authInfo.containsKey("password")) {
      if (flow == OAuth2FlowType.PASSWORD) {
        authenticate(new UsernamePasswordCredentials(authInfo.getString("username"), authInfo.getString("password")), handler);
      } else {
        handler.handle(Future.failedFuture("username/password provided but provider is not configured for PASSWORD"));
      }
      return;
    }

    authenticate(new Oauth2Credentials(authInfo), handler);
  }

  /**
   * Authenticates either an existing bearer token ({@link TokenCredentials}) or obtains a new token
   * from the server using the grant of the configured flow.
   *
   * <p>A credentials type that does not match the configured flow, or credentials that fail
   * validation, complete {@code handler} with a failure.
   */
  @Override
  public void authenticate(Credentials credentials, Handler<AsyncResult<User>> handler) {
    try {
      if (credentials instanceof TokenCredentials) {
        authenticateToken((TokenCredentials) credentials, handler);
      } else {
        requestToken(credentials, handler);
      }
    } catch (ClassCastException | CredentialValidationException e) {
      handler.handle(Future.failedFuture(e));
    }
  }

  /**
   * Authenticates a bearer token: verified locally when it decodes as a non-expired JWT, otherwise
   * checked through the token introspection endpoint when one is configured.
   */
  private void authenticateToken(TokenCredentials tokenCredentials, Handler<AsyncResult<User>> handler)
      throws CredentialValidationException {
    tokenCredentials.checkValid(null);

    final User user = createUser(new JsonObject().put("access_token", tokenCredentials.getToken()), false);

    // Fast path: a locally decoded, still-valid JWT needs no round trip to the server.
    if (user.attributes().containsKey("accessToken") && !jwt.isUnsecure()) {
      final JWTOptions jwtOptions = config.getJWTOptions();
      if (!user.expired(jwtOptions.getLeeway())) {
        validateUser(user, handler);
        return;
      }
    }

    if (config.getIntrospectionPath() == null) {
      if (user.attributes().containsKey("missing-kid")) {
        handler.handle(Future.failedFuture(new NoSuchKeyIdException(user.attributes().getString("missing-kid"))));
      } else {
        handler.handle(Future.failedFuture("Can't authenticate access_token: Provider doesn't support token introspection"));
      }
      return;
    }

    api.tokenIntrospection("access_token", user.principal().getString("access_token"), res -> {
      if (res.failed()) {
        handler.handle(Future.failedFuture(res.cause()));
        return;
      }

      final JsonObject json = res.result();

      if (json.containsKey("active") && !json.getBoolean("active", false)) {
        handler.handle(Future.failedFuture("Inactive Token"));
        return;
      }

      if (json.containsKey("client_id")) {
        final String clientId = config.getClientID();
        // fail closed when no client id is configured instead of throwing inside the callback
        if (clientId == null || !clientId.equals(json.getString("client_id"))) {
          handler.handle(Future.failedFuture("Wrong client_id"));
          return;
        }
      }

      // A missing key was already reported while decoding above; don't notify a second time.
      completeWithNewUser(json, user.attributes().containsKey("missing-kid"), handler);
    });
  }

  /**
   * Obtains a new token from the server using the grant type of the configured flow.
   *
   * @throws ClassCastException when {@code credentials} is not of the type the flow expects
   */
  private void requestToken(Credentials credentials, Handler<AsyncResult<User>> handler)
      throws CredentialValidationException {
    final OAuth2FlowType flow = config.getFlow();
    final JsonObject params = new JsonObject();

    switch (flow) {
      case PASSWORD: {
        final UsernamePasswordCredentials usernamePasswordCredentials = (UsernamePasswordCredentials) credentials;
        usernamePasswordCredentials.checkValid(flow);
        params
          .put("username", usernamePasswordCredentials.getUsername())
          .put("password", usernamePasswordCredentials.getPassword());
        break;
      }
      case AUTH_CODE:
      case CLIENT: {
        final Oauth2Credentials oauth2Credentials = (Oauth2Credentials) credentials;
        oauth2Credentials.checkValid(flow);
        params.mergeIn(oauth2Credentials.toJson());
        break;
      }
      case AUTH_JWT: {
        final Oauth2Credentials onBehalfOfCredentials = (Oauth2Credentials) credentials;
        onBehalfOfCredentials.checkValid(flow);
        // the credentials' JSON is signed with the provider's keys and sent as the assertion
        params.put("assertion", jwt.sign(onBehalfOfCredentials.toJson(), config.getJWTOptions()));
        break;
      }
      default:
        handler.handle(Future.failedFuture("Current flow does not allow acquiring a token by the replay party"));
        return;
    }

    api.token(flow.getGrantType(), params, getToken -> {
      if (getToken.failed()) {
        handler.handle(Future.failedFuture(getToken.cause()));
      } else {
        completeWithNewUser(getToken.result(), false, handler);
      }
    });
  }

  /**
   * Builds a {@link User} from a token/introspection response and completes {@code handler}: failed
   * when the user is already expired (within the configured leeway), otherwise via
   * {@link #validateUser(User, Handler)}.
   */
  private void completeWithNewUser(JsonObject json, boolean skipMissingKeyNotify, Handler<AsyncResult<User>> handler) {
    final User newUser = createUser(json, skipMissingKeyNotify);
    if (newUser.expired(config.getJWTOptions().getLeeway())) {
      handler.handle(Future.failedFuture("Used is expired."));
    } else {
      validateUser(newUser, handler);
    }
  }

  @Override
  public String authorizeURL(JsonObject params) {
    return api.authorizeURL(params);
  }

  /**
   * Exchanges the user's {@code refresh_token} for a new token and builds a new {@link User} from it.
   */
  @Override
  public OAuth2Auth refresh(User user, Handler<AsyncResult<User>> handler) {
    final JsonObject params = new JsonObject()
      .put("refresh_token", user.principal().getString("refresh_token"));

    api.token("refresh_token", params, getToken -> {
      if (getToken.failed()) {
        handler.handle(Future.failedFuture(getToken.cause()));
      } else {
        completeWithNewUser(getToken.result(), false, handler);
      }
    });
    return this;
  }

  /**
   * Revokes the token stored in the user's principal under {@code tokenType}.
   */
  @Override
  public OAuth2Auth revoke(User user, String tokenType, Handler<AsyncResult<Void>> handler) {
    api.tokenRevocation(tokenType, user.principal().getString(tokenType), handler);
    return this;
  }

  /**
   * Fetches the UserInfo document for the user's access token.
   *
   * <p>Fails when both the user and the UserInfo response carry a {@code sub} and they differ. On
   * success, {@code sub}, {@code name}, {@code email} and {@code picture} are copied (overwriting)
   * into the user's attributes.
   */
  @Override
  public OAuth2Auth userInfo(User user, Handler<AsyncResult<JsonObject>> handler) {
    api.userInfo(user.principal().getString("access_token"), jwt, userInfo -> {
      if (userInfo.succeeded()) {
        final JsonObject json = userInfo.result();
        final String userSub = user.principal().getString("sub", user.attributes().getString("sub"));
        final String userInfoSub = json.getString("sub");

        if (userSub != null && userInfoSub != null && !userSub.equals(userInfoSub)) {
          handler.handle(Future.failedFuture("Used 'sub' does not match UserInfo 'sub'."));
          return;
        }

        copyProperties(json, user.attributes(), true, "sub", "name", "email", "picture");
      }
      handler.handle(userInfo);
    });
    return this;
  }

  @Override
  public String endSessionURL(User user, JsonObject params) {
    return api.endSessionURL(user.principal().getString("id_token"), params);
  }

  /**
   * Creates a {@link User} whose principal is {@code json} (a token or introspection response).
   *
   * <p>Attributes populated:
   * <ul>
   *   <li>{@code iat}/{@code exp} from {@code expires_in}, when present.</li>
   *   <li>{@code accessToken} and {@code idToken} (decoded JWTs), unless the key set is unsecure.
   *       {@code exp}, {@code iat}, {@code nbf} and {@code sub} are copied from the access token
   *       (overwriting). {@code sub}, {@code name}, {@code email} and {@code picture} are copied
   *       from the id token (not overwriting).</li>
   *   <li>{@code missing-kid} when a token references an unknown key and
   *       {@code skipMissingKeyNotify} is false.</li>
   * </ul>
   */
  private User createUser(JsonObject json, boolean skipMissingKeyNotify) {
    final User user = User.create(json);
    final long now = System.currentTimeMillis() / 1000;

    final Long expiresIn = readExpiresIn(json);
    if (expiresIn != null) {
      user.attributes()
        .put("iat", now)
        .put("exp", now + expiresIn);
    }

    // Read the volatile field once so both tokens are decoded against the same key set.
    final JWT currentJwt = this.jwt;
    if (currentJwt.isUnsecure()) {
      return user;
    }

    String missingKid = null;

    if (json.containsKey("access_token")) {
      try {
        user.attributes()
          .put("accessToken", currentJwt.decode(json.getString("access_token")));
        copyProperties(user.attributes().getJsonObject("accessToken"), user.attributes(), true, "exp", "iat", "nbf", "sub");
        user.attributes()
          .put("rootClaim", "accessToken");
      } catch (NoSuchKeyIdException e) {
        if (!skipMissingKeyNotify) {
          user.attributes()
            .put("missing-kid", e.id());
          missingKid = e.id();
          notifyMissingKey(e, "access token");
        }
      } catch (DecodeException | IllegalStateException e) {
        LOG.trace("Cannot decode access token:", e);
      }
    }

    if (json.containsKey("id_token")) {
      try {
        user.attributes()
          .put("idToken", currentJwt.decode(json.getString("id_token")));
        copyProperties(user.attributes().getJsonObject("idToken"), user.attributes(), false, "sub", "name", "email", "picture");
      } catch (NoSuchKeyIdException e) {
        // report a missing key only once when both tokens reference the same kid
        if (!skipMissingKeyNotify && (missingKid == null || !missingKid.equals(e.id()))) {
          user.attributes()
            .put("missing-kid", e.id());
          notifyMissingKey(e, "id token");
        }
      } catch (DecodeException | IllegalStateException e) {
        LOG.trace("Cannot decode id token:", e);
      }
    }

    return user;
  }

  /**
   * Reads {@code expires_in} (seconds), also accepting a numeric JSON string.
   *
   * @return the value, or {@code null} when the key is absent or explicitly null
   * @throws NumberFormatException when the value is a non-numeric string
   */
  private static Long readExpiresIn(JsonObject json) {
    try {
      return json.getLong("expires_in");
    } catch (ClassCastException e) {
      return Long.valueOf(json.getString("expires_in"));
    }
  }

  /**
   * Forwards a missing key id to {@link #missingKeyHandler}, or logs it at trace level when no
   * handler is registered.
   *
   * @param tokenName human-readable token name used in the log message
   */
  private void notifyMissingKey(NoSuchKeyIdException e, String tokenName) {
    if (missingKeyHandler != null) {
      missingKeyHandler.handle(e.id());
    } else {
      LOG.trace("Cannot decode " + tokenName + ":", e);
    }
  }

  /**
   * Checks the decoded access token's audience and issuer against the configured
   * {@link JWTOptions}. Users without a decoded access token are accepted unchanged.
   */
  private void validateUser(User user, Handler<AsyncResult<User>> handler) {
    if (!user.attributes().containsKey("accessToken")) {
      handler.handle(Future.succeededFuture(user));
      return;
    }

    final JWTOptions jwtOptions = config.getJWTOptions();
    final JsonObject payload;
    try {
      payload = user.attributes().getJsonObject("accessToken");
    } catch (RuntimeException e) {
      handler.handle(Future.failedFuture("User accessToken isn't a JsonObject"));
      return;
    }
    if (payload == null) {
      // an explicit null value is not a usable token payload either
      handler.handle(Future.failedFuture("User accessToken isn't a JsonObject"));
      return;
    }

    if (jwtOptions.getAudience() != null) {
      // "aud" may be a single string or an array of strings
      final Object aud = payload.getValue("aud");
      final JsonArray target;
      if (aud instanceof String) {
        target = new JsonArray().add(aud);
      } else {
        target = payload.getJsonArray("aud", new JsonArray());
      }

      if (Collections.disjoint(jwtOptions.getAudience(), target.getList())) {
        handler.handle(Future.failedFuture("Invalid JWT audience. expected: " + Json.encode(jwtOptions.getAudience())));
        return;
      }
    }

    if (jwtOptions.getIssuer() != null) {
      if (!jwtOptions.getIssuer().equals(payload.getString("iss"))) {
        handler.handle(Future.failedFuture("Invalid JWT issuer"));
        return;
      }
    }

    handler.handle(Future.succeededFuture(user));
  }

  /**
   * Decodes {@code token} with the current keys and wraps it as an {@link AccessToken}.
   *
   * @deprecated kept for the deprecated {@link AccessToken} API
   */
  @Override
  @Deprecated
  public OAuth2Auth decodeToken(String token, Handler<AsyncResult<AccessToken>> handler) {
    try {
      final JsonObject json = jwt.decode(token);
      handler.handle(Future.succeededFuture(createAccessToken(json)));
    } catch (RuntimeException e) {
      handler.handle(Future.failedFuture(e));
    }
    return this;
  }

  /**
   * Not implemented by this provider. The handler is failed immediately so callers are not left
   * waiting; use {@link #authenticate(Credentials, Handler)} with {@link TokenCredentials}, which
   * performs token introspection when an introspection path is configured.
   *
   * @deprecated kept for the deprecated {@link AccessToken} API
   */
  @Override
  @Deprecated
  public OAuth2Auth introspectToken(String token, String tokenType, Handler<AsyncResult<AccessToken>> handler) {
    if (handler != null) {
      handler.handle(Future.failedFuture(new UnsupportedOperationException(
        "introspectToken is deprecated and not supported, use authenticate(TokenCredentials) instead")));
    }
    return this;
  }

  @Override
  @Deprecated
  public OAuth2FlowType getFlowType() {
    return config.getFlow();
  }

  /**
   * No-op; RBAC handlers are no longer used by this provider.
   */
  @Override
  @Deprecated
  public OAuth2Auth rbacHandler(OAuth2RBAC rbac) {
    return this;
  }

  /**
   * Legacy counterpart of {@link #createUser(JsonObject, boolean)} producing an {@link AccessToken}.
   * Unlike {@code createUser}, it does not check whether the key set is unsecure, only falls back to
   * the access token's {@code exp} when {@code expires_in} did not set one, and always reports
   * missing keys.
   */
  @Deprecated
  private AccessToken createAccessToken(JsonObject json) {
    final AccessToken user = new AccessTokenImpl(json, this);
    final long now = System.currentTimeMillis() / 1000;

    final Long expiresIn = readExpiresIn(json);
    if (expiresIn != null) {
      user.attributes()
        .put("iat", now)
        .put("exp", now + expiresIn);
    }

    // Read the volatile field once so both tokens are decoded against the same key set.
    final JWT currentJwt = this.jwt;

    if (json.getString("access_token") != null) {
      try {
        user.attributes()
          .put("accessToken", currentJwt.decode(json.getString("access_token")));

        if (!user.attributes().containsKey("exp")) {
          final Long exp = user.attributes()
            .getJsonObject("accessToken").getLong("exp");
          if (exp != null) {
            user.attributes()
              .put("exp", exp);
          }
        }

        user.attributes()
          .put("rootClaim", "accessToken");
      } catch (NoSuchKeyIdException e) {
        notifyMissingKey(e, "access token");
      } catch (DecodeException | IllegalStateException e) {
        LOG.trace("Cannot decode access token:", e);
      }
    }

    if (json.getString("id_token") != null) {
      try {
        user.attributes()
          .put("idToken", currentJwt.decode(json.getString("id_token")));
      } catch (NoSuchKeyIdException e) {
        notifyMissingKey(e, "id token");
      } catch (DecodeException | IllegalStateException e) {
        LOG.trace("Cannot decode id token:", e);
      }
    }

    return user;
  }

  /**
   * Copies each of {@code keys} present in {@code source} into {@code target}; existing target
   * entries are replaced only when {@code overwrite} is true. No-op when either object is null.
   */
  private static void copyProperties(JsonObject source, JsonObject target, boolean overwrite, String... keys) {
    if (source == null || target == null) {
      return;
    }
    for (String key : keys) {
      if (source.containsKey(key) && (overwrite || !target.containsKey(key))) {
        target.put(key, source.getValue(key));
      }
    }
  }
}