package com.oracle.svm.hosted;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.security.Provider;
import java.security.Provider.Service;
import java.security.cert.CertificateException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.graalvm.compiler.options.Option;
import org.graalvm.compiler.serviceprovider.JavaVersionUtil;
import org.graalvm.nativeimage.ImageSingletons;
import org.graalvm.nativeimage.hosted.Feature;
import org.graalvm.nativeimage.hosted.RuntimeReflection;
import org.graalvm.nativeimage.impl.RuntimeClassInitializationSupport;
import com.oracle.svm.core.SubstrateOptions;
import com.oracle.svm.core.annotate.AutomaticFeature;
import com.oracle.svm.core.jdk.JNIRegistrationUtil;
import com.oracle.svm.core.jdk.NativeLibrarySupport;
import com.oracle.svm.core.jdk.PlatformNativeLibrarySupport;
import com.oracle.svm.core.jni.JNIRuntimeAccess;
import com.oracle.svm.core.option.HostedOptionKey;
import com.oracle.svm.core.util.VMError;
import com.oracle.svm.hosted.c.NativeLibraries;
import com.oracle.svm.util.ReflectionUtil;
import sun.security.jca.Providers;
import sun.security.provider.NativePRNG;
import sun.security.x509.OIDMap;
/**
 * Native-image feature that prepares JCA security providers and services for use at image run
 * time.
 * <p>
 * During setup it schedules a fixed set of JDK security classes for re-initialization at run
 * time. Before analysis it installs reachability handlers that, once the relevant JDK methods
 * become reachable, register provider and service classes for reflection and statically link the
 * {@code sunec} and (on POSIX) {@code jaas} native libraries.
 */
@AutomaticFeature
public class SecurityServicesFeature extends JNIRegistrationUtil implements Feature {

    static class Options {
        /** Master switch for this feature; enabled by default. */
        @Option(help = "Enable the feature that provides support for security services.")
        public static final HostedOptionKey<Boolean> EnableSecurityServicesFeature = new HostedOptionKey<>(true);

        /** When enabled, {@link #trace(String)} output is printed to standard out; disabled by default. */
        @Option(help = "Enable trace logging for the security services feature.")
        static final HostedOptionKey<Boolean> TraceSecurityServices = new HostedOptionKey<>(false);
    }

    private static final String SUN_PROVIDER = "SUN";
    private static final String SECURE_RANDOM_SERVICE = "SecureRandom";
    private static final String MESSAGE_DIGEST_SERVICE = "MessageDigest";
    private static final String SIGNATURE_SERVICE = "Signature";
    private static final String CIPHER_SERVICE = "Cipher";
    private static final String KEY_AGREEMENT_SERVICE = "KeyAgreement";

    /** Shared result for services that declare no {@code SupportedKeyClasses} attribute. */
    private static final String[] EMPTY_STRING_ARRAY = new String[0];

    /** Field separator used when rendering a service for trace output. */
    private static final String SEP = " , ";

    @Override
    public boolean isInConfiguration(IsInConfigurationAccess access) {
        return Options.EnableSecurityServicesFeature.getValue();
    }

    /**
     * Schedules JDK security classes for re-initialization at image run time and performs the
     * SunEC JNI preparation.
     */
    @Override
    public void duringSetup(DuringSetupAccess access) {
        RuntimeClassInitializationSupport rci = ImageSingletons.lookup(RuntimeClassInitializationSupport.class);
        /*
         * Re-initialized at run time "for substitutions". NOTE(review): presumably because their
         * static state (seed sources, cached SecureRandom instances) must not be frozen into the
         * image heap -- confirm against the corresponding substitutions.
         */
        rci.rerunInitialization(NativePRNG.class, "for substitutions");
        rci.rerunInitialization(NativePRNG.Blocking.class, "for substitutions");
        rci.rerunInitialization(NativePRNG.NonBlocking.class, "for substitutions");
        rci.rerunInitialization(clazz(access, "sun.security.provider.SeedGenerator"), "for substitutions");
        rci.rerunInitialization(clazz(access, "sun.security.provider.SecureRandom$SeederHolder"), "for substitutions");
        if (JavaVersionUtil.JAVA_SPEC >= 11) {
            rci.rerunInitialization(clazz(access, "sun.security.provider.AbstractDrbg$SeederHolder"), "for substitutions");
        }
        if (JavaVersionUtil.JAVA_SPEC > 8) {
            rci.rerunInitialization(clazz(access, "sun.security.provider.FileInputStreamPool"), "for substitutions");
        }
        rci.rerunInitialization(clazz(access, "java.util.UUID$Holder"), "for substitutions");
        rci.rerunInitialization(clazz(access, "sun.security.jca.JCAUtil$CachedSecureRandomHolder"), "for substitutions");
        rci.rerunInitialization(clazz(access, "com.sun.crypto.provider.SunJCE$SecureRandomHolder"), "for substitutions");
        rci.rerunInitialization(clazz(access, "sun.security.krb5.Confounder"), "for substitutions");

        /* These classes read system properties in their static initializers; re-read them at run time. */
        rci.rerunInitialization(clazz(access, "sun.security.ssl.SSLContextImpl$DefaultManagersHolder"), "for reading properties at run time");
        /* Only one of these logging classes exists, depending on the JDK version. */
        optionalClazz(access, "sun.security.ssl.Debug").ifPresent(c -> rci.rerunInitialization(c, "for reading properties at run time"));
        optionalClazz(access, "sun.security.ssl.SSLLogger").ifPresent(c -> rci.rerunInitialization(c, "for reading properties at run time"));

        prepareSunEC();
    }

    /**
     * Registers {@code byte[]} for JNI run-time access. NOTE(review): presumably needed because the
     * SunEC native code creates or inspects byte arrays through JNI -- confirm.
     */
    private static void prepareSunEC() {
        JNIRuntimeAccess.register(byte[].class);
    }

    /**
     * Returns the providers whose services should be registered.
     *
     * @param enableAllSecurityServices if true, every provider in the current provider list;
     *            otherwise only the built-in {@code SUN} provider
     */
    private static List<Provider> getProviders(boolean enableAllSecurityServices) {
        if (enableAllSecurityServices) {
            return Providers.getProviderList().providers();
        } else {
            Provider sunProvider = Providers.getSunProvider();
            assert isSunProvider(sunProvider);
            return Collections.singletonList(sunProvider);
        }
    }

    /**
     * Installs reachability handlers that register services and link native libraries only when
     * the JDK code that needs them is reachable.
     */
    @Override
    public void beforeAnalysis(BeforeAnalysisAccess access) {
        /* Services are instantiated reflectively through Provider.Service.newInstance(Object). */
        access.registerReachabilityHandler(SecurityServicesFeature::registerServicesForReflection, method(access, "java.security.Provider$Service", "newInstance", Object.class));
        /* The SunEC native methods are only needed once EC signing or verification is reachable. */
        access.registerReachabilityHandler(SecurityServicesFeature::linkSunEC,
                        method(access, "sun.security.ec.ECDSASignature", "signDigest", byte[].class, byte[].class, byte[].class, byte[].class, int.class),
                        method(access, "sun.security.ec.ECDSASignature", "verifySignedDigest", byte[].class, byte[].class, byte[].class, byte[].class));
        if (isPosix()) {
            access.registerReachabilityHandler(SecurityServicesFeature::linkJaas, method(access, "com.sun.security.auth.module.UnixSystem", "getUnixInfo"));
        }
    }

    /**
     * Registers providers and their services for reflection.
     * <p>
     * Without {@code EnableAllSecurityServices}, only the SUN provider's {@code MessageDigest} and
     * {@code SecureRandom} services are registered. With it, every service of every provider is
     * registered, together with {@code JavaKeyStore$JKS} and every X.509 extension class known to
     * {@link OIDMap}.
     */
    @SuppressWarnings("unchecked")
    private static void registerServicesForReflection(BeforeAnalysisAccess access) {
        boolean enableAllSecurityServices = SubstrateOptions.EnableAllSecurityServices.getValue();
        Function<String, Class<?>> consParamClassAccessor = getConsParamClassAccessor(access);

        trace("Registering security services...");
        for (Provider provider : getProviders(enableAllSecurityServices)) {
            register(provider);
            for (Service service : provider.getServices()) {
                if (enableAllSecurityServices || isMessageDigest(service) || isSecureRandom(service)) {
                    register(access, service, consParamClassAccessor);
                }
            }
        }

        if (enableAllSecurityServices) {
            Class<?> javaKeyStoreJks = access.findClassByName("sun.security.provider.JavaKeyStore$JKS");
            registerForReflection(javaKeyStoreJks);
            trace("Class registered for reflection: " + javaKeyStoreJks);

            /* X.509 certificate extensions are instantiated reflectively from OIDMap's name table. */
            Map<String, Object> map = ReflectionUtil.readStaticField(OIDMap.class, "nameMap");
            for (String name : map.keySet()) {
                try {
                    Class<?> extensionClass = OIDMap.getClass(name);
                    assert sun.security.x509.Extension.class.isAssignableFrom(extensionClass);
                    registerForReflection(extensionClass);
                    trace("Class registered for reflection: " + extensionClass);
                } catch (CertificateException e) {
                    throw VMError.shouldNotReachHere(e);
                }
            }
        }
    }

    /** Statically links the {@code sunec} JNI library and treats it as a built-in library. */
    private static void linkSunEC(DuringAnalysisAccess duringAnalysisAccess) {
        NativeLibraries nativeLibraries = ((FeatureImpl.DuringAnalysisAccessImpl) duringAnalysisAccess).getNativeLibraries();
        NativeLibrarySupport.singleton().preregisterUninitializedBuiltinLibrary("sunec");
        /* Resolve native methods with the sun_security_ec prefix against the built-in library. */
        PlatformNativeLibrarySupport.singleton().addBuiltinPkgNativePrefix("sun_security_ec");
        nativeLibraries.addStaticJniLibrary("sunec");
        if (isPosix()) {
            /* NOTE(review): presumably the static sunec library is C++ and needs libstdc++ -- confirm. */
            nativeLibraries.addDynamicNonJniLibrary("stdc++");
        }
    }

    /**
     * Registers the {@code UnixSystem} fields written from native code and statically links the
     * {@code jaas} JNI library as a built-in library.
     */
    private static void linkJaas(DuringAnalysisAccess duringAnalysisAccess) {
        JNIRuntimeAccess.register(fields(duringAnalysisAccess, "com.sun.security.auth.module.UnixSystem", "username", "uid", "gid", "groups"));
        NativeLibraries nativeLibraries = ((FeatureImpl.DuringAnalysisAccessImpl) duringAnalysisAccess).getNativeLibraries();
        /* The JDK library name differs by version; the static library linked below is always "jaas". */
        NativeLibrarySupport.singleton().preregisterUninitializedBuiltinLibrary(JavaVersionUtil.JAVA_SPEC >= 11 ? "jaas" : "jaas_unix");
        PlatformNativeLibrarySupport.singleton().addBuiltinPkgNativePrefix("com_sun_security_auth_module_UnixSystem");
        nativeLibraries.addStaticJniLibrary("jaas");
    }

    /**
     * Builds a function that maps a service type (e.g. {@code "CertStore"}) to the class of the
     * constructor parameter declared for it in {@code Provider}'s private {@code knownEngines}
     * table.
     *
     * @return a function that yields {@code null} when the type is unknown, declares no
     *         constructor parameter class, or that class cannot be found
     */
    @SuppressWarnings("unchecked")
    private static Function<String, Class<?>> getConsParamClassAccessor(BeforeAnalysisAccess access) {
        Map<String, Object> knownEngines = ReflectionUtil.readStaticField(Provider.class, "knownEngines");
        Field consParamClassNameField = ReflectionUtil.lookupField(access.findClassByName("java.security.Provider$EngineDescription"), "constructorParameterClassName");
        return serviceType -> {
            Object engineDescription = knownEngines.get(serviceType);
            if (engineDescription == null) {
                return null;
            }
            try {
                String consParamClassName = (String) consParamClassNameField.get(engineDescription);
                return consParamClassName == null ? null : access.findClassByName(consParamClassName);
            } catch (IllegalAccessException e) {
                /* Previously the error was created but never thrown, silently yielding null. */
                throw VMError.shouldNotReachHere(e);
            }
        };
    }

    /**
     * Registers the provider class for reflection and invokes
     * {@code javax.crypto.JceSecurity.getVerificationResult(provider)} at image build time.
     * NOTE(review): presumably so the JCE provider verification result is cached in the image --
     * confirm.
     */
    private static void register(Provider provider) {
        registerForReflection(provider.getClass());
        try {
            Method getVerificationResult = ReflectionUtil.lookupMethod(Class.forName("javax.crypto.JceSecurity"), "getVerificationResult", Provider.class);
            getVerificationResult.invoke(null, provider);
        } catch (ReflectiveOperationException ex) {
            throw VMError.shouldNotReachHere(ex);
        }
    }

    /**
     * Registers a service implementation class for reflection, together with its engine's
     * constructor parameter class and, for Signature/Cipher/KeyAgreement services, the key classes
     * listed in its {@code SupportedKeyClasses} attribute. A service whose implementation class
     * cannot be found is only traced.
     */
    private static void register(BeforeAnalysisAccess access, Service service, Function<String, Class<?>> consParamClassAccessor) {
        Class<?> serviceClass = access.findClassByName(service.getClassName());
        if (serviceClass == null) {
            trace("Service registration failed: " + asString(service) + ". Cause: class not found " + service.getClassName());
            return;
        }
        registerForReflection(serviceClass);

        Class<?> consParamClass = consParamClassAccessor.apply(service.getType());
        if (consParamClass != null) {
            registerForReflection(consParamClass);
            trace("Parameter class registered: " + consParamClass);
        }

        if (isSignature(service) || isCipher(service) || isKeyAgreement(service)) {
            for (String keyClassName : getSupportedKeyClasses(service)) {
                Class<?> keyClass = access.findClassByName(keyClassName);
                if (keyClass != null) {
                    registerForReflection(keyClass);
                }
            }
        }
        trace("Service registered: " + asString(service));
    }

    /** Registers the class and all of its public constructors for run-time reflection. */
    private static void registerForReflection(Class<?> clazz) {
        RuntimeReflection.register(clazz);
        RuntimeReflection.register(clazz.getConstructors());
    }

    private static boolean isSunProvider(Provider provider) {
        return provider.getName().equals(SUN_PROVIDER);
    }

    private static boolean isSecureRandom(Service s) {
        return s.getType().equals(SECURE_RANDOM_SERVICE);
    }

    private static boolean isMessageDigest(Service s) {
        return s.getType().equals(MESSAGE_DIGEST_SERVICE);
    }

    private static boolean isSignature(Service s) {
        return s.getType().equals(SIGNATURE_SERVICE);
    }

    private static boolean isCipher(Service s) {
        return s.getType().equals(CIPHER_SERVICE);
    }

    private static boolean isKeyAgreement(Service s) {
        return s.getType().equals(KEY_AGREEMENT_SERVICE);
    }

    /**
     * Splits the {@code |}-separated {@code SupportedKeyClasses} attribute of a
     * Signature/Cipher/KeyAgreement service into class names.
     *
     * @return the class names, or an empty array if the attribute is absent
     */
    private static String[] getSupportedKeyClasses(Service s) {
        assert isSignature(s) || isCipher(s) || isKeyAgreement(s);
        String supportedKeyClasses = s.getAttribute("SupportedKeyClasses");
        if (supportedKeyClasses != null) {
            return supportedKeyClasses.split("\\|");
        }
        return EMPTY_STRING_ARRAY;
    }

    /** Renders a service (provider, type, algorithm, class and supported keys) for trace output. */
    private static String asString(Service s) {
        String str = "Provider = " + s.getProvider().getName() + SEP;
        str += "Type = " + s.getType() + SEP;
        str += "Algorithm = " + s.getAlgorithm() + SEP;
        str += "Class = " + s.getClassName();
        if (isSignature(s) || isCipher(s) || isKeyAgreement(s)) {
            str += SEP + "SupportedKeyClasses = " + Arrays.toString(getSupportedKeyClasses(s));
        }
        return str;
    }

    /** Prints the message to standard out when {@code TraceSecurityServices} is enabled. */
    private static void trace(String trace) {
        if (Options.TraceSecurityServices.getValue()) {
            System.out.println(trace);
        }
    }
}