package org.jooq.impl;
import static org.jooq.DatePart.MONTH;
import static org.jooq.DatePart.SECOND;
import static org.jooq.SQLDialect.CUBRID;
import static org.jooq.SQLDialect.DERBY;
import static org.jooq.SQLDialect.FIREBIRD;
import static org.jooq.SQLDialect.H2;
import static org.jooq.SQLDialect.HSQLDB;
import static org.jooq.SQLDialect.POSTGRES;
import static org.jooq.SQLDialect.SQLITE;
import static org.jooq.impl.DSL.function;
import static org.jooq.impl.DSL.inline;
import static org.jooq.impl.DSL.keyword;
import static org.jooq.impl.DSL.two;
import static org.jooq.impl.DSL.val;
import static org.jooq.impl.ExpressionOperator.ADD;
import static org.jooq.impl.ExpressionOperator.BIT_AND;
import static org.jooq.impl.ExpressionOperator.BIT_NAND;
import static org.jooq.impl.ExpressionOperator.BIT_NOR;
import static org.jooq.impl.ExpressionOperator.BIT_OR;
import static org.jooq.impl.ExpressionOperator.BIT_XNOR;
import static org.jooq.impl.ExpressionOperator.BIT_XOR;
import static org.jooq.impl.ExpressionOperator.MULTIPLY;
import static org.jooq.impl.ExpressionOperator.SHL;
import static org.jooq.impl.ExpressionOperator.SHR;
import static org.jooq.impl.ExpressionOperator.SUBTRACT;
import static org.jooq.impl.Internal.iadd;
import static org.jooq.impl.Internal.idiv;
import static org.jooq.impl.Internal.imul;
import static org.jooq.impl.Internal.isub;
import static org.jooq.impl.Keywords.K_AS;
import static org.jooq.impl.Keywords.K_CAST;
import static org.jooq.impl.Keywords.K_DAY;
import static org.jooq.impl.Keywords.K_DAY_MICROSECOND;
import static org.jooq.impl.Keywords.K_DAY_MILLISECOND;
import static org.jooq.impl.Keywords.K_DAY_TO_SECOND;
import static org.jooq.impl.Keywords.K_INTERVAL;
import static org.jooq.impl.Keywords.K_MILLISECOND;
import static org.jooq.impl.Keywords.K_MONTH;
import static org.jooq.impl.Keywords.K_SECOND;
import static org.jooq.impl.Keywords.K_YEAR_MONTH;
import static org.jooq.impl.Keywords.K_YEAR_TO_MONTH;
import static org.jooq.impl.Names.N_ADD_DAYS;
import static org.jooq.impl.Names.N_ADD_MONTHS;
import static org.jooq.impl.Names.N_ADD_SECONDS;
import static org.jooq.impl.Names.N_DATEADD;
import static org.jooq.impl.Names.N_DATE_ADD;
import static org.jooq.impl.Names.N_SQL_TSI_FRAC_SECOND;
import static org.jooq.impl.Names.N_SQL_TSI_MILLI_SECOND;
import static org.jooq.impl.Names.N_SQL_TSI_MONTH;
import static org.jooq.impl.Names.N_SQL_TSI_SECOND;
import static org.jooq.impl.Names.N_STRFTIME;
import static org.jooq.impl.Names.N_TIMESTAMPADD;
import static org.jooq.impl.Tools.castIfNeeded;
import java.sql.Timestamp;
import java.util.Set;
import java.util.regex.Pattern;
import org.jooq.Context;
import org.jooq.DataType;
import org.jooq.DatePart;
import org.jooq.Field;
import org.jooq.Param;
import org.jooq.SQLDialect;
import org.jooq.conf.TransformUnneededArithmeticExpressions;
import org.jooq.exception.DataTypeException;
import org.jooq.tools.Convert;
import org.jooq.types.DayToSecond;
import org.jooq.types.Interval;
import org.jooq.types.YearToMonth;
import org.jooq.types.YearToSecond;
/**
 * A binary expression {@code lhs <operator> rhs} for arithmetic, bitwise and
 * shift operators.
 * <p>
 * Rendering is dialect-dependent. Some bitwise / shift operators are rendered
 * as vendor functions (e.g. {@code bitand}, {@code bin_xor}, {@code lshift})
 * or emulated with other operators. Datetime arithmetic
 * ({@code datetime +/- number} or {@code datetime +/- interval}) is delegated
 * to {@link DateExpression}. Everything else is rendered as a parenthesised
 * infix {@link DefaultExpression}.
 */
final class Expression<T> extends AbstractTransformable<T> {

    private static final long             serialVersionUID       = -5522799070693019771L;
    private static final Set<SQLDialect>  SUPPORT_BIT_AND        = SQLDialect.supportedBy(H2, HSQLDB);
    private static final Set<SQLDialect>  SUPPORT_BIT_OR_XOR     = SQLDialect.supportedBy(H2, HSQLDB);
    private static final Set<SQLDialect>  EMULATE_BIT_XOR        = SQLDialect.supportedBy(SQLITE);
    private static final Set<SQLDialect>  EMULATE_SHR_SHL        = SQLDialect.supportedBy(HSQLDB);
    private static final Set<SQLDialect>  HASH_OP_FOR_BIT_XOR    = SQLDialect.supportedBy(POSTGRES);
    private static final Set<SQLDialect>  SUPPORT_YEAR_TO_SECOND = SQLDialect.supportedBy(POSTGRES);

    private final ExpressionOperator      operator;
    // NOTE(review): not read anywhere in this class; presumably consumed elsewhere — confirm
    private final boolean                 internal;
    private final Field<T>                lhs;
    private final Field<?>                rhs;

    /**
     * @param operator the binary operator
     * @param internal flag stored on the expression (not read in this class)
     * @param lhs      the left operand; its data type becomes this expression's data type
     * @param rhs      the right operand
     */
    Expression(ExpressionOperator operator, boolean internal, Field<T> lhs, Field<?> rhs) {
        super(DSL.name(operator.toSQL()), lhs.getDataType());

        this.operator = operator;
        this.internal = internal;
        this.lhs = lhs;
        this.rhs = rhs;
    }

    /**
     * Render this expression, choosing a dialect-specific function or
     * emulation where one is required. The branch order matters: the first
     * matching branch wins.
     */
    @SuppressWarnings("unchecked")
    @Override
    final void accept0(Context<?> ctx) {
        SQLDialect family = ctx.family();

        // Bitwise AND / XOR / OR rendered as functions
        if (BIT_AND == operator && SUPPORT_BIT_AND.contains(ctx.dialect()))
            ctx.visit(function("bitand", getDataType(), lhs, rhs));
        else if (BIT_AND == operator && FIREBIRD == family)
            ctx.visit(function("bin_and", getDataType(), lhs, rhs));
        else if (BIT_XOR == operator && SUPPORT_BIT_OR_XOR.contains(ctx.dialect()))
            ctx.visit(function("bitxor", getDataType(), lhs, rhs));
        else if (BIT_XOR == operator && FIREBIRD == family)
            ctx.visit(function("bin_xor", getDataType(), lhs, rhs));
        else if (BIT_OR == operator && SUPPORT_BIT_OR_XOR.contains(ctx.dialect()))
            ctx.visit(function("bitor", getDataType(), lhs, rhs));
        else if (BIT_OR == operator && FIREBIRD == family)
            ctx.visit(function("bin_or", getDataType(), lhs, rhs));

        // a XOR b == ~(a & b) & (a | b)
        else if (BIT_XOR == operator && EMULATE_BIT_XOR.contains(ctx.dialect()))
            ctx.visit(DSL.bitAnd(
                DSL.bitNot(DSL.bitAnd(lhsAsNumber(), rhsAsNumber())),
                DSL.bitOr(lhsAsNumber(), rhsAsNumber())));

        // Shifts: functions where available, otherwise multiplication / division by 2^rhs
        else if (operator == SHL || operator == SHR) {
            if (family == H2)
                ctx.visit(function(SHL == operator ? "lshift" : "rshift", getDataType(), lhs, rhs));
            else if (FIREBIRD == family)
                ctx.visit(function(SHL == operator ? "bin_shl" : "bin_shr", getDataType(), lhs, rhs));
            else if (SHL == operator && EMULATE_SHR_SHL.contains(ctx.dialect()))
                ctx.visit(imul(lhs, (Field<? extends Number>) castIfNeeded(DSL.power(two(), rhsAsNumber()), lhs)));
            else if (SHR == operator && EMULATE_SHR_SHL.contains(ctx.dialect()))
                ctx.visit(idiv(lhs, (Field<? extends Number>) castIfNeeded(DSL.power(two(), rhsAsNumber()), lhs)));
            else
                ctx.visit(new DefaultExpression<>(lhs, operator, rhs));
        }

        // Negated bitwise operators are always emulated with bitNot
        else if (BIT_NAND == operator)
            ctx.visit(DSL.bitNot(DSL.bitAnd(lhsAsNumber(), rhsAsNumber())));
        else if (BIT_NOR == operator)
            ctx.visit(DSL.bitNot(DSL.bitOr(lhsAsNumber(), rhsAsNumber())));
        else if (BIT_XNOR == operator)
            ctx.visit(DSL.bitNot(DSL.bitXor(lhsAsNumber(), rhsAsNumber())));

        // datetime +/- number and datetime +/- interval
        else if ((ADD == operator || SUBTRACT == operator) &&
                 lhs.getDataType().isDateTime() &&
                 (rhs.getDataType().isNumeric() ||
                  rhs.getDataType().isInterval()))
            ctx.visit(new DateExpression<>(lhs, operator, rhs));
        else
            ctx.visit(new DefaultExpression<>(lhs, operator, rhs));
    }

    /**
     * No transformation is applied; this expression is returned unchanged.
     */
    @Override
    @SuppressWarnings("null")
    public final Field<?> transform(TransformUnneededArithmeticExpressions transform) {
        return this;
    }

    /** Unchecked view of the left operand as a numeric field. */
    @SuppressWarnings("unchecked")
    private final Field<Number> lhsAsNumber() {
        return (Field<Number>) lhs;
    }

    /** Unchecked view of the right operand as a numeric field. */
    @SuppressWarnings("unchecked")
    private final Field<Number> rhsAsNumber() {
        return (Field<Number>) rhs;
    }

    /**
     * Truncates a fractional-second string to at most 6 fractional digits
     * (keeps group 1, drops up to 3 further digits).
     */
    private static final Pattern TRUNC_TO_MICROS = Pattern.compile("([^.]*\\.\\d{0,6})\\d{0,3}");

    /**
     * Renders {@code datetime +/- number} (number of days) and
     * {@code datetime +/- interval} using dialect-specific date arithmetic.
     */
    private static class DateExpression<T> extends AbstractField<T> {

        private static final long        serialVersionUID = 3160679741902222262L;

        private final Field<T>           lhs;
        private final ExpressionOperator operator;
        private final Field<?>           rhs;

        DateExpression(Field<T> lhs, ExpressionOperator operator, Field<?> rhs) {
            super(DSL.name(operator.toSQL()), lhs.getDataType());

            this.lhs = lhs;
            this.operator = operator;
            this.rhs = rhs;
        }

        /**
         * Wrap a derived value in a bind value that is inlined if and only if
         * {@code rhs} is inlined. Callers must only use this when {@code rhs}
         * is a {@link Param}; otherwise the cast below fails.
         */
        private final <U> Field<U> p(U u) {
            Param<U> result = val(u);

            if (((Param<?>) rhs).isInline())
                result.setInline(true);

            return result;
        }

        @Override
        public final void accept(Context<?> ctx) {
            if (rhs.getType() == YearToSecond.class && !SUPPORT_YEAR_TO_SECOND.contains(ctx.dialect()))
                acceptYTSExpression(ctx);
            else if (rhs.getDataType().isInterval())
                acceptIntervalExpression(ctx);
            else
                acceptNumberExpression(ctx);
        }

        /**
         * A {@link YearToSecond} bind value is split into a year-to-month
         * part and a day-to-second part, applied one after the other.
         * Non-bind values fall through to the generic interval rendering.
         */
        private final void acceptYTSExpression(Context<?> ctx) {
            if (rhs instanceof Param) {
                YearToSecond yts = rhsAsYTS();

                ctx.visit(new DateExpression<>(
                    new DateExpression<>(lhs, operator, p(yts.getYearToMonth())),
                    operator,
                    p(yts.getDayToSecond())
                ));
            }
            else {
                acceptIntervalExpression(ctx);
            }
        }

        /**
         * Render {@code datetime +/- interval}. Except for the fallthrough
         * dialects, this requires {@code rhs} to be a bind value whose
         * interval value can be read (see {@link #rhsValue(Class)}).
         */
        @SuppressWarnings({ "unchecked", "rawtypes" })
        private final void acceptIntervalExpression(Context<?> ctx) {
            SQLDialect family = ctx.family();

            int sign = (operator == ADD) ? 1 : -1;
            switch (family) {
                case CUBRID:
                case MARIADB:
                case MYSQL: {
                    Interval interval = rhsAsInterval();

                    if (operator == SUBTRACT)
                        interval = interval.neg();

                    if (rhs.getType() == YearToMonth.class)
                        ctx.visit(N_DATE_ADD).sql('(').visit(lhs).sql(", ").visit(K_INTERVAL).sql(' ')
                           .visit(Tools.field(interval, SQLDataType.VARCHAR)).sql(' ').visit(K_YEAR_MONTH).sql(')');
                    else if (family == CUBRID)
                        ctx.visit(N_DATE_ADD).sql('(').visit(lhs).sql(", ").visit(K_INTERVAL).sql(' ')
                           .visit(Tools.field(interval, SQLDataType.VARCHAR)).sql(' ').visit(K_DAY_MILLISECOND).sql(')');

                    // DAY_MICROSECOND accepts at most 6 fractional digits
                    else
                        ctx.visit(N_DATE_ADD).sql('(').visit(lhs).sql(", ").visit(K_INTERVAL).sql(' ')
                           .visit(Tools.field(TRUNC_TO_MICROS.matcher("" + interval).replaceAll("$1"), SQLDataType.VARCHAR)).sql(' ').visit(K_DAY_MICROSECOND).sql(')');

                    break;
                }

                case DERBY: {
                    boolean needsCast = getDataType().getType() != Timestamp.class;
                    if (needsCast)
                        ctx.visit(K_CAST).sql('(');

                    if (rhs.getType() == YearToMonth.class) {
                        ctx.sql("{fn ").visit(N_TIMESTAMPADD).sql('(').visit(N_SQL_TSI_MONTH).sql(", ")
                           .visit(p(sign * rhsAsYTM().intValue())).sql(", ").visit(lhs).sql(") }");
                    }
                    else {
                        DayToSecond dts = rhsAsDTS();

                        // Whole seconds and milliseconds are both signed consistently
                        // with the interval; millis are scaled by 1,000,000 for
                        // SQL_TSI_FRAC_SECOND (i.e. passed as nanoseconds)
                        ctx.sql("{fn ").visit(N_TIMESTAMPADD).sql('(').visit(N_SQL_TSI_SECOND).sql(", ")
                           .visit(p(sign * wholeSeconds(dts))).sql(", {fn ")
                           .visit(N_TIMESTAMPADD).sql('(').visit(N_SQL_TSI_FRAC_SECOND).sql(", ")
                           .visit(p(sign * signedMilli(dts) * 1000000L)).sql(", ").visit(lhs).sql(") }) }");
                    }

                    if (needsCast)
                        ctx.sql(' ').visit(K_AS).sql(' ').visit(keyword(getDataType().getCastTypeName(ctx.configuration()))).sql(')');

                    break;
                }

                case FIREBIRD: {
                    if (rhs.getType() == YearToMonth.class) {
                        ctx.visit(N_DATEADD).sql('(').visit(K_MONTH).sql(", ").visit(p(sign * rhsAsYTM().intValue())).sql(", ").visit(lhs).sql(')');
                    }
                    else {
                        DayToSecond dts = rhsAsDTS();
                        long seconds = sign * wholeSeconds(dts);
                        long millis = sign * signedMilli(dts);

                        // Only emit the nested MILLISECOND dateadd when there is a millisecond part
                        if (millis != 0)
                            ctx.visit(N_DATEADD).sql('(').visit(K_MILLISECOND).sql(", ").visit(p(millis)).sql(", ")
                               .visit(N_DATEADD).sql('(').visit(K_SECOND).sql(", ").visit(p(seconds)).sql(", ").visit(lhs).sql(')')
                               .sql(')');
                        else
                            ctx.visit(N_DATEADD).sql('(').visit(K_SECOND).sql(", ").visit(p(seconds)).sql(", ").visit(lhs).sql(')');
                    }

                    break;
                }

                case SQLITE: {
                    boolean ytm = rhs.getType() == YearToMonth.class;

                    // NOTE(review): the conditional expression promotes the int month count
                    // to double (e.g. "3.0 months"); kept as-is
                    Field<?> interval = p(ytm ? rhsAsYTM().intValue() : rhsAsDTS().getTotalSeconds());

                    if (sign < 0)
                        interval = interval.neg();

                    interval = interval.concat(inline(ytm ? " months" : " seconds"));
                    ctx.visit(N_STRFTIME).sql("('%Y-%m-%d %H:%M:%f', ").visit(lhs).sql(", ").visit(interval).sql(')');
                    break;
                }

                case H2:
                case HSQLDB:
                case POSTGRES:
                default:
                    ctx.visit(new DefaultExpression<>(lhs, operator, rhs));
                    break;
            }
        }

        /**
         * Render {@code datetime +/- number}, where the number is a count
         * of days.
         */
        @SuppressWarnings({ "unchecked", "rawtypes" })
        private final void acceptNumberExpression(Context<?> ctx) {
            switch (ctx.family()) {
                case FIREBIRD: {
                    if (operator == ADD)
                        ctx.visit(N_DATEADD).sql('(').visit(K_DAY).sql(", ").visit(rhsAsNumber()).sql(", ").visit(lhs).sql(')');
                    else
                        ctx.visit(N_DATEADD).sql('(').visit(K_DAY).sql(", ").visit(rhsAsNumber().neg()).sql(", ").visit(lhs).sql(')');

                    break;
                }

                case HSQLDB: {
                    if (operator == ADD)
                        ctx.visit(lhs.add(DSL.field("({0}) day", rhsAsNumber())));
                    else
                        ctx.visit(lhs.sub(DSL.field("({0}) day", rhsAsNumber())));

                    break;
                }

                case DERBY: {
                    boolean needsCast = getDataType().getType() != Timestamp.class;
                    if (needsCast)
                        ctx.visit(K_CAST).sql('(');

                    if (operator == ADD)
                        ctx.sql("{fn ").visit(N_TIMESTAMPADD).sql('(').visit(keyword("sql_tsi_day")).sql(", ").visit(rhsAsNumber()).sql(", ").visit(lhs).sql(") }");
                    else
                        ctx.sql("{fn ").visit(N_TIMESTAMPADD).sql('(').visit(keyword("sql_tsi_day")).sql(", ").visit(rhsAsNumber().neg()).sql(", ").visit(lhs).sql(") }");

                    if (needsCast)
                        ctx.sql(' ').visit(K_AS).sql(' ').visit(keyword(getDataType().getCastTypeName(ctx.configuration()))).sql(')');

                    break;
                }

                case CUBRID:
                case MARIADB:
                case MYSQL: {
                    if (operator == ADD)
                        ctx.visit(N_DATE_ADD).sql('(').visit(lhs).sql(", ").visit(K_INTERVAL).sql(' ').visit(rhsAsNumber()).sql(' ').visit(K_DAY).sql(')');
                    else
                        ctx.visit(N_DATE_ADD).sql('(').visit(lhs).sql(", ").visit(K_INTERVAL).sql(' ').visit(rhsAsNumber().neg()).sql(' ').visit(K_DAY).sql(')');

                    break;
                }

                case POSTGRES: {
                    if (operator == ADD)
                        ctx.visit(new DateAdd(lhs, rhsAsNumber(), DatePart.DAY));
                    else
                        ctx.visit(new DateAdd(lhs, rhsAsNumber().neg(), DatePart.DAY));

                    break;
                }

                case SQLITE:
                    if (operator == ADD)
                        ctx.visit(N_STRFTIME).sql("('%Y-%m-%d %H:%M:%f', ").visit(lhs).sql(", ").visit(rhsAsNumber().concat(inline(" day"))).sql(')');
                    else
                        ctx.visit(N_STRFTIME).sql("('%Y-%m-%d %H:%M:%f', ").visit(lhs).sql(", ").visit(rhsAsNumber().neg().concat(inline(" day"))).sql(')');

                    break;

                case H2:
                default:
                    ctx.visit(new DefaultExpression<>(lhs, operator, rhs));
                    break;
            }
        }

        /**
         * Signed number of whole seconds in {@code dts}, computed exactly from
         * its integer day / hour / minute / second components rather than by
         * truncating a floating-point total (which can lose a second).
         * The fractional part is excluded; see {@link #signedMilli(DayToSecond)}.
         */
        private static final long wholeSeconds(DayToSecond dts) {
            return dts.getSign() * (
                  dts.getDays() * 86400L
                + dts.getHours() * 3600L
                + dts.getMinutes() * 60L
                + dts.getSeconds());
        }

        /**
         * Signed (truncated) millisecond part of {@code dts}.
         * {@link DayToSecond#getMilli()} returns the unsigned component, so
         * the interval's sign is applied here; otherwise a negative interval
         * would have its millis added in the wrong direction.
         */
        private static final long signedMilli(DayToSecond dts) {
            return dts.getSign() * (long) dts.getMilli();
        }

        /**
         * Read the value of {@code rhs} as a bind value of the given type.
         *
         * @throws DataTypeException if {@code rhs} is not a {@link Param} or
         *             its value is not of the requested type
         */
        private final <U> U rhsValue(Class<U> type) {
            try {

                // Class.cast keeps the ClassCastException inside this try block
                // (an unchecked generic cast would only fail at the caller)
                return type.cast(((Param<?>) rhs).getValue());
            }
            catch (ClassCastException e) {
                throw new DataTypeException("Cannot perform datetime arithmetic with a non-numeric, non-interval data type on the right hand side of the expression: " + rhs, e);
            }
        }

        private final YearToSecond rhsAsYTS() {
            return rhsValue(YearToSecond.class);
        }

        private final YearToMonth rhsAsYTM() {
            return rhsValue(YearToMonth.class);
        }

        private final DayToSecond rhsAsDTS() {
            return rhsValue(DayToSecond.class);
        }

        private final Interval rhsAsInterval() {
            return rhsValue(Interval.class);
        }

        /** Unchecked view of the right operand as a numeric field. */
        @SuppressWarnings("unchecked")
        private final Field<Number> rhsAsNumber() {
            return (Field<Number>) rhs;
        }
    }

    /**
     * Renders {@code (lhs <op> rhs)} as a parenthesised infix expression.
     */
    private static class DefaultExpression<T> extends AbstractField<T> {

        private static final long        serialVersionUID = -5105004317793995419L;

        private final Field<T>           lhs;
        private final ExpressionOperator operator;
        private final Field<?>           rhs;

        DefaultExpression(Field<T> lhs, ExpressionOperator operator, Field<?> rhs) {
            super(operator.toName(), lhs.getDataType());

            this.lhs = lhs;
            this.operator = operator;
            this.rhs = rhs;
        }

        @Override
        public final void accept(Context<?> ctx) {
            ctx.sql('(');
            accept0(ctx, operator, lhs, rhs);
            ctx.sql(')');
        }

        /**
         * Render {@code lhs <op> rhs} without surrounding parentheses. For
         * associative operators on operands of equal data type, nested
         * {@link Expression}s with the same operator are flattened.
         */
        private static final void accept0(Context<?> ctx, ExpressionOperator operator, Field<?> lhs, Field<?> rhs) {
            String op = operator.toSQL();

            if (operator == BIT_XOR && HASH_OP_FOR_BIT_XOR.contains(ctx.dialect()))
                op = "#";

            boolean associativity = operator.associative() && lhs.getDataType().equals(rhs.getDataType());

            accept1(ctx, operator, lhs, associativity);
            ctx.sql(' ')
               .sql(op)
               .sql(' ');
            accept1(ctx, operator, rhs, associativity);
        }

        /**
         * Render one operand. When flattening is allowed and the operand is an
         * {@link Expression} with the same operator, its operands are rendered
         * inline (no extra parentheses); otherwise the operand is visited normally.
         * NOTE(review): flattening bypasses the nested Expression's own
         * dialect-specific rendering — confirm this is intended for all dialects.
         */
        private static final void accept1(Context<?> ctx, ExpressionOperator operator, Field<?> field, boolean associativity) {
            if (associativity && field instanceof Expression) {
                Expression<?> expr = (Expression<?>) field;

                if (operator == expr.operator) {
                    accept0(ctx, expr.operator, expr.lhs, expr.rhs);
                    return;
                }
            }

            ctx.visit(field);
        }
    }
}