package org.graalvm.compiler.replacements;
import static org.graalvm.compiler.nodes.util.ConstantReflectionUtil.loadByteArrayConstant;
import static org.graalvm.compiler.nodes.util.ConstantReflectionUtil.loadCharArrayConstant;
import static org.graalvm.compiler.replacements.ReplacementsUtil.byteArrayBaseOffset;
import static org.graalvm.compiler.replacements.ReplacementsUtil.charArrayBaseOffset;
import static org.graalvm.compiler.replacements.SnippetTemplate.DEFAULT_REPLACER;
import static org.graalvm.compiler.serviceprovider.GraalUnsafeAccess.getUnsafe;
import org.graalvm.compiler.api.replacements.Fold.InjectedParameter;
import org.graalvm.compiler.api.replacements.Snippet;
import org.graalvm.compiler.api.replacements.Snippet.ConstantParameter;
import org.graalvm.compiler.api.replacements.SnippetReflectionProvider;
import org.graalvm.compiler.debug.DebugHandlersFactory;
import org.graalvm.compiler.nodes.StructuredGraph;
import org.graalvm.compiler.nodes.spi.LoweringTool;
import org.graalvm.compiler.options.OptionValues;
import org.graalvm.compiler.phases.util.Providers;
import org.graalvm.compiler.replacements.SnippetTemplate.AbstractTemplates;
import org.graalvm.compiler.replacements.SnippetTemplate.Arguments;
import org.graalvm.compiler.replacements.SnippetTemplate.SnippetInfo;
import org.graalvm.compiler.replacements.nodes.ExplodeLoopNode;
import jdk.vm.ci.code.TargetDescription;
import jdk.vm.ci.meta.JavaConstant;
import jdk.vm.ci.meta.JavaKind;
import jdk.vm.ci.meta.MetaAccessProvider;
import sun.misc.Unsafe;
public class ConstantStringIndexOfSnippets implements Snippets {
private static final Unsafe UNSAFE = getUnsafe();
public static class Templates extends AbstractTemplates {
private final SnippetInfo indexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "indexOfConstant");
private final SnippetInfo latin1IndexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "latin1IndexOfConstant");
private final SnippetInfo utf16IndexOfConstant = snippet(ConstantStringIndexOfSnippets.class, "utf16IndexOfConstant");
public Templates(OptionValues options, Iterable<DebugHandlersFactory> factories, Providers providers, SnippetReflectionProvider snippetReflection, TargetDescription target) {
super(options, factories, providers, snippetReflection, target);
}
public void lower(SnippetLowerableMemoryNode stringIndexOf, LoweringTool tool) {
StructuredGraph graph = stringIndexOf.graph();
Arguments args = new Arguments(indexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
args.add("source", stringIndexOf.getArgument(0));
args.add("sourceOffset", stringIndexOf.getArgument(1));
args.add("sourceCount", stringIndexOf.getArgument(2));
args.addConst("target", stringIndexOf.getArgument(3));
args.add("targetOffset", stringIndexOf.getArgument(4));
args.add("targetCount", stringIndexOf.getArgument(5));
args.add("origFromIndex", stringIndexOf.getArgument(6));
JavaConstant targetArg = stringIndexOf.getArgument(3).asJavaConstant();
char[] targetCharArray = loadCharArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
args.addConst("md2", md2(targetCharArray));
args.addConst("cache", computeCache(targetCharArray));
template(stringIndexOf, args).instantiate(providers.getMetaAccess(), stringIndexOf, DEFAULT_REPLACER, args);
}
public void lowerLatin1(SnippetLowerableMemoryNode latin1IndexOf, LoweringTool tool) {
StructuredGraph graph = latin1IndexOf.graph();
Arguments args = new Arguments(latin1IndexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
args.add("source", latin1IndexOf.getArgument(0));
args.add("sourceCount", latin1IndexOf.getArgument(1));
args.addConst("target", latin1IndexOf.getArgument(2));
args.add("targetCount", latin1IndexOf.getArgument(3));
args.add("origFromIndex", latin1IndexOf.getArgument(4));
JavaConstant targetArg = latin1IndexOf.getArgument(2).asJavaConstant();
byte[] targetByteArray = loadByteArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
args.addConst("md2", md2(targetByteArray));
args.addConst("cache", computeCache(targetByteArray));
template(latin1IndexOf, args).instantiate(providers.getMetaAccess(), latin1IndexOf, DEFAULT_REPLACER, args);
}
public void lowerUTF16(SnippetLowerableMemoryNode utf16IndexOf, LoweringTool tool) {
StructuredGraph graph = utf16IndexOf.graph();
Arguments args = new Arguments(utf16IndexOfConstant, graph.getGuardsStage(), tool.getLoweringStage());
args.add("source", utf16IndexOf.getArgument(0));
args.add("sourceCount", utf16IndexOf.getArgument(1));
args.addConst("target", utf16IndexOf.getArgument(2));
args.add("targetCount", utf16IndexOf.getArgument(3));
args.add("origFromIndex", utf16IndexOf.getArgument(4));
JavaConstant targetArg = utf16IndexOf.getArgument(2).asJavaConstant();
byte[] targetByteArray = loadByteArrayConstant(providers.getConstantReflection(), targetArg, Integer.MAX_VALUE);
args.addConst("md2", md2Utf16(tool.getMetaAccess(), targetByteArray));
args.addConst("cache", computeCacheUtf16(tool.getMetaAccess(), targetByteArray));
template(utf16IndexOf, args).instantiate(providers.getMetaAccess(), utf16IndexOf, DEFAULT_REPLACER, args);
}
}
static int md2(char[] target) {
int c = target.length;
if (c == 0) {
return 0;
}
char lastChar = target[c - 1];
int md2 = c;
for (int i = 0; i < c - 1; i++) {
if (target[i] == lastChar) {
md2 = (c - 1) - i;
}
}
return md2;
}
static long computeCache(char[] s) {
int c = s.length;
int cache = 0;
int i;
for (i = 0; i < c - 1; i++) {
cache |= (1 << (s[i] & 63));
}
return cache;
}
static int md2(byte[] target) {
int c = target.length;
if (c == 0) {
return 0;
}
byte lastByte = target[c - 1];
int md2 = c;
for (int i = 0; i < c - 1; i++) {
if (target[i] == lastByte) {
md2 = (c - 1) - i;
}
}
return md2;
}
static long computeCache(byte[] s) {
int c = s.length;
int cache = 0;
int i;
for (i = 0; i < c - 1; i++) {
cache |= (1 << (s[i] & 63));
}
return cache;
}
static int md2Utf16(MetaAccessProvider metaAccess, byte[] target) {
int c = target.length / 2;
if (c == 0) {
return 0;
}
long base = metaAccess.getArrayBaseOffset(JavaKind.Byte);
char lastChar = UNSAFE.getChar(target, base + (c - 1) * 2);
int md2 = c;
for (int i = 0; i < c - 1; i++) {
char currChar = UNSAFE.getChar(target, base + i * 2);
if (currChar == lastChar) {
md2 = (c - 1) - i;
}
}
return md2;
}
static long computeCacheUtf16(MetaAccessProvider metaAccess, byte[] s) {
int c = s.length / 2;
int cache = 0;
int i;
long base = metaAccess.getArrayBaseOffset(JavaKind.Byte);
for (i = 0; i < c - 1; i++) {
char currChar = UNSAFE.getChar(s, base + i * 2);
cache |= (1 << (currChar & 63));
}
return cache;
}
static final MetaAccessProvider INJECTED = null;
@Snippet
public static int indexOfConstant(char[] source, int sourceOffset, int sourceCount,
@ConstantParameter char[] target, int targetOffset, int targetCount,
int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
int fromIndex = origFromIndex;
if (fromIndex >= sourceCount) {
return (targetCount == 0 ? sourceCount : -1);
}
if (fromIndex < 0) {
fromIndex = 0;
}
if (targetCount == 0) {
return fromIndex;
}
int targetCountLess1 = targetCount - 1;
int sourceEnd = sourceCount - targetCountLess1;
long base = charArrayBaseOffset(INJECTED);
int lastChar = UNSAFE.getChar(target, base + targetCountLess1 * 2);
outer_loop: for (long i = sourceOffset + fromIndex; i < sourceEnd;) {
int src = UNSAFE.getChar(source, base + (i + targetCountLess1) * 2);
if (src == lastChar) {
if (targetCount <= 8) {
ExplodeLoopNode.explodeLoop();
}
for (long j = 0; j < targetCountLess1; j++) {
char sourceChar = UNSAFE.getChar(source, base + (i + j) * 2);
if (UNSAFE.getChar(target, base + (targetOffset + j) * 2) != sourceChar) {
if ((cache & (1 << sourceChar)) == 0) {
if (md2 < j + 1) {
i += j + 1;
continue outer_loop;
}
}
i += md2;
continue outer_loop;
}
}
return (int) (i - sourceOffset);
}
if ((cache & (1 << src)) == 0) {
i += targetCountLess1;
}
i++;
}
return -1;
}
@Snippet
public static int utf16IndexOfConstant(byte[] source, int sourceCount,
@ConstantParameter byte[] target, int targetCount,
int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
int fromIndex = origFromIndex;
if (fromIndex >= sourceCount) {
return (targetCount == 0 ? sourceCount : -1);
}
if (fromIndex < 0) {
fromIndex = 0;
}
if (targetCount == 0) {
return fromIndex;
}
int targetCountLess1 = targetCount - 1;
int sourceEnd = sourceCount - targetCountLess1;
long base = byteArrayBaseOffset(INJECTED);
int lastChar = UNSAFE.getChar(target, base + targetCountLess1 * 2);
outer_loop: for (long i = fromIndex; i < sourceEnd;) {
int src = UNSAFE.getChar(source, base + (i + targetCountLess1) * 2);
if (src == lastChar) {
if (targetCount <= 8) {
ExplodeLoopNode.explodeLoop();
}
for (long j = 0; j < targetCountLess1; j++) {
char sourceChar = UNSAFE.getChar(source, base + (i + j) * 2);
if (UNSAFE.getChar(target, base + j * 2) != sourceChar) {
if ((cache & (1 << sourceChar)) == 0) {
if (md2 < j + 1) {
i += j + 1;
continue outer_loop;
}
}
i += md2;
continue outer_loop;
}
}
return (int) i;
}
if ((cache & (1 << src)) == 0) {
i += targetCountLess1;
}
i++;
}
return -1;
}
@Snippet
public static int latin1IndexOfConstant(byte[] source, int sourceCount,
@ConstantParameter byte[] target, int targetCount,
int origFromIndex, @ConstantParameter int md2, @ConstantParameter long cache) {
int fromIndex = origFromIndex;
if (fromIndex >= sourceCount) {
return (targetCount == 0 ? sourceCount : -1);
}
if (fromIndex < 0) {
fromIndex = 0;
}
if (targetCount == 0) {
return fromIndex;
}
int targetCountLess1 = targetCount - 1;
int sourceEnd = sourceCount - targetCountLess1;
long base = byteArrayBaseOffset(INJECTED);
int lastByte = UNSAFE.getByte(target, base + targetCountLess1);
outer_loop: for (long i = fromIndex; i < sourceEnd;) {
int src = UNSAFE.getByte(source, base + i + targetCountLess1);
if (src == lastByte) {
if (targetCount <= 8) {
ExplodeLoopNode.explodeLoop();
}
for (long j = 0; j < targetCountLess1; j++) {
byte sourceByte = UNSAFE.getByte(source, base + i + j);
if (UNSAFE.getByte(target, base + j) != sourceByte) {
if ((cache & (1 << sourceByte)) == 0) {
if (md2 < j + 1) {
i += j + 1;
continue outer_loop;
}
}
i += md2;
continue outer_loop;
}
}
return (int) i;
}
if ((cache & (1 << src)) == 0) {
i += targetCountLess1;
}
i++;
}
return -1;
}
}