package io.vertx.ext.auth.impl;
import io.vertx.core.AsyncResult;
import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.json.JsonObject;
import io.vertx.ext.auth.ChainAuth;
import io.vertx.ext.auth.User;
import io.vertx.ext.auth.authentication.AuthenticationProvider;
import io.vertx.ext.auth.authentication.CredentialValidationException;
import io.vertx.ext.auth.authentication.Credentials;
import java.util.ArrayList;
import java.util.List;
/**
 * {@link ChainAuth} implementation that delegates authentication to an ordered list of
 * {@link AuthenticationProvider}s.
 *
 * <p>The chain runs in one of two modes, selected at construction time:
 * <ul>
 *   <li><b>ANY</b> ({@code all == false}): providers are tried in insertion order and the first
 *       successful result is handed to the caller; the chain fails only when every provider
 *       has failed.</li>
 *   <li><b>ALL</b> ({@code all == true}): every provider must succeed; the first failure is
 *       handed to the caller as-is. On success the principals of all returned users are
 *       merged into a single {@link User}.</li>
 * </ul>
 */
public class ChainAuthImpl implements ChainAuth {

  /** Providers in the order they were added; this is also the order in which they are tried. */
  private final List<AuthenticationProvider> providers = new ArrayList<>();

  /** {@code true} when every provider must succeed, {@code false} when one success suffices. */
  private final boolean all;

  /**
   * Creates an empty chain.
   *
   * @param all {@code true} to require every provider to succeed, {@code false} to accept the
   *            first provider that succeeds
   */
  public ChainAuthImpl(boolean all) {
    this.all = all;
  }

  /**
   * Appends a provider to the end of the chain.
   *
   * @param other the provider to append
   * @return this chain, to allow fluent calls
   */
  @Override
  public ChainAuth add(AuthenticationProvider other) {
    providers.add(other);
    return this;
  }

  /**
   * Validates the given credentials and, if valid, authenticates their JSON form against the
   * chain.
   *
   * @param credentials   the credentials to authenticate; a {@code null} value is reported as a
   *                      failed result instead of throwing
   * @param resultHandler receives the authenticated user or the failure
   */
  @Override
  public void authenticate(Credentials credentials, Handler<AsyncResult<User>> resultHandler) {
    // Report a missing argument through the handler, like every other failure of this method,
    // rather than letting a NullPointerException escape to the caller.
    if (credentials == null) {
      resultHandler.handle(Future.failedFuture("Credentials cannot be null."));
      return;
    }
    try {
      credentials.checkValid(null);
      authenticate(credentials.toJson(), resultHandler);
    } catch (CredentialValidationException e) {
      resultHandler.handle(Future.failedFuture(e));
    }
  }

  /**
   * Authenticates the given authentication info against the chain.
   *
   * @param authInfo      the authentication info passed unchanged to each provider
   * @param resultHandler receives the authenticated user or the failure; fails immediately when
   *                      the chain has no providers
   */
  @Override
  public void authenticate(final JsonObject authInfo, final Handler<AsyncResult<User>> resultHandler) {
    if (providers.isEmpty()) {
      resultHandler.handle(Future.failedFuture("No providers in the auth chain."));
    } else {
      iterate(0, authInfo, resultHandler, null);
    }
  }

  /**
   * Asks the provider at position {@code idx} to authenticate and, depending on the mode and
   * the outcome, either completes the caller's handler or moves on to the next provider.
   *
   * @param idx           index of the provider to try
   * @param authInfo      the authentication info passed to the provider
   * @param resultHandler the caller's handler, completed exactly once at the end of the chain
   * @param previousUser  the user accumulated so far in ALL mode, or {@code null}
   */
  private void iterate(final int idx, final JsonObject authInfo, final Handler<AsyncResult<User>> resultHandler, final User previousUser) {
    // Stop condition: every provider has been consulted.
    if (idx >= providers.size()) {
      if (all) {
        // ALL mode only gets here when nobody failed.
        resultHandler.handle(Future.succeededFuture(previousUser));
      } else {
        // ANY mode only gets here when nobody succeeded.
        resultHandler.handle(Future.failedFuture("No more providers in the auth chain."));
      }
      return;
    }

    providers.get(idx).authenticate(authInfo, res -> {
      if (res.succeeded()) {
        if (all) {
          // ALL: keep going, carrying the merged user forward.
          iterate(idx + 1, authInfo, resultHandler, merge(previousUser, res.result()));
        } else {
          // ANY: the first success ends the chain.
          resultHandler.handle(res);
        }
      } else if (all) {
        // ALL: the first failure ends the chain.
        resultHandler.handle(res);
      } else {
        // ANY: this provider failed, try the next one.
        iterate(idx + 1, authInfo, resultHandler, null);
      }
    });
  }

  /**
   * Combines the user accumulated so far with the user returned by the current provider.
   *
   * @param previousUser the accumulated user, or {@code null} when {@code current} is the first
   * @param current      the user returned by the provider that just succeeded
   * @return {@code current} when there is no accumulated user, otherwise a new user whose
   *         principal holds the entries of both principals
   */
  private static User merge(final User previousUser, final User current) {
    if (previousUser == null) {
      return current;
    }
    // mergeIn modifies its receiver, so merge into a copy: the accumulated user may be the very
    // object a provider returned and its principal must not be altered behind its back.
    return User.create(previousUser.principal().copy().mergeIn(current.principal()));
  }
}