package io.vertx.ext.web.handler.impl;
import io.vertx.core.Vertx;
import io.vertx.core.http.Cookie;
import io.vertx.core.http.CookieSameSite;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.ext.auth.VertxContextPRNG;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.Session;
import io.vertx.ext.web.handler.CSRFHandler;
import io.vertx.ext.web.handler.SessionHandler;
import io.vertx.ext.web.impl.Origin;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;
/**
 * {@link CSRFHandler} implementation using signed, cookie-mirrored tokens.
 *
 * <p>Tokens have the form {@code base64(salt) "." timestampMillis "." base64(HMAC-SHA256(salt "." timestamp))}.
 * They are sent to the client in a {@code SameSite=Strict} cookie and exposed in the routing context
 * under {@code headerName}. A state-changing request (POST/PUT/DELETE/PATCH) must echo the cookie
 * value in the {@code headerName} header or form attribute.
 *
 * <p>When a session is available, the token is also recorded in the session as
 * {@code sessionId + "/" + token}. It is then bound to that session id and consumed on use.
 * All other HTTP methods pass through without a token being issued or checked.
 */
public class CSRFHandlerImpl implements CSRFHandler {

  private static final Logger log = LoggerFactory.getLogger(CSRFHandlerImpl.class);

  // MIME base64 wraps at 76 chars. Everything encoded here (a 32-byte salt and a 32-byte
  // HMAC-SHA256 digest) encodes to 44 chars, so no line separators are ever inserted.
  private static final Base64.Encoder BASE64 = Base64.getMimeEncoder();

  private final VertxContextPRNG random;
  // javax.crypto.Mac keeps internal state and is not thread-safe.
  // Every use goes through sign(), which locks it.
  private final Mac mac;

  private boolean nagHttps;
  private String cookieName = DEFAULT_COOKIE_NAME;
  private String cookiePath = DEFAULT_COOKIE_PATH;
  private String headerName = DEFAULT_HEADER_NAME;
  private long timeout = SessionHandler.DEFAULT_SESSION_TIMEOUT;
  private Origin origin;
  private boolean httpOnly;

  /**
   * Creates a handler that signs tokens with HMAC-SHA256 keyed by {@code secret}.
   *
   * @param vertx  used to obtain the random source for token salts
   * @param secret HMAC key material, encoded as UTF-8
   * @throws RuntimeException wrapping the JCA exception if HmacSHA256 is unavailable or the key is rejected
   */
  public CSRFHandlerImpl(final Vertx vertx, final String secret) {
    try {
      random = VertxContextPRNG.current(vertx);
      mac = Mac.getInstance("HmacSHA256");
      mac.init(new SecretKeySpec(secret.getBytes(StandardCharsets.UTF_8), "HmacSHA256"));
    } catch (NoSuchAlgorithmException | InvalidKeyException e) {
      throw new RuntimeException(e);
    }
  }

  @Override
  public CSRFHandler setOrigin(String origin) {
    this.origin = Origin.parse(origin);
    return this;
  }

  @Override
  public CSRFHandler setCookieName(String cookieName) {
    this.cookieName = cookieName;
    return this;
  }

  @Override
  public CSRFHandler setCookiePath(String cookiePath) {
    this.cookiePath = cookiePath;
    return this;
  }

  @Override
  public CSRFHandler setCookieHttpOnly(boolean httpOnly) {
    this.httpOnly = httpOnly;
    return this;
  }

  @Override
  public CSRFHandler setHeaderName(String headerName) {
    this.headerName = headerName;
    return this;
  }

  @Override
  public CSRFHandler setTimeout(long timeout) {
    this.timeout = timeout;
    return this;
  }

  @Override
  public CSRFHandler setNagHttps(boolean nag) {
    this.nagHttps = nag;
    return this;
  }

  /**
   * Returns the base64 HMAC-SHA256 signature of {@code payload} (encoded as US-ASCII).
   * The shared {@link Mac} is locked because it is not safe for concurrent use.
   */
  private String sign(String payload) {
    final byte[] digest;
    synchronized (mac) {
      digest = mac.doFinal(payload.getBytes(StandardCharsets.US_ASCII));
    }
    return BASE64.encodeToString(digest);
  }

  /** Compares two strings without short-circuiting on the first differing byte. */
  private static boolean constantTimeEquals(String a, String b) {
    return MessageDigest.isEqual(a.getBytes(StandardCharsets.UTF_8), b.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Issues a fresh signed token.
   *
   * <p>The token is set as a {@code SameSite=Strict} cookie. If a session exists, the token is also
   * recorded in it as {@code sessionId + "/" + token}, so the next state-changing request can be
   * matched against it.
   *
   * @return the newly issued token
   */
  private String generateAndStoreToken(RoutingContext ctx) {
    byte[] salt = new byte[32];
    random.nextBytes(salt);

    String saltPlusToken = BASE64.encodeToString(salt) + "." + System.currentTimeMillis();
    final String token = saltPlusToken + "." + sign(saltPlusToken);

    ctx.addCookie(
      Cookie.cookie(cookieName, token)
        .setPath(cookiePath)
        .setHttpOnly(httpOnly)
        // the same-site policy is deliberately not configurable
        .setSameSite(CookieSameSite.STRICT));

    Session session = ctx.session();
    if (session != null) {
      // Bind the token to the current session id: a changed id (e.g. a session upgrade)
      // invalidates it. Storing here also covers rotated and expired-then-reissued tokens,
      // which would otherwise be rejected on the next state-changing request.
      session.put(headerName, session.id() + "/" + token);
    }
    return token;
  }

  /**
   * Returns the token stored in the session for the current session id.
   *
   * @return the token, or {@code null} if there is no session, nothing is stored, or the stored
   *         value was recorded for a different session id
   */
  private String getTokenFromSession(RoutingContext ctx) {
    Session session = ctx.session();
    if (session == null) {
      return null;
    }
    String sessionToken = session.get(headerName);
    if (sessionToken != null) {
      int idx = sessionToken.indexOf('/');
      if (idx != -1 && session.id() != null && session.id().equals(sessionToken.substring(0, idx))) {
        return sessionToken.substring(idx + 1);
      }
    }
    return null;
  }

  private static boolean isBlank(String s) {
    return s == null || s.trim().isEmpty();
  }

  /** Parses a decimal long, returning {@code -1} for blank or malformed input. */
  private static long parseLong(String s) {
    if (isBlank(s)) {
      return -1;
    }
    try {
      return Long.parseLong(s);
    } catch (NumberFormatException e) {
      log.trace("Invalid Token format", e);
      return -1;
    }
  }

  /**
   * Decides whether a token timestamp is no longer usable.
   *
   * @param timestamp the millisecond timestamp segment of a token
   * @return {@code true} if the timestamp is unparsable or older than {@code timeout} ms
   */
  private boolean isExpired(String timestamp) {
    final long ts = parseLong(timestamp);
    return ts == -1 || System.currentTimeMillis() > ts + timeout;
  }

  /**
   * Checks the request origin when an expected origin is configured.
   * Falls back from the {@code Origin} header to {@code Referer}; with neither present the check fails.
   */
  private boolean isValidOrigin(RoutingContext ctx) {
    if (origin != null) {
      String source = ctx.request().getHeader(HttpHeaders.ORIGIN);
      if (isBlank(source)) {
        source = ctx.request().getHeader(HttpHeaders.REFERER);
        if (isBlank(source)) {
          log.trace("ORIGIN and REFERER request headers are both absent/empty");
          return false;
        }
      }
      if (!origin.sameOrigin(source)) {
        log.trace("Protocol/Host/Port do not fully match");
        return false;
      }
    }
    return true;
  }

  /**
   * Validates the token on a state-changing request.
   *
   * <p>The submitted value (header, falling back to form attribute) must equal the cookie value.
   * When a session exists, it must also match the value stored for the current session id, and the
   * stored value is consumed. Finally, the HMAC signature must verify and the timestamp must be
   * within {@code timeout}.
   */
  private boolean isValidRequest(RoutingContext ctx) {
    final Cookie cookie = ctx.getCookie(cookieName);
    String header = ctx.request().getHeader(headerName);
    if (header == null) {
      header = ctx.request().getFormAttribute(headerName);
    }

    if (header == null || cookie == null || isBlank(header) || isBlank(cookie.getValue())) {
      log.trace("Token provided via HTTP Header/Form is absent/empty");
      return false;
    }

    if (!header.equals(cookie.getValue())) {
      log.trace("Token provided via HTTP Header and via Cookie are not equal");
      return false;
    }

    if (ctx.session() != null) {
      Session session = ctx.session();
      String sessionToken = session.get(headerName);
      if (sessionToken != null) {
        int idx = sessionToken.indexOf('/');
        if (idx != -1 && session.id() != null && session.id().equals(sessionToken.substring(0, idx))) {
          String challenge = sessionToken.substring(idx + 1);
          if (!constantTimeEquals(challenge, header)) {
            log.trace("Token has been used or is outdated");
            return false;
          }
        } else {
          log.trace("Token has been issued for a different session");
          return false;
        }
      } else {
        log.trace("No Token has been added to the session");
        return false;
      }
    }

    String[] tokens = header.split("\\.");
    if (tokens.length != 3) {
      return false;
    }

    String signature = sign(tokens[0] + "." + tokens[1]);
    // constant-time comparison so the expected signature cannot be recovered byte by byte via timing
    if (!constantTimeEquals(signature, tokens[2])) {
      log.trace("Token signature does not match");
      return false;
    }

    if (ctx.session() != null) {
      // the session token is single-use: discard it to prevent replays
      ctx.session().remove(headerName);
    }

    return !isExpired(tokens[1]);
  }

  /**
   * Chooses the token for a GET request.
   *
   * <p>Without a session, a fresh token is issued on every request. With a session, the token stored
   * for the current session id is reused while it is well-formed and not expired. Otherwise a new
   * token is issued, which {@link #generateAndStoreToken} also records in the session.
   */
  private String tokenForSafeRequest(RoutingContext ctx) {
    if (ctx.session() != null) {
      String sessionToken = getTokenFromSession(ctx);
      if (sessionToken != null) {
        String[] parts = sessionToken.split("\\.");
        // guard the index: a malformed stored value is treated like an expired one
        if (parts.length >= 2 && !isExpired(parts[1])) {
          // the cookie is not re-issued: the user agent already holds this token
          return sessionToken;
        }
      }
    }
    return generateAndStoreToken(ctx);
  }

  @Override
  public void handle(RoutingContext ctx) {
    if (nagHttps) {
      String uri = ctx.request().absoluteURI();
      if (uri != null && !uri.startsWith("https:")) {
        log.trace("Using session cookies without https could make you susceptible to session hijacking: " + uri);
      }
    }

    HttpMethod method = ctx.request().method();

    if (!isValidOrigin(ctx)) {
      ctx.fail(403);
      return;
    }

    switch (method.name()) {
      case "GET":
        // expose the token so templates can render it
        ctx.put(headerName, tokenForSafeRequest(ctx));
        ctx.next();
        break;
      case "POST":
      case "PUT":
      case "DELETE":
      case "PATCH":
        if (isValidRequest(ctx)) {
          // the submitted token was consumed; rotate it so it cannot be replayed
          ctx.put(headerName, generateAndStoreToken(ctx));
          ctx.next();
        } else {
          ctx.fail(403);
        }
        break;
      default:
        ctx.next();
        break;
    }
  }
}