字节码 & ASM-MehotdVisitor实践

使用ASM几乎用户全部的精力都是对MethodVisitor的处理,方法code的处理都需要使用这个类进行操作。还是之前文章说过的,ASM单独学习意义并不大,难以达到触类旁通,先行掌握字节码基础后再玩起ASM才能体会真正的乐趣,不然真的蛮折磨人的。

场景一:需要在方法开始插入代码

这个应该非常简单的,相信对于熟练掌握MethodVisitor方法调用顺序及方法含义的你来说应当不是问题的,visitCode方法是访问code的开始,所以我们最优的选择就是重写该方法,在该方法中开始书写需要插入的代码

import org.objectweb.asm.*;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class HelloWorld {
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        ClassReader cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    //这里判断需要处理的方法,真实场景一般需要对类名、方法名、方法签名一起进行校验
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new MethodVisitor(Opcodes.ASM6, mv) {
                            @Override
                            public void visitCode() {
                                super.visitCode();
                                //插入代码,此处插入的代码为 System.out.println("code from asm insert");
                                mv.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                mv.visitLdcInsn("code from asm insert");
                                mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                            }
                        };
                    }
                }
                return mv;
            }
        }, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
        
        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

执行代码后不出意外可以看到控制台输出code from asm insert和hello world。在方法首部插入代码的模板非常简单,往往我们需要在ClassVisitor#visit方法记录类信息后判断出是否是需要处理的类,在ClassVisitor#visitMethod方法中判断是否是需要处理的方法,然后通过一个MethodVisitor的子类中重写visitCode方法插入相应的代码。

场景二:在方法尾部插入代码

在方法尾部插入代码相比在方法首部插入相对来说更困难些,因为方法首部我们无需更多的考虑,而方法的尾部却需要考虑到正常方法退出、异常退出相关的指令, 所以我们可以通过对这些指令进行监控,并且在这些指令入栈前插入我们相应的代码便可以达到在方法尾部插入代码的目的。

import org.objectweb.asm.*;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class HelloWorld {
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        ClassReader cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    //这里判断需要处理的方法,真实场景一般需要对类名、方法名、方法签名一起进行校验
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new MethodVisitor(Opcodes.ASM6, mv) {
                            @Override
                            public void visitInsn(int opcode) {
                                //opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN条件为正常方法退出操作码
                                //opcode == Opcodes.ATHROW条件为异常时方法退出操作码
                                if (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN || opcode == Opcodes.ATHROW) {
                                    //插入代码,此处插入的代码为 System.out.println("code from asm insert");
                                    mv.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                    mv.visitLdcInsn("code from asm insert");
                                    mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                                }
                                super.visitInsn(opcode);
                            }
                        };
                    }
                }
                return mv;
            }
        }, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

执行代码后不出意外可以看到控制台输出hello world和code from asm insert。方法尾部插入代码的套路和方法首部插入的套路很相似,难点在于如何拦截到方法退出操作码,而这么多退出操作码书写也非常的麻烦,我们可以通过封装抽象出一个类帮助我们屏蔽掉内部的逻辑,幸运的时ASM也考虑到了这一点为我们提供了封装好的AdviceAdapter

public abstract class AdviceAdapter extends GeneratorAdapter implements Opcodes {
	protected void onMethodEnter()
    
    protected void onMethodExit(int opcode)
}
复制代码

关于这个类的使用非常简单,我们着重关注的两个方法onMethodEnteronMethodExit,命名不难看出一个方法为方法进入时的回调,一个为方法退出的回调及相关的退出操作码。我们通过使用这个类替换上面两个例子

import org.objectweb.asm.*;
import org.objectweb.asm.commons.AdviceAdapter;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class HelloWorld {
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        ClassReader cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    //这里判断需要处理的方法,真实场景一般需要对类名、方法名、方法签名一起进行校验
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new AdviceAdapter(Opcodes.ASM6, mv, access, name, descriptor) {
                            @Override
                            protected void onMethodEnter() {
                                super.onMethodEnter();
                                //插入代码,此处插入的代码为 System.out.println("code from asm onMethodEnter insert");
                                mv.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                mv.visitLdcInsn("code from asm onMethodEnter insert");
                                mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                            }

                            @Override
                            protected void onMethodExit(int opcode) {
                                super.onMethodExit(opcode);
                                //插入代码,此处插入的代码为 System.out.println("code from asm onMethodExit insert");
                                mv.visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                mv.visitLdcInsn("code from asm onMethodExit insert");
                                mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                            }
                        };
                    }
                }
                return mv;
            }
        }, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

执行代码后不出意外可以看到控制台输出code from asm onMethodEnter insert、hello world和code from asm onMethodExit insert。这些例子中我们使用的字符串都是直接通过字面量的形式直接传入方法的,但是根据我们书写代码的经验来看声明局部变量更加高频,那么我们如何在ASM中在方法中生命变量存储数据并且如何在后续中使用该变量呢?为了编写出相关的代码,首先需要你对ClassFile中的method_info的结构有一定的了解,每个方法都有这自己对应的局部变量池,而这个变量池在编译期间确定长度,变量池的长度由方法结束时依旧存在的所有变量的和确定,需要注意的时变量池存在复用的机制,所以并不是说变量池的长度为方法中出现变量最高次数的值。对应到ASM中确定变量池大小所对应的方法为MethodVistor#visitMaxs(int maxStack, int maxLocals)。在ASM中我们可以通过自己计算变量池大小并且在方法结束前调用该方法,或者通过参数配置让ASM自行计算,个人推荐交权给ASM,毕竟人家是专业的么,当然如果你是大牛觉得性能更为重要,那么手动计算更适合你。在确定掌握前置相关知识后对于聪明的我们来说声明局部变量应该不是什么问题了就,首先我们知道变量池是一个类似于数组的数据结构,而数据项都有相关的数据类型和值,在新建一个局部变量时我们肯定需要存储新建变量在变量池的位置和变量的类型,后续我们便可以通过这两个数据关系进行操作码的存储和读取操作,所以核心就是通过一定数据结构维护我们新建的变量在变量池的索引及对应的变量类型。当然这里StackFrame由于过于复杂暂不考虑它带来的影响

import org.objectweb.asm.*;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;

/**
 * @Author: River
 * @Emial: [email protected]
 * @Create: 2022/3/10
 **/
public class HelloWorld {
    public static void hello() throws InterruptedException {
        System.out.println("hello world");
        Thread.sleep(300);
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        ClassReader cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    //这里判断需要处理的方法,真实场景一般需要对类名、方法名、方法签名一起进行校验
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new MethodVisitor(Opcodes.ASM6, mv) {
                            private int argSize = Type.getArgumentTypes(descriptor).length;
                            private HashMap<Integer, Type> argIndexTypes = new HashMap<>();

                            private int timeArgIndex;

                            @Override
                            public void visitCode() {
                                super.visitCode();
                                timeArgIndex = newLocal(Type.LONG_TYPE);
                                visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
                                visitVarInsn(Opcodes.LSTORE, timeArgIndex);
                            }

                            @Override
                            public void visitInsn(int opcode) {
                                if (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN || opcode == Opcodes.ATHROW) {
                                    visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
                                    visitVarInsn(Opcodes.LLOAD, timeArgIndex);
                                    visitInsn(Opcodes.LSUB);
                                    visitVarInsn(Opcodes.LSTORE, timeArgIndex);

                                    visitLdcInsn("method %s cost %d ms");

                                    visitInsn(Opcodes.ICONST_2);
                                    visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object");
                                    visitInsn(Opcodes.DUP);
                                    visitInsn(Opcodes.ICONST_0);
                                    visitLdcInsn(name);
                                    visitInsn(Opcodes.AASTORE);

                                    visitInsn(Opcodes.DUP);
                                    visitInsn(Opcodes.ICONST_1);
                                    visitVarInsn(Opcodes.LLOAD, timeArgIndex);
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;", false);
                                    visitInsn(Opcodes.AASTORE);
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/String", "format", "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", false);

                                    visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                                }
                                super.visitInsn(opcode);
                            }

                            protected int newLocal(Type type) {
                                argIndexTypes.put(++argSize, type);
                                return argSize;
                            }
                        };
                    }
                }
                return mv;
            }
        }, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

执行代码后不出意外可以看到控制台输出hello world和method hello cost 315 ms。如果单从声明局部变量起始实现非常简单,这里通过初始化时先确定下 argSize的值,即当前函数参数占用的变量池大小,Type类时ASM提供给开发者使用的工具类,这个类提供很多有用的转换操作,比如常见的Class到Type的转换后可以便捷的取到全限定名等信息,当然这里我们通过利用它来拿到方法签名中方法参数个数值。然后维护一个为名argIndexTypesHashMap将变量索引和类型一一对应。最后添加一个newLocal方法用于新建变量并返回这个变量所在变量池中的索引。上方代码完成的功能就是一个统计函数耗时的功能,执行完后输出函数耗时。有一个点需要注意到的是程序利用String.format格式化字符串的时候第二个参数为数组,而对数组初始化后在将第二个参数为long类型添加进数组时需要将J转为Long,而我们平时说的自动装箱自动拆箱其实就是这么做的,只不过真正在我们接触字节码时才会更有体会。当然ASM也想到了局部变量自己维护不容易所以也提供了相应的类LocalVariablesSorter帮助我们快速起飞。使用上和我们现在的这个类一样通过调用int newLocal(final Type type)方法新建变量。让我们通过该类实现上方一样性质的代码

import org.objectweb.asm.*;
import org.objectweb.asm.commons.LocalVariablesSorter;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class HelloWorld {
    public static void hello() throws InterruptedException {
        System.out.println("hello world");
        Thread.sleep(300);
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        ClassReader cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    //这里判断需要处理的方法,真实场景一般需要对类名、方法名、方法签名一起进行校验
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new LocalVariablesSorter(Opcodes.ASM6, access, descriptor, mv) {
                            private int timeArgIndex;

                            @Override
                            public void visitCode() {
                                super.visitCode();
                                timeArgIndex = newLocal(Type.LONG_TYPE);
                                visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
                                visitVarInsn(Opcodes.LSTORE, timeArgIndex);
                            }

                            @Override
                            public void visitInsn(int opcode) {
                                if (opcode >= Opcodes.IRETURN && opcode <= Opcodes.RETURN || opcode == Opcodes.ATHROW) {
                                    visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/System", "currentTimeMillis", "()J", false);
                                    visitVarInsn(Opcodes.LLOAD, timeArgIndex);
                                    visitInsn(Opcodes.LSUB);
                                    visitVarInsn(Opcodes.LSTORE, timeArgIndex);

                                    visitLdcInsn("method %s cost %d ms");

                                    visitInsn(Opcodes.ICONST_2);
                                    visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object");
                                    visitInsn(Opcodes.DUP);
                                    visitInsn(Opcodes.ICONST_0);
                                    visitLdcInsn(name);
                                    visitInsn(Opcodes.AASTORE);

                                    visitInsn(Opcodes.DUP);
                                    visitInsn(Opcodes.ICONST_1);
                                    visitVarInsn(Opcodes.LLOAD, timeArgIndex);
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;", false);
                                    visitInsn(Opcodes.AASTORE);
                                    visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/String", "format", "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", false);

                                    visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                                }
                                super.visitInsn(opcode);
                            }
                        };
                    }
                }
                return mv;
            }
        }, ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

AdviceAdapterLocalVariablesSorter的子类,因此可以通过AdviceAdapter继续优化我们的代码,这里就不再贴出代码,留给聪明且勤奋的你来敲敲代码了。事实上很多统计相关的工作我们便可以已当前的模板再加优化便可以完成相应的很好。以上的代码我们都是通过写死的对hello方法进行处理,但是真实场景我们不可能把所有的方法都写进判断语句中,比较常见的是依赖注解来处理方法,所以这里我们已通过注解的标记来处理方法

import lombok.Data;
import org.objectweb.asm.*;
import org.objectweb.asm.commons.AdviceAdapter;

import java.io.IOException;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;

public class HelloWorld {
    @Event(eventName = "event_hello")
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        final int readerOption = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
        ClassReader cr = new ClassReader("HelloWorld");
        CollectClassVisitor collectClassVisitor = new CollectClassVisitor();
        cr.accept(collectClassVisitor, readerOption);

        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            private String clzName;

            @Override
            public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
                super.visit(version, access, name, signature, superName, interfaces);
                clzName = name;
            }

            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    EventBean eventBean = collectClassVisitor.matchMethod(clzName, name, descriptor);
                    if (eventBean != null) {
                        mv = new AdviceAdapter(Opcodes.ASM6, mv, access, name, descriptor) {
                            @Override
                            protected void onMethodEnter() {
                                super.onMethodEnter();
                                visitFieldInsn(Opcodes.GETSTATIC, "java/lang/System", "out", "Ljava/io/PrintStream;");
                                visitLdcInsn("event name: " + eventBean.eventName);
                                visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/io/PrintStream", "println", "(Ljava/lang/String;)V", false);
                            }
                        };
                    }
                }
                return mv;
            }
        }, readerOption);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class CollectClassVisitor extends ClassVisitor {
        private String clzName;
        private ArrayList<EventBean> eventBeans = new ArrayList<>();

        public CollectClassVisitor() {
            super(Opcodes.ASM6);
        }

        @Override
        public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
            super.visit(version, access, name, signature, superName, interfaces);
            clzName = name;
        }

        @Override
        public MethodVisitor visitMethod(int access, String name, String methodDescriptor, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, methodDescriptor, signature, exceptions);
            mv = new MethodVisitor(Opcodes.ASM6, mv) {
                @Override
                public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
                    AnnotationVisitor an = super.visitAnnotation(descriptor, visible);
                    if (descriptor.equals("LHelloWorld$Event;")) {
                        final EventBean eventBean = new EventBean();
                        eventBean.clzName = clzName;
                        eventBean.methodName = name;
                        eventBean.methodDesc = methodDescriptor;
                        an = new AnnotationVisitor(Opcodes.ASM6, an) {
                            @Override
                            public void visit(String name, Object value) {
                                super.visit(name, value);
                                eventBean.eventName = (String) value;
                            }

                            @Override
                            public void visitEnd() {
                                super.visitEnd();
                                eventBeans.add(eventBean);
                            }
                        };
                    }
                    return an;
                }
            };
            return mv;
        }

        public EventBean matchMethod(String clzName, String name, String descriptor) {
            for (EventBean eventBean : eventBeans) {
                if (eventBean.clzName.equals(clzName) && eventBean.methodName.equals(name) && eventBean.methodDesc.equals(descriptor)) {
                    return eventBean;
                }
            }
            return null;
        }
    }

    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface Event {
        String eventName();
    }

    @Data
    static class EventBean {
        private String clzName;
        private String methodName;
        private String methodDesc;
        private String eventName;
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

通过注解处理方法首先我们需要手机被注解标记的类的方法,收集完成后再次通过ASM处理类进行匹配,匹配成功后进行相应代码的处理

场景三:偷梁换柱

在很多情况我们可以通过偷梁换柱达到我们的目的,而偷梁换柱的主要精力时对关心的字节操作码匹配,匹配成功后再进行偷梁换柱,首先我们通过替换hello方法中字面量字符串开启第一个偷梁换柱的例子吧

import org.objectweb.asm.*;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class HelloWorld {
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        final int readerOption = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
        ClassReader cr = new ClassReader("HelloWorld");

        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            private String clzName;

            @Override
            public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
                super.visit(version, access, name, signature, superName, interfaces);
                clzName = name;
            }

            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {
                    if (name.equals("hello") && descriptor.equals("()V")) {
                        mv = new MethodVisitor(Opcodes.ASM6, mv) {
                            @Override
                            public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
                                if (opcode == Opcodes.INVOKEVIRTUAL && owner.equals("java/io/PrintStream") && name.equals("println") && descriptor.equals("(Ljava/lang/String;)V")) {
                                    visitInsn(Opcodes.POP);
                                    visitLdcInsn("hello river!");
                                }
                                super.visitMethodInsn(opcode, owner, name, descriptor, isInterface);
                            }
                        };
                    }
                }
                return mv;
            }
        }, readerOption);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

执行代码后不出意外可以看到控制台输出的是hello river!而并不是hello world。方法其实很多,这里经过分析没法方法执行前应该将所需要的参数压入操作栈内,而我们需要替换的话很简单将操作栈内的原有字符串数据弹出并且压入我们需要的字符串那么是不是就达到了偷梁换柱的目的了。这里我们通过拦截PrintStream#println的方法将操作栈内数据替换便可以看到运行结果为hello river!。而利用这种我们可以通过对各种校验相关的函数进行劫持已达到绕过验证的机制。这里我们已Charler为例子,看看我们如何真实的破解该软件,这里已4.6.2版本为例

import org.objectweb.asm.*;

import java.io.*;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Map;
import java.util.TreeMap;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.jar.JarOutputStream;
import java.util.zip.ZipEntry;

/**
 * @Author: River
 * @Emial: [email protected]
 * @Create: 2022/3/24
 **/
public class CharlesCrack {
    public static void main(String[] args) throws Exception {
        //This is a 30 day trial version. If you continue using Charles you must\npurchase a license. Please see the Help menu for details.
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        System.out.println("请输入charles安装的目录:");

        String charlesPath = br.readLine();
        String charlesJarPath = charlesPath + "\\lib\\charles.jar";
        File charlesJarFile = new File(charlesJarPath);
        if (!charlesJarFile.exists()) {
            System.out.println("charles目录(" + charlesJarPath + ")不存在!");
            return;
        }
        System.out.println("charles检测通过");
        System.out.println("正在处理文件...");

        final int readerOption = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
        ClassReader cr;
        ClassWriter cw = null;

        JarFile jarFile = new JarFile(charlesJarPath);
        Enumeration<JarEntry> entries = jarFile.entries();
        while (entries.hasMoreElements()) {
            JarEntry jarEntry = entries.nextElement();
            String entryName = jarEntry.getName();
            if (entryName.equals("com/xk72/charles/p.class")) {
                ZipEntry zipEntry = new ZipEntry(entryName);
                InputStream inputStream = jarFile.getInputStream(zipEntry);
                cr = new ClassReader(inputStream);
                cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
                cr.accept(new ClassVisitor(Opcodes.ASM7, cw) {
                    @Override
                    public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                        MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                        if (mv != null) {
                            if (name.equals("c") && descriptor.equals("()Ljava/lang/String;")) {
                                mv = new MethodVisitor(Opcodes.ASM7, mv) {
                                    @Override
                                    public void visitMethodInsn(int opcode, String owner, String name, String descriptor, boolean isInterface) {
                                    }

                                    @Override
                                    public void visitInsn(int opcode) {
                                        if (opcode == Opcodes.ARETURN) {
                                            visitInsn(Opcodes.POP);
                                            visitLdcInsn("River破解");
                                        }
                                        super.visitInsn(opcode);
                                    }
                                };
                            }

                            if (name.equals("a") && descriptor.equals("()Z")) {
                                mv = new MethodVisitor(Opcodes.ASM6, mv) {
                                    @Override
                                    public void visitInsn(int opcode) {
                                        if (opcode == Opcodes.IRETURN) {
                                            visitInsn(Opcodes.POP);
                                            visitInsn(Opcodes.ICONST_1);
                                        }
                                        super.visitInsn(opcode);
                                    }
                                };
                            }
                        }
                        return mv;
                    }
                }, readerOption);
            }
        }

        writeJarFile(charlesJarPath, "com/xk72/charles/p.class", cw.toByteArray());

        System.out.println("破解完成!");
    }


    public static void writeJarFile(String jarFilePath, String entryName, byte[] data) throws Exception {

        //1、首先将原Jar包里的所有内容读取到内存里,用TreeMap保存
        JarFile jarFile = new JarFile(jarFilePath);
        //可以保持排列的顺序,所以用TreeMap 而不用HashMap
        TreeMap tm = new TreeMap();
        Enumeration es = jarFile.entries();
        while (es.hasMoreElements()) {
            JarEntry je = (JarEntry) es.nextElement();
            byte[] b = readStream(jarFile.getInputStream(je));
            tm.put(je.getName(), b);
        }

        JarOutputStream jos = new JarOutputStream(new FileOutputStream(jarFilePath));
        Iterator it = tm.entrySet().iterator();
        boolean has = false;

        //2、将TreeMap重新写到原jar里,如果TreeMap里已经有entryName文件那么覆盖,否则在最后添加
        while (it.hasNext()) {
            Map.Entry item = (Map.Entry) it.next();
            String name = (String) item.getKey();
            JarEntry entry = new JarEntry(name);
            jos.putNextEntry(entry);
            byte[] temp;
            if (name.equals(entryName)) {
                //覆盖
                temp = data;
                has = true;
            } else {
                temp = (byte[]) item.getValue();
            }
            jos.write(temp, 0, temp.length);
        }

        if (!has) {
            //最后添加
            JarEntry newEntry = new JarEntry(entryName);
            jos.putNextEntry(newEntry);
            jos.write(data, 0, data.length);
        }
        jos.finish();
        jos.close();

    }

    public static byte[] readStream(InputStream inStream) throws Exception {
        ByteArrayOutputStream outSteam = new ByteArrayOutputStream();
        byte[] buffer = new byte[1024];
        int len = -1;
        while ((len = inStream.read(buffer)) != -1) {
            outSteam.write(buffer, 0, len);
        }
        outSteam.close();
        inStream.close();
        return outSteam.toByteArray();
    }
}


复制代码

通过反编译软件查看Charles的源码,而通过分析This is a 30 day trial version. If you continue using Charles you must\npurchase a license. Please see the Help menu for details这段字符串的位置调用可以分析出关键类及校验方法,通过ASM将该核心类的方法进行修改后使得永远校验通过,当然为了保险我们将方法内部所有的字节码全部删除值保留我们需要的return true字节码。在我们自己开发时破解应该注意并且最好做些机制防止使用这种办法就简单的给到攻击者可乘之机,最好将核心代码通过native书写。再次运行Charles将看到破解后的相关信息

CDB4AC51-9A5E-4a25-A056-44215815FA46.png

场景四:AOP

AOP没有使用过至少你也应该听说过,这里不再赘述相关内容,我们将实现一个非常简单的aspectj版本

import lombok.Data;
import lombok.SneakyThrows;
import org.objectweb.asm.*;
import org.objectweb.asm.commons.AdviceAdapter;

import java.io.IOException;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;

public class HelloWorld {
    public static void hello() {
        System.out.println("hello world");
    }

    public static void main(String[] args) throws IOException, NoSuchMethodException, InvocationTargetException, IllegalAccessException {
        final int readerOption = ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES;
        ClassReader cr = new ClassReader("HelloWorld$AopTest");
        AopCollectClassVisitor aopCollectClassVisitor = new AopCollectClassVisitor();
        cr.accept(aopCollectClassVisitor, readerOption);

        cr = new ClassReader("HelloWorld");
        ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS | ClassWriter.COMPUTE_FRAMES);
        cr.accept(new ClassVisitor(Opcodes.ASM6, cw) {
            private String clzName;

            @Override
            public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
                super.visit(version, access, name, signature, superName, interfaces);
                clzName = name;
            }

            @Override
            public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
                MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
                if (mv != null) {

                    mv = new AdviceAdapter(Opcodes.ASM6, mv, access, name, descriptor) {
                        @Override
                        protected void onMethodEnter() {
                            super.onMethodEnter();
                            AopBean aopBean = aopCollectClassVisitor.matchAop(clzName, name, descriptor, Before.class);

                            if (aopBean != null) {
                                visitMethodInsn(Opcodes.INVOKESTATIC, aopBean.targetClzName, aopBean.targetMethodName, aopBean.targetMethodDesc, false);
                            }
                        }

                        @Override
                        protected void onMethodExit(int opcode) {
                            super.onMethodExit(opcode);
                            AopBean aopBean = aopCollectClassVisitor.matchAop(clzName, name, descriptor, Exit.class);
                            if (aopBean != null) {
                                visitMethodInsn(Opcodes.INVOKESTATIC, aopBean.targetClzName, aopBean.targetMethodName, aopBean.targetMethodDesc, false);
                            }
                        }
                    };
                }
                return mv;
            }
        }, readerOption);

        //通过ClassLoader加载bytes后执行hello方法
        ASMGenClassLoader classLoader = new ASMGenClassLoader(cw.toByteArray());
        Class<?> helloWorld = classLoader.findClass("HelloWorld");
        Method hello = helloWorld.getDeclaredMethod("hello");
        hello.invoke(null);
    }

    static class AopCollectClassVisitor extends ClassVisitor {
        private String clzName;
        private ArrayList<AopBean> aopBeans = new ArrayList<>();

        public AopCollectClassVisitor() {
            super(Opcodes.ASM6);
        }

        @Override
        public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
            super.visit(version, access, name, signature, superName, interfaces);
            clzName = name;
        }

        @Override
        public MethodVisitor visitMethod(int access, String name, String methodDescriptor, String signature, String[] exceptions) {
            MethodVisitor mv = super.visitMethod(access, name, methodDescriptor, signature, exceptions);
            mv = new MethodVisitor(Opcodes.ASM6, mv) {
                @Override
                public AnnotationVisitor visitAnnotation(String descriptor, boolean visible) {
                    AnnotationVisitor av = super.visitAnnotation(descriptor, visible);
                    if (descriptor.equals("LHelloWorld$Before;") || descriptor.equals("LHelloWorld$Exit;")) {
                        AopBean aopBean = new AopBean();
                        aopBean.targetClzName = clzName;
                        aopBean.targetMethodName = name;
                        aopBean.targetMethodDesc = methodDescriptor;
                        aopBean.aopAnnClzName = descriptor;

                        av = new AnnotationVisitor(Opcodes.ASM6, av) {
                            @SneakyThrows
                            @Override
                            public void visit(String name, Object value) {
                                super.visit(name, value);

                                String setMethodName = "set" + name.substring(0, 1).toUpperCase() + name.substring(1);
                                Method setMethod = aopBean.getClass().getDeclaredMethod(setMethodName, String.class);
                                setMethod.invoke(aopBean, (String) value);
                            }

                            @Override
                            public void visitEnd() {
                                super.visitEnd();
                                aopBeans.add(aopBean);
                            }
                        };
                    }
                    return av;
                }
            };
            return mv;
        }

        public AopBean matchAop(String clzName, String methodName, String methodDesc, Class annoClz) {
            for (AopBean aopBean : aopBeans) {
                boolean clzMatch = aopBean.clzName.equals(clzName) || aopBean.clzName.equals("*");
                boolean methodNameMatch = aopBean.methodName.equals(methodName) || aopBean.methodName.equals("*");
                boolean methodDescMatch = aopBean.methodDesc.equals(methodDesc) || aopBean.methodDesc.equals("*");
                if (clzMatch && methodNameMatch && methodDescMatch && aopBean.aopAnnClzName.equals(Type.getDescriptor(annoClz))) {
                    return aopBean;
                }
            }
            return null;
        }
    }

    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface Before {
        String clzName();

        String methodName();

        String methodDesc();
    }

    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    public @interface Exit {
        String clzName();

        String methodName();

        String methodDesc();
    }

    public static class AopTest {
        @Before(clzName = "*", methodName = "hello", methodDesc = "()V")
        public static void sayBefore() {
            System.out.println("aop code insert from AopTest.sayBefore");
        }

        @Exit(clzName = "*", methodName = "hello", methodDesc = "()V")
        public static void sayAfter() {
            System.out.println("aop code insert from AopTest.sayAfter");
        }
    }

    @Data
    static class AopBean {
        private String clzName;
        private String methodName;
        private String methodDesc;
        private String aopAnnClzName;

        private String targetClzName;
        private String targetMethodName;
        private String targetMethodDesc;
    }

    static class ASMGenClassLoader extends ClassLoader {
        private byte[] bytes;

        public ASMGenClassLoader(byte[] bytes) {
            this.bytes = bytes;
        }

        @Override
        protected Class<?> findClass(String name) {
            return defineClass(name, bytes, 0, bytes.length);
        }
    }
}

复制代码

代码实现的AOP非常简单,相对来说实现成熟一些的AOP无非就是额外添加些注解,额外添加更复杂的匹配机制,但作为一个AOP的实现例子最好不过了

猜你喜欢

转载自juejin.im/post/7078598490544144391