package io.vertx.ext.web.handler.impl;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.VertxContextPRNG;
import io.vertx.ext.auth.authentication.Credentials;
import io.vertx.ext.auth.authentication.TokenCredentials;
import io.vertx.ext.auth.oauth2.OAuth2Auth;
import io.vertx.ext.auth.oauth2.Oauth2Credentials;
import io.vertx.ext.web.Route;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.Session;
import io.vertx.ext.web.handler.OAuth2AuthHandler;
import io.vertx.ext.web.impl.Origin;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
/**
 * {@link OAuth2AuthHandler} implementation.
 *
 * <p>Credentials are parsed from the HTTP {@code Authorization} header as a bearer token (the
 * parsing itself is done by the parent {@code HTTPAuthorizationHandler}, configured with
 * {@code Type.BEARER}). Once {@link #setupCallback(Route)} has been called, a request without a
 * token is answered with a 302 {@link HttpStatusException} pointing at the provider's
 * authorization URL, and the callback route exchanges the returned authorization code for a user.
 * PKCE (S256) is used when {@link #pkceVerifierLength(int)} was given a positive length.
 */
public class OAuth2AuthHandlerImpl extends HTTPAuthorizationHandler<OAuth2Auth> implements OAuth2AuthHandler {

  private static final Logger LOG = LoggerFactory.getLogger(OAuth2AuthHandlerImpl.class);

  /** PRNG used to generate the {@code state} value and the PKCE code verifier. */
  private final VertxContextPRNG prng;
  /** String form of the callback URL's {@link Origin}, or {@code null} when no callback URL was given. */
  private final String host;
  /** Resource (path) part of the callback URL, or {@code null} when no callback URL was given. */
  private final String callbackPath;
  /** Digest shared by all requests for PKCE S256 challenges; guarded by synchronizing on itself. */
  private final MessageDigest sha256;

  /** Callback route installed by {@link #setupCallback(Route)}; {@code null} until then. */
  private Route callback;
  /** Extra parameters merged into the authorization URL and attached to the code-exchange credentials. */
  private JsonObject extraParams;
  /** Scopes requested in the authorization URL. */
  private final List<String> scopes = new ArrayList<>();
  /** Optional {@code prompt} parameter for the authorization URL. */
  private String prompt;
  /** PKCE code verifier length; values {@code <= 0} disable PKCE (default {@code -1}). */
  private int pkce = -1;
  /** {@code true} until a callback is set up; while {@code true} a missing token is a plain failure. */
  private boolean bearerOnly = true;

  /**
   * Creates the handler.
   *
   * @param vertx        Vert.x instance used to obtain the context PRNG
   * @param authProvider OAuth2 provider used to build authorization URLs and exchange codes
   * @param callbackURL  the callback URL, or {@code null}; when {@code null} no
   *                     {@code redirect_uri} is computed
   * @throws IllegalStateException if no SHA-256 {@link MessageDigest} is available
   */
  public OAuth2AuthHandlerImpl(Vertx vertx, OAuth2Auth authProvider, String callbackURL) {
    super(authProvider, Type.BEARER);
    this.prng = VertxContextPRNG.current(vertx);
    try {
      sha256 = MessageDigest.getInstance("SHA-256");
    } catch (NoSuchAlgorithmException e) {
      throw new IllegalStateException("Cannot get instance of SHA-256 MessageDigest", e);
    }
    if (callbackURL != null) {
      final Origin origin = Origin.parse(callbackURL);
      this.host = origin.toString();
      this.callbackPath = origin.resource();
    } else {
      this.host = null;
      this.callbackPath = null;
    }
  }

  /**
   * Parses bearer credentials, or starts the authorization code flow when no token is present.
   *
   * <p>Outcomes delivered to {@code handler}: success with {@link TokenCredentials} when a token is
   * found; failure when parsing fails or no callback route is configured; failure with a 500
   * {@link HttpStatusException} when the callback route itself would be redirected (loop); failure
   * with a 302 {@link HttpStatusException} carrying the authorization URL otherwise. Non-GET
   * requests, and PKCE without a session, fail the context directly (400 / 500) without calling
   * {@code handler}.
   */
  @Override
  public void parseCredentials(RoutingContext context, Handler<AsyncResult<Credentials>> handler) {
    // NOTE(review): the second argument presumably makes the header optional once a callback
    // exists (a missing token then leads to a redirect) — confirm against HTTPAuthorizationHandler.
    parseAuthorization(context, !bearerOnly, parseAuthorization -> {
      if (parseAuthorization.failed()) {
        handler.handle(Future.failedFuture(parseAuthorization.cause()));
        return;
      }
      final String token = parseAuthorization.result();
      if (token != null) {
        handler.handle(Future.succeededFuture(new TokenCredentials(token)));
        return;
      }
      if (bearerOnly || callback == null) {
        handler.handle(Future.failedFuture("callback route is not configured."));
        return;
      }
      if (context.request().method() == HttpMethod.GET && context.normalizedPath().equals(callback.getPath())) {
        // redirecting the callback itself to the IdP would never terminate
        LOG.warn("The callback route is shaded by the OAuth2AuthHandler, ensure the callback route is added BEFORE the OAuth2AuthHandler route!");
        handler.handle(Future.failedFuture(new HttpStatusException(500, "Infinite redirect loop [oauth2 callback]")));
        return;
      }
      if (context.request().method() != HttpMethod.GET) {
        // only GET requests can be safely replayed after the IdP redirects back
        LOG.error("OAuth2 redirect attempt to non GET resource");
        context.fail(400);
        return;
      }
      startAuthorizationRedirect(context, handler);
    });
  }

  /**
   * Stores the flow state in the session (when present) and fails {@code handler} with a 302
   * {@link HttpStatusException} whose payload is the provider authorization URL.
   */
  private void startAuthorizationRedirect(RoutingContext context, Handler<AsyncResult<Credentials>> handler) {
    final String redirectUri = context.request().uri();
    final Session session = context.session();
    String state = null;
    String codeVerifier = null;
    if (session == null) {
      // the verifier must survive until the callback, which needs server-side state
      if (pkce > 0) {
        LOG.error("OAuth2 PKCE requires a session to be present");
        context.fail(500);
        return;
      }
    } else {
      session.put("redirect_uri", redirectUri);
      state = prng.nextString(6);
      session.put("state", state);
      if (pkce > 0) {
        codeVerifier = prng.nextString(pkce);
        session.put("pkce", codeVerifier);
      }
    }
    handler.handle(Future.failedFuture(new HttpStatusException(302, authURI(redirectUri, state, codeVerifier))));
  }

  /**
   * Builds the provider authorization URL.
   *
   * @param redirectURL  originally requested URI; used as {@code state} when {@code state} is null
   * @param state        session-bound state value, or {@code null} when there is no session
   * @param codeVerifier PKCE code verifier, or {@code null} when PKCE is not used
   * @return the authorization URL produced by {@link OAuth2Auth#authorizeURL(JsonObject)}
   */
  private String authURI(String redirectURL, String state, String codeVerifier) {
    final JsonObject config = new JsonObject()
      .put("state", state != null ? state : redirectURL);
    if (host != null) {
      // only reached with a configured callback (checked by the caller)
      config.put("redirect_uri", host + callback.getPath());
    }
    if (!scopes.isEmpty()) {
      config.put("scopes", scopes);
    }
    if (prompt != null) {
      config.put("prompt", prompt);
    }
    if (codeVerifier != null) {
      config
        .put("code_challenge", s256CodeChallenge(codeVerifier))
        .put("code_challenge_method", "S256");
    }
    if (extraParams != null) {
      config.mergeIn(extraParams);
    }
    return authProvider.authorizeURL(config);
  }

  /**
   * Computes the RFC 7636 S256 challenge: BASE64URL(SHA-256(ASCII(verifier))) without padding.
   * The encoding is explicit so the result does not depend on how Vert.x JSON renders
   * {@code byte[]} values (its legacy mode uses padded standard base64).
   */
  private String s256CodeChallenge(String codeVerifier) {
    final byte[] digest;
    // MessageDigest instances are stateful and shared by all requests
    synchronized (sha256) {
      digest = sha256.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
    }
    return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
  }

  /** Sets extra parameters for the authorization URL and the code-exchange credentials. */
  @Override
  public OAuth2AuthHandler extraParams(JsonObject extraParams) {
    this.extraParams = extraParams;
    return this;
  }

  /** Adds a scope to request in the authorization URL. */
  @Override
  public OAuth2AuthHandler withScope(String scope) {
    this.scopes.add(scope);
    return this;
  }

  /** Sets the {@code prompt} parameter of the authorization URL. */
  @Override
  public OAuth2AuthHandler prompt(String prompt) {
    this.prompt = prompt;
    return this;
  }

  /**
   * Configures the PKCE code verifier length.
   *
   * @param length a negative value disables PKCE; otherwise it must be within 43..128
   *               (the verifier length range allowed by RFC 7636)
   * @throws IllegalArgumentException if {@code length} is non-negative and outside 43..128
   */
  @Override
  public OAuth2AuthHandler pkceVerifierLength(int length) {
    if (length >= 0 && (length < 43 || length > 128)) {
      throw new IllegalArgumentException("Length must be between 43 and 128");
    }
    this.pkce = length;
    return this;
  }

  /**
   * Installs the OAuth2 callback handler on {@code route} and enables the redirect flow.
   * When a callback URL was given at construction, the route path is forced to its path.
   */
  @Override
  public OAuth2AuthHandler setupCallback(final Route route) {
    if (callbackPath != null && !callbackPath.isEmpty() && !callbackPath.equals(route.getPath())) {
      if (LOG.isWarnEnabled()) {
        LOG.warn("route path changed to match callback URL");
      }
      route.path(callbackPath);
    }
    route.method(HttpMethod.GET);
    route.handler(ctx -> handleIdpCallback(ctx, route));
    bearerOnly = false;
    callback = route;
    return this;
  }

  /**
   * Handles the IdP redirect: validates {@code error}/{@code code}/{@code state}, checks the state
   * against the session (when present), exchanges the code, then sends the user to the resource
   * originally requested.
   */
  private void handleIdpCallback(RoutingContext ctx, Route route) {
    final String error = ctx.request().getParam("error");
    if (error != null) {
      final String errorDescription = ctx.request().getParam("error_description");
      ctx.response()
        .setStatusMessage(errorDescription != null ? error + ": " + errorDescription : error);
      ctx.fail(400);
      return;
    }
    final String code = ctx.request().getParam("code");
    if (code == null) {
      ctx.response()
        .setStatusMessage("Missing code parameter");
      ctx.fail(400);
      return;
    }
    final Oauth2Credentials credentials = new Oauth2Credentials()
      .setCode(code)
      .setExtra(extraParams);
    final String state = ctx.request().getParam("state");
    if (state == null) {
      LOG.error("Missing IdP state parameter to the callback endpoint");
      ctx.fail(400);
      return;
    }
    final String resource;
    final Session session = ctx.session();
    if (session != null) {
      // the state must match the value stored when the IdP redirect was issued
      final String ctxState = session.remove("state");
      if (!state.equals(ctxState)) {
        ctx.fail(401);
        return;
      }
      final String codeVerifier = session.remove("pkce");
      if (codeVerifier != null) {
        // copy instead of mutating: the extras are presumably the shared extraParams instance
        final JsonObject extras = credentials.getExtra();
        final JsonObject merged = extras != null ? new JsonObject().mergeIn(extras) : new JsonObject();
        credentials.setExtra(merged.put("code_verifier", codeVerifier));
      }
      resource = session.get("redirect_uri");
    } else {
      // without a session the state parameter carries the original request URI
      resource = state;
    }
    if (host == null) {
      if (LOG.isWarnEnabled()) {
        LOG.warn("Cannot compute: 'redirect_uri' variable. OAuth2AuthHandler was created without a origin/callback URL.");
      }
    } else {
      credentials.setRedirectUri(host + route.getPath());
    }
    authProvider.authenticate(credentials, res -> {
      if (res.failed()) {
        ctx.fail(res.cause());
        return;
      }
      ctx.setUser(res.result());
      final String location = resource != null ? resource : "/";
      if (session != null) {
        // the user just became authenticated: rotate the session id
        session.regenerateId();
      } else if (!location.isEmpty() && location.charAt(0) == '/') {
        // no session to keep state across a redirect: serve local paths via an internal reroute
        ctx.reroute(location);
        return;
      }
      // NOTE(review): with no session, a non-local, request-supplied `state` ends up as the
      // Location below unvalidated (potential open redirect) — confirm this is intended.
      ctx.response()
        .putHeader(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, must-revalidate")
        .putHeader("Pragma", "no-cache")
        .putHeader(HttpHeaders.EXPIRES, "0")
        .putHeader(HttpHeaders.LOCATION, location)
        .setStatusCode(302)
        .end("Redirecting to " + location + ".");
    });
  }
}