package com.oracle.truffle.regex.tregex.nodes.input;
import com.oracle.truffle.api.ArrayUtils;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Specialization;
import com.oracle.truffle.api.interop.TruffleObject;
import com.oracle.truffle.api.nodes.Node;
/**
 * Truffle DSL node that tests whether a region of {@code input} starting at {@code fromIndex1} is
 * equal to a region of {@code match} starting at {@code fromIndex2}, both of length
 * {@code length}.
 *
 * <p>
 * One specialization exists per supported combination of operand types ({@code byte[]},
 * {@link String} and {@link TruffleObject}). Each combination comes in two flavours selected by
 * the {@code mask == null} / {@code mask != null} guards: a plain comparison and a masked
 * comparison. In the masked {@link TruffleObject} variants the mask is indexed relative to the
 * start of the region (index {@code i}, not {@code fromIndex2 + i}).
 *
 * <p>
 * The {@link TruffleObject} code paths return {@code false} when either region would extend past
 * the end of its operand. That bounds check is performed in {@code long} arithmetic so that
 * {@code fromIndex + length} cannot wrap around for very large values. Negative indices are not
 * validated by these paths.
 */
public abstract class InputRegionMatchesNode extends Node {

    /**
     * Creates a new node instance through the DSL-generated {@code InputRegionMatchesNodeGen}.
     *
     * @return a fresh, uncached-state node
     */
    public static InputRegionMatchesNode create() {
        return InputRegionMatchesNodeGen.create();
    }

    /**
     * Dispatches to the specialization matching the runtime types of {@code input}, {@code match}
     * and {@code mask}.
     *
     * @param input the object whose region is examined
     * @param fromIndex1 start of the region in {@code input}
     * @param match the object holding the expected content
     * @param fromIndex2 start of the region in {@code match}
     * @param length number of elements to compare
     * @param mask optional mask; {@code null} selects the unmasked specializations
     * @return {@code true} if the two regions are considered equal
     */
    public abstract boolean execute(Object input, int fromIndex1, Object match, int fromIndex2, int length, Object mask);

    /** Unmasked {@code byte[]} vs. {@code byte[]} comparison; delegates to {@link ArrayUtils}. */
    @Specialization(guards = "mask == null")
    public boolean doBytes(byte[] input, int fromIndex1, byte[] match, int fromIndex2, int length, @SuppressWarnings("unused") Object mask) {
        return ArrayUtils.regionEqualsWithOrMask(input, fromIndex1, match, fromIndex2, length, null);
    }

    /** Masked {@code byte[]} vs. {@code byte[]} comparison; delegates to {@link ArrayUtils}. */
    @Specialization(guards = "mask != null")
    public boolean doBytesMask(byte[] input, int fromIndex1, byte[] match, int fromIndex2, int length, byte[] mask) {
        return ArrayUtils.regionEqualsWithOrMask(input, fromIndex1, match, fromIndex2, length, mask);
    }

    /** Unmasked {@link String} vs. {@link String} comparison via {@link String#regionMatches}. */
    @Specialization(guards = "mask == null")
    public boolean doString(String input, int fromIndex1, String match, int fromIndex2, int length, @SuppressWarnings("unused") Object mask) {
        return input.regionMatches(fromIndex1, match, fromIndex2, length);
    }

    /** Masked {@link String} vs. {@link String} comparison; delegates to {@link ArrayUtils}. */
    @Specialization(guards = "mask != null")
    public boolean doJavaStringMask(String input, int fromIndex1, String match, int fromIndex2, int length, String mask) {
        return ArrayUtils.regionEqualsWithOrMask(input, fromIndex1, match, fromIndex2, length, mask);
    }

    /** Unmasked {@link TruffleObject} input compared against a {@code byte[]} match. */
    @Specialization(guards = "mask == null")
    public boolean doTruffleObjBytes(TruffleObject input, int fromIndex1, byte[] match, int fromIndex2, int length, @SuppressWarnings("unused") Object mask,
                    @Cached InputLengthNode lengthNode,
                    @Cached InputReadNode charAtNode) {
        return regionMatchesTruffleObj(input, fromIndex1, match, fromIndex2, length, null, lengthNode, charAtNode);
    }

    /** Masked {@link TruffleObject} input compared against a {@code byte[]} match. */
    @Specialization(guards = "mask != null")
    public boolean doTruffleObjBytesMask(TruffleObject input, int fromIndex1, byte[] match, int fromIndex2, int length, byte[] mask,
                    @Cached InputLengthNode lengthNode,
                    @Cached InputReadNode charAtNode) {
        // Sanity check (only active with assertions enabled): mask and match have equal length.
        assert match.length == mask.length;
        return regionMatchesTruffleObj(input, fromIndex1, match, fromIndex2, length, mask, lengthNode, charAtNode);
    }

    /** Unmasked {@link TruffleObject} input compared against a {@link String} match. */
    @Specialization(guards = "mask == null")
    public boolean doTruffleObjString(TruffleObject input, int fromIndex1, String match, int fromIndex2, int length, @SuppressWarnings("unused") Object mask,
                    @Cached InputLengthNode lengthNode,
                    @Cached InputReadNode charAtNode) {
        return regionMatchesTruffleObj(input, fromIndex1, match, fromIndex2, length, null, lengthNode, charAtNode);
    }

    /** Masked {@link TruffleObject} input compared against a {@link String} match. */
    @Specialization(guards = "mask != null")
    public boolean doTruffleObjStringMask(TruffleObject input, int fromIndex1, String match, int fromIndex2, int length, String mask,
                    @Cached InputLengthNode lengthNode,
                    @Cached InputReadNode charAtNode) {
        // Sanity check (only active with assertions enabled): mask and match have equal length.
        assert match.length() == mask.length();
        return regionMatchesTruffleObj(input, fromIndex1, match, fromIndex2, length, mask, lengthNode, charAtNode);
    }

    /**
     * Unmasked comparison of two {@link TruffleObject}s, reading both element by element.
     *
     * @return {@code false} if either region exceeds its operand's length or any element pair
     *         differs; {@code true} otherwise
     */
    @Specialization(guards = "mask == null")
    public boolean doTruffleObjTruffleObj(TruffleObject input, int fromIndex1, TruffleObject match, int fromIndex2, int length, @SuppressWarnings("unused") Object mask,
                    @Cached InputLengthNode lengthNode1,
                    @Cached InputReadNode charAtNode1,
                    @Cached InputLengthNode lengthNode2,
                    @Cached InputReadNode charAtNode2) {
        if (exceedsLength(fromIndex1, length, lengthNode1.execute(input)) || exceedsLength(fromIndex2, length, lengthNode2.execute(match))) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            if (charAtNode1.execute(input, fromIndex1 + i) != charAtNode2.execute(match, fromIndex2 + i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares a region of a {@link TruffleObject} against a region of a {@code byte[]}.
     *
     * <p>
     * Input elements are obtained through {@code InputReadNode.readWithMask} with mask index
     * {@code i}; match bytes are widened as unsigned values before comparison.
     *
     * @param mask mask passed to {@code readWithMask}; may be {@code null}
     * @return {@code false} if either region is out of bounds or any element differs
     */
    private static boolean regionMatchesTruffleObj(TruffleObject input, int fromIndex1, byte[] match, int fromIndex2, int length, byte[] mask,
                    InputLengthNode lengthNode,
                    InputReadNode charAtNode) {
        if (exceedsLength(fromIndex1, length, lengthNode.execute(input)) || exceedsLength(fromIndex2, length, match.length)) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            if (InputReadNode.readWithMask(input, fromIndex1 + i, mask, i, charAtNode) != Byte.toUnsignedInt(match[fromIndex2 + i])) {
                return false;
            }
        }
        return true;
    }

    /**
     * Compares a region of a {@link TruffleObject} against a region of a {@link String}.
     *
     * <p>
     * Input elements are obtained through {@code InputReadNode.readWithMask} with mask index
     * {@code i} and compared against {@code match.charAt(fromIndex2 + i)}.
     *
     * @param mask mask passed to {@code readWithMask}; may be {@code null}
     * @return {@code false} if either region is out of bounds or any element differs
     */
    private static boolean regionMatchesTruffleObj(TruffleObject input, int fromIndex1, String match, int fromIndex2, int length, String mask,
                    InputLengthNode lengthNode,
                    InputReadNode charAtNode) {
        if (exceedsLength(fromIndex1, length, lengthNode.execute(input)) || exceedsLength(fromIndex2, length, match.length())) {
            return false;
        }
        for (int i = 0; i < length; i++) {
            if (InputReadNode.readWithMask(input, fromIndex1 + i, mask, i, charAtNode) != match.charAt(fromIndex2 + i)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Tells whether the region {@code [fromIndex, fromIndex + length)} extends past
     * {@code totalLength}.
     *
     * <p>
     * The sum is computed as a {@code long}: in {@code int} arithmetic, {@code fromIndex + length}
     * can wrap to a negative value for large operands, which would let an out-of-range region
     * slip through the check.
     */
    private static boolean exceedsLength(int fromIndex, int length, int totalLength) {
        return (long) fromIndex + length > totalLength;
    }
}