package org.apache.cassandra.cql3.functions;
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Locale;
import org.apache.cassandra.cql3.CQL3Type;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.db.marshal.AsciiType;
import org.apache.cassandra.db.marshal.BooleanType;
import org.apache.cassandra.db.marshal.ByteType;
import org.apache.cassandra.db.marshal.CounterColumnType;
import org.apache.cassandra.db.marshal.DecimalType;
import org.apache.cassandra.db.marshal.DoubleType;
import org.apache.cassandra.db.marshal.FloatType;
import org.apache.cassandra.db.marshal.InetAddressType;
import org.apache.cassandra.db.marshal.Int32Type;
import org.apache.cassandra.db.marshal.IntegerType;
import org.apache.cassandra.db.marshal.LongType;
import org.apache.cassandra.db.marshal.ShortType;
import org.apache.cassandra.db.marshal.SimpleDateType;
import org.apache.cassandra.db.marshal.TimeType;
import org.apache.cassandra.db.marshal.TimeUUIDType;
import org.apache.cassandra.db.marshal.TimestampType;
import org.apache.cassandra.db.marshal.UTF8Type;
import org.apache.cassandra.db.marshal.UUIDType;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.commons.lang3.text.WordUtils;
/**
 * Native CQL functions backing {@code CAST(<value> AS <type>)}.
 * <p>
 * Every cast is a single-argument {@link NativeScalarFunction}. Its name is {@code "castAs"} followed by the
 * capitalised, lower-case CQL name of the output type (see {@link #getFunctionName(CQL3Type)}).
 */
public final class CastFcts
{
    private static final String FUNCTION_NAME_PREFIX = "castAs";

    /**
     * Creates all the supported cast functions.
     * <p>
     * Every numeric type, counters included, can be cast to every other non-counter numeric type and to
     * {@code ascii} and {@code text}. Several non-numeric types can be cast to text. A few temporal conversions
     * delegate to the matching {@link TimeFcts} functions.
     *
     * @return a new mutable collection holding one function per supported (input, output) pair
     */
    public static Collection<Function> all()
    {
        List<Function> functions = new ArrayList<>();

        @SuppressWarnings("unchecked")
        final AbstractType<? extends Number>[] numericTypes = new AbstractType[] {ByteType.instance,
                                                                                  ShortType.instance,
                                                                                  Int32Type.instance,
                                                                                  LongType.instance,
                                                                                  FloatType.instance,
                                                                                  DoubleType.instance,
                                                                                  DecimalType.instance,
                                                                                  CounterColumnType.instance,
                                                                                  IntegerType.instance};

        for (AbstractType<? extends Number> inputType : numericTypes)
        {
            // Conversions to fixed-size types follow the java.lang.Number narrowing/widening rules.
            addFunctionIfNeeded(functions, inputType, ByteType.instance, Number::byteValue);
            addFunctionIfNeeded(functions, inputType, ShortType.instance, Number::shortValue);
            addFunctionIfNeeded(functions, inputType, Int32Type.instance, Number::intValue);
            addFunctionIfNeeded(functions, inputType, LongType.instance, Number::longValue);
            addFunctionIfNeeded(functions, inputType, FloatType.instance, Number::floatValue);
            addFunctionIfNeeded(functions, inputType, DoubleType.instance, Number::doubleValue);
            // Arbitrary-precision outputs need input-specific conversions so no precision is lost.
            addFunctionIfNeeded(functions, inputType, DecimalType.instance, getDecimalConversionFunction(inputType));
            addFunctionIfNeeded(functions, inputType, IntegerType.instance, getBigIntegerConversionFunction(inputType));
            functions.add(CastAsTextFunction.create(inputType, AsciiType.instance));
            functions.add(CastAsTextFunction.create(inputType, UTF8Type.instance));
        }

        functions.add(JavaFunctionWrapper.create(AsciiType.instance, UTF8Type.instance, p -> p));

        functions.add(CastAsTextFunction.create(InetAddressType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(InetAddressType.instance, UTF8Type.instance));

        functions.add(CastAsTextFunction.create(BooleanType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(BooleanType.instance, UTF8Type.instance));

        functions.add(CassandraFunctionWrapper.create(TimeUUIDType.instance, SimpleDateType.instance, TimeFcts.timeUuidtoDate));
        functions.add(CassandraFunctionWrapper.create(TimeUUIDType.instance, TimestampType.instance, TimeFcts.timeUuidToTimestamp));
        functions.add(CastAsTextFunction.create(TimeUUIDType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(TimeUUIDType.instance, UTF8Type.instance));

        functions.add(CassandraFunctionWrapper.create(TimestampType.instance, SimpleDateType.instance, TimeFcts.timestampToDate));
        functions.add(CastAsTextFunction.create(TimestampType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(TimestampType.instance, UTF8Type.instance));

        functions.add(CassandraFunctionWrapper.create(SimpleDateType.instance, TimestampType.instance, TimeFcts.dateToTimestamp));
        functions.add(CastAsTextFunction.create(SimpleDateType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(SimpleDateType.instance, UTF8Type.instance));

        functions.add(CastAsTextFunction.create(TimeType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(TimeType.instance, UTF8Type.instance));

        functions.add(CastAsTextFunction.create(UUIDType.instance, AsciiType.instance));
        functions.add(CastAsTextFunction.create(UUIDType.instance, UTF8Type.instance));

        return functions;
    }

    /**
     * Returns the conversion used to cast a value of {@code inputType} to {@code decimal}.
     * <p>
     * Floating-point inputs are converted from their shortest decimal string representation
     * ({@link Float#toString(float)} / {@link Double#toString(double)}). For example, the float {@code 1.1}
     * becomes {@code 1.1} rather than the {@code 1.100000023841858} obtained by first widening it to double.
     * {@link BigDecimal} rejects NaN and infinite values with a {@link NumberFormatException}.
     *
     * @param inputType the numeric type of the value being cast
     * @return a function converting a value of {@code inputType} to a {@link BigDecimal}
     */
    private static <I extends Number> java.util.function.Function<I, BigDecimal> getDecimalConversionFunction(AbstractType<? extends Number> inputType)
    {
        if (inputType == FloatType.instance)
            return p -> new BigDecimal(Float.toString(p.floatValue()));

        if (inputType == DoubleType.instance)
            return p -> BigDecimal.valueOf(p.doubleValue());

        if (inputType == IntegerType.instance)
            return p -> new BigDecimal((BigInteger) p);

        return p -> BigDecimal.valueOf(p.longValue());
    }

    /**
     * Returns the conversion used to cast a value of {@code inputType} to {@code varint}.
     * <p>
     * Decimal, float and double inputs are truncated towards zero through {@link BigDecimal#toBigInteger()}.
     * Going through {@link Number#longValue()} would instead silently wrap (decimal) or saturate (float/double)
     * values outside the {@code long} range. Float and double inputs are first converted with
     * {@link #getDecimalConversionFunction(AbstractType)}, so NaN and infinities are rejected the same way as
     * for a cast to {@code decimal}.
     *
     * @param inputType the numeric type of the value being cast
     * @return a function converting a value of {@code inputType} to a {@link BigInteger}
     */
    private static <I extends Number> java.util.function.Function<I, BigInteger> getBigIntegerConversionFunction(AbstractType<? extends Number> inputType)
    {
        if (inputType == FloatType.instance || inputType == DoubleType.instance)
        {
            java.util.function.Function<I, BigDecimal> toDecimal = CastFcts.<I>getDecimalConversionFunction(inputType);
            return p -> toDecimal.apply(p).toBigInteger();
        }

        if (inputType == DecimalType.instance)
            return p -> ((BigDecimal) p).toBigInteger();

        // Every remaining input type (including counters) fits in a long.
        return p -> BigInteger.valueOf(p.longValue());
    }

    /**
     * Returns the name of the cast function producing {@code outputType}.
     *
     * @param outputType the type the function casts to
     * @return {@code "castAs"} followed by the capitalised, lower-case CQL name of {@code outputType}
     */
    public static String getFunctionName(AbstractType<?> outputType)
    {
        return getFunctionName(outputType.asCQL3Type());
    }

    /**
     * Returns the name of the cast function producing {@code outputType}.
     *
     * @param outputType the CQL type the function casts to
     * @return {@code "castAs"} followed by the capitalised, lower-case CQL name of {@code outputType}
     */
    public static String getFunctionName(CQL3Type outputType)
    {
        return FUNCTION_NAME_PREFIX + WordUtils.capitalize(toLowerCaseString(outputType));
    }

    /**
     * Adds a cast from {@code inputType} to {@code outputType} backed by {@code converter}.
     * Nothing is added when both types are equal, since an identity cast needs no function.
     */
    private static <I, O> void addFunctionIfNeeded(List<Function> functions,
                                                   AbstractType<I> inputType,
                                                   AbstractType<O> outputType,
                                                   java.util.function.Function<I, O> converter)
    {
        if (!inputType.equals(outputType))
            functions.add(wrapJavaFunction(inputType, outputType, converter));
    }

    /**
     * Wraps {@code converter} into a cast function.
     * Counter inputs get a dedicated wrapper that decodes the argument as a {@code Long}.
     */
    @SuppressWarnings("unchecked")
    private static <O, I> Function wrapJavaFunction(AbstractType<I> inputType,
                                                    AbstractType<O> outputType,
                                                    java.util.function.Function<I, O> converter)
    {
        // The unchecked cast is sound because JavaCounterFunctionWrapper always feeds the converter a Long.
        return inputType.equals(CounterColumnType.instance)
               ? JavaCounterFunctionWrapper.create(outputType, (java.util.function.Function<Long, O>) converter)
               : JavaFunctionWrapper.create(inputType, outputType, converter);
    }

    /**
     * Lower-cases the CQL representation of {@code type}.
     * Locale.ROOT is used so that function and column names do not depend on the JVM's default locale
     * (e.g. the Turkish dotted/dotless 'i').
     */
    private static String toLowerCaseString(CQL3Type type)
    {
        return type.toString().toLowerCase(Locale.ROOT);
    }

    /**
     * Base class of all cast functions.
     * The function name is derived from the output type, and the result column is named
     * {@code cast(<argument> as <type>)}.
     */
    private static abstract class CastFunction<I, O> extends NativeScalarFunction
    {
        public CastFunction(AbstractType<I> inputType, AbstractType<O> outputType)
        {
            super(getFunctionName(outputType), outputType, inputType);
        }

        @Override
        public String columnName(List<String> columnNames)
        {
            return String.format("cast(%s as %s)", columnNames.get(0), toLowerCaseString(outputType().asCQL3Type()));
        }

        @SuppressWarnings("unchecked")
        protected AbstractType<O> outputType()
        {
            return (AbstractType<O>) returnType;
        }

        @SuppressWarnings("unchecked")
        protected AbstractType<I> inputType()
        {
            return (AbstractType<I>) argTypes.get(0);
        }
    }

    /**
     * Cast function that decodes its argument, applies a Java conversion and encodes the result.
     * A {@code null} argument yields a {@code null} result.
     */
    private static class JavaFunctionWrapper<I, O> extends CastFunction<I, O>
    {
        private final java.util.function.Function<I, O> converter;

        public static <I, O> JavaFunctionWrapper<I, O> create(AbstractType<I> inputType,
                                                              AbstractType<O> outputType,
                                                              java.util.function.Function<I, O> converter)
        {
            return new JavaFunctionWrapper<I, O>(inputType, outputType, converter);
        }

        protected JavaFunctionWrapper(AbstractType<I> inputType,
                                      AbstractType<O> outputType,
                                      java.util.function.Function<I, O> converter)
        {
            super(inputType, outputType);
            this.converter = converter;
        }

        @Override
        public final ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
        {
            ByteBuffer bb = parameters.get(0);
            if (bb == null)
                return null;

            return outputType().decompose(converter.apply(compose(bb)));
        }

        /** Decodes the serialized argument; subclasses may override to use a different decoder. */
        protected I compose(ByteBuffer bb)
        {
            return inputType().compose(bb);
        }
    }

    /**
     * Java-based cast function for counter inputs.
     * The argument is decoded with {@link LongType} rather than the declared {@link CounterColumnType}.
     */
    private static class JavaCounterFunctionWrapper<O> extends JavaFunctionWrapper<Long, O>
    {
        public static <O> JavaFunctionWrapper<Long, O> create(AbstractType<O> outputType,
                                                              java.util.function.Function<Long, O> converter)
        {
            return new JavaCounterFunctionWrapper<O>(outputType, converter);
        }

        protected JavaCounterFunctionWrapper(AbstractType<O> outputType,
                                             java.util.function.Function<Long, O> converter)
        {
            super(CounterColumnType.instance, outputType, converter);
        }

        @Override
        protected Long compose(ByteBuffer bb)
        {
            return LongType.instance.compose(bb);
        }
    }

    /**
     * Cast function that delegates to an existing native function.
     * The delegate must take exactly {@code inputType} and return {@code outputType}.
     */
    private static final class CassandraFunctionWrapper<I, O> extends CastFunction<I, O>
    {
        private final NativeScalarFunction delegate;

        public static <I, O> CassandraFunctionWrapper<I, O> create(AbstractType<I> inputType,
                                                                   AbstractType<O> outputType,
                                                                   NativeScalarFunction delegate)
        {
            return new CassandraFunctionWrapper<I, O>(inputType, outputType, delegate);
        }

        private CassandraFunctionWrapper(AbstractType<I> inputType,
                                         AbstractType<O> outputType,
                                         NativeScalarFunction delegate)
        {
            super(inputType, outputType);
            assert delegate.argTypes().size() == 1 && inputType.equals(delegate.argTypes().get(0));
            assert outputType.equals(delegate.returnType());
            this.delegate = delegate;
        }

        @Override
        public ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
        {
            return delegate.execute(protocolVersion, parameters);
        }
    }

    /**
     * Cast function to a textual type.
     * It renders the argument with its serializer's {@code toCQLLiteral} and encodes that string with the
     * output type. A {@code null} argument yields a {@code null} result.
     */
    private static final class CastAsTextFunction<I> extends CastFunction<I, String>
    {
        public static <I> CastAsTextFunction<I> create(AbstractType<I> inputType,
                                                       AbstractType<String> outputType)
        {
            return new CastAsTextFunction<I>(inputType, outputType);
        }

        private CastAsTextFunction(AbstractType<I> inputType,
                                   AbstractType<String> outputType)
        {
            super(inputType, outputType);
        }

        @Override
        public ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
        {
            ByteBuffer bb = parameters.get(0);
            if (bb == null)
                return null;

            return outputType().decompose(inputType().getSerializer().toCQLLiteral(bb));
        }
    }

    /** Static utility holder; not instantiable. */
    private CastFcts()
    {
    }
}