package io.vertx.pgclient.impl;
import io.vertx.pgclient.PgConnectOptions;
import io.vertx.pgclient.SslMode;
import io.vertx.sqlclient.impl.Connection;
import io.vertx.sqlclient.impl.command.CommandResponse;
import io.vertx.core.*;
import io.vertx.core.impl.NetSocketInternal;
import io.vertx.core.net.*;
import java.util.HashMap;
import java.util.Map;
/**
 * Opens socket connections to a PostgreSQL server on behalf of a single {@link Context}.
 *
 * <p>The factory owns one {@link NetClient}. TLS is never enabled on that client directly
 * (see the constructor). When the configured {@link SslMode} asks for TLS, the plain socket
 * is upgraded after connecting via {@code PgSocketConnection#upgradeToSSLConnection}.
 *
 * <p>All configuration is read from the {@link PgConnectOptions} at construction time.
 * Connection attempts must be issued from the factory's own context; {@link #doConnect}
 * checks this and throws otherwise.
 */
class PgConnectionFactory {

  private final NetClient client;
  private final Context ctx;
  private final boolean registerCloseHook;
  private final String host;
  private final int port;
  private final SslMode sslMode;
  private final TrustOptions trustOptions;
  private final String hostnameVerificationAlgorithm;
  private final String database;
  private final String username;
  private final String password;
  private final Map<String, String> properties;
  private final boolean cachePreparedStatements;
  private final int preparedStatementCacheSize;
  private final int preparedStatementCacheSqlLimit;
  private final int pipeliningLimit;
  private final boolean isUsingDomainSocket;
  private final Closeable hook;

  /**
   * Creates a factory bound to {@code context}.
   *
   * @param context           the context connections belong to; connection attempts must run on it
   * @param registerCloseHook when {@code true}, a close hook that closes the NetClient is added to
   *                          {@code context}, and removed again by {@link #close()}
   * @param options           connection options; values are copied, so later changes to
   *                          {@code options} are not observed
   */
  PgConnectionFactory(Context context,
                      boolean registerCloseHook,
                      PgConnectOptions options) {
    this.ctx = context;
    this.registerCloseHook = registerCloseHook;
    this.hook = this::close;

    NetClientOptions netClientOptions = new NetClientOptions(options);
    // TLS is negotiated after the connection is established (see doConnect), so the
    // NetClient itself must not start a TLS handshake when connecting.
    netClientOptions.setSsl(false);

    this.sslMode = options.getSslMode();
    this.hostnameVerificationAlgorithm = netClientOptions.getHostnameVerificationAlgorithm();
    this.trustOptions = netClientOptions.getTrustOptions();
    this.host = options.getHost();
    this.port = options.getPort();
    this.database = options.getDatabase();
    this.username = options.getUser();
    this.password = options.getPassword();
    this.properties = new HashMap<>(options.getProperties());
    this.cachePreparedStatements = options.getCachePreparedStatements();
    this.pipeliningLimit = options.getPipeliningLimit();
    this.preparedStatementCacheSize = options.getPreparedStatementCacheMaxSize();
    this.preparedStatementCacheSqlLimit = options.getPreparedStatementCacheSqlLimit();
    this.isUsingDomainSocket = options.isUsingDomainSocket();
    this.client = context.owner().createNetClient(netClientOptions);

    // Register the hook only once the NetClient exists. A failure while building the
    // client then never leaves the context holding a hook that would hit a null client.
    if (registerCloseHook) {
      ctx.addCloseHook(hook);
    }
  }

  /**
   * Close-hook callback invoked by the context: closes the NetClient and reports
   * success to {@code completionHandler} right away.
   */
  private void close(Handler<AsyncResult<Void>> completionHandler) {
    client.close();
    completionHandler.handle(Future.succeededFuture());
  }

  /**
   * Closes the NetClient. First unregisters the context close hook if the constructor
   * registered one.
   */
  void close() {
    if (registerCloseHook) {
      ctx.removeCloseHook(hook);
    }
    client.close();
  }

  /**
   * Connects (see {@link #connect}), then calls {@code init()} on the new connection and
   * sends the startup message with the configured credentials, database and properties.
   *
   * @param completionHandler receives the outcome of the startup exchange, or the
   *                          connection failure
   */
  void connectAndInit(Handler<AsyncResult<Connection>> completionHandler) {
    connect(ar -> {
      if (ar.succeeded()) {
        PgSocketConnection conn = ar.result();
        conn.init();
        conn.sendStartupMessage(username, password, database, properties, completionHandler);
      } else {
        completionHandler.handle(CommandResponse.failure(ar.cause()));
      }
    });
  }

  /**
   * Opens a socket connection according to the configured {@link SslMode}:
   * <ul>
   *   <li>DISABLE: plain connection only;</li>
   *   <li>ALLOW: plain first, retry with TLS if that fails;</li>
   *   <li>PREFER: TLS first, retry plain if that fails;</li>
   *   <li>REQUIRE: TLS only;</li>
   *   <li>VERIFY_CA: TLS only, and trust options must be configured;</li>
   *   <li>VERIFY_FULL: as VERIFY_CA, and a hostname verification algorithm must be set.</li>
   * </ul>
   * Configuration errors are reported to {@code handler} as failed results.
   *
   * @param handler receives the connected socket connection, or the failure
   */
  void connect(Handler<AsyncResult<PgSocketConnection>> handler) {
    switch (sslMode) {
      case DISABLE:
        doConnect(false, handler);
        break;
      case ALLOW:
        doConnect(false, ar -> {
          if (ar.succeeded()) {
            handler.handle(ar);
          } else {
            doConnect(true, handler);
          }
        });
        break;
      case PREFER:
        doConnect(true, ar -> {
          if (ar.succeeded()) {
            handler.handle(ar);
          } else {
            doConnect(false, handler);
          }
        });
        break;
      case VERIFY_FULL:
        if (hostnameVerificationAlgorithm == null || hostnameVerificationAlgorithm.isEmpty()) {
          handler.handle(Future.failedFuture(new IllegalArgumentException("Host verification algorithm must be specified under verify-full sslmode")));
          return;
        }
        // fall through: verify-full also needs everything verify-ca needs
      case VERIFY_CA:
        if (trustOptions == null) {
          handler.handle(Future.failedFuture(new IllegalArgumentException("Trust options must be specified under verify-full or verify-ca sslmode")));
          return;
        }
        // fall through: configuration validated, connect with TLS as for require
      case REQUIRE:
        doConnect(true, handler);
        break;
      default:
        // Report through the handler, like the other configuration errors above.
        handler.handle(Future.failedFuture(new IllegalArgumentException("Unsupported SSL mode")));
        break;
    }
  }

  /**
   * Performs one connection attempt, optionally upgrading it to TLS.
   * The upgrade is skipped for domain sockets even when {@code ssl} is {@code true}.
   *
   * @param ssl     whether to upgrade the connection to TLS after connecting
   * @param handler receives the connection, or the connect/upgrade failure
   * @throws IllegalStateException if not called from the factory's context
   */
  private void doConnect(boolean ssl, Handler<AsyncResult<PgSocketConnection>> handler) {
    if (Vertx.currentContext() != ctx) {
      throw new IllegalStateException("PgConnectionFactory must be used from its own context");
    }
    SocketAddress socketAddress = isUsingDomainSocket
      ? SocketAddress.domainSocketAddress(host + "/.s.PGSQL." + port)
      : SocketAddress.inetSocketAddress(port, host);

    Promise<NetSocket> promise = Promise.promise();
    promise.future().setHandler(ar -> {
      if (ar.failed()) {
        handler.handle(Future.failedFuture(ar.cause()));
        return;
      }
      NetSocketInternal socket = (NetSocketInternal) ar.result();
      PgSocketConnection conn = newSocketConnection(socket);
      if (!ssl || isUsingDomainSocket) {
        handler.handle(Future.succeededFuture(conn));
        return;
      }
      conn.upgradeToSSLConnection(ar2 -> {
        if (ar2.succeeded()) {
          handler.handle(Future.succeededFuture(conn));
        } else {
          // The connection is open but unusable. Close it so that a failed upgrade
          // (notably on the PREFER path, which then opens a second plain connection)
          // does not leak a socket.
          socket.close();
          handler.handle(Future.failedFuture(ar2.cause()));
        }
      });
    });

    try {
      client.connect(socketAddress, null, promise);
    } catch (Exception e) {
      // connect may throw synchronously (presumably, for example, once the client is
      // closed); report it through the same handler instead.
      promise.fail(e);
    }
  }

  /** Wraps a freshly connected socket using this factory's cache and pipelining settings. */
  private PgSocketConnection newSocketConnection(NetSocketInternal socket) {
    return new PgSocketConnection(socket, cachePreparedStatements, preparedStatementCacheSize, preparedStatementCacheSqlLimit, pipeliningLimit, ctx);
  }
}