package io.vertx.ext.web.handler.impl;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.json.JsonArray;
import io.vertx.core.json.JsonObject;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;
import io.vertx.ext.auth.AuthProvider;
import io.vertx.ext.auth.oauth2.OAuth2Auth;
import io.vertx.ext.web.Route;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.Session;
import io.vertx.ext.web.handler.AuthHandler;
import io.vertx.ext.web.handler.OAuth2AuthHandler;
import java.net.MalformedURLException;
import java.net.URL;
import java.util.HashSet;
import java.util.Set;
import static io.vertx.ext.auth.oauth2.OAuth2FlowType.AUTH_CODE;
public class OAuth2AuthHandlerImpl extends AuthorizationAuthHandler implements OAuth2AuthHandler {
private static final Logger log = LoggerFactory.getLogger(OAuth2AuthHandlerImpl.class);
private static AuthProvider verifyProvider(AuthProvider provider) {
if (provider instanceof OAuth2Auth) {
if (((OAuth2Auth) provider).getFlowType() != AUTH_CODE) {
throw new IllegalArgumentException("OAuth2Auth + Bearer Auth requires OAuth2 AUTH_CODE flow");
}
}
return provider;
}
private final String host;
private final String callbackPath;
private final Set<String> scopes = new HashSet<>();
private Route callback;
private JsonObject extraParams;
private boolean bearerOnly = true;
public OAuth2AuthHandlerImpl(OAuth2Auth authProvider, String callbackURL) {
super(verifyProvider(authProvider), Type.BEARER);
try {
if (callbackURL != null) {
final URL url = new URL(callbackURL);
this.host = url.getProtocol() + "://" + url.getHost() + (url.getPort() == -1 ? "" : ":" + url.getPort());
this.callbackPath = url.getPath();
} else {
this.host = null;
this.callbackPath = null;
}
} catch (MalformedURLException e) {
throw new RuntimeException(e);
}
}
@Override
public AuthHandler addAuthority(String authority) {
scopes.add(authority);
return this;
}
@Override
public AuthHandler addAuthorities(Set<String> authorities) {
this.scopes.addAll(authorities);
return this;
}
@Override
public void parseCredentials(RoutingContext context, Handler<AsyncResult<JsonObject>> handler) {
parseAuthorization(context, !bearerOnly, parseAuthorization -> {
if (parseAuthorization.failed()) {
handler.handle(Future.failedFuture(parseAuthorization.cause()));
return;
}
final String token = parseAuthorization.result();
if (token == null) {
if (callback == null) {
handler.handle(Future.failedFuture("callback route is not configured."));
return;
}
if (
context.request().method() == HttpMethod.GET &&
context.normalisedPath().equals(callback.getPath())) {
if (log.isWarnEnabled()) {
log.warn("The callback route is shaded by the OAuth2AuthHandler, ensure the callback route is added BEFORE the OAuth2AuthHandler route!");
}
handler.handle(Future.failedFuture(new HttpStatusException(500, "Infinite redirect loop [oauth2 callback]")));
} else {
handler.handle(Future.failedFuture(new HttpStatusException(302, authURI(context.request().uri()))));
}
} else {
((OAuth2Auth) authProvider).decodeToken(token, decodeToken -> {
if (decodeToken.failed()) {
handler.handle(Future.failedFuture(new HttpStatusException(401, decodeToken.cause().getMessage())));
return;
}
context.setUser(decodeToken.result());
handler.handle(Future.succeededFuture());
});
}
});
}
private String authURI(String redirectURL) {
final JsonObject config = new JsonObject()
.put("state", redirectURL);
if (host != null) {
config.put("redirect_uri", host + callback.getPath());
}
if (extraParams != null) {
config.mergeIn(extraParams);
}
if (scopes.size() > 0) {
JsonArray _scopes = new JsonArray();
for (String authority : scopes) {
_scopes.add(authority);
}
config.put("scopes", _scopes);
}
return ((OAuth2Auth) authProvider).authorizeURL(config);
}
@Override
public OAuth2AuthHandler extraParams(JsonObject extraParams) {
this.extraParams = extraParams;
return this;
}
@Override
public OAuth2AuthHandler setupCallback(final Route route) {
if (callbackPath != null && !"".equals(callbackPath)) {
route.path(callbackPath);
}
route.method(HttpMethod.GET);
route.handler(ctx -> {
final String code = ctx.request().getParam("code");
if (code == null) {
ctx.fail(400);
return;
}
final String state = ctx.request().getParam("state");
final JsonObject config = new JsonObject()
.put("code", code);
if (host != null) {
config.put("redirect_uri", host + route.getPath());
}
if (extraParams != null) {
config.mergeIn(extraParams);
}
authProvider.authenticate(config, res -> {
if (res.failed()) {
ctx.fail(res.cause());
} else {
ctx.setUser(res.result());
Session session = ctx.session();
if (session != null) {
session.regenerateId();
ctx.response()
.putHeader(HttpHeaders.CACHE_CONTROL, "no-cache, no-store, must-revalidate")
.putHeader("Pragma", "no-cache")
.putHeader(HttpHeaders.EXPIRES, "0")
.putHeader(HttpHeaders.LOCATION, state != null ? state : "/")
.setStatusCode(302)
.end("Redirecting to " + (state != null ? state : "/") + ".");
} else {
ctx.reroute(state != null ? state : "/");
}
}
});
});
bearerOnly = false;
callback = route;
return this;
}
}