package org.jooq.impl;
import static java.lang.Boolean.FALSE;
import static java.lang.Boolean.TRUE;
import static org.jooq.SQLDialect.DERBY;
import static org.jooq.SQLDialect.FIREBIRD;
import static org.jooq.SQLDialect.H2;
import static org.jooq.SQLDialect.HSQLDB;
import static org.jooq.SQLDialect.MARIADB;
import static org.jooq.SQLDialect.MYSQL;
import static org.jooq.SQLDialect.POSTGRES;
import static org.jooq.conf.SettingsTools.renderLocale;
import static org.jooq.impl.DSL.name;
import static org.jooq.impl.DSL.select;
import static org.jooq.impl.DSL.unquotedName;
import static org.jooq.impl.Keywords.K_BEGIN;
import static org.jooq.impl.Keywords.K_BULK_COLLECT_INTO;
import static org.jooq.impl.Keywords.K_DECLARE;
import static org.jooq.impl.Keywords.K_END;
import static org.jooq.impl.Keywords.K_FOR;
import static org.jooq.impl.Keywords.K_FORALL;
import static org.jooq.impl.Keywords.K_FROM;
import static org.jooq.impl.Keywords.K_IN;
import static org.jooq.impl.Keywords.K_INTO;
import static org.jooq.impl.Keywords.K_OPEN;
import static org.jooq.impl.Keywords.K_OUTPUT;
import static org.jooq.impl.Keywords.K_RETURNING;
import static org.jooq.impl.Keywords.K_ROWCOUNT;
import static org.jooq.impl.Keywords.K_SELECT;
import static org.jooq.impl.Keywords.K_SQL;
import static org.jooq.impl.Keywords.K_TABLE;
import static org.jooq.impl.Tools.EMPTY_FIELD;
import static org.jooq.impl.Tools.EMPTY_STRING;
import static org.jooq.impl.Tools.flattenCollection;
import static org.jooq.impl.Tools.qualify;
import static org.jooq.impl.Tools.BooleanDataKey.DATA_EMULATE_BULK_INSERT_RETURNING;
import static org.jooq.impl.Tools.BooleanDataKey.DATA_UNALIAS_ALIASED_EXPRESSIONS;
import static org.jooq.impl.Tools.DataKey.DATA_DML_TARGET_TABLE;
import static org.jooq.tools.StringUtils.defaultIfNull;
import static org.jooq.util.sqlite.SQLiteDSL.rowid;
import java.sql.CallableStatement;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.time.Instant;
import java.time.OffsetDateTime;
import java.time.OffsetTime;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.jooq.Asterisk;
import org.jooq.Binding;
import org.jooq.CommonTableExpression;
import org.jooq.Condition;
import org.jooq.Configuration;
import org.jooq.Context;
import org.jooq.DSLContext;
import org.jooq.DataType;
import org.jooq.Delete;
import org.jooq.ExecuteContext;
import org.jooq.ExecuteListener;
import org.jooq.Field;
import org.jooq.Identity;
import org.jooq.Insert;
import org.jooq.Name;
import org.jooq.Param;
import org.jooq.QualifiedAsterisk;
import org.jooq.Record;
import org.jooq.Result;
import org.jooq.SQLDialect;
import org.jooq.Scope;
import org.jooq.Select;
import org.jooq.SelectFieldOrAsterisk;
import org.jooq.Table;
import org.jooq.UniqueKey;
import org.jooq.Update;
import org.jooq.conf.ExecuteWithoutWhere;
import org.jooq.conf.RenderNameCase;
import org.jooq.conf.SettingsTools;
import org.jooq.exception.DataAccessException;
import org.jooq.impl.DefaultUnwrapperProvider.DefaultUnwrapper;
import org.jooq.impl.Tools.BooleanDataKey;
import org.jooq.impl.Tools.DataKey;
import org.jooq.tools.JooqLogger;
import org.jooq.tools.jdbc.BatchedPreparedStatement;
import org.jooq.tools.jdbc.JDBCUtils;
/**
 * Base class for DML statements operating on a single target table, such as
 * {@code INSERT}, {@code UPDATE} and {@code DELETE}, with an optional
 * {@code WITH} clause and optional {@code RETURNING} support.
 * <p>
 * For dialects listed in the {@code NATIVE_SUPPORT_*_RETURNING} sets, the
 * {@code RETURNING} clause is rendered into the SQL by {@code toSQLReturning()}.
 * For other dialects, it is emulated at execution time through JDBC generated
 * keys and, where needed, a follow-up {@code SELECT} by identity value.
 *
 * @param <R> the record type of the target table
 */
abstract class AbstractDMLQuery<R extends Record> extends AbstractRowCountQuery {

    private static final long       serialVersionUID = -7438014075226919192L;

    // NOTE(review): this logger is registered under AbstractQuery's category, not
    // AbstractDMLQuery's. Kept as-is so existing logging configuration keeps matching.
    private static final JooqLogger log              = JooqLogger.getLogger(AbstractQuery.class);

    /**
     * Dialects in which {@code table(Context)} renders an {@code INSERT} target
     * through {@code Tools.aliased(table())} when that is non-null. Judging by
     * the name, these dialects do not accept an aliased INSERT target.
     */
    private static final Set<SQLDialect> NO_SUPPORT_INSERT_ALIASED_TABLE  = SQLDialect.supportedBy(DERBY, FIREBIRD, H2, MARIADB, MYSQL);

    /** Dialects whose {@code INSERT .. RETURNING} is rendered natively. */
    private static final Set<SQLDialect> NATIVE_SUPPORT_INSERT_RETURNING  = SQLDialect.supportedBy(FIREBIRD, MARIADB, POSTGRES);

    /** Dialects whose {@code UPDATE .. RETURNING} is rendered natively. */
    private static final Set<SQLDialect> NATIVE_SUPPORT_UPDATE_RETURNING  = SQLDialect.supportedBy(FIREBIRD, POSTGRES);

    /** Dialects whose {@code DELETE .. RETURNING} is rendered natively. */
    private static final Set<SQLDialect> NATIVE_SUPPORT_DELETE_RETURNING  = SQLDialect.supportedBy(FIREBIRD, MARIADB, POSTGRES);

    /** Optional {@code WITH} clause; visited before the statement body when non-null. */
    private final WithImpl                         with;

    /** The DML target table. */
    private final Table<R>                         table;

    /** The {@code RETURNING} clause as specified by the user (may contain asterisks). */
    final SelectFieldList<SelectFieldOrAsterisk>   returning;

    /** {@link #returning} with every asterisk expanded into concrete fields. */
    final List<Field<?>>                           returningResolvedAsterisks;

    /** Data returned by the last execution; reset by {@code execute()}, lazily created by {@link #getResult()}. */
    Result<Record>                                 returnedResult;

    /** {@link #returnedResult} converted into the target table's record type; cached by {@link #getReturnedRecords()}. */
    Result<R>                                      returned;

    AbstractDMLQuery(Configuration configuration, WithImpl with, Table<R> table) {
        super(configuration);

        this.with = with;
        this.table = table;
        this.returning = new SelectFieldList<>();
        this.returningResolvedAsterisks = new ArrayList<>();
    }

    // ------------------------------------------------------------------------
    // RETURNING clause configuration
    // ------------------------------------------------------------------------

    /**
     * Sets the {@code RETURNING} clause to all fields of the target table.
     */
    public final void setReturning() {
        setReturning(table.fields());
    }

    /**
     * Sets the {@code RETURNING} clause to the given identity's field.
     * <p>
     * A {@code null} identity leaves the current {@code RETURNING} clause unchanged.
     */
    public final void setReturning(Identity<R, ?> identity) {
        if (identity != null)
            setReturning(identity.getField());
    }

    /**
     * Replaces the {@code RETURNING} clause with the given fields or asterisks.
     */
    public final void setReturning(SelectFieldOrAsterisk... fields) {
        setReturning(Arrays.asList(fields));
    }

    /**
     * Replaces the {@code RETURNING} clause with the given fields or asterisks.
     * <p>
     * Asterisks are also expanded eagerly into {@link #returningResolvedAsterisks}:
     * <ul>
     * <li>a plain {@link Asterisk} becomes all fields of the target table,</li>
     * <li>a {@link QualifiedAsterisk} becomes all fields of its qualifier.</li>
     * </ul>
     *
     * @throws AssertionError if an element is neither a {@link Field},
     *             a {@link QualifiedAsterisk} nor an {@link Asterisk}
     */
    public final void setReturning(Collection<? extends SelectFieldOrAsterisk> fields) {
        returning.clear();
        returning.addAll(fields);
        returningResolvedAsterisks.clear();

        for (SelectFieldOrAsterisk f : fields)
            if (f instanceof Field<?>)
                returningResolvedAsterisks.add((Field<?>) f);

            // QualifiedAsterisk is checked before the unqualified Asterisk on purpose.
            else if (f instanceof QualifiedAsterisk)
                returningResolvedAsterisks.addAll(Arrays.asList(((QualifiedAsterisk) f).qualifier().fields()));
            else if (f instanceof Asterisk)
                returningResolvedAsterisks.addAll(Arrays.asList(table.fields()));
            else
                throw new AssertionError("Type not supported: " + f);
    }

    // ------------------------------------------------------------------------
    // Access to returned data
    // ------------------------------------------------------------------------

    /**
     * Returns the first record from {@link #getReturnedRecords()}, or
     * {@code null} if no records were returned.
     */
    public final R getReturnedRecord() {
        Result<R> records = getReturnedRecords();
        return records.isEmpty() ? null : records.get(0);
    }

    /**
     * Returns the data produced by the {@code RETURNING} clause, converted into
     * the target table's record type.
     * <p>
     * The conversion happens once per execution and is cached. If the target
     * table has no fields, the raw result is returned through an unchecked cast
     * instead. Otherwise, returned columns that the target table does not contain
     * are reported as API misuse.
     */
    @SuppressWarnings("unchecked")
    public final Result<R> getReturnedRecords() {
        if (returned == null) {
            if (table.fields().length > 0) {
                warnOnAPIMisuse();
                returned = getResult().into(table);
            }
            else {
                returned = (Result<R>) getResult();
            }
        }

        return returned;
    }

    /**
     * Logs a warning for every returned column that is absent from the target
     * table. Such columns cannot be represented in the target table's record type.
     */
    private final void warnOnAPIMisuse() {
        for (Field<?> field : getResult().fields())
            if (table.field(field) == null)
                log.warn("API misuse", "Column " + field + " has been requested through the returning() clause, which is not present in table " + table + ". Use StoreQuery.getResult() or the returningResult() clause instead.");
    }

    /** The DML target table, as configured. */
    final Table<R> table() {
        return table;
    }

    /**
     * The target table as it should be rendered in {@code ctx}.
     * <p>
     * For an {@link Insert} in a dialect from {@code NO_SUPPORT_INSERT_ALIASED_TABLE},
     * this returns {@code Tools.aliased(table())} when that is non-null. In all
     * other cases it returns {@link #table()}.
     */
    final Table<?> table(Context<?> ctx) {
        if (NO_SUPPORT_INSERT_ALIASED_TABLE.contains(ctx.dialect()) && this instanceof Insert)
            return defaultIfNull(Tools.aliased(table()), table());
        else
            return table();
    }

    /**
     * Returns the raw data produced by the {@code RETURNING} clause.
     * <p>
     * If nothing has been returned yet, an empty result is created with the
     * resolved {@code RETURNING} fields and cached.
     */
    public final Result<?> getResult() {
        if (returnedResult == null)
            returnedResult = new ResultImpl<>(configuration(), returningResolvedAsterisks);

        return returnedResult;
    }

    // ------------------------------------------------------------------------
    // Rendering
    // ------------------------------------------------------------------------

    /**
     * Renders the optional {@code WITH} clause followed by the statement body
     * ({@link #accept0(Context)}).
     * <p>
     * While rendering, {@code DATA_DML_TARGET_TABLE} is set to the target table
     * in the context data.
     */
    @Override
    public final void accept(Context<?> ctx) {
        ctx.data(DATA_DML_TARGET_TABLE, table);

        try {
            if (with != null)
                ctx.visit(with);

            accept0(ctx);
        }
        finally {
            // Clear the marker even if rendering fails, so that the context does
            // not keep pointing at this statement's target table.
            ctx.data().remove(DATA_DML_TARGET_TABLE);
        }
    }

    /**
     * Renders the statement body. This is called by {@link #accept(Context)}
     * after any {@code WITH} clause has been rendered.
     */
    abstract void accept0(Context<?> ctx);

    /**
     * Reacts to a statement being executed without a {@code WHERE} clause,
     * according to the configured policy.
     *
     * @param message the log message context passed to the logger
     * @param executeWithoutWhere the policy: ignore, log at DEBUG/INFO/WARN, or throw
     * @throws DataAccessException if the policy is {@code THROW}
     */
    void executeWithoutWhere(String message, ExecuteWithoutWhere executeWithoutWhere) {
        switch (executeWithoutWhere) {
            case IGNORE:
                break;
            case LOG_DEBUG:
                if (log.isDebugEnabled())
                    log.debug(message, "A statement is executed without WHERE clause");
                break;
            case LOG_INFO:
                if (log.isInfoEnabled())
                    log.info(message, "A statement is executed without WHERE clause");
                break;
            case LOG_WARN:
                log.warn(message, "A statement is executed without WHERE clause");
                break;
            case THROW:
                throw new DataAccessException("A statement is executed without WHERE clause");
        }
    }

    /**
     * Renders a native {@code RETURNING} clause, if there is one and the
     * dialect supports it natively for this statement type. Otherwise it
     * renders nothing.
     * <p>
     * Dialect-specific details:
     * <ul>
     * <li>In MariaDB, field qualification is switched off while rendering the
     * clause, then restored.</li>
     * <li>In Firebird and MariaDB, the asterisk-expanded field list is rendered
     * instead of the list as specified.</li>
     * </ul>
     * Field declaration is switched on for the clause and restored afterwards.
     */
    final void toSQLReturning(Context<?> ctx) {
        if (returning.isEmpty() || !nativeSupportReturning(ctx))
            return;

        boolean declareFields = ctx.declareFields();
        boolean qualify = ctx.qualify();
        boolean unqualify = ctx.family() == MARIADB;

        if (unqualify)
            ctx.qualify(false);

        ctx.formatSeparator()
           .visit(K_RETURNING)
           .sql(' ')
           .declareFields(true)
           .visit(
               ctx.family() == FIREBIRD || ctx.family() == MARIADB
             ? new SelectFieldList<>(returningResolvedAsterisks)
             : returning
           )
           .declareFields(declareFields);

        if (unqualify)
            ctx.qualify(qualify);
    }

    /**
     * Whether this statement type ({@link Insert}, {@link Update} or
     * {@link Delete}) has native {@code RETURNING} support in the scope's
     * dialect.
     */
    private final boolean nativeSupportReturning(Scope ctx) {
        return this instanceof Insert && NATIVE_SUPPORT_INSERT_RETURNING.contains(ctx.dialect())
            || this instanceof Update && NATIVE_SUPPORT_UPDATE_RETURNING.contains(ctx.dialect())
            || this instanceof Delete && NATIVE_SUPPORT_DELETE_RETURNING.contains(ctx.dialect());
    }

    // ------------------------------------------------------------------------
    // Execution
    // ------------------------------------------------------------------------

    /**
     * Prepares the JDBC statement with {@code prepare0()}, then calls
     * {@code Tools.setFetchSize(ctx, 0)}.
     */
    @Override
    protected final void prepare(ExecuteContext ctx) throws SQLException {
        prepare0(ctx);
        Tools.setFetchSize(ctx, 0);
    }

    /**
     * Creates the JDBC statement, requesting generated keys where {@code RETURNING}
     * has to be emulated.
     * <ul>
     * <li>No {@code RETURNING}, native support, SQLite or CUBRID: the default
     * preparation.</li>
     * <li>Derby, H2, MariaDB (without native support) and MySQL: request
     * {@link Statement#RETURN_GENERATED_KEYS}.</li>
     * <li>Others (including HSQLDB): request generated keys by column name,
     * cased according to the configured {@link RenderNameCase}.</li>
     * </ul>
     * An already-present statement on {@code ctx} is left untouched in the
     * latter two cases.
     */
    private final void prepare0(ExecuteContext ctx) throws SQLException {
        Connection connection = ctx.connection();

        if (returning.isEmpty() || nativeSupportReturning(ctx)) {
            super.prepare(ctx);
            return;
        }

        switch (ctx.family()) {
            case SQLITE:
            case CUBRID:
                super.prepare(ctx);
                break;

            case DERBY:
            case H2:
            case MARIADB:
            case MYSQL:
                if (ctx.statement() == null)
                    ctx.statement(connection.prepareStatement(ctx.sql(), Statement.RETURN_GENERATED_KEYS));
                break;

            case HSQLDB:
            default: {
                if (ctx.statement() == null) {
                    List<String> names = new ArrayList<>(returningResolvedAsterisks.size());
                    RenderNameCase style = SettingsTools.getRenderNameCase(configuration().settings());

                    // Column names are passed to JDBC verbatim, so apply the same
                    // casing that the rendered SQL would use.
                    for (Field<?> f : Tools.flattenCollection(returningResolvedAsterisks, false)) {
                        String columnName = f.getName();

                        if (style == RenderNameCase.UPPER)
                            columnName = columnName.toUpperCase(renderLocale(configuration().settings()));
                        else if (style == RenderNameCase.LOWER)
                            columnName = columnName.toLowerCase(renderLocale(configuration().settings()));

                        names.add(columnName);
                    }

                    ctx.statement(connection.prepareStatement(ctx.sql(), names.toArray(EMPTY_STRING)));
                }
                break;
            }
        }
    }

    /**
     * Executes the statement and collects any {@code RETURNING} data into
     * {@link #returnedResult}. Previously returned data is discarded first.
     *
     * @return the number of affected rows. When {@code RETURNING} data is
     *         fetched from a result set, this is the number of returned
     *         records. The exception is HSQLDB: when it returns no generated
     *         keys, the update count is reported instead.
     */
    @SuppressWarnings({ "rawtypes", "unchecked" })
    @Override
    protected final int execute(ExecuteContext ctx, ExecuteListener listener) throws SQLException {
        returned = null;
        returnedResult = null;

        if (returning.isEmpty())
            return super.execute(ctx, listener);

        int result = 1;
        ResultSet rs;

        switch (ctx.family()) {

            // Run the update, then re-select the affected row by ROWID, using the last inserted ID.
            case SQLITE: {
                listener.executeStart(ctx);
                result = executeImmediate(ctx.statement()).executeUpdate();
                ctx.rows(result);
                listener.executeEnd(ctx);

                DSLContext create = ctx.dsl();
                returnedResult =
                create.select(returning)
                      .from(table)
                      .where(rowid().equal(rowid().getDataType().convert(create.lastID())))
                      .fetch();

                returnedResult.attach(((DefaultExecuteContext) ctx).originalConfiguration());
                return result;
            }

            // Run the update, then re-select by the last inserted ID.
            case CUBRID: {
                listener.executeStart(ctx);
                result = executeImmediate(ctx.statement()).executeUpdate();
                ctx.rows(result);
                listener.executeEnd(ctx);

                selectReturning(
                    ((DefaultExecuteContext) ctx).originalConfiguration(),
                    ctx.configuration(),
                    ctx.dsl().lastID()
                );
                return result;
            }

            // Generated keys only contain the identity; fetch other columns separately.
            case DERBY:
            case H2:
            case MYSQL: {
                return executeReturningGeneratedKeysFetchAdditionalRows(ctx, listener);
            }

            case MARIADB: {
                if (!nativeSupportReturning(ctx))
                    return executeReturningGeneratedKeysFetchAdditionalRows(ctx, listener);

                rs = executeReturningQuery(ctx, listener);
                break;
            }

            // Native RETURNING: the statement itself produces a result set.
            case FIREBIRD:
            case POSTGRES: {
                rs = executeReturningQuery(ctx, listener);
                break;
            }

            case HSQLDB:
            default: {
                rs = executeReturningGeneratedKeys(ctx, listener);
                break;
            }
        }

        // Fetch the RETURNING result set with a separate context and listener
        // chain based on the original configuration.
        ExecuteContext ctx2 = new DefaultExecuteContext(((DefaultExecuteContext) ctx).originalConfiguration());
        ExecuteListener listener2 = ExecuteListeners.getAndStart(ctx2);

        ctx2.resultSet(rs);
        returnedResult = new CursorImpl<>(ctx2, listener2, returningResolvedAsterisks.toArray(EMPTY_FIELD), null, false, true).fetch();

        if (!returnedResult.isEmpty() || ctx.family() != HSQLDB) {
            result = returnedResult.size();
            ctx.rows(result);
        }
        else {
            // HSQLDB returned no generated keys (presumably because it cannot
            // return them for this statement type). Report the update count that
            // executeReturningGeneratedKeys() recorded on ctx, rather than the
            // placeholder value 1.
            result = ctx.rows();
        }

        return result;
    }

    /**
     * If {@code s} wraps a {@link BatchedPreparedStatement}, switches it to
     * execute immediately. Returns {@code s} in all cases.
     * <p>
     * NOTE(review): presumably this stops the update from being deferred into
     * a batch, because the returned data is needed right away.
     */
    private final PreparedStatement executeImmediate(PreparedStatement s) throws SQLException {
        if (DefaultUnwrapper.isWrapperFor(s, BatchedPreparedStatement.class))
            s.unwrap(BatchedPreparedStatement.class).setExecuteImmediate(true);

        return s;
    }

    /**
     * Executes the update, records the update count on {@code ctx}, and
     * returns the statement's generated keys.
     */
    private final ResultSet executeReturningGeneratedKeys(ExecuteContext ctx, ExecuteListener listener) throws SQLException {
        listener.executeStart(ctx);
        int result = executeImmediate(ctx.statement()).executeUpdate();
        ctx.rows(result);
        listener.executeEnd(ctx);
        return ctx.statement().getGeneratedKeys();
    }

    /**
     * Executes the update and collects the first column of every generated-key
     * row. It then uses {@code selectReturning()} to populate the
     * {@code RETURNING} data from those identity values.
     * <p>
     * The generated-keys result set is always closed.
     *
     * @return the update count
     */
    private final int executeReturningGeneratedKeysFetchAdditionalRows(ExecuteContext ctx, ExecuteListener listener) throws SQLException {
        listener.executeStart(ctx);
        int result = executeImmediate(ctx.statement()).executeUpdate();
        ctx.rows(result);
        listener.executeEnd(ctx);

        ResultSet rs = ctx.statement().getGeneratedKeys();

        try {
            List<Object> list = new ArrayList<>();

            if (rs != null)
                while (rs.next())
                    list.add(rs.getObject(1));

            selectReturning(
                ((DefaultExecuteContext) ctx).originalConfiguration(),
                ctx.configuration(),
                list.toArray()
            );

            return result;
        }
        finally {
            JDBCUtils.safeClose(rs);
        }
    }

    /**
     * Executes a statement with a native {@code RETURNING} clause as a query
     * and returns its result set.
     * <p>
     * The update count is not recorded here; {@code execute()} derives it from
     * the fetched result.
     */
    private final ResultSet executeReturningQuery(ExecuteContext ctx, ExecuteListener listener) throws SQLException {
        listener.executeStart(ctx);
        ResultSet rs = ctx.statement().executeQuery();
        listener.executeEnd(ctx);
        return rs;
    }

    /**
     * Populates the {@code RETURNING} data from a set of identity values.
     * <p>
     * Each value is first converted to the identity field's data type.
     * <ul>
     * <li>If the {@code RETURNING} clause consists of exactly the identity
     * field, one record per value is built locally, without a database round
     * trip.</li>
     * <li>Otherwise, the {@code RETURNING} fields are selected from the target
     * table with {@code identity IN (values)}.</li>
     * </ul>
     * Nothing happens if there are no values or no identity field can be found.
     *
     * @param originalConfiguration the configuration to attach result records to
     * @param derivedConfiguration the configuration used to run the follow-up query
     * @param values the identity values reported by the database
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private final void selectReturning(
        Configuration originalConfiguration,
        Configuration derivedConfiguration,
        Object... values
    ) {
        if (values != null && values.length > 0) {
            final Field<Object> returnIdentity = (Field<Object>) returnedIdentity();

            if (returnIdentity != null) {
                Object[] ids = new Object[values.length];
                for (int i = 0; i < values.length; i++)
                    ids[i] = returnIdentity.getDataType().convert(values[i]);

                // Only the identity was requested: no need for a round trip.
                if (returningResolvedAsterisks.size() == 1 && new Fields<>(returningResolvedAsterisks).field(returnIdentity) != null) {
                    AbstractRow fields = Tools.row0(returningResolvedAsterisks.toArray(EMPTY_FIELD));

                    for (final Object id : ids) {
                        ((Result) getResult()).add(
                        Tools.newRecord(
                                true,
                                AbstractRecord.class,
                                fields,
                                originalConfiguration)
                             .operate(new RecordOperation<AbstractRecord, RuntimeException>() {

                                @Override
                                public AbstractRecord operate(AbstractRecord record) throws RuntimeException {
                                    record.values[0] = id;
                                    record.originals[0] = id;

                                    return record;
                                }
                            }));
                    }
                }

                // Other columns were requested: re-select them by identity.
                else {
                    returnedResult =
                    derivedConfiguration.dsl()
                                        .select(returning)
                                        .from(table)
                                        .where(table.field(returnIdentity).in(ids))
                                        .fetch();

                    returnedResult.attach(originalConfiguration);
                }
            }
        }
    }

    /**
     * Returns the target table's identity field, if it has one. Otherwise it
     * returns the first resolved {@code RETURNING} field whose data type is
     * flagged as identity, or {@code null} if there is none.
     */
    private final Field<?> returnedIdentity() {
        if (table.getIdentity() != null)
            return table.getIdentity().getField();
        else
            for (Field<?> field : returningResolvedAsterisks)
                if (field.getDataType().identity())
                    return field;

        return null;
    }
}