package io.vertx.jdbcclient.impl.actions;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.json.JsonArray;
import io.vertx.ext.jdbc.impl.actions.JDBCStatementHelper;
import io.vertx.ext.sql.SQLOptions;
import io.vertx.jdbcclient.SqlOutParam;
import io.vertx.sqlclient.Row;
import io.vertx.sqlclient.Tuple;
import io.vertx.sqlclient.impl.command.ExtendedQueryCommand;
import java.sql.*;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalTime;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collector;
/**
 * Runs an {@link ExtendedQueryCommand} over JDBC.
 *
 * <p>The statement is prepared as a {@link CallableStatement} when any parameter is an
 * {@link SqlOutParam}. Otherwise it is prepared as a {@link PreparedStatement}, configured for
 * generated keys according to the {@link SQLOptions}.
 *
 * @param <C> the collector's intermediate accumulation type
 * @param <R> the result type produced by the collector
 */
public class JDBCPreparedQuery<C, R> extends JDBCQueryAction<C, R> {

  private final ExtendedQueryCommand<R> query;
  private final Tuple params;
  /** 1-based JDBC parameter positions whose tuple value is an {@link SqlOutParam}. */
  private final List<Integer> outParams;

  /**
   * Collects the 1-based positions of the tuple entries that are {@link SqlOutParam}s.
   *
   * @param tuple the query parameters; may be {@code null}
   * @return the positions in ascending order, never {@code null}
   */
  private static List<Integer> countOut(Tuple tuple) {
    List<Integer> positions = new ArrayList<>();
    if (tuple != null) {
      for (int i = 0; i < tuple.size(); i++) {
        if (tuple.getValue(i) instanceof SqlOutParam) {
          // JDBC parameter indexes start at 1.
          positions.add(i + 1);
        }
      }
    }
    return positions;
  }

  /**
   * @param helper    statement helper forwarded to the superclass
   * @param options   SQL options; may be {@code null}, which enables generated keys by default
   * @param query     the command whose SQL text is executed
   * @param collector collector forwarded to the superclass
   * @param params    the statement parameters; may be {@code null} (treated as no parameters)
   */
  public JDBCPreparedQuery(JDBCStatementHelper helper, SQLOptions options, ExtendedQueryCommand<R> query, Collector<Row, C, R> collector, Tuple params) {
    super(helper, options, collector);
    this.query = query;
    this.params = params;
    this.outParams = countOut(params);
  }

  /**
   * Prepares, binds and executes the statement, then decodes the result.
   *
   * <p>The statement is closed before this method returns.
   */
  @Override
  public JDBCResponse<R> execute(Connection conn) throws SQLException {
    try (PreparedStatement ps = prepare(conn)) {
      fillStatement(ps, conn);
      return decode(ps, ps.execute(), true, outParams);
    }
  }

  /**
   * Creates the JDBC statement for {@link #query}.
   *
   * <ul>
   *   <li>OUT parameters present: {@link Connection#prepareCall(String)}.</li>
   *   <li>Explicit generated-key columns configured: prepared with those column indexes/names.</li>
   *   <li>Generated keys enabled (or no options at all): {@link Statement#RETURN_GENERATED_KEYS}.</li>
   *   <li>Otherwise: a plain prepared statement.</li>
   * </ul>
   */
  private PreparedStatement prepare(Connection conn) throws SQLException {
    final String sql = query.sql();
    if (!outParams.isEmpty()) {
      // OUT parameters can only be registered on a CallableStatement.
      return conn.prepareCall(sql);
    }
    boolean autoGeneratedKeys = options == null || options.isAutoGeneratedKeys();
    boolean autoGeneratedIndexes = options != null && options.getAutoGeneratedKeysIndexes() != null;
    if (autoGeneratedKeys && !autoGeneratedIndexes) {
      return conn.prepareStatement(sql, Statement.RETURN_GENERATED_KEYS);
    } else if (autoGeneratedIndexes) {
      return prepareWithKeyColumns(conn, sql, options.getAutoGeneratedKeysIndexes());
    } else {
      return conn.prepareStatement(sql);
    }
  }

  /**
   * Prepares {@code sql}, asking the driver to return the given generated-key columns.
   *
   * <p>The type of the first element decides whether all entries are read as column indexes
   * ({@code int}) or as column names ({@code String}). Any runtime failure while reading the
   * array or preparing the statement is wrapped in an {@link SQLException}.
   *
   * @throws SQLException if {@code indexes} is empty, holds an unsupported element type, or
   *                      preparation fails
   */
  private static PreparedStatement prepareWithKeyColumns(Connection conn, String sql, JsonArray indexes) throws SQLException {
    if (indexes.isEmpty()) {
      throw new SQLException("Invalid auto generated keys indexes: array must not be empty");
    }
    try {
      Object first = indexes.getValue(0);
      if (first instanceof Number) {
        int[] keys = new int[indexes.size()];
        for (int i = 0; i < keys.length; i++) {
          keys[i] = indexes.getInteger(i);
        }
        return conn.prepareStatement(sql, keys);
      } else if (first instanceof String) {
        String[] keys = new String[indexes.size()];
        for (int i = 0; i < keys.length; i++) {
          keys[i] = indexes.getString(i);
        }
        return conn.prepareStatement(sql, keys);
      } else {
        throw new SQLException("Invalid type of index, only [int, String] allowed");
      }
    } catch (RuntimeException e) {
      // e.g. a mixed-type array (ClassCastException) or a driver runtime failure.
      throw new SQLException(e);
    }
  }

  /**
   * Binds {@link #params} to {@code ps}, adapting Vert.x/java.time values to JDBC types and
   * registering OUT parameters.
   */
  private void fillStatement(PreparedStatement ps, Connection conn) throws SQLException {
    if (params == null) {
      // The constructor accepts a null tuple (see countOut); treat it as "no parameters".
      return;
    }
    for (int i = 0; i < params.size(); i++) {
      int position = i + 1; // JDBC parameter indexes are 1-based
      Object value = adaptType(conn, params.getValue(i));
      if (value instanceof SqlOutParam) {
        SqlOutParam outValue = (SqlOutParam) value;
        if (outValue.in()) {
          // The OUT param also carries an input value: bind it as well.
          ps.setObject(position, adaptType(conn, outValue.value()));
        }
        // Safe cast: prepare() uses prepareCall whenever an SqlOutParam is present.
        ((CallableStatement) ps).registerOutParameter(position, outValue.type());
      } else {
        ps.setObject(position, value);
      }
    }
  }

  /**
   * Converts values that JDBC drivers may not accept directly into JDBC types.
   *
   * <p>{@link LocalTime} becomes {@link Time}, {@link LocalDate} becomes {@link Date},
   * {@link Instant} becomes {@link Timestamp}, and {@link Buffer} becomes a {@link Blob}
   * created on {@code conn}. Any other value is returned unchanged.
   */
  private Object adaptType(Connection conn, Object value) throws SQLException {
    if (value instanceof LocalTime) {
      return Time.valueOf((LocalTime) value);
    } else if (value instanceof LocalDate) {
      return Date.valueOf((LocalDate) value);
    } else if (value instanceof Instant) {
      return Timestamp.from((Instant) value);
    } else if (value instanceof Buffer) {
      Buffer buffer = (Buffer) value;
      Blob blob = conn.createBlob();
      // Blob positions are 1-based, and setBytes returns the number of bytes written
      // (not the Blob), so the Blob itself must be returned as the parameter value.
      blob.setBytes(1, buffer.getBytes());
      return blob;
    }
    return value;
  }
}