package com.mongodb.client.internal;
import com.mongodb.MongoClientException;
import com.mongodb.MongoException;
import com.mongodb.MongoInternalException;
import com.mongodb.MongoSocketReadException;
import com.mongodb.ServerAddress;
import com.mongodb.client.model.vault.DataKeyOptions;
import com.mongodb.client.model.vault.EncryptOptions;
import com.mongodb.crypt.capi.MongoCrypt;
import com.mongodb.crypt.capi.MongoCryptContext;
import com.mongodb.crypt.capi.MongoCryptException;
import com.mongodb.crypt.capi.MongoDataKeyOptions;
import com.mongodb.crypt.capi.MongoExplicitEncryptOptions;
import com.mongodb.crypt.capi.MongoKeyDecryptor;
import com.mongodb.lang.Nullable;
import org.bson.BsonBinary;
import org.bson.BsonDocument;
import org.bson.BsonValue;
import org.bson.RawBsonDocument;
import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import static com.mongodb.assertions.Assertions.notNull;
import static com.mongodb.crypt.capi.MongoCryptContext.State;
/**
 * Helper that performs encryption, decryption and data key creation by driving libmongocrypt contexts to completion.
 *
 * <p>Each operation creates a {@link MongoCryptContext}. It then repeatedly inspects the context's {@link State}
 * and supplies what the context asks for: collection info, command markings, key documents, or KMS responses.
 * Once the context reports {@code READY}, the finished document is returned.</p>
 */
class Crypt implements Closeable {
    private final MongoCrypt mongoCrypt;
    // May be null (see the three-argument constructor); only consulted in the NEED_MONGO_COLLINFO state.
    private final CollectionInfoRetriever collectionInfoRetriever;
    // May be null (see the three-argument constructor); only consulted in the NEED_MONGO_MARKINGS state.
    private final CommandMarker commandMarker;
    private final KeyRetriever keyRetriever;
    private final KeyManagementService keyManagementService;
    private final boolean bypassAutoEncryption;

    /**
     * Creates an instance without a collection info retriever or command marker, with auto-encryption bypass disabled.
     *
     * @param mongoCrypt           the libmongocrypt handle
     * @param keyRetriever         used to fetch key documents when a context needs them
     * @param keyManagementService used to stream key decryption requests to a KMS
     */
    Crypt(final MongoCrypt mongoCrypt, final KeyRetriever keyRetriever, final KeyManagementService keyManagementService) {
        this(mongoCrypt, null, null, keyRetriever, keyManagementService, false);
    }

    /**
     * Creates an instance.
     *
     * @param mongoCrypt              the libmongocrypt handle
     * @param collectionInfoRetriever used to answer collection info requests; may be null
     * @param commandMarker           used to answer command marking requests; may be null
     * @param keyRetriever            used to fetch key documents when a context needs them
     * @param keyManagementService    used to stream key decryption requests to a KMS
     * @param bypassAutoEncryption    if true, {@link #encrypt(String, RawBsonDocument)} returns commands unchanged
     */
    Crypt(final MongoCrypt mongoCrypt, @Nullable final CollectionInfoRetriever collectionInfoRetriever,
          @Nullable final CommandMarker commandMarker, final KeyRetriever keyRetriever,
          final KeyManagementService keyManagementService, final boolean bypassAutoEncryption) {
        this.mongoCrypt = mongoCrypt;
        this.collectionInfoRetriever = collectionInfoRetriever;
        this.commandMarker = commandMarker;
        this.keyRetriever = keyRetriever;
        this.keyManagementService = keyManagementService;
        this.bypassAutoEncryption = bypassAutoEncryption;
    }

    /**
     * Automatically encrypts a command.
     *
     * @param databaseName the database the command runs against
     * @param command      the command to encrypt
     * @return the encrypted command, or {@code command} itself when auto-encryption is bypassed
     * @throws MongoClientException if the encryption library raises a {@link MongoCryptException}
     */
    public RawBsonDocument encrypt(final String databaseName, final RawBsonDocument command) {
        notNull("databaseName", databaseName);
        notNull("command", command);

        if (bypassAutoEncryption) {
            return command;
        }
        try {
            MongoCryptContext encryptionContext = mongoCrypt.createEncryptionContext(databaseName, command);
            try {
                return executeStateMachine(encryptionContext, databaseName);
            } finally {
                encryptionContext.close();
            }
        } catch (MongoCryptException e) {
            throw wrapInClientException(e);
        }
    }

    /**
     * Automatically decrypts a command response.
     *
     * @param commandResponse the response to decrypt
     * @return the decrypted response
     * @throws MongoClientException if the encryption library raises a {@link MongoCryptException}
     */
    RawBsonDocument decrypt(final RawBsonDocument commandResponse) {
        notNull("commandResponse", commandResponse);

        try {
            MongoCryptContext decryptionContext = mongoCrypt.createDecryptionContext(commandResponse);
            try {
                return executeStateMachine(decryptionContext, null);
            } finally {
                decryptionContext.close();
            }
        } catch (MongoCryptException e) {
            throw wrapInClientException(e);
        }
    }

    /**
     * Creates a data key.
     *
     * @param kmsProvider the KMS provider name
     * @param options     the key alt names and master key to use
     * @return the document produced by the data key context
     * @throws MongoClientException if the encryption library raises a {@link MongoCryptException}
     */
    BsonDocument createDataKey(final String kmsProvider, final DataKeyOptions options) {
        notNull("kmsProvider", kmsProvider);
        notNull("options", options);

        try {
            MongoCryptContext dataKeyCreationContext = mongoCrypt.createDataKeyContext(kmsProvider,
                    MongoDataKeyOptions.builder()
                            .keyAltNames(options.getKeyAltNames())
                            .masterKey(options.getMasterKey())
                            .build());
            try {
                return executeStateMachine(dataKeyCreationContext, null);
            } finally {
                dataKeyCreationContext.close();
            }
        } catch (MongoCryptException e) {
            throw wrapInClientException(e);
        }
    }

    /**
     * Explicitly encrypts a single value.
     *
     * @param value   the value to encrypt
     * @param options the algorithm and, optionally, the key id and/or key alt name
     * @return the encrypted value
     * @throws MongoClientException if the encryption library raises a {@link MongoCryptException}
     */
    BsonBinary encryptExplicitly(final BsonValue value, final EncryptOptions options) {
        notNull("value", value);
        notNull("options", options);

        try {
            MongoExplicitEncryptOptions.Builder encryptOptionsBuilder = MongoExplicitEncryptOptions.builder()
                    .algorithm(options.getAlgorithm());
            if (options.getKeyId() != null) {
                encryptOptionsBuilder.keyId(options.getKeyId());
            }
            if (options.getKeyAltName() != null) {
                encryptOptionsBuilder.keyAltName(options.getKeyAltName());
            }

            // The library operates on documents, so the value is wrapped under the "v" key and unwrapped afterwards.
            MongoCryptContext encryptionContext = mongoCrypt.createExplicitEncryptionContext(
                    new BsonDocument("v", value), encryptOptionsBuilder.build());
            try {
                return executeStateMachine(encryptionContext, null).getBinary("v");
            } finally {
                encryptionContext.close();
            }
        } catch (MongoCryptException e) {
            throw wrapInClientException(e);
        }
    }

    /**
     * Explicitly decrypts a single value.
     *
     * @param value the encrypted value
     * @return the decrypted value
     * @throws MongoClientException if the encryption library raises a {@link MongoCryptException}
     */
    BsonValue decryptExplicitly(final BsonBinary value) {
        notNull("value", value);

        try {
            // The value is wrapped under the "v" key, as in encryptExplicitly.
            MongoCryptContext decryptionContext = mongoCrypt.createExplicitDecryptionContext(new BsonDocument("v", value));
            try {
                return executeStateMachine(decryptionContext, null).get("v");
            } finally {
                decryptionContext.close();
            }
        } catch (MongoCryptException e) {
            throw wrapInClientException(e);
        }
    }

    /**
     * Closes the libmongocrypt handle, the command marker (if any) and the key retriever.
     *
     * <p>Each resource is closed even if closing an earlier one throws.</p>
     */
    @Override
    public void close() {
        try {
            mongoCrypt.close();
        } finally {
            try {
                // commandMarker is null when this instance was built with the three-argument constructor.
                if (commandMarker != null) {
                    commandMarker.close();
                }
            } finally {
                keyRetriever.close();
            }
        }
    }

    /**
     * Drives the context until it is ready, then returns its finished document.
     *
     * @param cryptContext the context to drive
     * @param databaseName the database name; null for operations other than automatic encryption
     * @return the result of {@link MongoCryptContext#finish()}
     * @throws MongoInternalException if the context reports a state this class does not handle
     */
    private RawBsonDocument executeStateMachine(final MongoCryptContext cryptContext, @Nullable final String databaseName) {
        while (true) {
            State state = cryptContext.getState();
            switch (state) {
                case NEED_MONGO_COLLINFO:
                    collInfo(cryptContext, databaseName);
                    break;
                case NEED_MONGO_MARKINGS:
                    mark(cryptContext, databaseName);
                    break;
                case NEED_MONGO_KEYS:
                    fetchKeys(cryptContext);
                    break;
                case NEED_KMS:
                    decryptKeys(cryptContext);
                    break;
                case READY:
                    return cryptContext.finish();
                default:
                    throw new MongoInternalException("Unsupported encryptor state " + state);
            }
        }
    }

    /**
     * Answers a collection info request.
     *
     * <p>A result is added only if the retriever returns one; the operation is then marked complete.</p>
     */
    private void collInfo(final MongoCryptContext cryptContext, final String databaseName) {
        try {
            BsonDocument collectionInfo = collectionInfoRetriever.filter(databaseName, cryptContext.getMongoOperation());
            if (collectionInfo != null) {
                cryptContext.addMongoOperationResult(collectionInfo);
            }
            cryptContext.completeMongoOperation();
        } catch (Throwable t) {
            throw MongoException.fromThrowableNonNull(t);
        }
    }

    /** Answers a command marking request with the command returned by the command marker. */
    private void mark(final MongoCryptContext cryptContext, final String databaseName) {
        try {
            RawBsonDocument markedCommand = commandMarker.mark(databaseName, cryptContext.getMongoOperation());
            cryptContext.addMongoOperationResult(markedCommand);
            cryptContext.completeMongoOperation();
        } catch (Throwable t) {
            throw MongoException.fromThrowableNonNull(t);
        }
    }

    /** Feeds every key document returned by the key retriever into the context, then marks the operation complete. */
    private void fetchKeys(final MongoCryptContext keyBroker) {
        try {
            for (BsonDocument bsonDocument : keyRetriever.find(keyBroker.getMongoOperation())) {
                keyBroker.addMongoOperationResult(bsonDocument);
            }
            keyBroker.completeMongoOperation();
        } catch (Throwable t) {
            throw MongoException.fromThrowableNonNull(t);
        }
    }

    /** Processes every pending key decryptor, then signals that all key decryptors are complete. */
    private void decryptKeys(final MongoCryptContext cryptContext) {
        try {
            MongoKeyDecryptor keyDecryptor = cryptContext.nextKeyDecryptor();
            while (keyDecryptor != null) {
                decryptKey(keyDecryptor);
                keyDecryptor = cryptContext.nextKeyDecryptor();
            }
            cryptContext.completeKeyDecryptors();
        } catch (Throwable t) {
            throw MongoException.fromThrowableNonNull(t);
        }
    }

    /**
     * Sends the decryptor's message to its KMS host and feeds the response back until no more bytes are needed.
     *
     * @throws MongoSocketReadException if reading fails or the stream ends before the decryptor is satisfied
     */
    private void decryptKey(final MongoKeyDecryptor keyDecryptor) {
        InputStream inputStream = keyManagementService.stream(keyDecryptor.getHostName(), keyDecryptor.getMessage());
        try {
            int bytesNeeded = keyDecryptor.bytesNeeded();
            while (bytesNeeded > 0) {
                byte[] bytes = new byte[bytesNeeded];
                int bytesRead = inputStream.read(bytes, 0, bytes.length);
                if (bytesRead < 0) {
                    // End of stream before the decryptor received everything it needs. Without this check,
                    // ByteBuffer.wrap would fail with an IndexOutOfBoundsException on the -1 length.
                    throw new MongoSocketReadException("Prematurely reached end of stream from key management service",
                            kmsServerAddress(keyDecryptor));
                }
                keyDecryptor.feed(ByteBuffer.wrap(bytes, 0, bytesRead));
                bytesNeeded = keyDecryptor.bytesNeeded();
            }
        } catch (IOException e) {
            throw new MongoSocketReadException("Exception receiving message from key management service",
                    kmsServerAddress(keyDecryptor), e);
        } finally {
            closeQuietly(inputStream);
        }
    }

    /** Builds the address of the KMS host the decryptor talks to, for use in error reporting. */
    private ServerAddress kmsServerAddress(final MongoKeyDecryptor keyDecryptor) {
        return new ServerAddress(keyDecryptor.getHostName(), keyManagementService.getPort());
    }

    /** Closes the stream, deliberately ignoring any failure to do so. */
    private static void closeQuietly(final InputStream inputStream) {
        try {
            inputStream.close();
        } catch (IOException ignored) {
            // Best effort: either the response was fully consumed, or a more relevant exception is already
            // propagating; a close failure must not mask either outcome.
        }
    }

    /** Wraps a libmongocrypt exception in a client exception, preserving it as the cause. */
    private MongoClientException wrapInClientException(final MongoCryptException e) {
        return new MongoClientException("Exception in encryption library: " + e.getMessage(), e);
    }
}