package io.undertow.security.impl;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.security.api.AuthenticationMechanism;
import io.undertow.security.api.AuthenticationMechanism.AuthenticationMechanismOutcome;
import io.undertow.security.api.AuthenticationMechanism.ChallengeResult;
import io.undertow.security.api.AuthenticationMechanismContext;
import io.undertow.security.api.AuthenticationMode;
import io.undertow.security.idm.Account;
import io.undertow.security.idm.IdentityManager;
import io.undertow.security.idm.PasswordCredential;
import io.undertow.server.HttpServerExchange;
import io.undertow.util.StatusCodes;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
public class SecurityContextImpl extends AbstractSecurityContext implements AuthenticationMechanismContext {
// Permission checked in the three-argument constructor when a SecurityManager is installed.
private static final RuntimePermission PERMISSION = new RuntimePermission("MODIFY_UNDERTOW_SECURITY_CONTEXT");
// Current position in the authentication state machine; starts as NOT_ATTEMPTED.
private AuthenticationState authenticationState = AuthenticationState.NOT_ATTEMPTED;
// Mode supplied at construction; PRO_ACTIVE makes authTransitionRequired() request an attempt even when authentication is not required.
private final AuthenticationMode authenticationMode;
// Mechanism name passed to authenticationComplete() by login(); replaceable via setProgramaticMechName().
private String programaticMechName = "Programatic";
// Head of the singly linked list of registered mechanisms, in registration order; null while none are registered.
private Node<AuthenticationMechanism> authMechanisms = null;
// Identity manager used by login() to verify username/password credentials.
private final IdentityManager identityManager;
/**
 * Creates a security context for the given exchange using {@link AuthenticationMode#PRO_ACTIVE}.
 * Delegates to the three-argument constructor.
 *
 * @param exchange        the exchange this context belongs to
 * @param identityManager the identity manager used for programmatic logins
 */
public SecurityContextImpl(final HttpServerExchange exchange, final IdentityManager identityManager) {
this(exchange, AuthenticationMode.PRO_ACTIVE, identityManager);
}
public SecurityContextImpl(final HttpServerExchange exchange, final AuthenticationMode authenticationMode, final IdentityManager identityManager) {
super(exchange);
this.authenticationMode = authenticationMode;
this.identityManager = identityManager;
if (System.getSecurityManager() != null) {
System.getSecurityManager().checkPermission(PERMISSION);
}
}
/**
 * Runs the authentication state machine for the current exchange.
 * <p>
 * If an earlier call left the state at {@code ATTEMPTED}, or at {@code CHALLENGE_SENT} while the
 * response has not been started, the state is reset to {@code NOT_ATTEMPTED} first so that the
 * mechanisms are consulted again.
 *
 * @return the negation of {@link #authTransition()}: {@code false} when the state machine ended
 *         with a challenge having been sent, {@code true} otherwise
 */
@Override
public boolean authenticate() {
    UndertowLogger.SECURITY_LOGGER.debugf("Attempting to authenticate %s, authentication required: %s", exchange.getRequestPath(), isAuthenticationRequired());

    final boolean previouslyAttempted = authenticationState == AuthenticationState.ATTEMPTED;
    // The response is only inspected when a challenge was recorded, as in a short-circuit chain.
    final boolean challengeNotYetOnWire = !previouslyAttempted
            && authenticationState == AuthenticationState.CHALLENGE_SENT
            && !exchange.isResponseStarted();

    if (previouslyAttempted || challengeNotYetOnWire) {
        // Re-attempt: start the state machine from the beginning.
        authenticationState = AuthenticationState.NOT_ATTEMPTED;
    }

    final boolean challengeIssued = authTransition();
    return !challengeIssued;
}
/**
 * Advances the authentication state machine until no further transition is required.
 * <p>
 * From {@code NOT_ATTEMPTED} the registered mechanisms are asked to authenticate; from
 * {@code ATTEMPTED} they are asked to send challenges.
 *
 * @return {@code false} if the final state is {@code NOT_ATTEMPTED}, {@code ATTEMPTED} or
 *         {@code AUTHENTICATED}; {@code true} for any other state
 * @throws IllegalStateException if a transition is requested from a state that has none
 */
private boolean authTransition() {
    while (authTransitionRequired()) {
        switch (authenticationState) {
            case NOT_ATTEMPTED:
                authenticationState = attemptAuthentication();
                break;
            case ATTEMPTED:
                authenticationState = sendChallenges();
                break;
            default:
                throw new IllegalStateException("It should not be possible to reach this.");
        }
    }

    UndertowLogger.SECURITY_LOGGER.debugf("Authentication result was %s for %s", authenticationState, exchange.getRequestPath());

    final AuthenticationState settled = authenticationState;
    final boolean requestMayProceed = settled == AuthenticationState.NOT_ATTEMPTED
            || settled == AuthenticationState.ATTEMPTED
            || settled == AuthenticationState.AUTHENTICATED;
    return !requestMayProceed;
}
/**
 * Asks the registered mechanisms, in order, to authenticate the current exchange.
 *
 * @return the state produced by a fresh {@link AuthAttempter} run over the mechanism list
 */
private AuthenticationState attemptAuthentication() {
    final AuthAttempter attempter = new AuthAttempter(authMechanisms, exchange);
    return attempter.transition();
}
/**
 * Asks every registered mechanism, in order, to send its challenge for the current exchange.
 *
 * @return the state produced by a fresh {@link ChallengeSender} run over the mechanism list
 */
private AuthenticationState sendChallenges() {
    UndertowLogger.SECURITY_LOGGER.debugf("Sending authentication challenge for %s", exchange);
    final ChallengeSender sender = new ChallengeSender(authMechanisms, exchange);
    return sender.transition();
}
/**
 * Decides whether the state machine has to take another step.
 *
 * @return for {@code NOT_ATTEMPTED}: {@code true} if authentication is required or the mode is
 *         {@code PRO_ACTIVE}; for {@code ATTEMPTED}: {@code true} only if authentication is
 *         required; {@code false} for every other state
 */
private boolean authTransitionRequired() {
    if (authenticationState == AuthenticationState.NOT_ATTEMPTED) {
        if (isAuthenticationRequired()) {
            return true;
        }
        return authenticationMode == AuthenticationMode.PRO_ACTIVE;
    }
    if (authenticationState == AuthenticationState.ATTEMPTED) {
        return isAuthenticationRequired();
    }
    // AUTHENTICATED and CHALLENGE_SENT have no outgoing transition.
    return false;
}
/**
 * Sets the mechanism name that {@code login} passes to {@code authenticationComplete}
 * after a successful programmatic login.
 *
 * @param programaticMechName the mechanism name to use from now on
 */
public void setProgramaticMechName(final String programaticMechName) {
this.programaticMechName = programaticMechName;
}
/**
 * Appends a mechanism to the end of the mechanism list, so mechanisms are consulted in the
 * order in which they were added.
 *
 * @param handler the mechanism to register
 */
@Override
public void addAuthenticationMechanism(final AuthenticationMechanism handler) {
    final Node<AuthenticationMechanism> appended = new Node<>(handler);
    if (authMechanisms == null) {
        // First registration becomes the head of the list.
        authMechanisms = appended;
        return;
    }
    Node<AuthenticationMechanism> tail = authMechanisms;
    while (tail.next != null) {
        tail = tail.next;
    }
    tail.next = appended;
}
/**
 * Returns a snapshot of the registered mechanisms in registration order.
 *
 * @return an unmodifiable list that is a copy of the internal mechanism list; empty when no
 *         mechanism has been registered
 */
@Override
@Deprecated
public List<AuthenticationMechanism> getAuthenticationMechanisms() {
    final List<AuthenticationMechanism> snapshot = new LinkedList<>();
    for (Node<AuthenticationMechanism> node = authMechanisms; node != null; node = node.next) {
        snapshot.add(node.item);
    }
    return Collections.unmodifiableList(snapshot);
}
/**
 * Returns the identity manager supplied at construction.
 *
 * @return the identity manager used by this context
 */
@Override
@Deprecated
public IdentityManager getIdentityManager() {
return identityManager;
}
/**
 * Performs a programmatic username/password login against the identity manager.
 * <p>
 * When a security manager is installed, the verification runs inside a privileged action.
 * On success the account is registered through {@code authenticationComplete} under the
 * configured programmatic mechanism name and the state becomes {@code AUTHENTICATED}.
 *
 * @param username the user name to verify
 * @param password the password to verify
 * @return {@code true} if the identity manager returned an account, {@code false} if it
 *         returned {@code null}
 */
@Override
public boolean login(final String username, final String password) {
    UndertowLogger.SECURITY_LOGGER.debugf("Attempting programatic login for user %s for request %s", username, exchange);

    final Account verified;
    if (System.getSecurityManager() != null) {
        verified = AccessController.doPrivileged(new PrivilegedAction<Account>() {
            @Override
            public Account run() {
                return identityManager.verify(username, new PasswordCredential(password.toCharArray()));
            }
        });
    } else {
        verified = identityManager.verify(username, new PasswordCredential(password.toCharArray()));
    }

    if (verified != null) {
        authenticationComplete(verified, programaticMechName, true);
        this.authenticationState = AuthenticationState.AUTHENTICATED;
        return true;
    }
    return false;
}
/**
 * Logs the current user out: logs who (if anyone) is being logged out, delegates to the
 * superclass logout and resets the state machine to {@code NOT_ATTEMPTED}.
 */
@Override
public void logout() {
    final Account current = getAuthenticatedAccount();
    if (current == null) {
        UndertowLogger.SECURITY_LOGGER.debugf("Logout called with no authenticated user in exchange %s", exchange);
    } else {
        final String principalName = current.getPrincipal().getName();
        UndertowLogger.SECURITY_LOGGER.debugf("Logging out user %s for %s", principalName, exchange);
    }
    super.logout();
    this.authenticationState = AuthenticationState.NOT_ATTEMPTED;
}
/**
 * Walks the mechanism list and asks each mechanism in turn to authenticate the exchange,
 * stopping at the first mechanism that reports something other than {@code NOT_ATTEMPTED}.
 */
private class AuthAttempter {
    // Next mechanism to consult; advanced before each mechanism is invoked.
    private Node<AuthenticationMechanism> remaining;
    private final HttpServerExchange exchange;

    private AuthAttempter(Node<AuthenticationMechanism> firstMechanism, final HttpServerExchange exchange) {
        this.exchange = exchange;
        this.remaining = firstMechanism;
    }

    /**
     * Consults the remaining mechanisms in order.
     *
     * @return {@code AUTHENTICATED} if a mechanism authenticated the request; {@code ATTEMPTED}
     *         if a mechanism reported {@code NOT_AUTHENTICATED} (authentication is then also
     *         marked as required) or if every mechanism reported {@code NOT_ATTEMPTED}
     * @throws IllegalStateException if a mechanism returns an outcome this class does not handle
     */
    private AuthenticationState transition() {
        while (remaining != null) {
            final AuthenticationMechanism mechanism = remaining.item;
            remaining = remaining.next;

            final AuthenticationMechanismOutcome outcome = mechanism.authenticate(exchange, SecurityContextImpl.this);

            if (UndertowLogger.SECURITY_LOGGER.isDebugEnabled()) {
                UndertowLogger.SECURITY_LOGGER.debugf("Authentication outcome was %s with method %s for %s", outcome, mechanism, exchange.getRequestURI());
                if (UndertowLogger.SECURITY_LOGGER.isTraceEnabled()) {
                    UndertowLogger.SECURITY_LOGGER.tracef("Contents of exchange after authentication attempt is %s", exchange);
                }
            }

            if (outcome == null) {
                throw UndertowMessages.MESSAGES.authMechanismOutcomeNull();
            }

            switch (outcome) {
                case AUTHENTICATED:
                    return AuthenticationState.AUTHENTICATED;
                case NOT_AUTHENTICATED:
                    // The mechanism tried and failed, so challenges must be sent next.
                    setAuthenticationRequired();
                    return AuthenticationState.ATTEMPTED;
                case NOT_ATTEMPTED:
                    // Nothing for this mechanism to do; move on to the next one.
                    continue;
                default:
                    throw new IllegalStateException();
            }
        }
        // Every mechanism declined to attempt authentication.
        return AuthenticationState.ATTEMPTED;
    }
}
/**
 * Walks the whole mechanism list, lets every mechanism send its challenge and decides which
 * response status code ends up on the exchange.
 */
private class ChallengeSender {
    // Next mechanism to consult; advanced before each mechanism is invoked.
    private Node<AuthenticationMechanism> remaining;
    private final HttpServerExchange exchange;
    // Status code selected so far; null until a mechanism that sent a challenge supplies one.
    private Integer chosenStatusCode = null;
    // Whether at least one mechanism reported that it sent a challenge.
    private boolean challengeSent = false;

    private ChallengeSender(Node<AuthenticationMechanism> firstMechanism, final HttpServerExchange exchange) {
        this.exchange = exchange;
        this.remaining = firstMechanism;
    }

    /**
     * Invokes {@code sendChallenge} on every remaining mechanism and applies the status code.
     * <p>
     * A desired code is adopted while no code has been chosen or the chosen code is still
     * {@code OK}; a non-{@code OK} choice is written to the exchange immediately. After the
     * last mechanism, and only if the response has not started, {@code FORBIDDEN} is set when
     * no code was chosen and no challenge was sent, and a chosen {@code OK} is written.
     *
     * @return always {@code CHALLENGE_SENT}
     */
    private AuthenticationState transition() {
        while (remaining != null) {
            final AuthenticationMechanism mechanism = remaining.item;
            remaining = remaining.next;

            final ChallengeResult result = mechanism.sendChallenge(exchange, SecurityContextImpl.this);
            if (result == null) {
                throw UndertowMessages.MESSAGES.sendChallengeReturnedNull(mechanism);
            }
            if (!result.isChallengeSent()) {
                continue;
            }

            challengeSent = true;
            final Integer desiredCode = result.getDesiredResponseCode();
            final boolean codeStillOpen = chosenStatusCode == null || chosenStatusCode.equals(StatusCodes.OK);
            if (desiredCode == null || !codeStillOpen) {
                continue;
            }

            chosenStatusCode = desiredCode;
            // An OK choice is deferred to the end because a later mechanism may still replace it.
            if (!chosenStatusCode.equals(StatusCodes.OK) && !exchange.isResponseStarted()) {
                exchange.setStatusCode(chosenStatusCode);
            }
        }

        if (exchange.isResponseStarted()) {
            return AuthenticationState.CHALLENGE_SENT;
        }

        if (chosenStatusCode == null) {
            if (!challengeSent) {
                // No mechanism produced a challenge.
                exchange.setStatusCode(StatusCodes.FORBIDDEN);
            }
        } else if (chosenStatusCode.equals(StatusCodes.OK)) {
            exchange.setStatusCode(chosenStatusCode);
        }
        return AuthenticationState.CHALLENGE_SENT;
    }
}
/**
 * States of the authentication state machine driven by the enclosing security context.
 */
enum AuthenticationState {
    /** Initial state; also the state after a logout or when a re-attempt is started. */
    NOT_ATTEMPTED,
    /** The mechanisms were consulted and none of them authenticated the request. */
    ATTEMPTED,
    /** A mechanism or a programmatic login authenticated the request. */
    AUTHENTICATED,
    /** The mechanisms were asked to send their challenges. */
    CHALLENGE_SENT
}
/**
 * Element of a minimal singly linked list: an immutable item plus a mutable link to the
 * following node.
 *
 * @param <T> the type of the stored item
 */
private static final class Node<T> {
// The value held by this node.
final T item;
// The following node, or null while this node is the last one.
Node<T> next;
private Node(T item) {
this.item = item;
}
}
}