package io.vertx.pgclient.impl;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DecoderException;
import io.vertx.core.impl.ContextInternal;
import io.vertx.pgclient.PgException;
import io.vertx.pgclient.impl.codec.PgCodec;
import io.vertx.sqlclient.impl.Connection;
import io.vertx.sqlclient.impl.Notice;
import io.vertx.sqlclient.impl.Notification;
import io.vertx.sqlclient.impl.QueryResultHandler;
import io.vertx.sqlclient.impl.SocketConnectionBase;
import io.vertx.sqlclient.impl.command.CommandBase;
import io.vertx.sqlclient.impl.command.InitCommand;
import io.vertx.core.*;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.net.impl.NetSocketInternal;
import io.vertx.sqlclient.impl.command.QueryCommandBase;
import io.vertx.sqlclient.impl.command.SimpleQueryCommand;
import io.vertx.sqlclient.impl.command.TxCommand;
import io.vertx.sqlclient.spi.DatabaseMetadata;
import java.util.Map;
import java.util.function.Predicate;
/**
 * PostgreSQL flavour of {@link SocketConnectionBase}.
 *
 * <p>Adds the {@link PgCodec} to the channel pipeline, sends the startup and cancel-request
 * messages, performs the SSL upgrade handshake, dispatches notifications/notices, and rewrites
 * transaction commands as simple queries.
 */
public class PgSocketConnection extends SocketConnectionBase {

  /**
   * CancelRequest code from the PostgreSQL wire protocol: 1234 in the most significant 16 bits
   * and 5678 in the least significant 16 bits.
   */
  private static final int CANCEL_REQUEST_CODE = 80877102;

  /** Total length of a CancelRequest message in bytes, including the length field itself. */
  private static final int CANCEL_REQUEST_LENGTH = 16;

  private PgCodec codec;
  // The following fields are not assigned in this class; presumably they are filled in by the
  // startup/handshake code elsewhere (TODO confirm). They back the getters below.
  public int processId;
  public int secretKey;
  public PgDatabaseMetadata dbMetaData;

  public PgSocketConnection(NetSocketInternal socket,
                            boolean cachePreparedStatements,
                            int preparedStatementCacheSize,
                            Predicate<String> preparedStatementCacheSqlFilter,
                            int pipeliningLimit,
                            ContextInternal context) {
    super(socket, cachePreparedStatements, preparedStatementCacheSize, preparedStatementCacheSqlFilter, pipeliningLimit, context);
  }

  /**
   * Creates the codec and inserts it into the pipeline just before the handler named
   * {@code "handler"}, then runs the base-class initialisation.
   */
  @Override
  public void init() {
    codec = new PgCodec();
    ChannelPipeline pipeline = socket.channelHandlerContext().pipeline();
    pipeline.addBefore("handler", "codec", codec);
    super.init();
  }

  /**
   * Schedules an {@link InitCommand} carrying the credentials, database and extra properties.
   *
   * @param completionHandler completed with the outcome of the scheduled init command
   */
  void sendStartupMessage(String username, String password, String database, Map<String, String> properties, Promise<Connection> completionHandler) {
    InitCommand cmd = new InitCommand(this, username, password, database, properties);
    schedule(cmd, completionHandler);
  }

  /**
   * Writes a CancelRequest message for the given backend and, once the write succeeds, closes
   * this connection if it is still {@code CONNECTED}.
   *
   * @param processId backend process id of the connection whose query should be cancelled
   * @param secretKey secret key of that backend
   * @param handler   succeeded after a successful write, failed with the write error otherwise
   */
  void sendCancelRequestMessage(int processId, int secretKey, Handler<AsyncResult<Void>> handler) {
    Buffer buffer = Buffer.buffer(CANCEL_REQUEST_LENGTH);
    buffer.appendInt(CANCEL_REQUEST_LENGTH);
    buffer.appendInt(CANCEL_REQUEST_CODE);
    buffer.appendInt(processId);
    buffer.appendInt(secretKey);
    socket.write(buffer, ar -> {
      if (ar.succeeded()) {
        // The cancel connection has done its job; close it unless something else already is.
        if (status == Status.CONNECTED) {
          status = Status.CLOSING;
          socket.close();
        }
        handler.handle(Future.succeededFuture());
      } else {
        handler.handle(Future.failedFuture(ar.cause()));
      }
    });
  }

  /**
   * Lets the base class process the message, then routes {@link Notification}s to
   * {@code handleEvent} and logs {@link Notice}s.
   */
  @Override
  protected void handleMessage(Object msg) {
    super.handleMessage(msg);
    if (msg instanceof Notification) {
      handleEvent(msg);
    } else if (msg instanceof Notice) {
      handleNotice((Notice) msg);
    }
  }

  /** Logs the notice using the inherited {@code logger}. */
  private void handleNotice(Notice notice) {
    notice.log(logger);
  }

  @Override
  public int getProcessId() {
    return processId;
  }

  @Override
  public int getSecretKey() {
    return secretKey;
  }

  @Override
  public DatabaseMetadata getDatabaseMetaData() {
    return dbMetaData;
  }

  /**
   * Installs an {@code InitiateSslHandler} before the {@code "handler"} entry of the pipeline and
   * reports the outcome of the upgrade promise it is given.
   *
   * <p>On failure, a {@link DecoderException} wrapper is stripped to expose the underlying cause,
   * but only when such a cause exists, so the error is never replaced by {@code null}.
   *
   * @param completionHandler notified once the upgrade promise completes
   */
  void upgradeToSSLConnection(Handler<AsyncResult<Void>> completionHandler) {
    ChannelPipeline pipeline = socket.channelHandlerContext().pipeline();
    Promise<Void> upgradePromise = Promise.promise();
    upgradePromise.future().onComplete(ar -> {
      if (ar.succeeded()) {
        completionHandler.handle(Future.succeededFuture());
      } else {
        completionHandler.handle(Future.failedFuture(unwrapDecoderException(ar.cause())));
      }
    });
    pipeline.addBefore("handler", "initiate-ssl-handler", new InitiateSslHandler(this, upgradePromise));
  }

  /**
   * Returns the cause of a {@link DecoderException} when it has one; otherwise returns
   * {@code failure} unchanged (including a cause-less {@code DecoderException}).
   */
  private static Throwable unwrapDecoderException(Throwable failure) {
    if (failure instanceof DecoderException) {
      Throwable inner = failure.getCause();
      if (inner != null) {
        return inner;
      }
    }
    return failure;
  }

  /**
   * Turns a {@link TxCommand} into a {@link SimpleQueryCommand} that runs the transaction SQL
   * ({@code tx.kind.sql}) and discards its rows, then maps success to {@code tx.result}. All
   * other commands go straight to the base implementation.
   */
  @Override
  protected <R> void doSchedule(CommandBase<R> cmd, Handler<AsyncResult<R>> handler) {
    if (cmd instanceof TxCommand) {
      TxCommand<R> tx = (TxCommand<R>) cmd;
      SimpleQueryCommand<Void> cmd2 = new SimpleQueryCommand<>(
        tx.kind.sql,
        false,
        false,
        QueryCommandBase.NULL_COLLECTOR,
        QueryResultHandler.NOOP_HANDLER);
      super.doSchedule(cmd2, ar -> handler.handle(ar.map(tx.result)));
    } else {
      super.doSchedule(cmd, handler);
    }
  }

  /**
   * Returns {@code true} when {@code error} is a {@link PgException} with SQLSTATE
   * {@code 42P18} (PostgreSQL's {@code indeterminate_datatype}).
   */
  @Override
  public boolean isIndeterminatePreparedStatementError(Throwable error) {
    if (error instanceof PgException) {
      PgException e = (PgException) error;
      return "42P18".equals(e.getCode());
    }
    return false;
  }
}