package org.skife.jdbi.v2;
import edu.umd.cs.findbugs.annotations.SuppressFBWarnings;
import org.antlr.runtime.ANTLRInputStream;
import org.antlr.runtime.Token;
import org.skife.jdbi.v2.exceptions.UnableToCreateStatementException;
import org.skife.jdbi.v2.tweak.StatementLocator;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.Locale;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.regex.Matcher;
public class ClasspathStatementLocator implements StatementLocator
{
private static class CacheKey {
final String name;
final Class<?> sqlObjectType;
final Method sqlObjectMethod;
CacheKey(String name, Class<?> sqlObjectType, Method sqlObjectMethod) {
this.name = name;
this.sqlObjectType = sqlObjectType;
this.sqlObjectMethod = sqlObjectMethod;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
CacheKey that = (CacheKey) o;
return eq(this.name, that.name)
&& eq(this.sqlObjectType, that.sqlObjectType)
&& eq(this.sqlObjectMethod, that.sqlObjectMethod);
}
private boolean eq(Object left, Object right) {
return left == null ? right == null : left.equals(right);
}
@Override
public int hashCode() {
int result = name == null ? 0 : name.hashCode();
result = 31 * result + (sqlObjectType != null ? sqlObjectType.hashCode() : 0);
result = 31 * result + (sqlObjectMethod != null ? sqlObjectMethod.hashCode() : 0);
return result;
}
}
private final Map<CacheKey, String> found = Collections.synchronizedMap(new WeakHashMap<CacheKey, String>());
public static boolean looksLikeSql(String sql)
{
final String local = left(stripStart(sql), 8).toLowerCase();
return local.startsWith("insert ")
|| local.startsWith("update ")
|| local.startsWith("select ")
|| local.startsWith("call ")
|| local.startsWith("delete ")
|| local.startsWith("create ")
|| local.startsWith("alter ")
|| local.startsWith("merge ")
|| local.startsWith("replace ")
|| local.startsWith("drop ");
}
@Override
@SuppressWarnings("PMD.EmptyCatchBlock")
@SuppressFBWarnings("DM_STRING_CTOR")
public String locate(String name, StatementContext ctx)
{
final CacheKey cache_key = new CacheKey(name, ctx.getSqlObjectType(), ctx.getSqlObjectMethod());
boolean isSqlObjectMethod = ctx.getSqlObjectType() != null && ctx.getSqlObjectMethod() != null;
String cached = found.get(cache_key);
if (cached != null) {
return cached;
}
if (looksLikeSql(name)) {
if (isSqlObjectMethod) {
found.put(cache_key, name);
}
return name;
}
final ClassLoader loader = selectClassLoader();
InputStream in_stream = null;
try {
in_stream = loader.getResourceAsStream(name);
if (in_stream == null) {
in_stream = loader.getResourceAsStream(name + ".sql");
}
if (in_stream == null && ctx.getSqlObjectType() != null) {
String filename = '/' + mungify(ctx.getSqlObjectType().getName() + '.' + name) + ".sql";
in_stream = loader.getResourceAsStream(filename);
if (in_stream == null) {
in_stream = ctx.getSqlObjectType().getResourceAsStream(filename);
}
}
if (in_stream == null) {
found.put(cache_key, isSqlObjectMethod ? name : new String(name));
return name;
}
String sql;
try {
sql = SQL_SCRIPT_PARSER.parse(new ANTLRInputStream(in_stream));
} catch (IOException e) {
throw new UnableToCreateStatementException(e.getMessage(), e, ctx);
}
found.put(cache_key, sql);
return sql;
}
finally {
try {
if (in_stream != null) {
in_stream.close();
}
}
catch (IOException e) {
e.printStackTrace();
}
}
}
private static ClassLoader selectClassLoader()
{
ClassLoader loader;
if (Thread.currentThread().getContextClassLoader() != null) {
loader = Thread.currentThread().getContextClassLoader();
}
else {
loader = ClasspathStatementLocator.class.getClassLoader();
}
return loader;
}
private static boolean (final String line)
{
return line.startsWith("#") || line.startsWith("--") || line.startsWith("//");
}
private static final String SEP = "/";
private static String mungify(String path)
{
return path.replaceAll("\\.", Matcher.quoteReplacement(SEP));
}
private static String stripStart(String str)
{
int strLen;
if (str == null || (strLen = str.length()) == 0) {
return "";
}
int start = 0;
while (start != strLen && Character.isWhitespace(str.charAt(start))) {
start++;
}
return str.substring(start);
}
private static String left(String str, int len)
{
if (str == null || len < 0) {
return "";
}
if (str.length() <= len) {
return str;
}
return str.substring(0, len);
}
private static final SqlScriptParser SQL_SCRIPT_PARSER = new SqlScriptParser(new SqlScriptParser.TokenHandler() {
@Override
public void handle(Token t, StringBuilder sb) {
sb.append(t.getText());
}
});
}