/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.cassandra.cql3.functions;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.*;
import java.nio.ByteBuffer;
import java.security.*;
import java.security.cert.Certificate;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Pattern;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.ByteStreams;
import com.google.common.reflect.TypeToken;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.datastax.driver.core.TypeCodec;
import org.apache.cassandra.concurrent.NamedThreadFactory;
import org.apache.cassandra.cql3.ColumnIdentifier;
import org.apache.cassandra.db.marshal.AbstractType;
import org.apache.cassandra.exceptions.InvalidRequestException;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.utils.FBUtilities;
import org.apache.cassandra.security.SecurityThreadGroup;
import org.apache.cassandra.security.ThreadAwareSecurityManager;
import org.eclipse.jdt.core.compiler.IProblem;
import org.eclipse.jdt.internal.compiler.*;
import org.eclipse.jdt.internal.compiler.Compiler;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFileReader;
import org.eclipse.jdt.internal.compiler.classfmt.ClassFormatException;
import org.eclipse.jdt.internal.compiler.env.ICompilationUnit;
import org.eclipse.jdt.internal.compiler.env.INameEnvironment;
import org.eclipse.jdt.internal.compiler.env.NameEnvironmentAnswer;
import org.eclipse.jdt.internal.compiler.impl.CompilerOptions;
import org.eclipse.jdt.internal.compiler.problem.DefaultProblemFactory;

/**
 * User-defined function whose body is Java source code.
 * <p>
 * The constructor splices the user supplied body into the {@code JavaSourceUDF.txt} template,
 * compiles the result with the Eclipse compiler (ECJ), checks the produced byte code with
 * {@link UDFByteCodeVerifier} and instantiates the generated class via {@code EcjTargetClassLoader}.
 * </p>
 */
public final class JavaBasedUDFunction extends UDFunction
{
    // Root package of all generated classes; the constructor appends a per-UDF sub-package to it.
    private static final String BASE_PACKAGE = "org.apache.cassandra.cql3.udf.gen";

    // Matches a "java.lang." qualifier; javaSourceName() removes every match from a type name.
    private static final Pattern JAVA_LANG_PREFIX = Pattern.compile("\\bjava\\.lang\\.");

    static final Logger logger = LoggerFactory.getLogger(JavaBasedUDFunction.class);

    // Monotonic counter appended to every generated identifier (see generateClassName()).
    private static final AtomicInteger classSequence = new AtomicInteger();

    // use a JVM standard ExecutorService as DebuggableThreadPoolExecutor references internal
    // classes, which triggers AccessControlException from the UDF sandbox
    private static final UDFExecutorService executor =
        new UDFExecutorService(new NamedThreadFactory("UserDefinedFunctions",
                                                      Thread.MIN_PRIORITY,
                                                      udfClassLoader,
                                                      new SecurityThreadGroup("UserDefinedFunctions", null, UDFunction::initializeThread)),
                               "userfunction");

    // Receives the byte code produced by ECJ and defines the generated UDF classes.
    private static final EcjTargetClassLoader targetClassLoader = new EcjTargetClassLoader();

    // Configured with the disallowed classes / method calls in the static initializer below.
    private static final UDFByteCodeVerifier udfByteCodeVerifier = new UDFByteCodeVerifier();

    // Assigned in the static initializer using ThreadAwareSecurityManager.noPermissions;
    // passed to defineClass() by EcjTargetClassLoader.findClass().
    private static final ProtectionDomain protectionDomain;

    // ECJ collaborators handed to every Compiler instance created by the constructor.
    private static final IErrorHandlingPolicy errorHandlingPolicy = DefaultErrorHandlingPolicies.proceedWithAllProblems();
    private static final IProblemFactory problemFactory = new DefaultProblemFactory(Locale.ENGLISH);
    // Assigned in the static initializer (source and target level are set to 1.8 there).
    private static final CompilerOptions compilerOptions;

    
Poor man's template - just a text file split at '#' characters. Each string at an even index is a constant string (copied verbatim); each string at an odd index is an 'instruction'.
/** * Poor man's template - just a text file splitted at '#' chars. * Each string at an even index is a constant string (just copied), * each string at an odd index is an 'instruction'. */
private static final String[] javaSourceTemplate; static { udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "forName"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getClassLoader"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResource"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/Class", "getResourceAsStream"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "clearAssertionStatus"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResource"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResourceAsStream"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getResources"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemClassLoader"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResource"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResourceAsStream"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "getSystemResources"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "loadClass"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setClassAssertionStatus"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setDefaultAssertionStatus"); udfByteCodeVerifier.addDisallowedMethodCall("java/lang/ClassLoader", "setPackageAssertionStatus"); udfByteCodeVerifier.addDisallowedMethodCall("java/nio/ByteBuffer", "allocateDirect"); for (String ia : new String[]{"java/net/InetAddress", "java/net/Inet4Address", "java/net/Inet6Address"}) { // static method, probably performing DNS lookups (despite SecurityManager) udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByAddress"); udfByteCodeVerifier.addDisallowedMethodCall(ia, "getAllByName"); udfByteCodeVerifier.addDisallowedMethodCall(ia, "getByName"); 
udfByteCodeVerifier.addDisallowedMethodCall(ia, "getLocalHost"); // instance methods, probably performing DNS lookups (despite SecurityManager) udfByteCodeVerifier.addDisallowedMethodCall(ia, "getHostName"); udfByteCodeVerifier.addDisallowedMethodCall(ia, "getCanonicalHostName"); // ICMP PING udfByteCodeVerifier.addDisallowedMethodCall(ia, "isReachable"); } udfByteCodeVerifier.addDisallowedClass("java/net/NetworkInterface"); udfByteCodeVerifier.addDisallowedClass("java/net/SocketException"); Map<String, String> settings = new HashMap<>(); settings.put(CompilerOptions.OPTION_LineNumberAttribute, CompilerOptions.GENERATE); settings.put(CompilerOptions.OPTION_SourceFileAttribute, CompilerOptions.DISABLED); settings.put(CompilerOptions.OPTION_ReportDeprecation, CompilerOptions.IGNORE); settings.put(CompilerOptions.OPTION_Source, CompilerOptions.VERSION_1_8); settings.put(CompilerOptions.OPTION_TargetPlatform, CompilerOptions.VERSION_1_8); compilerOptions = new CompilerOptions(settings); compilerOptions.parseLiteralExpressionsAsConstants = true; try (InputStream input = JavaBasedUDFunction.class.getResource("JavaSourceUDF.txt").openConnection().getInputStream()) { ByteArrayOutputStream output = new ByteArrayOutputStream(); FBUtilities.copy(input, output, Long.MAX_VALUE); String template = output.toString(); StringTokenizer st = new StringTokenizer(template, "#"); javaSourceTemplate = new String[st.countTokens()]; for (int i = 0; st.hasMoreElements(); i++) javaSourceTemplate[i] = st.nextToken(); } catch (IOException e) { throw new RuntimeException(e); } CodeSource codeSource; try { codeSource = new CodeSource(new URL("udf", "localhost", 0, "/java", new URLStreamHandler() { protected URLConnection openConnection(URL u) { return null; } }), (Certificate[])null); } catch (MalformedURLException e) { throw new RuntimeException(e); } protectionDomain = new ProtectionDomain(codeSource, ThreadAwareSecurityManager.noPermissions, targetClassLoader, null); } private final JavaUDF 
javaUDF; JavaBasedUDFunction(FunctionName name, List<ColumnIdentifier> argNames, List<AbstractType<?>> argTypes, AbstractType<?> returnType, boolean calledOnNullInput, String body) { super(name, argNames, argTypes, UDHelper.driverTypes(argTypes), returnType, UDHelper.driverType(returnType), calledOnNullInput, "java", body); // javaParamTypes is just the Java representation for argTypes resp. argDataTypes TypeToken<?>[] javaParamTypes = UDHelper.typeTokens(argCodecs, calledOnNullInput); // javaReturnType is just the Java representation for returnType resp. returnTypeCodec TypeToken<?> javaReturnType = returnCodec.getJavaType(); // put each UDF in a separate package to prevent cross-UDF code access String pkgName = BASE_PACKAGE + '.' + generateClassName(name, 'p'); String clsName = generateClassName(name, 'C'); String executeInternalName = generateClassName(name, 'x'); StringBuilder javaSourceBuilder = new StringBuilder(); int lineOffset = 1; for (int i = 0; i < javaSourceTemplate.length; i++) { String s = javaSourceTemplate[i]; // strings at odd indexes are 'instructions' if ((i & 1) == 1) { switch (s) { case "package_name": s = pkgName; break; case "class_name": s = clsName; break; case "body": lineOffset = countNewlines(javaSourceBuilder); s = body; break; case "arguments": s = generateArguments(javaParamTypes, argNames, false); break; case "arguments_aggregate": s = generateArguments(javaParamTypes, argNames, true); break; case "argument_list": s = generateArgumentList(javaParamTypes, argNames); break; case "return_type": s = javaSourceName(javaReturnType); break; case "execute_internal_name": s = executeInternalName; break; } } javaSourceBuilder.append(s); } String targetClassName = pkgName + '.' 
+ clsName; String javaSource = javaSourceBuilder.toString(); logger.trace("Compiling Java source UDF '{}' as class '{}' using source:\n{}", name, targetClassName, javaSource); try { EcjCompilationUnit compilationUnit = new EcjCompilationUnit(javaSource, targetClassName); Compiler compiler = new Compiler(compilationUnit, errorHandlingPolicy, compilerOptions, compilationUnit, problemFactory); compiler.compile(new ICompilationUnit[]{ compilationUnit }); if (compilationUnit.problemList != null && !compilationUnit.problemList.isEmpty()) { boolean fullSource = false; StringBuilder problems = new StringBuilder(); for (IProblem problem : compilationUnit.problemList) { long ln = problem.getSourceLineNumber() - lineOffset; if (ln < 1L) { if (problem.isError()) { // if generated source around UDF source provided by the user is buggy, // this code is appended. problems.append("GENERATED SOURCE ERROR: line ") .append(problem.getSourceLineNumber()) .append(" (in generated source): ") .append(problem.getMessage()) .append('\n'); fullSource = true; } } else { problems.append("Line ") .append(Long.toString(ln)) .append(": ") .append(problem.getMessage()) .append('\n'); } } if (fullSource) throw new InvalidRequestException("Java source compilation failed:\n" + problems + "\n generated source:\n" + javaSource); else throw new InvalidRequestException("Java source compilation failed:\n" + problems); } // Verify the UDF bytecode against use of probably dangerous code Set<String> errors = udfByteCodeVerifier.verify(targetClassName, targetClassLoader.classData(targetClassName)); String validDeclare = "not allowed method declared: " + executeInternalName + '('; for (Iterator<String> i = errors.iterator(); i.hasNext();) { String error = i.next(); // we generate a random name of the private, internal execute method, which is detected by the byte-code verifier if (error.startsWith(validDeclare)) i.remove(); } if (!errors.isEmpty()) throw new InvalidRequestException("Java UDF validation 
failed: " + errors); // Load the class and create a new instance of it Thread thread = Thread.currentThread(); ClassLoader orig = thread.getContextClassLoader(); try { thread.setContextClassLoader(UDFunction.udfClassLoader); // Execute UDF intiialization from UDF class loader Class cls = Class.forName(targetClassName, false, targetClassLoader); // Count only non-synthetic methods, so code coverage instrumentation doesn't cause a miscount int nonSyntheticMethodCount = 0; for (Method m : cls.getDeclaredMethods()) { if (!m.isSynthetic()) { nonSyntheticMethodCount += 1; } } if (nonSyntheticMethodCount != 3 || cls.getDeclaredConstructors().length != 1) throw new InvalidRequestException("Check your source to not define additional Java methods or constructors"); MethodType methodType = MethodType.methodType(void.class) .appendParameterTypes(TypeCodec.class, TypeCodec[].class, UDFContext.class); MethodHandle ctor = MethodHandles.lookup().findConstructor(cls, methodType); this.javaUDF = (JavaUDF) ctor.invokeWithArguments(returnCodec, argCodecs, udfContext); } finally { thread.setContextClassLoader(orig); } } catch (InvocationTargetException e) { // in case of an ITE, use the cause throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e.getCause())); } catch (InvalidRequestException | VirtualMachineError e) { throw e; } catch (Throwable e) { logger.error(String.format("Could not compile function '%s' from Java source:%n%s", name, javaSource), e); throw new InvalidRequestException(String.format("Could not compile function '%s' from Java source: %s", name, e)); } } protected ExecutorService executor() { return executor; } protected ByteBuffer executeUserDefined(ProtocolVersion protocolVersion, List<ByteBuffer> params) { return javaUDF.executeImpl(protocolVersion, params); } protected Object executeAggregateUserDefined(ProtocolVersion protocolVersion, Object firstParam, List<ByteBuffer> params) { return 
javaUDF.executeAggregateImpl(protocolVersion, firstParam, params); } private static int countNewlines(StringBuilder javaSource) { int ln = 0; for (int i = 0; i < javaSource.length(); i++) if (javaSource.charAt(i) == '\n') ln++; return ln; } private static String generateClassName(FunctionName name, char prefix) { String qualifiedName = name.toString(); StringBuilder sb = new StringBuilder(qualifiedName.length() + 10); sb.append(prefix); for (int i = 0; i < qualifiedName.length(); i++) { char c = qualifiedName.charAt(i); if (Character.isJavaIdentifierPart(c)) sb.append(c); else sb.append(Integer.toHexString(((short)c)&0xffff)); } sb.append('_') .append(ThreadLocalRandom.current().nextInt() & 0xffffff) .append('_') .append(classSequence.incrementAndGet()); return sb.toString(); } @VisibleForTesting public static String javaSourceName(TypeToken<?> type) { String n = type.toString(); return JAVA_LANG_PREFIX.matcher(n).replaceAll(""); } private static String generateArgumentList(TypeToken<?>[] paramTypes, List<ColumnIdentifier> argNames) { // initial builder size can just be a guess (prevent temp object allocations) StringBuilder code = new StringBuilder(32 * paramTypes.length); for (int i = 0; i < paramTypes.length; i++) { if (i > 0) code.append(", "); code.append(javaSourceName(paramTypes[i])) .append(' ') .append(argNames.get(i)); } return code.toString(); }
Generate Java source code snippet for the arguments part to call the UDF implementation function - i.e. the private #return_type# #execute_internal_name#(#argument_list#) function (see JavaSourceUDF.txt template file for details).

This method generates the arguments code snippet for both executeImpl and executeAggregateImpl. The general signature for both is the protocolVersion followed by all UDF arguments. For aggregate UDF calls, the first argument is always passed unserialized (as a plain object), because it is the state variable.

An example output for executeImpl: (double) super.compose_double(protocolVersion, 0, params.get(0)), (double) super.compose_double(protocolVersion, 1, params.get(1))

Similar output for executeAggregateImpl: firstParam, (double) super.compose_double(protocolVersion, 1, params.get(1))

/** * Generate Java source code snippet for the arguments part to call the UDF implementation function - * i.e. the {@code private #return_type# #execute_internal_name#(#argument_list#)} function * (see {@code JavaSourceUDF.txt} template file for details). * <p> * This method generates the arguments code snippet for both {@code executeImpl} and * {@code executeAggregateImpl}. General signature for both is the {@code protocolVersion} and * then all UDF arguments. For aggregation UDF calls the first argument is always unserialized as * that is the state variable. * </p> * <p> * An example output for {@code executeImpl}: * {@code (double) super.compose_double(protocolVersion, 0, params.get(0)), (double) super.compose_double(protocolVersion, 1, params.get(1))} * </p> * <p> * Similar output for {@code executeAggregateImpl}: * {@code firstParam, (double) super.compose_double(protocolVersion, 1, params.get(1))} * </p> */
private static String generateArguments(TypeToken<?>[] paramTypes, List<ColumnIdentifier> argNames, boolean forAggregate) { StringBuilder code = new StringBuilder(64 * paramTypes.length); for (int i = 0; i < paramTypes.length; i++) { if (i > 0) // add separator, if not the first argument code.append(",\n"); // add comment only if trace is enabled if (logger.isTraceEnabled()) code.append(" /* parameter '").append(argNames.get(i)).append("' */\n"); // cast to Java type code.append(" (").append(javaSourceName(paramTypes[i])).append(") "); if (forAggregate && i == 0) // special case for aggregations where the state variable (1st arg to state + final function and // return value from state function) is not re-serialized code.append("firstParam"); else // generate object representation of input parameter (call UDFunction.compose) code.append(composeMethod(paramTypes[i])).append("(protocolVersion, ").append(i).append(", params.get(").append(forAggregate ? i - 1 : i).append("))"); } return code.toString(); } private static String composeMethod(TypeToken<?> type) { return (type.isPrimitive()) ? ("super.compose_" + type.getRawType().getName()) : "super.compose"; } // Java source UDFs are a very simple compilation task, which allows us to let one class implement // all interfaces required by ECJ. static final class EcjCompilationUnit implements ICompilationUnit, ICompilerRequestor, INameEnvironment { List<IProblem> problemList; private final String className; private final char[] sourceCode; EcjCompilationUnit(String sourceCode, String className) { this.className = className; this.sourceCode = sourceCode.toCharArray(); } // ICompilationUnit @Override public char[] getFileName() { return sourceCode; } @Override public char[] getContents() { return sourceCode; } @Override public char[] getMainTypeName() { int dot = className.lastIndexOf('.'); return ((dot > 0) ? 
className.substring(dot + 1) : className).toCharArray(); } @Override public char[][] getPackageName() { StringTokenizer izer = new StringTokenizer(className, "."); char[][] result = new char[izer.countTokens() - 1][]; for (int i = 0; i < result.length; i++) result[i] = izer.nextToken().toCharArray(); return result; } @Override public boolean ignoreOptionalProblems() { return false; } // ICompilerRequestor @Override public void acceptResult(CompilationResult result) { if (result.hasErrors()) { IProblem[] problems = result.getProblems(); if (problemList == null) problemList = new ArrayList<>(problems.length); Collections.addAll(problemList, problems); } else { ClassFile[] classFiles = result.getClassFiles(); for (ClassFile classFile : classFiles) targetClassLoader.addClass(className, classFile.getBytes()); } } // INameEnvironment @Override public NameEnvironmentAnswer findType(char[][] compoundTypeName) { StringBuilder result = new StringBuilder(); for (int i = 0; i < compoundTypeName.length; i++) { if (i > 0) result.append('.'); result.append(compoundTypeName[i]); } return findType(result.toString()); } @Override public NameEnvironmentAnswer findType(char[] typeName, char[][] packageName) { StringBuilder result = new StringBuilder(); int i = 0; for (; i < packageName.length; i++) { if (i > 0) result.append('.'); result.append(packageName[i]); } if (i > 0) result.append('.'); result.append(typeName); return findType(result.toString()); } private NameEnvironmentAnswer findType(String className) { if (className.equals(this.className)) { return new NameEnvironmentAnswer(this, null); } String resourceName = className.replace('.', '/') + ".class"; try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(resourceName)) { if (is != null) { byte[] classBytes = ByteStreams.toByteArray(is); char[] fileName = className.toCharArray(); ClassFileReader classFileReader = new ClassFileReader(classBytes, fileName, true); return new NameEnvironmentAnswer(classFileReader, 
null); } } catch (IOException | ClassFormatException exc) { throw new RuntimeException(exc); } return null; } private boolean isPackage(String result) { if (result.equals(this.className)) return false; String resourceName = result.replace('.', '/') + ".class"; try (InputStream is = UDFunction.udfClassLoader.getResourceAsStream(resourceName)) { return is == null; } catch (IOException e) { // we are here, since close on is failed. That means it was not null return false; } } @Override public boolean isPackage(char[][] parentPackageName, char[] packageName) { StringBuilder result = new StringBuilder(); int i = 0; if (parentPackageName != null) for (; i < parentPackageName.length; i++) { if (i > 0) result.append('.'); result.append(parentPackageName[i]); } if (Character.isUpperCase(packageName[0]) && !isPackage(result.toString())) return false; if (i > 0) result.append('.'); result.append(packageName); return isPackage(result.toString()); } @Override public void cleanup() { } } static final class EcjTargetClassLoader extends SecureClassLoader { EcjTargetClassLoader() { super(UDFunction.udfClassLoader); } // This map is usually empty. // It only contains data *during* UDF compilation but not during runtime. // // addClass() is invoked by ECJ after successful compilation of the generated Java source. // loadClass(targetClassName) is invoked by buildUDF() after ECJ returned from successful compilation. 
// private final Map<String, byte[]> classes = new ConcurrentHashMap<>(); void addClass(String className, byte[] classData) { classes.put(className, classData); } byte[] classData(String className) { return classes.get(className); } protected Class<?> findClass(String name) throws ClassNotFoundException { // remove the class binary - it's only used once - so it's wasting heap byte[] classData = classes.remove(name); if (classData != null) return defineClass(name, classData, 0, classData.length, protectionDomain); return getParent().loadClass(name); } protected PermissionCollection getPermissions(CodeSource codesource) { return ThreadAwareSecurityManager.noPermissions; } }}