package com.mongodb.internal.connection;
import com.mongodb.MongoException;
import com.mongodb.MongoInterruptedException;
import com.mongodb.MongoSecurityException;
import com.mongodb.ServerAddress;
import com.mongodb.async.SingleResultCallback;
import com.mongodb.lang.Nullable;
import com.mongodb.connection.ConnectionDescription;
import org.bson.BsonBinary;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
import org.bson.BsonString;
import javax.security.auth.Subject;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.security.PrivilegedAction;
import static com.mongodb.MongoCredential.JAVA_SUBJECT_KEY;
import static com.mongodb.internal.connection.CommandHelper.executeCommand;
import static com.mongodb.internal.connection.CommandHelper.executeCommandAsync;
abstract class SaslAuthenticator extends Authenticator {
SaslAuthenticator(final MongoCredentialWithCache credential) {
super(credential);
}
public void authenticate(final InternalConnection connection, final ConnectionDescription connectionDescription) {
doAsSubject(new PrivilegedAction<Void>() {
@Override
public Void run() {
SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
throwIfSaslClientIsNull(saslClient);
try {
byte[] response = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null);
BsonDocument res = sendSaslStart(response, connection);
BsonInt32 conversationId = res.getInt32("conversationId");
while (!(res.getBoolean("done")).getValue()) {
response = saslClient.evaluateChallenge((res.getBinary("payload")).getData());
if (response == null) {
throw new MongoSecurityException(getMongoCredential(),
"SASL protocol error: no client response to challenge for credential "
+ getMongoCredential());
}
res = sendSaslContinue(conversationId, response, connection);
}
} catch (Exception e) {
throw wrapException(e);
} finally {
disposeOfSaslClient(saslClient);
}
return null;
}
});
}
@Override
void authenticateAsync(final InternalConnection connection, final ConnectionDescription connectionDescription,
final SingleResultCallback<Void> callback) {
try {
doAsSubject(new PrivilegedAction<Void>() {
@Override
public Void run() {
final SaslClient saslClient = createSaslClient(connection.getDescription().getServerAddress());
throwIfSaslClientIsNull(saslClient);
try {
byte[] response = (saslClient.hasInitialResponse() ? saslClient.evaluateChallenge(new byte[0]) : null);
sendSaslStartAsync(response, connection, new SingleResultCallback<BsonDocument>() {
@Override
public void onResult(final BsonDocument result, final Throwable t) {
if (t != null) {
callback.onResult(null, wrapException(t));
} else if (result.getBoolean("done").getValue()) {
callback.onResult(null, null);
} else {
new Continuator(saslClient, result, connection, callback).start();
}
}
});
} catch (SaslException e) {
throw wrapException(e);
}
return null;
}
});
} catch (Throwable t) {
callback.onResult(null, t);
}
}
public abstract String getMechanismName();
protected abstract SaslClient createSaslClient(ServerAddress serverAddress);
private void throwIfSaslClientIsNull(final SaslClient saslClient) {
if (saslClient == null) {
throw new MongoSecurityException(getMongoCredential(),
String.format("This JDK does not support the %s SASL mechanism", getMechanismName()));
}
}
@Nullable
private Subject getSubject() {
return getMongoCredential().getMechanismProperty(JAVA_SUBJECT_KEY, null);
}
private BsonDocument sendSaslStart(final byte[] outToken, final InternalConnection connection) {
return executeCommand(getMongoCredential().getSource(), createSaslStartCommandDocument(outToken), connection);
}
private BsonDocument sendSaslContinue(final BsonInt32 conversationId, final byte[] outToken, final InternalConnection connection) {
return executeCommand(getMongoCredential().getSource(), createSaslContinueDocument(conversationId, outToken), connection);
}
private void sendSaslStartAsync(final byte[] outToken, final InternalConnection connection,
final SingleResultCallback<BsonDocument> callback) {
executeCommandAsync(getMongoCredential().getSource(), createSaslStartCommandDocument(outToken), connection,
callback);
}
private void sendSaslContinueAsync(final BsonInt32 conversationId, final byte[] outToken, final InternalConnection connection,
final SingleResultCallback<BsonDocument> callback) {
executeCommandAsync(getMongoCredential().getSource(), createSaslContinueDocument(conversationId, outToken), connection,
callback);
}
private BsonDocument createSaslStartCommandDocument(final byte[] outToken) {
return new BsonDocument("saslStart", new BsonInt32(1)).append("mechanism", new BsonString(getMechanismName()))
.append("payload", new BsonBinary(outToken != null ? outToken : new byte[0]));
}
private BsonDocument createSaslContinueDocument(final BsonInt32 conversationId, final byte[] outToken) {
return new BsonDocument("saslContinue", new BsonInt32(1)).append("conversationId", conversationId)
.append("payload", new BsonBinary(outToken));
}
private void disposeOfSaslClient(final SaslClient saslClient) {
try {
saslClient.dispose();
} catch (SaslException e) {
}
}
private MongoException wrapException(final Throwable t) {
if (t instanceof MongoInterruptedException) {
return (MongoInterruptedException) t;
} else if (t instanceof MongoSecurityException) {
return (MongoSecurityException) t;
} else {
return new MongoSecurityException(getMongoCredential(), "Exception authenticating " + getMongoCredential(), t);
}
}
void doAsSubject(final java.security.PrivilegedAction<Void> action) {
if (getSubject() == null) {
action.run();
} else {
Subject.doAs(getSubject(), action);
}
}
private final class Continuator implements SingleResultCallback<BsonDocument> {
private final SaslClient saslClient;
private final BsonDocument saslStartDocument;
private final InternalConnection connection;
private final SingleResultCallback<Void> callback;
Continuator(final SaslClient saslClient, final BsonDocument saslStartDocument, final InternalConnection connection,
final SingleResultCallback<Void> callback) {
this.saslClient = saslClient;
this.saslStartDocument = saslStartDocument;
this.connection = connection;
this.callback = callback;
}
@Override
public void onResult(final BsonDocument result, final Throwable t) {
if (t != null) {
callback.onResult(null, wrapException(t));
disposeOfSaslClient(saslClient);
} else if (result.getBoolean("done").getValue()) {
callback.onResult(null, null);
disposeOfSaslClient(saslClient);
} else {
continueConversation(result);
}
}
public void start() {
continueConversation(saslStartDocument);
}
private void continueConversation(final BsonDocument result) {
try {
doAsSubject(new PrivilegedAction<Void>() {
@Override
public Void run() {
try {
sendSaslContinueAsync(saslStartDocument.getInt32("conversationId"),
saslClient.evaluateChallenge((result.getBinary("payload")).getData()), connection, Continuator.this);
} catch (SaslException e) {
throw wrapException(e);
}
return null;
}
});
} catch (Throwable t) {
callback.onResult(null, t);
disposeOfSaslClient(saslClient);
}
}
}
}