package org.apache.cassandra.cql3.functions;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.net.InetAddress;
import java.net.URL;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import com.google.common.base.Objects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.datastax.driver.core.DataType;
import com.datastax.driver.core.TypeCodec;
import com.datastax.driver.core.UserType;
import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.config.Schema;
import org.apache.cassandra.cql3.ColumnIdentifier;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.FunctionExecutionException;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.schema.Functions;
import org.apache.cassandra.schema.KeyspaceMetadata;
import org.apache.cassandra.service.ClientWarn;
import org.apache.cassandra.service.MigrationManager;
import org.apache.cassandra.tracing.Tracing;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.ByteBufferUtil;
import org.apache.cassandra.utils.JVMStabilityInspector;
public abstract class UDFunction extends AbstractFunction implements ScalarFunction
{
/** Logger for this class; among other things it traces resources rejected by {@code secureResource}. */
protected static final Logger logger = LoggerFactory.getLogger(UDFunction.class);
/** JVM thread bean, used to read per-thread CPU time around asynchronous UDF invocations. */
static final ThreadMXBean threadMXBean = ManagementFactory.getThreadMXBean();
/** Argument names as passed to the constructor (asserted there to be free of duplicates). */
protected final List<ColumnIdentifier> argNames;
/** Language of the function body; {@code "java"} selects the Java implementation in {@code create}. */
protected final String language;
/** Source text of the function. */
protected final String body;
/** One driver codec per argument, derived from the argument driver types; entries are replaced by {@code userTypeUpdated}. */
protected final TypeCodec<Object>[] argCodecs;
/** Driver codec for the return value. */
protected final TypeCodec<Object> returnCodec;
/** When {@code false}, the function is not invoked (and yields {@code null}) if any argument is null or empty. */
protected final boolean calledOnNullInput;
/** Context object built from the argument names, codecs and the function's keyspace metadata. */
protected final UDFContext udfContext;
/**
 * Resource-name prefixes that {@code secureResource} accepts, provided the resource does not
 * also match an entry of {@code blacklistedPatterns}. Entries ending in {@code /} cover a whole
 * package tree (prefix match); entries ending in {@code .class} name a single class file.
 */
private static final String[] whitelistedPatterns =
{
"com/datastax/driver/core/",
"com/google/common/reflect/TypeToken",
"java/io/IOException.class",
"java/io/Serializable.class",
"java/lang/",
"java/math/",
"java/net/InetAddress.class",
"java/net/Inet4Address.class",
"java/net/Inet6Address.class",
"java/net/UnknownHostException.class",
"java/net/NetworkInterface.class",
"java/net/SocketException.class",
"java/nio/Buffer.class",
"java/nio/ByteBuffer.class",
"java/text/",
"java/time/",
"java/util/",
"org/apache/cassandra/cql3/functions/JavaUDF.class",
"org/apache/cassandra/cql3/functions/UDFContext.class",
"org/apache/cassandra/exceptions/",
"org/apache/cassandra/transport/ProtocolVersion.class"
};
/**
 * Resource-name prefixes that {@code secureResource} rejects even though they fall under a
 * whitelisted prefix. Matching is by {@code String.startsWith}, so a {@code .class} entry
 * also matches any longer resource name that begins with the same characters.
 */
private static final String[] blacklistedPatterns =
{
"com/datastax/driver/core/Cluster.class",
"com/datastax/driver/core/Metrics.class",
"com/datastax/driver/core/NettyOptions.class",
"com/datastax/driver/core/Session.class",
"com/datastax/driver/core/Statement.class",
"com/datastax/driver/core/TimestampGenerator.class",
"java/lang/Compiler.class",
"java/lang/InheritableThreadLocal.class",
"java/lang/Package.class",
"java/lang/Process.class",
"java/lang/ProcessBuilder.class",
"java/lang/ProcessEnvironment.class",
"java/lang/ProcessImpl.class",
"java/lang/Runnable.class",
"java/lang/Runtime.class",
"java/lang/Shutdown.class",
"java/lang/Thread.class",
"java/lang/ThreadGroup.class",
"java/lang/ThreadLocal.class",
"java/lang/instrument/",
"java/lang/invoke/",
"java/lang/management/",
"java/lang/ref/",
"java/lang/reflect/",
"java/util/ServiceLoader.class",
"java/util/Timer.class",
"java/util/concurrent/",
"java/util/function/",
"java/util/jar/",
"java/util/logging/",
"java/util/prefs/",
"java/util/spi/",
"java/util/stream/",
"java/util/zip/",
};
/**
 * Decides whether a class-path resource may be made visible to user-defined functions.
 * <p>
 * Leading {@code '/'} characters are ignored. A resource is allowed when it starts with at
 * least one whitelisted prefix and with no blacklisted prefix; every rejection is traced.
 *
 * @param resource resource name using {@code '/'} separators, e.g. {@code java/lang/String.class}
 * @return {@code true} if the resource may be accessed, {@code false} otherwise
 */
static boolean secureResource(String resource)
{
    // skip over every leading slash in one go
    int firstNonSlash = 0;
    while (resource.startsWith("/", firstNonSlash))
        firstNonSlash++;
    resource = resource.substring(firstNonSlash);

    boolean permitted = startsWithAny(resource, whitelistedPatterns)
                        && !startsWithAny(resource, blacklistedPatterns);
    if (!permitted)
        logger.trace("access denied: resource {}", resource);
    return permitted;
}

/** Returns {@code true} if {@code resource} starts with at least one of {@code prefixes}. */
private static boolean startsWithAny(String resource, String[] prefixes)
{
    for (String prefix : prefixes)
    {
        if (resource.startsWith(prefix))
            return true;
    }
    return false;
}
/** Shared class loader instance that filters class and resource lookups through {@code secureResource}. */
static final ClassLoader udfClassLoader = new UDFClassLoader();
/**
 * Convenience constructor that derives the driver-side argument and return data types from the
 * given Cassandra types (via {@code UDHelper}) and delegates to the full constructor.
 *
 * @param name              function name
 * @param argNames          argument names
 * @param argTypes          Cassandra types of the arguments
 * @param returnType        Cassandra type of the return value
 * @param calledOnNullInput whether the function is invoked when an argument is null or empty
 * @param language          language of the function body
 * @param body              source text of the function
 */
protected UDFunction(FunctionName name,
List<ColumnIdentifier> argNames,
List<AbstractType<?>> argTypes,
AbstractType<?> returnType,
boolean calledOnNullInput,
String language,
String body)
{
this(name, argNames, argTypes, UDHelper.driverTypes(argTypes), returnType,
UDHelper.driverType(returnType), calledOnNullInput, language, body);
}
/**
 * Full constructor: stores the function definition, resolves the driver codecs for the
 * arguments and the return value, and builds the {@code UDFContext} handed to the function.
 *
 * @param name              function name; its keyspace is used to look up the keyspace metadata
 * @param argNames          argument names, asserted to contain no duplicates
 * @param argTypes          Cassandra types of the arguments
 * @param argDataTypes      driver data types of the arguments, used to resolve the argument codecs
 * @param returnType        Cassandra type of the return value
 * @param returnDataType    driver data type of the return value, used to resolve the return codec
 * @param calledOnNullInput whether the function is invoked when an argument is null or empty
 * @param language          language of the function body
 * @param body              source text of the function
 */
protected UDFunction(FunctionName name,
                     List<ColumnIdentifier> argNames,
                     List<AbstractType<?>> argTypes,
                     DataType[] argDataTypes,
                     AbstractType<?> returnType,
                     DataType returnDataType,
                     boolean calledOnNullInput,
                     String language,
                     String body)
{
    super(name, argTypes, returnType);

    assert new HashSet<>(argNames).size() == argNames.size() : "duplicate argument names";

    // plain definition data
    this.argNames = argNames;
    this.language = language;
    this.body = body;

    // codecs must exist before the context is built from them
    this.argCodecs = UDHelper.codecsFor(argDataTypes);
    this.returnCodec = UDHelper.codecFor(returnDataType);
    this.calledOnNullInput = calledOnNullInput;

    this.udfContext = new UDFContextImpl(this.argNames,
                                         this.argCodecs,
                                         this.returnCodec,
                                         Schema.instance.getKSMetaData(name.keyspace));
}
/**
 * Factory for user-defined functions. Verifies that UDFs of the given language are enabled and
 * then instantiates the Java-based implementation for {@code "java"} or the script-based
 * implementation for every other language.
 *
 * @param name              function name
 * @param argNames          argument names
 * @param argTypes          Cassandra types of the arguments
 * @param returnType        Cassandra type of the return value
 * @param calledOnNullInput whether the function is invoked when an argument is null or empty
 * @param language          language of the function body
 * @param body              source text of the function
 * @return the newly created function
 * @throws InvalidRequestException if UDFs (or scripted UDFs) are disabled
 */
public static UDFunction create(FunctionName name,
                                List<ColumnIdentifier> argNames,
                                List<AbstractType<?>> argTypes,
                                AbstractType<?> returnType,
                                boolean calledOnNullInput,
                                String language,
                                String body)
{
    UDFunction.assertUdfsEnabled(language);

    // exact, case-sensitive match selects the Java implementation
    if (language.equals("java"))
        return new JavaBasedUDFunction(name, argNames, argTypes, returnType, calledOnNullInput, body);

    return new ScriptBasedUDFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body);
}
/**
 * Creates a placeholder for a function that exists in the schema but could not be loaded.
 * Every attempt to execute the returned function fails with an {@link InvalidRequestException}
 * that carries the message of {@code reason}.
 *
 * @param name              function name
 * @param argNames          argument names
 * @param argTypes          Cassandra types of the arguments
 * @param returnType        Cassandra type of the return value
 * @param calledOnNullInput whether the function is invoked when an argument is null or empty
 * @param language          language of the function body
 * @param body              source text of the function
 * @param reason            the failure that prevented the function from being loaded
 * @return a function whose execution always throws
 */
public static UDFunction createBrokenFunction(FunctionName name,
                                              List<ColumnIdentifier> argNames,
                                              List<AbstractType<?>> argTypes,
                                              AbstractType<?> returnType,
                                              boolean calledOnNullInput,
                                              String language,
                                              String body,
                                              InvalidRequestException reason)
{
    return new UDFunction(name, argNames, argTypes, returnType, calledOnNullInput, language, body)
    {
        // A single executor per broken function. Creating a fresh executor on every executor()
        // call would leave one never-shut-down worker thread behind per asynchronous invocation.
        // The worker thread itself is only started on the first submitted task.
        private final ExecutorService brokenExecutor = Executors.newSingleThreadExecutor();

        @Override
        protected ExecutorService executor()
        {
            return brokenExecutor;
        }

        @Override
        protected Object executeAggregateUserDefined(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> parameters)
        {
            throw broken();
        }

        @Override
        public ByteBuffer executeUserDefined(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
        {
            throw broken();
        }

        // Builds the exception reported to the client for any execution attempt.
        private InvalidRequestException broken()
        {
            return new InvalidRequestException(String.format("Function '%s' exists but hasn't been loaded successfully "
                                                             + "for the following reason: %s. Please see the server log for details",
                                                             this,
                                                             reason.getMessage()));
        }
    };
}
/**
 * Executes this function as a scalar function.
 * <p>
 * Returns {@code null} without invoking the function when it is not callable for the given
 * (null/empty) arguments. Otherwise empty arguments are normalized to {@code null} and the
 * function runs either on the UDF executor or directly on the calling thread, depending on
 * {@code DatabaseDescriptor.enableUserDefinedFunctionsThreads()}.
 *
 * @param protocolVersion native protocol version used for (de)serialization
 * @param parameters      serialized argument values
 * @return the serialized result, possibly {@code null}
 * @throws InvalidRequestException    if UDFs are disabled, or if raised by the execution itself
 * @throws FunctionExecutionException for any other failure, except {@link VirtualMachineError}
 *                                    which is rethrown unchanged
 */
public final ByteBuffer execute(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
{
    assertUdfsEnabled(language);

    if (!isCallableWrtNullable(parameters))
        return null;

    final long startNanos = System.nanoTime();
    final List<ByteBuffer> normalizedParams = makeEmptyParametersNull(parameters);

    try
    {
        final ByteBuffer outcome;
        if (DatabaseDescriptor.enableUserDefinedFunctionsThreads())
            outcome = executeAsync(protocolVersion, normalizedParams);
        else
            outcome = executeUserDefined(protocolVersion, normalizedParams);

        // elapsed time is reported in microseconds
        Tracing.trace("Executed UDF {} in {}\u03bcs", name(), (System.nanoTime() - startNanos) / 1000);
        return outcome;
    }
    catch (InvalidRequestException invalid)
    {
        // already client-facing, pass through untouched
        throw invalid;
    }
    catch (Throwable failure)
    {
        logger.trace("Invocation of user-defined function '{}' failed", this, failure);
        if (failure instanceof VirtualMachineError)
            throw (VirtualMachineError) failure;
        throw FunctionExecutionException.create(this, failure);
    }
}
/**
 * Executes this function as the state or final function of an aggregate.
 * <p>
 * Returns {@code null} without invoking the function when the function is not called on null
 * input and {@code firstParam} is {@code null}, or when it is not callable for the given
 * (null/empty) arguments. Otherwise behaves like {@code execute}, but passes the first
 * parameter as an already deserialized object.
 *
 * @param protocolVersion native protocol version used for (de)serialization
 * @param firstParam      first argument as a Java object
 * @param parameters      remaining arguments in serialized form
 * @return the function result as a Java object, possibly {@code null}
 * @throws InvalidRequestException    if UDFs are disabled, or if raised by the execution itself
 * @throws FunctionExecutionException for any other failure, except {@link VirtualMachineError}
 *                                    which is rethrown unchanged
 */
public final Object executeForAggregate(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> parameters)
{
    assertUdfsEnabled(language);

    if (!calledOnNullInput && firstParam == null)
        return null;
    if (!isCallableWrtNullable(parameters))
        return null;

    final long startNanos = System.nanoTime();
    final List<ByteBuffer> normalizedParams = makeEmptyParametersNull(parameters);

    try
    {
        final Object outcome;
        if (DatabaseDescriptor.enableUserDefinedFunctionsThreads())
            outcome = executeAggregateAsync(protocolVersion, firstParam, normalizedParams);
        else
            outcome = executeAggregateUserDefined(protocolVersion, firstParam, normalizedParams);

        // elapsed time is reported in microseconds
        Tracing.trace("Executed UDF {} in {}\u03bcs", name(), (System.nanoTime() - startNanos) / 1000);
        return outcome;
    }
    catch (InvalidRequestException invalid)
    {
        // already client-facing, pass through untouched
        throw invalid;
    }
    catch (Throwable failure)
    {
        logger.debug("Invocation of user-defined function '{}' failed", this, failure);
        if (failure instanceof VirtualMachineError)
            throw (VirtualMachineError) failure;
        throw FunctionExecutionException.create(this, failure);
    }
}
/**
 * Verifies that user-defined functions of the given language may be used on this node.
 *
 * @param language language of the function; anything other than {@code "java"} (compared
 *                 case-insensitively) is treated as a scripted language
 * @throws InvalidRequestException if UDFs are disabled altogether, or if the language is a
 *                                 scripted one and scripted UDFs are disabled
 */
public static void assertUdfsEnabled(String language)
{
    if (!DatabaseDescriptor.enableUserDefinedFunctions())
    {
        throw new InvalidRequestException("User-defined functions are disabled in cassandra.yaml - set enable_user_defined_functions=true to enable");
    }

    boolean scripted = !"java".equalsIgnoreCase(language);
    if (scripted && !DatabaseDescriptor.enableScriptedUserDefinedFunctions())
    {
        throw new InvalidRequestException("Scripted user-defined functions are disabled in cassandra.yaml - set enable_scripted_user_defined_functions=true to enable if you are aware of the security risks");
    }
}
/**
 * Exercises the inet and ascii driver codecs once on the calling thread; both results are discarded.
 * NOTE(review): presumably a warm-up so that the classes these codecs need are already
 * loaded/initialized before UDF code runs on the thread - confirm with the callers before
 * changing or reordering these calls.
 */
static void initializeThread()
{
// format a loopback address purely for its class-loading/initialization side effects
TypeCodec.inet().format(InetAddress.getLoopbackAddress());
// same for the ascii codec, using an empty string
TypeCodec.ascii().format("");
}
/**
 * Records which thread executes a UDF invocation and that thread's CPU time at the start of
 * the invocation. The future completes (with {@code null}) once both values have been stored,
 * so a caller that waits on it reads them only after {@code setup()} has run.
 */
private static final class ThreadIdAndCpuTime extends CompletableFuture<Object>
{
// id of the thread that called setup()
long threadId;
// CPU time (nanoseconds, per ThreadMXBean) of that thread at the moment setup() ran
long cpuTime;
ThreadIdAndCpuTime()
{
// The result is intentionally discarded.
// NOTE(review): presumably this call only makes sure the java.lang.management machinery is
// linked/initialized on the constructing thread before setup() runs on the executor thread -
// confirm before removing.
threadMXBean.getCurrentThreadCpuTime();
}
/** Captures the current thread's id and CPU time, then completes this future. Called on the executing thread. */
void setup()
{
this.threadId = Thread.currentThread().getId();
this.cpuTime = threadMXBean.getCurrentThreadCpuTime();
// complete last, so that both fields are written before any waiter is released
complete(null);
}
}
/**
 * Runs {@code executeUserDefined} on the UDF executor, recording the executing thread and its
 * CPU time first so that {@code async} can account for the CPU time actually consumed.
 *
 * @param protocolVersion native protocol version used for (de)serialization
 * @param parameters      serialized argument values
 * @return the serialized result of the function
 */
private ByteBuffer executeAsync(ProtocolVersion protocolVersion, List<ByteBuffer> parameters)
{
    final ThreadIdAndCpuTime cpuTracker = new ThreadIdAndCpuTime();

    Callable<ByteBuffer> invocation = () ->
    {
        // runs on the executor thread: capture thread id and CPU time before the actual call
        cpuTracker.setup();
        return executeUserDefined(protocolVersion, parameters);
    };

    return async(cpuTracker, invocation);
}
/**
 * Runs {@code executeAggregateUserDefined} on the UDF executor, recording the executing thread
 * and its CPU time first so that {@code async} can account for the CPU time actually consumed.
 *
 * @param protocolVersion native protocol version used for (de)serialization
 * @param firstParam      first argument as a Java object
 * @param parameters      remaining arguments in serialized form
 * @return the result of the function as a Java object
 */
private Object executeAggregateAsync(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> parameters)
{
    final ThreadIdAndCpuTime cpuTracker = new ThreadIdAndCpuTime();

    Callable<Object> invocation = () ->
    {
        // runs on the executor thread: capture thread id and CPU time before the actual call
        cpuTracker.setup();
        return executeAggregateUserDefined(protocolVersion, firstParam, parameters);
    };

    return async(cpuTracker, invocation);
}
/**
 * Submits {@code callable} to {@link #executor()} and waits for its result in stages:
 * <ol>
 *   <li>if a warn timeout is configured, wait that long; on expiry log and send a client warning;</li>
 *   <li>wait for the remainder up to the fail timeout;</li>
 *   <li>on expiry, wait once more for the fail timeout minus the CPU time the executing thread
 *       has actually consumed, so that wall-clock time not spent on the CPU by the function
 *       does not count against it;</li>
 *   <li>if that also expires, report the timeout to {@code JVMStabilityInspector} and fail.</li>
 * </ol>
 *
 * @param threadIdAndCpuTime tracker that the callable populates via {@code setup()} when it starts
 * @param callable           the UDF invocation
 * @return the value produced by {@code callable}
 * @throws RuntimeException           the callable's own failure (rethrown as-is if it is a
 *                                    {@code RuntimeException}, wrapped otherwise), or a wrapped
 *                                    {@code InterruptedException} if the waiting thread was interrupted
 * @throws FunctionExecutionException if the function exceeded the fail timeout
 */
private <T> T async(ThreadIdAndCpuTime threadIdAndCpuTime, Callable<T> callable)
{
    Future<T> future = executor().submit(callable);

    try
    {
        if (DatabaseDescriptor.getUserDefinedFunctionWarnTimeout() > 0)
            try
            {
                return future.get(DatabaseDescriptor.getUserDefinedFunctionWarnTimeout(), TimeUnit.MILLISECONDS);
            }
            catch (TimeoutException e)
            {
                // warn and keep waiting below
                String warn = String.format("User defined function %s ran longer than %dms", this, DatabaseDescriptor.getUserDefinedFunctionWarnTimeout());
                logger.warn(warn);
                ClientWarn.instance.warn(warn);
            }

        // wait for the time remaining until the fail timeout
        return future.get(DatabaseDescriptor.getUserDefinedFunctionFailTimeout() - DatabaseDescriptor.getUserDefinedFunctionWarnTimeout(), TimeUnit.MILLISECONDS);
    }
    catch (InterruptedException e)
    {
        Thread.currentThread().interrupt();
        throw new RuntimeException(e);
    }
    catch (ExecutionException e)
    {
        throw unwrapExecutionFailure(e);
    }
    catch (TimeoutException e)
    {
        // last retry, based on the CPU time consumed by the executing thread
        try
        {
            // wait (briefly) until the executing thread has published its id and starting CPU time
            threadIdAndCpuTime.get(1, TimeUnit.SECONDS);

            long cpuTimeMillis = threadMXBean.getThreadCpuTime(threadIdAndCpuTime.threadId) - threadIdAndCpuTime.cpuTime;
            cpuTimeMillis /= 1000000L; // nanoseconds -> milliseconds

            return future.get(Math.max(DatabaseDescriptor.getUserDefinedFunctionFailTimeout() - cpuTimeMillis, 0L),
                              TimeUnit.MILLISECONDS);
        }
        catch (InterruptedException e1)
        {
            Thread.currentThread().interrupt();
            // report the interruption itself, not the earlier timeout
            throw new RuntimeException(e1);
        }
        catch (ExecutionException e1)
        {
            // the failure of the function is carried by e1; the outer timeout has no cause
            throw unwrapExecutionFailure(e1);
        }
        catch (TimeoutException e1)
        {
            TimeoutException cause = new TimeoutException(String.format("User defined function %s ran longer than %dms%s",
                                                                        this,
                                                                        DatabaseDescriptor.getUserDefinedFunctionFailTimeout(),
                                                                        DatabaseDescriptor.getUserFunctionTimeoutPolicy() == Config.UserFunctionTimeoutPolicy.ignore
                                                                        ? "" : " - will stop Cassandra VM"));
            FunctionExecutionException fe = FunctionExecutionException.create(this, cause);
            JVMStabilityInspector.userFunctionTimeout(cause);
            throw fe;
        }
    }
}

/** Converts a failed execution into the exception to throw: the cause itself if unchecked, else a wrapper around it. */
private static RuntimeException unwrapExecutionFailure(ExecutionException failure)
{
    Throwable cause = failure.getCause();
    if (cause instanceof RuntimeException)
        return (RuntimeException) cause;
    return new RuntimeException(cause);
}
/**
 * Returns a copy of {@code parameters} in which every value that {@code UDHelper.isNullOrEmpty}
 * reports as null or empty for its argument type is replaced by {@code null}.
 *
 * @param parameters serialized argument values, positionally matching {@code argTypes}
 * @return a new list of the same size
 */
private List<ByteBuffer> makeEmptyParametersNull(List<ByteBuffer> parameters)
{
    final int count = parameters.size();
    List<ByteBuffer> normalized = new ArrayList<>(count);

    for (int idx = 0; idx < count; idx++)
    {
        ByteBuffer value = parameters.get(idx);
        if (UDHelper.isNullOrEmpty(argTypes.get(idx), value))
            normalized.add(null);
        else
            normalized.add(value);
    }

    return normalized;
}
/** Returns the executor that {@code async} submits UDF invocations to when UDF threads are enabled. */
protected abstract ExecutorService executor();
/**
 * Tells whether the function may be invoked with the given arguments with respect to null handling.
 *
 * @param parameters serialized argument values, positionally matching {@code argTypes}
 * @return {@code true} if the function is called on null input, or if no argument is null or
 *         empty according to {@code UDHelper.isNullOrEmpty}; {@code false} otherwise
 */
public boolean isCallableWrtNullable(List<ByteBuffer> parameters)
{
    // functions declared "called on null input" accept anything
    if (calledOnNullInput)
        return true;

    for (int idx = 0; idx < parameters.size(); idx++)
    {
        boolean missing = UDHelper.isNullOrEmpty(argTypes.get(idx), parameters.get(idx));
        if (missing)
            return false;
    }
    return true;
}
/**
 * Performs the actual scalar invocation of the function. Called by {@code execute}, either
 * directly or from the executor thread, after empty arguments have been replaced by {@code null}.
 */
protected abstract ByteBuffer executeUserDefined(ProtocolVersion protocolVersion, List<ByteBuffer> parameters);
/**
 * Performs the actual invocation of the function on behalf of an aggregate. Called by
 * {@code executeForAggregate}, either directly or from the executor thread; {@code firstParam}
 * is passed as a Java object, the remaining parameters in serialized form.
 */
protected abstract Object executeAggregateUserDefined(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> parameters);
/** Always {@code false}: a user-defined function is not reported as an aggregate. */
public boolean isAggregate()
{
return false;
}
/** Always {@code false}: a user-defined function is not reported as a native function. */
public boolean isNative()
{
return false;
}
/** Returns whether the function is invoked even when an argument is null or empty. */
public boolean isCalledOnNullInput()
{
return calledOnNullInput;
}
/** Returns the argument names; this is the list instance held by the function, not a copy. */
public List<ColumnIdentifier> argNames()
{
return argNames;
}
/** Returns the source text of the function. */
public String body()
{
return body;
}
/** Returns the language the function body is written in. */
public String language()
{
return language;
}
/**
 * Deserializes the value of the argument at {@code argIndex} using this function's argument codecs.
 *
 * @return the deserialized object, or {@code null} if {@code value} is {@code null}
 */
protected Object compose(ProtocolVersion protocolVersion, int argIndex, ByteBuffer value)
{
return compose(argCodecs, protocolVersion, argIndex, value);
}
/**
 * Deserializes an argument value with the codec at position {@code argIndex} of {@code codecs}.
 *
 * @param codecs          argument codecs, indexed by argument position
 * @param protocolVersion native protocol version used for deserialization
 * @param argIndex        position of the argument
 * @param value           serialized value, may be {@code null}
 * @return the deserialized object, or {@code null} if {@code value} is {@code null}
 */
protected static Object compose(TypeCodec<Object>[] codecs, ProtocolVersion protocolVersion, int argIndex, ByteBuffer value)
{
    // a null value never reaches the codec (nor the array lookup)
    if (value == null)
        return null;

    return UDHelper.deserialize(codecs[argIndex], protocolVersion, value);
}
/**
 * Serializes a result value using this function's return codec.
 *
 * @return the serialized value, or {@code null} if {@code value} is {@code null}
 */
protected ByteBuffer decompose(ProtocolVersion protocolVersion, Object value)
{
return decompose(returnCodec, protocolVersion, value);
}
/**
 * Serializes a value with the given codec.
 *
 * @param codec           codec to serialize with
 * @param protocolVersion native protocol version used for serialization
 * @param value           value to serialize, may be {@code null}
 * @return the serialized value, or {@code null} if {@code value} is {@code null}
 */
protected static ByteBuffer decompose(TypeCodec<Object> codec, ProtocolVersion protocolVersion, Object value)
{
    // a null value never reaches the codec
    if (value == null)
        return null;

    return UDHelper.serialize(codec, protocolVersion, value);
}
/**
 * Two user-defined functions are equal when their name, argument names, language and body are
 * equal and their argument and return types match according to {@code Functions.typesMatch}.
 */
@Override
public boolean equals(Object o)
{
    if (!(o instanceof UDFunction))
        return false;

    UDFunction other = (UDFunction) o;

    // checks are made in the same order as listed in the contract above, stopping at the first mismatch
    if (!Objects.equal(name, other.name))
        return false;
    if (!Objects.equal(argNames, other.argNames))
        return false;
    if (!Functions.typesMatch(argTypes, other.argTypes))
        return false;
    if (!Functions.typesMatch(returnType, other.returnType))
        return false;
    if (!Objects.equal(language, other.language))
        return false;
    return Objects.equal(body, other.body);
}
/**
 * Hash code consistent with {@link #equals(Object)}.
 * <p>
 * {@code equals} compares the argument and return types through {@code Functions.typesMatch},
 * so the types contribute to the hash only through {@code Functions.typeHashCode}. The raw
 * {@code returnType} is deliberately not hashed: should {@code typesMatch} accept two types
 * that are not {@code equals} to each other, hashing the type object itself would give equal
 * functions different hash codes.
 */
@Override
public int hashCode()
{
    return Objects.hashCode(name, Functions.typeHashCode(argTypes), Functions.typeHashCode(returnType), language, body);
}
/**
 * Reacts to a change of the user type {@code ksName.typeName}: every argument whose codec is
 * for that user type gets a codec and argument type rebuilt from the current keyspace
 * metadata. If at least one argument was refreshed, the function is announced again through
 * {@code MigrationManager}. Only the argument codecs are inspected here.
 *
 * @param ksName   keyspace of the updated user type
 * @param typeName name of the updated user type
 */
public void userTypeUpdated(String ksName, String typeName)
{
    boolean anyRefreshed = false;

    for (int idx = 0; idx < argCodecs.length; idx++)
    {
        DataType cqlType = argCodecs[idx].getCqlType();
        if (!(cqlType instanceof UserType))
            continue;

        UserType driverUdt = (UserType) cqlType;
        if (!driverUdt.getKeyspace().equals(ksName) || !driverUdt.getTypeName().equals(typeName))
            continue;

        // argument uses the updated type: rebuild codec and type from the current schema
        KeyspaceMetadata ksm = Schema.instance.getKSMetaData(ksName);
        assert ksm != null;

        org.apache.cassandra.db.marshal.UserType currentUdt = ksm.types.get(ByteBufferUtil.bytes(typeName)).get();
        DataType refreshedDriverType = UDHelper.driverType(currentUdt);
        argCodecs[idx] = UDHelper.codecFor(refreshedDriverType);
        argTypes.set(idx, currentUdt);
        anyRefreshed = true;
    }

    if (anyRefreshed)
        MigrationManager.announceNewFunction(this, true);
}
/**
 * Class loader that only exposes classes and resources accepted by {@code secureResource}.
 * Accepted lookups are served by the context class loader of the thread that initialized this
 * class; everything else is reported as missing.
 */
private static class UDFClassLoader extends ClassLoader
{
    // the unrestricted loader that actually serves accepted requests
    static final ClassLoader insecureClassLoader = Thread.currentThread().getContextClassLoader();

    /** Returns the resource from the unrestricted loader, or {@code null} if access is denied. */
    @Override
    public URL getResource(String name)
    {
        return secureResource(name) ? insecureClassLoader.getResource(name) : null;
    }

    /** Same filtering as {@link #getResource(String)}. */
    @Override
    protected URL findResource(String name)
    {
        return getResource(name);
    }

    /** Enumerating resources is not supported: the result is always empty. */
    @Override
    public Enumeration<URL> getResources(String name)
    {
        return Collections.emptyEnumeration();
    }

    /** Loads an accepted class through the unrestricted loader; denied classes are reported as not found. */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException
    {
        if (!isAccessibleClass(name))
            throw new ClassNotFoundException(name);
        return insecureClassLoader.loadClass(name);
    }

    /** Rejects denied classes up front, then follows the regular {@code ClassLoader} loading procedure. */
    @Override
    public Class<?> loadClass(String name) throws ClassNotFoundException
    {
        if (!isAccessibleClass(name))
            throw new ClassNotFoundException(name);
        return super.loadClass(name);
    }

    // Maps a binary class name to its class-file resource name and checks it against the filter.
    private static boolean isAccessibleClass(String className)
    {
        String classResource = className.replace('.', '/') + ".class";
        return secureResource(classResource);
    }
}
}