package org.glassfish.pfl.tf.spi;
import java.io.PrintWriter;
import org.glassfish.pfl.basic.func.UnaryFunction;
import org.glassfish.pfl.objectweb.asm.ClassAdapter;
import org.glassfish.pfl.objectweb.asm.ClassReader;
import org.glassfish.pfl.objectweb.asm.ClassVisitor;
import org.glassfish.pfl.objectweb.asm.ClassWriter;
import org.glassfish.pfl.objectweb.asm.MethodVisitor;
import org.glassfish.pfl.objectweb.asm.Opcodes;
import org.glassfish.pfl.objectweb.asm.Type;
import org.glassfish.pfl.objectweb.asm.tree.LocalVariableNode;
import org.glassfish.pfl.objectweb.asm.tree.MethodInsnNode;
import org.glassfish.pfl.objectweb.asm.tree.MethodNode;
import org.glassfish.pfl.objectweb.asm.util.AbstractVisitor;
import org.glassfish.pfl.objectweb.asm.util.CheckClassAdapter;
/**
 * Helpers for emitting bytecode through ASM {@link MethodVisitor}s and for
 * running a class-rewriting pass, plus simple verbosity-controlled console
 * logging.
 */
public class Util {
    private final boolean debug ;
    private final int verbose ;

    /**
     * Creates a helper instance.
     *
     * @param debug when true, {@link #transform} verifies the classes it
     *     produces, and the verbosity level is raised to at least 1
     * @param verbose maximum message level printed by {@link #info}
     */
    public Util( final boolean debug, final int verbose ) {
        this.debug = debug ;
        // Debug mode always implies at least level-1 messages.
        this.verbose = (debug && (verbose < 1)) ? 1 : verbose ;
    }

    /**
     * @return the debug flag supplied at construction
     */
    public boolean getDebug() {
        return debug ;
    }

    /**
     * Prints {@code str} if {@code level} does not exceed the configured
     * verbosity. Messages above level 1 are indented by right-aligning a
     * {@code ">"} marker in a field of {@code 4*(level-1)+1} characters.
     *
     * @param level the message's level
     * @param str the message text
     */
    public void info( int level, String str ) {
        if (verbose >= level) {
            final String format = level>1 ? "%" + (4*(level-1) + 1) + "s"
                : "%s" ;
            final String pad = String.format( format, ">" ) ;
            msg( pad + str ) ;
        }
    }

    /**
     * Writes {@code str} unconditionally to standard output.
     *
     * @param str the message text
     */
    public void msg( String str ) {
        System.out.println( str ) ;
    }

    /**
     * Reports a fatal error. This method never returns normally.
     *
     * @param str the error message
     * @throws RuntimeException always, carrying {@code str}
     */
    public void error( String str ) {
        throw new RuntimeException( str ) ;
    }

    /**
     * Emits code that stores a zero value ({@code 0}, {@code 0L},
     * {@code 0.0f}, {@code 0.0d}) or {@code null} into the local variable
     * slot of {@code var}, chosen by the variable's descriptor type.
     *
     * @param mv visitor receiving the instructions
     * @param var local variable to initialize
     */
    public void initLocal( MethodVisitor mv, LocalVariableNode var ) {
        info( 2, "Initializing variable " + var ) ;
        Type type = Type.getType( var.desc ) ;
        switch (type.getSort()) {
            case Type.BYTE :
            case Type.BOOLEAN :
            case Type.CHAR :
            case Type.SHORT :
            case Type.INT :
                mv.visitInsn( Opcodes.ICONST_0 ) ;
                mv.visitVarInsn( Opcodes.ISTORE, var.index ) ;
                break ;
            case Type.LONG :
                mv.visitInsn( Opcodes.LCONST_0 ) ;
                mv.visitVarInsn( Opcodes.LSTORE, var.index ) ;
                break ;
            case Type.FLOAT :
                mv.visitInsn( Opcodes.FCONST_0 ) ;
                mv.visitVarInsn( Opcodes.FSTORE, var.index ) ;
                break ;
            case Type.DOUBLE :
                mv.visitInsn( Opcodes.DCONST_0 ) ;
                mv.visitVarInsn( Opcodes.DSTORE, var.index ) ;
                break ;
            default :
                // Object and array references.
                mv.visitInsn( Opcodes.ACONST_NULL ) ;
                mv.visitVarInsn( Opcodes.ASTORE, var.index ) ;
        }
    }

    /**
     * @return the method name concatenated with its descriptor
     */
    public String getFullMethodDescriptor( String name, String desc ) {
        return name + desc ;
    }

    /**
     * @return the method name concatenated with its descriptor
     */
    public String getFullMethodDescriptor( MethodNode mn ) {
        return mn.name + mn.desc ;
    }

    /**
     * @return the invoked method's name concatenated with its descriptor
     */
    public String getFullMethodDescriptor( MethodInsnNode mn ) {
        return mn.name + mn.desc ;
    }

    /**
     * @return the reflected method's name concatenated with its JVM
     *     descriptor
     */
    public String getFullMethodDescriptor( java.lang.reflect.Method method ) {
        final String desc = Type.getMethodDescriptor(method) ;
        return method.getName() + desc ;
    }

    /**
     * Emits {@code NEW cls; DUP; INVOKESPECIAL cls.<init>()V}, leaving a
     * reference to the newly constructed instance on the operand stack.
     *
     * @param mv visitor receiving the instructions
     * @param cls class to instantiate via its no-argument constructor
     */
    public void newWithSimpleConstructor( MethodVisitor mv, Class<?> cls ) {
        info( 2, "generating new for class " + cls ) ;
        Type type = Type.getType( cls ) ;
        mv.visitTypeInsn( Opcodes.NEW, type.getInternalName() );
        mv.visitInsn( Opcodes.DUP ) ;
        mv.visitMethodInsn( Opcodes.INVOKESPECIAL,
            type.getInternalName(), "<init>", "()V" );
    }

    /**
     * Returns {@code desc} with two extra trailing parameters appended: a
     * {@link MethodMonitor} and an {@code int}. The return type is unchanged.
     *
     * @param desc original method descriptor
     * @return the augmented method descriptor
     */
    public String augmentInfoMethodDescriptor( String desc ) {
        info( 2, "Augmenting infoMethod descriptor " + desc ) ;
        Type[] oldArgTypes = Type.getArgumentTypes( desc ) ;
        Type retType = Type.getReturnType( desc ) ;
        int oldlen = oldArgTypes.length ;
        Type[] argTypes = new Type[ oldlen + 2 ] ;
        System.arraycopy(oldArgTypes, 0, argTypes, 0, oldlen);
        argTypes[oldlen] = Type.getType( MethodMonitor.class ) ;
        argTypes[oldlen+1] = Type.INT_TYPE ;
        String newDesc = Type.getMethodDescriptor(retType, argTypes) ;
        info( 3, "result is " + newDesc ) ;
        return newDesc ;
    }

    /**
     * Emits an instruction that pushes the int constant {@code val}. It uses
     * the dedicated {@code ICONST_M1..ICONST_5} opcodes where possible and
     * {@code LDC} otherwise.
     *
     * <p>Previously any negative value emitted nothing, leaving the operand
     * stack short by one slot.
     *
     * @param mv visitor receiving the instruction
     * @param val the constant to push
     */
    public void emitIntConstant( MethodVisitor mv, int val ) {
        info( 2, "Emitting constant " + val ) ;
        switch (val) {
            case -1:
                mv.visitInsn( Opcodes.ICONST_M1 ) ;
                break ;
            case 0:
                mv.visitInsn( Opcodes.ICONST_0 ) ;
                break ;
            case 1:
                mv.visitInsn( Opcodes.ICONST_1 ) ;
                break ;
            case 2:
                mv.visitInsn( Opcodes.ICONST_2 ) ;
                break ;
            case 3:
                mv.visitInsn( Opcodes.ICONST_3 ) ;
                break ;
            case 4:
                mv.visitInsn( Opcodes.ICONST_4 ) ;
                break ;
            case 5:
                mv.visitInsn( Opcodes.ICONST_5 ) ;
                break ;
            default:
                // LDC accepts any int constant.
                mv.visitLdcInsn( val ) ;
                break ;
        }
    }

    /**
     * Emits code that loads the argument in local slot {@code argIndex} and,
     * for primitive types, boxes it with the wrapper's {@code valueOf}.
     * Reference-typed arguments are loaded as-is.
     *
     * @param mv visitor receiving the instructions
     * @param argIndex local variable slot of the argument
     * @param atype the argument's type
     * @return the slot index of the next argument (long/double take 2 slots)
     */
    public int wrapArg( MethodVisitor mv, int argIndex, Type atype ) {
        info( 2, "Emitting code to wrap argument at " + argIndex
            + " of type " + atype ) ;
        switch (atype.getSort() ) {
            case Type.BOOLEAN :
                mv.visitVarInsn( Opcodes.ILOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Boolean.class ), "valueOf",
                    "(Z)Ljava/lang/Boolean;" );
                break ;
            case Type.BYTE :
                mv.visitVarInsn( Opcodes.ILOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Byte.class ), "valueOf",
                    "(B)Ljava/lang/Byte;" );
                break ;
            case Type.CHAR :
                mv.visitVarInsn( Opcodes.ILOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Character.class ), "valueOf",
                    "(C)Ljava/lang/Character;" );
                break ;
            case Type.SHORT :
                mv.visitVarInsn( Opcodes.ILOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Short.class ), "valueOf",
                    "(S)Ljava/lang/Short;" );
                break ;
            case Type.INT :
                mv.visitVarInsn( Opcodes.ILOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Integer.class ), "valueOf",
                    "(I)Ljava/lang/Integer;" );
                break ;
            case Type.LONG :
                mv.visitVarInsn( Opcodes.LLOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Long.class ), "valueOf",
                    "(J)Ljava/lang/Long;" );
                break ;
            case Type.DOUBLE :
                mv.visitVarInsn( Opcodes.DLOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Double.class ), "valueOf",
                    "(D)Ljava/lang/Double;" );
                break ;
            case Type.FLOAT :
                mv.visitVarInsn( Opcodes.FLOAD, argIndex ) ;
                mv.visitMethodInsn( Opcodes.INVOKESTATIC,
                    Type.getInternalName( Float.class ), "valueOf",
                    "(F)Ljava/lang/Float;" );
                break ;
            default :
                mv.visitVarInsn( Opcodes.ALOAD, argIndex ) ;
                break ;
        }
        return argIndex + atype.getSize() ;
    }

    /**
     * Emits code that builds an {@code Object[]} holding every argument of
     * the method described by {@code desc}, primitives boxed. The array is
     * left on the operand stack.
     *
     * @param mv visitor receiving the instructions
     * @param access the method's access flags (used to detect static)
     * @param desc the method's descriptor
     */
    public void wrapArgs( MethodVisitor mv, int access, String desc ) {
        info( 2, "Wrapping args for descriptor " + desc ) ;
        Type[] atypes = Type.getArgumentTypes( desc ) ;
        emitIntConstant( mv, atypes.length ) ;
        mv.visitTypeInsn( Opcodes.ANEWARRAY, "java/lang/Object" ) ;
        // Instance methods keep 'this' in slot 0, so arguments start at 1.
        int argIndex = hasAccess( access, Opcodes.ACC_STATIC ) ? 0 : 1 ;
        for (int ctr=0; ctr<atypes.length; ctr++) {
            // Stack before AASTORE: array, array, index, value.
            mv.visitInsn( Opcodes.DUP ) ;
            emitIntConstant( mv, ctr );
            argIndex = wrapArg( mv, argIndex, atypes[ctr] ) ;
            mv.visitInsn( Opcodes.AASTORE ) ;
        }
    }

    /**
     * Emits the store instruction matching {@code returnOpcode}. It saves
     * the value about to be returned into {@code holder}'s slot. Emits
     * nothing for {@code RETURN} or unrecognized opcodes.
     *
     * @param mv visitor receiving the instruction
     * @param returnOpcode one of the xRETURN opcodes
     * @param holder local variable receiving the value
     */
    public void storeFromXReturn( MethodVisitor mv, int returnOpcode,
        LocalVariableNode holder ) {
        switch (returnOpcode) {
            case Opcodes.RETURN :
                break ;
            case Opcodes.ARETURN :
                mv.visitVarInsn( Opcodes.ASTORE, holder.index ) ;
                break ;
            case Opcodes.IRETURN :
                mv.visitVarInsn( Opcodes.ISTORE, holder.index ) ;
                break ;
            case Opcodes.LRETURN :
                mv.visitVarInsn( Opcodes.LSTORE, holder.index ) ;
                break ;
            case Opcodes.FRETURN :
                mv.visitVarInsn( Opcodes.FSTORE, holder.index ) ;
                break ;
            case Opcodes.DRETURN :
                mv.visitVarInsn( Opcodes.DSTORE, holder.index ) ;
                break ;
        }
    }

    /**
     * Emits the load instruction matching {@code returnOpcode}. It pushes
     * the value saved in {@code holder}'s slot. Emits nothing for
     * {@code RETURN} or unrecognized opcodes.
     *
     * @param mv visitor receiving the instruction
     * @param returnOpcode one of the xRETURN opcodes
     * @param holder local variable holding the value
     */
    public void loadFromXReturn( MethodVisitor mv, int returnOpcode,
        LocalVariableNode holder ) {
        switch (returnOpcode) {
            case Opcodes.RETURN :
                break ;
            case Opcodes.ARETURN :
                mv.visitVarInsn( Opcodes.ALOAD, holder.index ) ;
                break ;
            case Opcodes.IRETURN :
                mv.visitVarInsn( Opcodes.ILOAD, holder.index ) ;
                break ;
            case Opcodes.LRETURN :
                mv.visitVarInsn( Opcodes.LLOAD, holder.index ) ;
                break ;
            case Opcodes.FRETURN :
                mv.visitVarInsn( Opcodes.FLOAD, holder.index ) ;
                break ;
            case Opcodes.DRETURN :
                mv.visitVarInsn( Opcodes.DLOAD, holder.index ) ;
                break ;
        }
    }

    /**
     * In debug mode, runs ASM's {@link CheckClassAdapter} verifier on
     * {@code cls} and dumps its report to standard output.
     *
     * @param cls class file bytes to verify
     */
    private void verify( byte[] cls ) {
        if (getDebug()) {
            info( 2, "Verifying enhanced class") ;
            ClassReader cr = new ClassReader( cls ) ;
            PrintWriter pw = new PrintWriter( System.out ) ;
            try {
                CheckClassAdapter.verify( cr, true, pw ) ;
            } finally {
                // PrintWriter buffers, so flush or the report may never
                // appear. Do NOT close: that would close System.out itself.
                pw.flush() ;
            }
        }
    }

    /**
     * @return true if every bit of {@code flag} is set in {@code access}
     */
    public boolean hasAccess( int access, int flag ) {
        return (access & flag) == flag ;
    }

    /**
     * Returns the mnemonic for {@code opcode}, or {@code "ILLEGAL[n]"} when
     * it lies outside ASM's opcode name table.
     *
     * @param opcode the JVM opcode
     * @return its printable name
     */
    public static String opcodeToString( int opcode ) {
        String[] opcodes = AbstractVisitor.OPCODES ;
        // Valid indices are 0..length-1; the old '>' check let
        // opcode == length through and threw ArrayIndexOutOfBoundsException.
        if ((opcode < 0) || (opcode >= opcodes.length)) {
            return "ILLEGAL[" + opcode + "]" ;
        }
        return opcodes[opcode] ;
    }

    /**
     * Rewrites the class file {@code cls} through the adapter produced by
     * {@code factory}. Max stack/locals are recomputed and existing stack
     * map frames are skipped.
     *
     * @param debug when true, prints the stack trace of any non-fatal
     *     exception raised during the rewrite
     * @param cls original class file bytes
     * @param factory builds the transforming adapter around the output
     *     {@link ClassVisitor}
     * @return the rewritten class bytes
     * @throws TraceEnhancementException propagated unchanged from the
     *     transformation
     */
    public byte[] transform( final boolean debug, final byte[] cls,
        final UnaryFunction<ClassVisitor,ClassAdapter> factory ) {
        final ClassReader cr = new ClassReader(cls) ;
        final ClassWriter cw = new ClassWriter( ClassWriter.COMPUTE_MAXS ) ;
        final ClassVisitor target = cw ;
        final ClassAdapter xform = factory.evaluate( target ) ;

        // No PrintWriter over System.out is created here. The previous one
        // was never written to, and closing it closed System.out for the
        // whole JVM, silently discarding all later msg()/info() output.
        try {
            cr.accept( xform, ClassReader.SKIP_FRAMES ) ;
        } catch (TraceEnhancementException exc) {
            throw exc ;
        } catch (Exception exc) {
            // NOTE(review): best-effort behavior preserved. The exception is
            // only reported, and the bytes written so far are still returned
            // below, which may be an incomplete class. Confirm callers expect
            // this.
            info( 1, "Exception: " + exc ) ;
            if (debug) {
                exc.printStackTrace() ;
            }
        }

        byte[] enhancedClass = cw.toByteArray() ;
        verify( enhancedClass ) ;
        return enhancedClass ;
    }
}