package com.microsoft.sqlserver.jdbc;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.SECONDS;
import java.text.MessageFormat;
import java.util.Base64;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
/**
 * One-shot eviction task for {@link SQLServerSymmetricKeyCache}.
 *
 * <p>When run, it looks up the entry stored under the lookup key given at construction and, if the
 * entry is still present, calls {@code zeroOutKey()} on it and then removes it from the cache. The
 * whole operation holds {@link SQLServerSymmetricKeyCache#lock}. If the entry is already absent, the
 * task does nothing.
 */
class CacheClear implements Runnable {
    /** Cache key of the entry this task evicts. */
    private String keylookupValue;
    private static java.util.logging.Logger aeLogger = java.util.logging.Logger
            .getLogger("com.microsoft.sqlserver.jdbc.CacheClear");

    /**
     * Creates an eviction task for a single cache entry.
     *
     * @param keylookupValue cache key of the entry to evict when this task runs
     */
    CacheClear(String keylookupValue) {
        this.keylookupValue = keylookupValue;
    }

    @Override
    public void run() {
        synchronized (SQLServerSymmetricKeyCache.lock) {
            final SQLServerSymmetricKeyCache keyCache = SQLServerSymmetricKeyCache.getInstance();

            // Nothing to evict: the entry was never stored or has already been removed.
            if (!keyCache.getCache().containsKey(keylookupValue)) {
                return;
            }

            // Call zeroOutKey() (presumably wipes the key material) before dropping the entry.
            keyCache.getCache().get(keylookupValue).zeroOutKey();
            keyCache.getCache().remove(keylookupValue);

            boolean traceEnabled = aeLogger.isLoggable(java.util.logging.Level.FINE);
            if (traceEnabled) {
                aeLogger.fine("Removed encryption key from cache...");
            }
        }
    }
}
/**
 * Process-wide cache of decrypted column encryption keys.
 *
 * <p>Entries are keyed by {@code serverName:base64(encryptedKey):keyStoreName}. When
 * {@code SQLServerConnection.getColumnEncryptionKeyCacheTtl()} is non-zero, each newly decrypted key
 * is cached, and a {@link CacheClear} task is scheduled to evict it after that many seconds. A TTL of
 * zero disables caching: the key is decrypted on every call. {@link #getKey} and {@link CacheClear}
 * both synchronize on {@link #lock}.
 */
final class SQLServerSymmetricKeyCache {
    /** Monitor shared by {@link #getKey} and {@link CacheClear} around cache reads and updates. */
    static final Object lock = new Object();

    private final ConcurrentHashMap<String, SQLServerSymmetricKey> cache;
    private static final SQLServerSymmetricKeyCache instance = new SQLServerSymmetricKeyCache();

    // Single daemon thread, so pending evictions never keep the JVM alive.
    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1,
            new ThreadFactory() {
                @Override
                public Thread newThread(Runnable r) {
                    Thread t = Executors.defaultThreadFactory().newThread(r);
                    t.setDaemon(true);
                    return t;
                }
            });

    static final private java.util.logging.Logger aeLogger = java.util.logging.Logger
            .getLogger("com.microsoft.sqlserver.jdbc.SQLServerSymmetricKeyCache");

    private SQLServerSymmetricKeyCache() {
        cache = new ConcurrentHashMap<>();
    }

    /** Returns the singleton cache instance. */
    static SQLServerSymmetricKeyCache getInstance() {
        return instance;
    }

    /** Returns the backing map. Callers that mutate it should hold {@link #lock}. */
    ConcurrentHashMap<String, SQLServerSymmetricKey> getCache() {
        return cache;
    }

    /**
     * Returns the decrypted symmetric key for {@code keyInfo}, from the cache when available.
     *
     * <p>Before any lookup, the key path is checked against the trusted master key paths configured
     * for the connection's server. On a cache miss, the key is decrypted with the connection's key
     * store provider. If the TTL is non-zero, the key is cached and its eviction is scheduled.
     *
     * @param keyInfo metadata of the encrypted column encryption key
     * @param connection connection supplying the trusted server name and key store provider
     * @return the decrypted symmetric key
     * @throws SQLServerException if the key path is not trusted for the server, or decryption fails
     */
    SQLServerSymmetricKey getKey(EncryptionKeyInfo keyInfo, SQLServerConnection connection) throws SQLServerException {
        synchronized (lock) {
            String serverName = connection.getTrustedServerNameAE();
            assert null != serverName : "serverName should not be null in getKey.";
            String keyLookupValue = buildKeyLookupValue(serverName, keyInfo);

            if (aeLogger.isLoggable(java.util.logging.Level.FINE)) {
                aeLogger.fine("Checking trusted master key path...");
            }
            verifyTrustedKeyPath(serverName, keyInfo);

            if (aeLogger.isLoggable(java.util.logging.Level.FINE)) {
                aeLogger.fine("Checking Symmetric key cache...");
            }
            // ConcurrentHashMap never stores null values, so null here means "not cached".
            SQLServerSymmetricKey cachedKey = cache.get(keyLookupValue);
            if (null != cachedKey) {
                return cachedKey;
            }

            byte[] plaintextKey = connection.getColumnEncryptionKeyStoreProvider(keyInfo.keyStoreName)
                    .decryptColumnEncryptionKey(keyInfo.keyPath, keyInfo.algorithmName, keyInfo.encryptedKey);
            SQLServerSymmetricKey encryptionKey = new SQLServerSymmetricKey(plaintextKey);

            long columnEncryptionKeyCacheTtl = SQLServerConnection.getColumnEncryptionKeyCacheTtl();
            if (0 != columnEncryptionKeyCacheTtl) {
                cache.putIfAbsent(keyLookupValue, encryptionKey);
                if (aeLogger.isLoggable(java.util.logging.Level.FINE)) {
                    aeLogger.fine("Adding encryption key to cache...");
                }
                // The TTL is interpreted in seconds.
                scheduler.schedule(new CacheClear(keyLookupValue), columnEncryptionKeyCacheTtl, SECONDS);
            }
            return encryptionKey;
        }
    }

    /**
     * Builds the cache key {@code serverName:base64(encryptedKey):keyStoreName}.
     *
     * <p>The raw encrypted-key bytes are Base64-encoded directly. They are arbitrary bytes, not
     * text. Decoding them as UTF-8 first would replace malformed sequences with U+FFFD, so distinct
     * encrypted keys could map to the same cache key, and the re-encoding would depend on the
     * platform charset.
     */
    private static String buildKeyLookupValue(String serverName, EncryptionKeyInfo keyInfo) {
        return serverName + ":" + Base64.getEncoder().encodeToString(keyInfo.encryptedKey) + ":"
                + keyInfo.keyStoreName;
    }

    /**
     * Rejects {@code keyInfo.keyPath} when trusted master key paths are configured for
     * {@code serverName} and the path is not among them.
     *
     * @throws SQLServerException with message {@code R_UntrustedKeyPath} when the path is not trusted
     */
    private void verifyTrustedKeyPath(String serverName, EncryptionKeyInfo keyInfo) throws SQLServerException {
        Boolean[] hasEntry = new Boolean[1];
        List<String> trustedKeyPaths = SQLServerConnection.getColumnEncryptionTrustedMasterKeyPaths(serverName,
                hasEntry);
        // NOTE(review): assumes the callee always sets hasEntry[0]. If it stays null, the unboxing
        // below throws NPE, which at least fails closed.
        if (hasEntry[0]) {
            if ((null == trustedKeyPaths) || trustedKeyPaths.isEmpty()
                    || !trustedKeyPaths.contains(keyInfo.keyPath)) {
                MessageFormat form = new MessageFormat(SQLServerException.getErrString("R_UntrustedKeyPath"));
                Object[] msgArgs = {keyInfo.keyPath, serverName};
                throw new SQLServerException(this, form.format(msgArgs), null, 0, false);
            }
        }
    }
}