package com.mongodb.internal.connection;
import com.mongodb.AuthenticationMechanism;
import com.mongodb.MongoCredential;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
import com.mongodb.ServerAddress;
import com.mongodb.lang.NonNull;
import com.mongodb.lang.Nullable;
import org.bson.BsonBinary;
import org.bson.BsonBinaryWriter;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
import org.bson.BsonString;
import org.bson.RawBsonDocument;
import org.bson.codecs.BsonDocumentCodec;
import org.bson.codecs.EncoderContext;
import org.bson.io.BasicOutputBuffer;
import javax.security.sasl.SaslClient;
import javax.security.sasl.SaslException;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.security.SecureRandom;
import java.time.Instant;
import java.time.ZoneId;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import static com.mongodb.AuthenticationMechanism.MONGODB_AWS;
import static java.lang.String.format;
/**
 * Authenticator for the {@code MONGODB-AWS} SASL mechanism.
 *
 * <p>The SASL conversation is performed by {@link AwsSaslClient}: a client-first message carrying a
 * random client nonce, followed by a client-final message carrying an authorization header built from
 * AWS credentials and the server-supplied nonce and host.</p>
 */
public class AwsAuthenticator extends SaslAuthenticator {
    private static final String MONGODB_AWS_MECHANISM_NAME = "MONGODB-AWS";
    /** Length in bytes of the client nonce; the server nonce is required to be twice this length. */
    private static final int RANDOM_LENGTH = 32;

    /**
     * Creates an authenticator for the given credential.
     *
     * @param credential the credential to authenticate with; its mechanism must be {@code MONGODB_AWS}
     * @throws MongoException if the credential's authentication mechanism is not {@code MONGODB_AWS}
     */
    public AwsAuthenticator(final MongoCredentialWithCache credential) {
        super(credential);
        if (getMongoCredential().getAuthenticationMechanism() != MONGODB_AWS) {
            throw new MongoException("Incorrect mechanism: " + getMongoCredential().getMechanism());
        }
    }

    @Override
    public String getMechanismName() {
        return MONGODB_AWS_MECHANISM_NAME;
    }

    @Override
    protected SaslClient createSaslClient(final ServerAddress serverAddress) {
        return new AwsSaslClient(getMongoCredential());
    }

    /**
     * Two-step SASL client for MONGODB-AWS.
     *
     * <p>Holds mutable per-conversation state (step counter, client nonce, cached HTTP response) without
     * any synchronization.</p>
     *
     * <p>Credentials are resolved in order: the {@link MongoCredential} itself, then the
     * {@code AWS_ACCESS_KEY_ID} / {@code AWS_SECRET_ACCESS_KEY} / {@code AWS_SESSION_TOKEN} environment
     * variables, then an HTTP credentials endpoint (a container-relative URI if
     * {@code AWS_CONTAINER_CREDENTIALS_RELATIVE_URI} is set, otherwise the EC2 metadata service).</p>
     */
    private static class AwsSaslClient implements SaslClient {
        /** Request timestamp format, rendered in UTC. DateTimeFormatter is immutable, so one instance is shared. */
        private static final DateTimeFormatter TIMESTAMP_FORMATTER = DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss'Z'")
                .withZone(ZoneId.of("UTC"));
        /** Connect and read timeout, in milliseconds, for requests to the credential endpoints. */
        private static final int HTTP_TIMEOUT_MILLIS = 10000;

        private final MongoCredential credential;
        private final byte[] clientNonce = new byte[RANDOM_LENGTH];
        /** -1 before the first challenge, 0 after the client-first message, 1 after the client-final message. */
        private int step = -1;
        /** Lazily fetched body of the HTTP credentials endpoint; null until first needed. */
        private String httpResponse;

        AwsSaslClient(final MongoCredential credential) {
            this.credential = credential;
        }

        /**
         * Returns the mechanism name of the wrapped credential.
         *
         * @throws IllegalArgumentException if the credential has no authentication mechanism
         */
        @Override
        public String getMechanismName() {
            AuthenticationMechanism authMechanism = credential.getAuthenticationMechanism();
            if (authMechanism == null) {
                throw new IllegalArgumentException("Authentication mechanism cannot be null");
            }
            return authMechanism.getMechanismName();
        }

        @Override
        public boolean hasInitialResponse() {
            return true;
        }

        /**
         * Produces the next client message.
         *
         * <p>The first call ignores {@code challenge} and returns the client-first message. The second call
         * treats {@code challenge} as the server-first message and returns the client-final message.</p>
         *
         * @throws SaslException if called more than twice, or if the server nonce fails validation
         */
        @Override
        public byte[] evaluateChallenge(final byte[] challenge) throws SaslException {
            step++;
            if (step == 0) {
                return computeClientFirstMessage();
            }
            if (step == 1) {
                return computeClientFinalMessage(challenge);
            } else {
                throw new SaslException(format("Too many steps involved in the %s negotiation.", getMechanismName()));
            }
        }

        @Override
        public boolean isComplete() {
            return step == 1;
        }

        @Override
        public byte[] unwrap(final byte[] bytes, final int i, final int i1) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public byte[] wrap(final byte[] bytes, final int i, final int i1) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public Object getNegotiatedProperty(final String s) {
            throw new UnsupportedOperationException("Not implemented yet!");
        }

        @Override
        public void dispose() {
        }

        /**
         * Fills {@link #clientNonce} with fresh secure random bytes and encodes the client-first message
         * {@code {r: <nonce>, p: 'n'}}.
         */
        private byte[] computeClientFirstMessage() {
            new SecureRandom().nextBytes(this.clientNonce);
            BsonDocument document = new BsonDocument()
                    .append("r", new BsonBinary(this.clientNonce))
                    .append("p", new BsonInt32('n'));
            return toBson(document);
        }

        /**
         * Validates the server-first message and encodes the client-final message.
         *
         * @param serverFirst BSON bytes containing host {@code h} and server nonce {@code s}
         * @return BSON bytes {@code {a: <authorization header>, d: <timestamp>[, t: <session token>]}}
         * @throws SaslException if the server nonce is not {@code 2 * RANDOM_LENGTH} bytes long or does not
         *                       start with the client nonce
         */
        private byte[] computeClientFinalMessage(final byte[] serverFirst) throws SaslException {
            final BsonDocument document = new RawBsonDocument(serverFirst);
            final String host = document.getString("h").getValue();
            final byte[] serverNonce = document.getBinary("s").getData();
            if (serverNonce.length != (2 * RANDOM_LENGTH)) {
                throw new SaslException(String.format("Server nonce must be %d bytes", 2 * RANDOM_LENGTH));
            } else if (!Arrays.equals(Arrays.copyOf(serverNonce, RANDOM_LENGTH), this.clientNonce)) {
                throw new SaslException(String.format("The first %d bytes of the server nonce must be the client nonce", RANDOM_LENGTH));
            }

            String timestamp = TIMESTAMP_FORMATTER.format(Instant.now());
            // Resolve the token first: it validates that env-var credentials are complete before they are used.
            String token = getSessionToken();
            final AuthorizationHeader authorizationHeader = AuthorizationHeader.builder()
                    .setAccessKeyID(getUserName())
                    .setSecretKey(getPassword())
                    .setSessionToken(token)
                    .setHost(host)
                    .setNonce(serverNonce)
                    .setTimestamp(timestamp)
                    .build();

            BsonDocument ret = new BsonDocument()
                    .append("a", new BsonString(authorizationHeader.toString()))
                    .append("d", new BsonString(authorizationHeader.getTimestamp()));
            if (token != null) {
                ret.append("t", new BsonString(token));
            }
            return toBson(ret);
        }

        /** Encodes {@code document} to BSON and returns exactly the written bytes. */
        private static byte[] toBson(final BsonDocument document) {
            BasicOutputBuffer buffer = new BasicOutputBuffer();
            new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build());
            // The internal buffer may be larger than what was written; keep only the written prefix.
            return Arrays.copyOf(buffer.getInternalBuffer(), buffer.getSize());
        }

        /**
         * Returns the access key id: the credential's user name, else {@code AWS_ACCESS_KEY_ID}, else the
         * {@code AccessKeyId} field of the HTTP credentials response.
         */
        @NonNull
        String getUserName() {
            String userName = credential.getUserName();
            if (userName != null) {
                return userName;
            }
            String envAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
            return envAccessKey != null ? envAccessKey : getHttpResponseField("AccessKeyId");
        }

        /**
         * Returns the secret key: the credential's password, else {@code AWS_SECRET_ACCESS_KEY}, else the
         * {@code SecretAccessKey} field of the HTTP credentials response.
         */
        @NonNull
        private String getPassword() {
            char[] password = credential.getPassword();
            if (password != null) {
                return new String(password);
            }
            String envSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
            return envSecretKey != null ? envSecretKey : getHttpResponseField("SecretAccessKey");
        }

        /**
         * Returns the session token, or null if none applies.
         *
         * <ul>
         *   <li>If the credential has a user name, returns its {@code AWS_SESSION_TOKEN} mechanism property
         *   (possibly null).</li>
         *   <li>Otherwise, if any AWS env var is set, returns {@code AWS_SESSION_TOKEN} from the environment.</li>
         *   <li>Otherwise, returns the {@code Token} field of the HTTP credentials response.</li>
         * </ul>
         *
         * @throws IllegalArgumentException if a session token is configured without a user name, or if only one
         *                                  of {@code AWS_ACCESS_KEY_ID} / {@code AWS_SECRET_ACCESS_KEY} is set
         */
        @Nullable
        private String getSessionToken() {
            String token = credential.getMechanismProperty("AWS_SESSION_TOKEN", null);
            if (credential.getUserName() != null) {
                return token;
            }
            if (token != null) {
                throw new IllegalArgumentException("The connection string contains a session token without credentials");
            }

            String envSecretKey = System.getenv("AWS_SECRET_ACCESS_KEY");
            String envAccessKey = System.getenv("AWS_ACCESS_KEY_ID");
            String envSessionToken = System.getenv("AWS_SESSION_TOKEN");
            if (envSecretKey != null || envAccessKey != null || envSessionToken != null) {
                if (envSecretKey == null || envAccessKey == null) {
                    throw new IllegalArgumentException("The environment variables 'AWS_ACCESS_KEY_ID' and 'AWS_SECRET_ACCESS_KEY' must "
                            + "either both be set or both be null");
                }
                return envSessionToken;
            }

            return getHttpResponseField("Token");
        }

        /** Parses the (cached) HTTP credentials response as a JSON document and returns string field {@code key}. */
        private String getHttpResponseField(final String key) {
            return BsonDocument
                    .parse(getHttpResponse())
                    .getString(key)
                    .getValue();
        }

        /**
         * Fetches, once, the credentials document. It comes from {@code http://169.254.170.2} plus
         * {@code AWS_CONTAINER_CREDENTIALS_RELATIVE_URI} when that variable is set, otherwise from the EC2
         * metadata service.
         */
        @NonNull
        private String getHttpResponse() {
            if (httpResponse != null) {
                return httpResponse;
            }
            String path = System.getenv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI");
            httpResponse = (path == null)
                    ? getEc2Response()
                    : getHttpContents("GET", "http://169.254.170.2" + path, null);
            return httpResponse;
        }

        /**
         * Queries the EC2 metadata service in three steps. It first obtains a session token via PUT
         * {@code /latest/api/token}, then reads the role name from the security-credentials listing, and
         * finally fetches that role's credentials document.
         */
        private String getEc2Response() {
            final String endpoint = "http://169.254.169.254";
            final String path = "/latest/meta-data/iam/security-credentials/";

            Map<String, String> header = new HashMap<>();
            header.put("X-aws-ec2-metadata-token-ttl-seconds", "30");
            String token = getHttpContents("PUT", endpoint + "/latest/api/token", header);

            header.clear();
            header.put("X-aws-ec2-metadata-token", token);
            // NOTE(review): getHttpContents drops line terminators, so a multi-line role listing would be
            // concatenated into a single path segment here — presumably exactly one role is attached; confirm.
            String role = getHttpContents("GET", endpoint + path, header);
            return getHttpContents("GET", endpoint + path + role, header);
        }

        /**
         * Performs an HTTP request and returns the response body with line terminators removed.
         *
         * <p>Both the connect and read timeouts are bounded. Without a connect timeout, an unreachable
         * link-local endpoint (e.g. when not running on EC2/ECS) could block for the platform default
         * connect timeout.</p>
         *
         * @param method  HTTP method, e.g. {@code GET} or {@code PUT}
         * @param endpoint absolute URL to request
         * @param headers optional request headers; may be null
         * @throws MongoInternalException wrapping the IOException on I/O failure or a non-200 status
         */
        @NonNull
        private static String getHttpContents(final String method, final String endpoint, final Map<String, String> headers) {
            StringBuilder content = new StringBuilder();
            HttpURLConnection conn = null;
            try {
                conn = (HttpURLConnection) new URL(endpoint).openConnection();
                conn.setRequestMethod(method);
                conn.setConnectTimeout(HTTP_TIMEOUT_MILLIS);
                conn.setReadTimeout(HTTP_TIMEOUT_MILLIS);
                if (headers != null) {
                    for (Map.Entry<String, String> kvp : headers.entrySet()) {
                        conn.setRequestProperty(kvp.getKey(), kvp.getValue());
                    }
                }

                int status = conn.getResponseCode();
                if (status != HttpURLConnection.HTTP_OK) {
                    throw new IOException(String.format("%d %s", status, conn.getResponseMessage()));
                }

                try (BufferedReader in = new BufferedReader(new InputStreamReader(conn.getInputStream(), StandardCharsets.UTF_8))) {
                    String inputLine;
                    while ((inputLine = in.readLine()) != null) {
                        content.append(inputLine);
                    }
                }
            } catch (IOException e) {
                throw new MongoInternalException("Unexpected IOException", e);
            } finally {
                if (conn != null) {
                    conn.disconnect();
                }
            }
            return content.toString();
        }
    }
}