package io.reactiverse.pgclient.impl;
import io.reactiverse.pgclient.PgConnectOptions;
import io.reactiverse.pgclient.SslMode;
import io.vertx.core.*;
import io.vertx.core.impl.NetSocketInternal;
import io.vertx.core.net.*;
/**
 * Opens socket connections to a PostgreSQL server and wraps them in {@link SocketConnection}s,
 * applying the configured {@link SslMode}.
 *
 * <p>Connection attempts must be started from the {@link Context} given to the constructor;
 * {@code doConnect} throws {@link IllegalStateException} otherwise.
 */
public class PgConnectionFactory {

  private final NetClient client;
  private final Context ctx;
  private final boolean registerCloseHook;
  private final String host;
  private final int port;
  private final SslMode sslMode;
  private final TrustOptions trustOptions;
  private final String hostnameVerificationAlgorithm;
  private final String database;
  private final String username;
  private final String password;
  private final boolean cachePreparedStatements;
  private final int pipeliningLimit;
  private final boolean isUsingDomainSocket;
  private final Closeable hook;

  /**
   * Creates a factory bound to {@code context}.
   *
   * @param context the context on which connection attempts must run; its owner creates the
   *     underlying {@link NetClient}
   * @param registerCloseHook whether to close this factory automatically when {@code context} is
   *     closed
   * @param options address, credentials, SSL mode and pipelining settings
   */
  public PgConnectionFactory(Context context,
                             boolean registerCloseHook,
                             PgConnectOptions options) {
    NetClientOptions netClientOptions = new NetClientOptions(options);
    // The NetClient never starts in SSL mode: when SSL is wanted, the plain socket is upgraded
    // afterwards through SocketConnection#upgradeToSSLConnection (see doConnect).
    netClientOptions.setSsl(false);

    this.ctx = context;
    this.registerCloseHook = registerCloseHook;
    this.sslMode = options.getSslMode();
    this.hostnameVerificationAlgorithm = netClientOptions.getHostnameVerificationAlgorithm();
    this.trustOptions = netClientOptions.getTrustOptions();
    this.host = options.getHost();
    this.port = options.getPort();
    this.database = options.getDatabase();
    this.username = options.getUser();
    this.password = options.getPassword();
    this.cachePreparedStatements = options.getCachePreparedStatements();
    this.pipeliningLimit = options.getPipeliningLimit();
    this.isUsingDomainSocket = options.isUsingDomainSocket();
    this.client = context.owner().createNetClient(netClientOptions);

    // Register the close hook only once every field (notably client) is assigned. Registering
    // earlier publishes a half-built object and, if createNetClient throws, leaves a hook behind
    // that would fail on a null client when the context closes.
    this.hook = this::close;
    if (registerCloseHook) {
      ctx.addCloseHook(hook);
    }
  }

  /**
   * Close-hook callback: closes the underlying {@link NetClient} and reports success.
   *
   * @param completionHandler notified once the client has been closed
   */
  private void close(Handler<AsyncResult<Void>> completionHandler) {
    client.close();
    completionHandler.handle(Future.succeededFuture());
  }

  /**
   * Closes the underlying {@link NetClient}. If a context close hook was registered, it is
   * removed first so it does not run again when the context closes.
   */
  public void close() {
    if (registerCloseHook) {
      ctx.removeCloseHook(hook);
    }
    client.close();
  }

  /**
   * Connects, initializes the protocol codec and sends the startup message with the configured
   * credentials and database.
   *
   * @param completionHandler receives the result of the startup exchange, or
   *     {@code CommandResponse.failure(cause)} if the socket connection could not be established
   */
  public void create(Handler<? super CommandResponse<Connection>> completionHandler) {
    connect(ar -> {
      if (ar.succeeded()) {
        SocketConnection conn = ar.result();
        conn.initializeCodec();
        conn.sendStartupMessage(username, password, database, completionHandler);
      } else {
        completionHandler.handle(CommandResponse.failure(ar.cause()));
      }
    });
  }

  /**
   * Opens a socket connection according to the configured {@link SslMode}.
   *
   * @param handler receives the connected {@link SocketConnection} or the failure cause; for
   *     VERIFY_FULL / VERIFY_CA it fails with {@link IllegalArgumentException} when the required
   *     hostname verification algorithm / trust options are missing
   * @throws IllegalArgumentException if the SSL mode is not handled by this factory
   */
  public void connect(Handler<AsyncResult<SocketConnection>> handler) {
    switch (sslMode) {
      case DISABLE:
        doConnect(false, handler);
        break;
      case ALLOW:
        // Plain first, SSL only if the plain attempt fails.
        // NOTE(review): a plain attempt can only fail at the socket level here (startup is sent
        // later, in create), so the SSL retry rarely helps — confirm this is the intended ALLOW.
        doConnect(false, ar -> {
          if (ar.succeeded()) {
            handler.handle(Future.succeededFuture(ar.result()));
          } else {
            doConnect(true, handler);
          }
        });
        break;
      case PREFER:
        // SSL first, falling back to a plain connection if SSL fails.
        doConnect(true, ar -> {
          if (ar.succeeded()) {
            handler.handle(Future.succeededFuture(ar.result()));
          } else {
            doConnect(false, handler);
          }
        });
        break;
      case VERIFY_FULL:
        if (hostnameVerificationAlgorithm == null || hostnameVerificationAlgorithm.isEmpty()) {
          handler.handle(Future.failedFuture(new IllegalArgumentException("Host verification algorithm must be specified under verify-full sslmode")));
          return;
        }
        // fall through: verify-full also requires the verify-ca checks
      case VERIFY_CA:
        if (trustOptions == null) {
          handler.handle(Future.failedFuture(new IllegalArgumentException("Trust options must be specified under verify-full or verify-ca sslmode")));
          return;
        }
        // fall through: once validated, both verify modes connect exactly like REQUIRE
      case REQUIRE:
        doConnect(true, handler);
        break;
      default:
        throw new IllegalArgumentException("Unsupported SSL mode");
    }
  }

  /**
   * Opens one socket (TCP or Unix domain) and optionally upgrades it to SSL.
   *
   * <p>SSL is never attempted over a domain socket: with {@code ssl == true} such a connection is
   * handed back as plain.
   *
   * @param ssl whether to upgrade the connection to SSL after connecting
   * @param handler receives the connection or the failure cause
   * @throws IllegalStateException if not called from the factory's context
   */
  private void doConnect(boolean ssl, Handler<AsyncResult<SocketConnection>> handler) {
    if (Vertx.currentContext() != ctx) {
      throw new IllegalStateException("Must be called from the connection factory context");
    }
    SocketAddress socketAddress = isUsingDomainSocket
        ? SocketAddress.domainSocketAddress(host + "/.s.PGSQL." + port)
        : SocketAddress.inetSocketAddress(port, host);

    Future<NetSocket> future = Future.<NetSocket>future().setHandler(ar -> {
      if (ar.failed()) {
        handler.handle(Future.failedFuture(ar.cause()));
        return;
      }
      NetSocketInternal socket = (NetSocketInternal) ar.result();
      SocketConnection conn = newSocketConnection(socket);
      if (ssl && !isUsingDomainSocket) {
        conn.upgradeToSSLConnection(ar2 -> {
          if (ar2.succeeded()) {
            handler.handle(Future.succeededFuture(conn));
          } else {
            // Release the socket that failed the upgrade; otherwise it stays open, which is
            // especially wasteful under PREFER where a second, plain connection follows.
            socket.close();
            handler.handle(Future.failedFuture(ar2.cause()));
          }
        });
      } else {
        handler.handle(Future.succeededFuture(conn));
      }
    });

    try {
      client.connect(socketAddress, null, future);
    } catch (Exception e) {
      // tryFail rather than fail: if the future was already completed before the exception,
      // fail() would itself throw IllegalStateException.
      future.tryFail(e);
    }
  }

  /** Wraps a freshly connected socket using this factory's statement-cache and pipelining settings. */
  private SocketConnection newSocketConnection(NetSocketInternal socket) {
    return new SocketConnection(socket, cachePreparedStatements, pipeliningLimit, ctx);
  }
}