package org.hibernate.sql.ordering.antlr;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.hibernate.dialect.function.SQLFunction;
import org.hibernate.internal.util.StringHelper;
import org.hibernate.sql.Template;
import org.jboss.logging.Logger;
import antlr.CommonAST;
import antlr.TokenStream;
import antlr.collections.AST;
public class OrderByFragmentParser extends GeneratedOrderByFragmentParser {
private static final Logger LOG = Logger.getLogger( OrderByFragmentParser.class.getName() );
private final TranslationContext context;
private Set<String> columnReferences = new HashSet<String>();
public OrderByFragmentParser(TokenStream lexer, TranslationContext context) {
super( lexer );
super.setASTFactory( new Factory() );
this.context = context;
}
public Set<String> getColumnReferences() {
return columnReferences;
}
@Override
protected AST quotedIdentifier(AST ident) {
final String columnName = context.getDialect().quote( '`' + ident.getText() + '`' );
columnReferences.add( columnName );
final String marker = '{' + columnName + '}';
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, marker );
}
@Override
protected AST quotedString(AST ident) {
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, context.getDialect().quote( ident.getText() ) );
}
@Override
@SuppressWarnings("SimplifiableIfStatement")
protected boolean isFunctionName(AST ast) {
AST child = ast.getFirstChild();
if ( child != null && "{param list}".equals( child.getText() ) ) {
return true;
}
final SQLFunction function = context.getSqlFunctionRegistry().findSQLFunction( ast.getText() );
if ( function == null ) {
return false;
}
else {
return !function.hasParenthesesIfNoArguments();
}
}
@SuppressWarnings("unchecked")
@Override
protected AST resolveFunction(AST ast) {
AST child = ast.getFirstChild();
if ( child != null ) {
assert "{param list}".equals( child.getText() );
child = child.getFirstChild();
}
final String functionName = ast.getText();
final SQLFunction function = context.getSqlFunctionRegistry().findSQLFunction( functionName );
if ( function == null ) {
String text = functionName;
if ( child != null ) {
text += '(';
while ( child != null ) {
text += resolveFunctionArgument( child );
child = child.getNextSibling();
if ( child != null ) {
text += ", ";
}
}
text += ')';
}
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, text );
}
else {
ArrayList expressions = new ArrayList();
while ( child != null ) {
expressions.add( resolveFunctionArgument( child ) );
child = child.getNextSibling();
}
final String text = function.render( null, expressions, context.getSessionFactory() );
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, text );
}
}
private String resolveFunctionArgument(AST argumentNode) {
final String nodeText = argumentNode.getText();
final String adjustedText;
if ( nodeText.contains( Template.TEMPLATE ) ) {
adjustedText = adjustTemplateReferences( nodeText );
}
else if ( nodeText.startsWith( "{" ) && nodeText.endsWith( "}" ) ) {
columnReferences.add( nodeText.substring( 1, nodeText.length() - 1 ) );
return nodeText;
}
else {
adjustedText = nodeText;
Pattern pattern = Pattern.compile( "\\{(.*)\\}" );
Matcher matcher = pattern.matcher( adjustedText );
while ( matcher.find() ) {
columnReferences.add( matcher.group( 1 ) );
}
}
return adjustedText;
}
@Override
protected AST resolveIdent(AST ident) {
String text = ident.getText();
SqlValueReference[] sqlValueReferences;
try {
sqlValueReferences = context.getColumnMapper().map( text );
}
catch (Throwable t) {
sqlValueReferences = null;
}
if ( sqlValueReferences == null || sqlValueReferences.length == 0 ) {
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, makeColumnReference( text ) );
}
else if ( sqlValueReferences.length == 1 ) {
return processSqlValueReference( sqlValueReferences[0] );
}
else {
final AST root = getASTFactory().create( OrderByTemplateTokenTypes.IDENT_LIST, "{ident list}" );
for ( SqlValueReference sqlValueReference : sqlValueReferences ) {
root.addChild( processSqlValueReference( sqlValueReference ) );
}
return root;
}
}
private AST processSqlValueReference(SqlValueReference sqlValueReference) {
if ( ColumnReference.class.isInstance( sqlValueReference ) ) {
final String columnName = ( (ColumnReference) sqlValueReference ).getColumnName();
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, makeColumnReference( columnName ) );
}
else {
final String formulaFragment = ( (FormulaReference) sqlValueReference ).getFormulaFragment();
final String adjustedText = adjustTemplateReferences( formulaFragment );
return getASTFactory().create( OrderByTemplateTokenTypes.IDENT, adjustedText );
}
}
private String makeColumnReference(String text) {
columnReferences.add( text );
return "{" + text + "}";
}
private static final int TEMPLATE_MARKER_LENGTH = Template.TEMPLATE.length();
private String adjustTemplateReferences(String template) {
int templateLength = template.length();
int startPos = template.indexOf( Template.TEMPLATE );
while ( startPos != -1 && startPos < templateLength ) {
int dotPos = startPos + TEMPLATE_MARKER_LENGTH;
int pos = dotPos + 1;
while ( pos < templateLength && isValidIdentifierCharacter( template.charAt( pos ) ) ) {
pos++;
}
final String columnReference = template.substring( dotPos + 1, pos );
final String replacement = "{" + columnReference + "}";
template = template.replace( template.substring( startPos, pos ), replacement );
columnReferences.add( columnReference );
startPos = template.indexOf( Template.TEMPLATE, ( pos - TEMPLATE_MARKER_LENGTH ) + 1 );
templateLength = template.length();
}
return template;
}
private static boolean isValidIdentifierCharacter(char c) {
return Character.isLetter( c )
|| Character.isDigit( c )
|| '_' == c
|| '\"' == c;
}
@Override
protected AST postProcessSortSpecification(AST sortSpec) {
assert SORT_SPEC == sortSpec.getType();
SortSpecification sortSpecification = (SortSpecification) sortSpec;
AST sortKey = sortSpecification.getSortKey();
if ( IDENT_LIST == sortKey.getFirstChild().getType() ) {
AST identList = sortKey.getFirstChild();
AST ident = identList.getFirstChild();
AST holder = new CommonAST();
do {
holder.addChild(
createSortSpecification(
ident,
sortSpecification.getCollation(),
sortSpecification.getOrdering()
)
);
ident = ident.getNextSibling();
} while ( ident != null );
sortSpec = holder.getFirstChild();
}
return sortSpec;
}
private SortSpecification createSortSpecification(
AST ident,
CollationSpecification collationSpecification,
OrderingSpecification orderingSpecification) {
AST sortSpecification = getASTFactory().create( SORT_SPEC, "{{sort specification}}" );
AST sortKey = getASTFactory().create( SORT_KEY, "{{sort key}}" );
AST newIdent = getASTFactory().create( ident.getType(), ident.getText() );
sortKey.setFirstChild( newIdent );
sortSpecification.setFirstChild( sortKey );
if ( collationSpecification != null ) {
sortSpecification.addChild( collationSpecification );
}
if ( orderingSpecification != null ) {
sortSpecification.addChild( orderingSpecification );
}
return (SortSpecification) sortSpecification;
}
private int traceDepth = 0;
@Override
public void traceIn(String ruleName) {
if ( inputState.guessing > 0 ) {
return;
}
String prefix = StringHelper.repeat( '-', ( traceDepth++ * 2 ) ) + "-> ";
LOG.trace( prefix + ruleName );
}
@Override
public void traceOut(String ruleName) {
if ( inputState.guessing > 0 ) {
return;
}
String prefix = "<-" + StringHelper.repeat( '-', ( --traceDepth * 2 ) ) + " ";
LOG.trace( prefix + ruleName );
}
@Override
protected void trace(String msg) {
LOG.trace( msg );
}
}