使用ASM实现AOP

本文基于 ASM 5.2 版本编写，其他版本的 API 可能有差异，请自行调整。

AopMethodAdapter.java

package com.wly.aop.adapter;

import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

import com.wly.aop.Aop;
import com.wly.aop.inv.BaseInvocation;

/**
 * Method visitor that weaves the advice listed in the {@link Aop} annotation of
 * {@code aopClass} around a method body.
 *
 * <p>Shape of every adapted method:
 * <pre>
 * try {
 *     new Advice1().before(); new Advice2().before(); ...
 *     ...original body, with new AdviceN().after() calls inserted before every return...
 * } catch (Exception e) {
 *     e.printStackTrace();
 *     return &lt;default&gt;;   // 0 / 0L / 0f / 0d / null, or nothing for void methods
 * }
 * </pre>
 * If the original method never returns normally (it always throws), the handler prints the
 * exception and rethrows it instead.
 *
 * <p>No stack map frames are emitted for the injected code, so the enclosing ClassWriter
 * must be created with {@code ClassWriter.COMPUTE_FRAMES} (as AopUtil does).
 */
public class AopMethodAdapter extends MethodVisitor implements Opcodes {
	/** Start of the region guarded by the injected handler; visited in visitCode. */
	Label l0 = new Label();
	/** Exclusive end of the guarded region; visited in visitMaxs, after all original code. */
	Label l1 = new Label();
	/** Entry of the injected {@code catch (Exception e)} handler. */
	Label l2 = new Label();
	/** Class whose {@link Aop} annotation lists the advice classes to invoke. */
	private final Class<?> aopClass;
	/** Return opcode seen in the original body, or -1 if the method never returns normally. */
	private int returnOpcode = -1;

	/**
	 * @param mv       visitor that receives the woven instructions
	 * @param aopClass class carrying the {@link Aop} annotation
	 */
	public AopMethodAdapter(MethodVisitor mv, Class<?> aopClass) {
		super(ASM5, mv);
		this.aopClass = aopClass;
	}

	/**
	 * Same as {@link #AopMethodAdapter(MethodVisitor, Class)}; {@code loadCount} is ignored
	 * and only kept for source compatibility.
	 */
	public AopMethodAdapter(MethodVisitor mv, Class<?> aopClass, int loadCount) {
		this(mv, aopClass);
	}

	/** Opens the guarded region and calls every advice's {@code before()}. */
	@Override
	public void visitCode() {
		super.visitCode();
		super.visitLabel(l0);
		invokeAdvice("before");
	}

	/** Inserts the {@code after()} advice calls in front of every return instruction. */
	@Override
	public void visitInsn(int opcode) {
		if (opcode >= IRETURN && opcode <= RETURN) {
			// A returned value stays on the operand stack underneath the advice calls; those
			// calls are NEW/DUP/<init>()V/after()V and leave the stack as they found it.
			invokeAdvice("after");
			returnOpcode = opcode;
		}
		super.visitInsn(opcode);
	}

	/** Closes the guarded region and appends the exception handler after the original code. */
	@Override
	public void visitMaxs(int maxStack, int maxLocals) {
		// Registered here, after the original method's own try/catch blocks have been visited,
		// so it ends up last in the exception table and the method's own handlers keep
		// precedence over this catch-all.
		super.visitTryCatchBlock(l0, l1, l2, "java/lang/Exception");
		super.visitLabel(l1);
		super.visitLabel(l2);
		if (returnOpcode < 0) {
			// No normal exit in the original code, so there is no value to return: print and rethrow.
			super.visitInsn(DUP);
			super.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Exception", "printStackTrace", "()V", false);
			super.visitInsn(ATHROW);
		} else {
			super.visitMethodInsn(INVOKEVIRTUAL, "java/lang/Exception", "printStackTrace", "()V", false);
			pushDefaultValue(returnOpcode);
			super.visitInsn(returnOpcode);
		}
		super.visitMaxs(maxStack, maxLocals);
	}

	/**
	 * Emits {@code new C().<methodName>()} for each advice class C listed in the {@link Aop}
	 * annotation of {@code aopClass}, in declaration order; emits nothing if the class is not
	 * annotated.
	 */
	private void invokeAdvice(String methodName) {
		Aop aop = aopClass.getAnnotation(Aop.class);
		if (aop == null) {
			return;
		}
		for (Class<? extends BaseInvocation> advice : aop.values()) {
			String owner = advice.getName().replace(".", "/");
			super.visitTypeInsn(NEW, owner);
			super.visitInsn(DUP);
			super.visitMethodInsn(INVOKESPECIAL, owner, "<init>", "()V", false);
			super.visitMethodInsn(INVOKEVIRTUAL, owner, methodName, "()V", false);
		}
	}

	/** Pushes the zero/null value returned by {@code opcode}; nothing for a void RETURN. */
	private void pushDefaultValue(int opcode) {
		switch (opcode) {
		case IRETURN:
			super.visitInsn(ICONST_0);
			break;
		case LRETURN:
			super.visitInsn(LCONST_0);
			break;
		case FRETURN:
			super.visitInsn(FCONST_0);
			break;
		case DRETURN:
			super.visitInsn(DCONST_0);
			break;
		case ARETURN:
			super.visitInsn(ACONST_NULL);
			break;
		default:
			// RETURN: void methods have nothing to push.
			break;
		}
	}
}

AopClassAdapter.java（最后一个函数 splitAsmType 没有自己实现，直接取自哈库纳大佬的 Hasor 框架）

package com.wly.aop.adapter;

import java.io.IOException;
import java.io.StringReader;
import java.util.ArrayList;

import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;


/**
 * Class visitor that rewrites a class {@code X} into a subclass named {@code X$AOP}.
 *
 * <p>Methods of {@code X} are copied into the subclass:
 * <ul>
 * <li>Constructors go through {@link ChangeConstructorMethodAdapter}, so their
 * {@code super(...)}/{@code this(...)} call targets {@code X}.</li>
 * <li>All other methods go through {@link AopMethodAdapter}, which weaves the advice
 * declared by {@code X}'s {@code @Aop} annotation.</li>
 * </ul>
 * The static initializer is not copied.
 */
public class AopClassAdapter extends ClassVisitor implements Opcodes {

	/** The original class being enhanced. */
	private Class<?> aopClass;

	/**
	 * @param cv        writer that receives the enhanced class
	 * @param className the original class to enhance
	 */
	public AopClassAdapter(ClassWriter cv, Class<?> className) {
		super(ASM5, cv);
		this.aopClass = className;
	}

	@Override
	public MethodVisitor visitMethod(final int access, final String name, final String desc, final String signature,
			final String[] exceptions) {
		if ("<clinit>".equals(name)) {
			// Static initialization already runs in the superclass X. A copy in X$AOP would run
			// it a second time, and its putstatic to X's static final fields would fail with
			// IllegalAccessError, because only X's own <clinit> may assign them.
			return null;
		}
		MethodVisitor mv = cv.visitMethod(access, name, desc, signature, exceptions);
		if (mv == null) {
			return null;
		}
		if ("<init>".equals(name)) {
			return new ChangeConstructorMethodAdapter(mv, aopClass.getName().replace(".", "/"));
		}
		return new AopMethodAdapter(mv, aopClass);
	}

	/** Renames the class to {@code name$AOP} and makes the original class its superclass. */
	@Override
	public void visit(final int version, final int access, final String name, final String signature,
			final String superName, final String[] interfaces) {
		String enhancedName = name + "$AOP";
		super.visit(version, access, enhancedName, signature, name, interfaces);
	}

	/**
	 * Splits concatenated ASM/JVM type descriptors into one element per type.
	 *
	 * <p>Example: {@code IIIILjava/lang/Integer;F[[[Ljava/util/Date;} yields
	 * {@code I, I, I, I, Ljava/lang/Integer;, F, [[[Ljava/util/Date;}. Array prefixes stay
	 * attached to their element type. An object descriptor missing its closing ';' is
	 * returned as-is up to the end of the input.
	 *
	 * <p>Copied from the Hasor framework (author: 哈库纳).
	 *
	 * @param asmTypes concatenated descriptors
	 * @return the individual descriptors, in order
	 * @throws RuntimeException if reading fails (e.g. {@code asmTypes} is null); the original
	 *                          exception is kept as the cause
	 */
	public static String[] splitAsmType(final String asmTypes) {
		class AsmTypeRead {
			private final StringReader sread;

			AsmTypeRead(final String sr) {
				this.sread = new StringReader(sr);
			}

			/** Reads up to and including the next ';', or to the end of input. */
			private String readToSemicolon() throws IOException {
				StringBuilder res = new StringBuilder();
				int strInt;
				while ((strInt = this.sread.read()) != -1) {
					res.append((char) strInt);
					if (strInt == ';') {
						break;
					}
				}
				return res.toString();
			}

			/** Reads one type descriptor; returns "" at end of input. */
			private String readType() throws IOException {
				int strInt = this.sread.read();
				if (strInt == -1) {
					return "";
				}
				switch ((char) strInt) {
				case '[': // array: prefix of the element type
					return '[' + this.readType();
				case 'L': // object type: runs until ';'
					return 'L' + this.readToSemicolon();
				default: // primitive
					return String.valueOf((char) strInt);
				}
			}

			/** Reads all remaining type descriptors. */
			String[] readTypes() throws IOException {
				ArrayList<String> types = new ArrayList<>();
				for (String s = this.readType(); !s.isEmpty(); s = this.readType()) {
					types.add(s);
				}
				return types.toArray(new String[types.size()]);
			}
		}
		try {
			return new AsmTypeRead(asmTypes).readTypes();
		} catch (Exception e) {
			throw new RuntimeException("不合法的ASM类型desc。", e);
		}
	}
}

ChangeConstructorMethodAdapter.java

package com.wly.aop.adapter;

import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;

public class ChangeConstructorMethodAdapter extends MethodVisitor{
	private String superClassName;

	public ChangeConstructorMethodAdapter(MethodVisitor mv, String superClassName) {
		super(Opcodes.ASM5, mv);
		this.superClassName = superClassName;
	}

	public void visitMethodInsn(int opcode, String owner, String name, String desc, boolean b) {
		// 调用父类的构造函数时
		if (opcode == Opcodes.INVOKESPECIAL && name.equals("<init>")) {
			owner = superClassName;
		}
		super.visitMethodInsn(opcode, owner, name, desc, b);// 改写父类为 superClassName
	}
}

AopUtil.java

package com.wly.aop.util;

import java.util.HashMap;
import java.util.Map;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassVisitor;
import org.objectweb.asm.ClassWriter;

import com.wly.aop.adapter.AopClassAdapter;

/**
 * Creates instances of ASM-generated, AOP-enhanced subclasses ({@code X$AOP}).
 *
 * <p>Generated classes are cached by the original class's name. The cache is a plain
 * {@link HashMap} without synchronization, so this class is not safe for concurrent use.
 */
public class AopUtil {
	/** Loader that defines every generated class (created with the default parent loader). */
	private static AccountGeneratorClassLoader classLoader = new AccountGeneratorClassLoader();
	/** Generated classes keyed by the original class's binary name. */
	private static Map<String, Class<?>> secureClassMap = new HashMap<>();

	/**
	 * Returns a new instance of the enhanced subclass of {@code obj}, generating and defining
	 * that subclass on first use.
	 *
	 * @param obj class to enhance. ASM's ClassReader reads its bytes by class name, and the
	 *            generated subclass must have an accessible no-arg constructor.
	 * @param <T> expected type of the result (unchecked cast)
	 * @return the new instance, or {@code null} if generation or instantiation threw an
	 *         Exception (its stack trace is printed)
	 */
	@SuppressWarnings("unchecked")
	public static <T> T getClassInstance(Class<?> obj) {
		try {
			Class<?> secureAccountClass = secureClassMap.get(obj.getName());
			if (secureAccountClass == null) {
				ClassReader cr = new ClassReader(obj.getName());
				// COMPUTE_FRAMES: let ASM recompute stack map frames and max stack/locals for the
				// injected instructions.
				ClassWriter cw = new ClassWriter(ClassWriter.COMPUTE_FRAMES);
				ClassVisitor classAdapter = new AopClassAdapter(cw, obj);
				cr.accept(classAdapter, ClassReader.EXPAND_FRAMES);
				byte[] data = cw.toByteArray();
				secureAccountClass = classLoader.defineClassFromClassFile(obj.getName() + "$AOP", data);
				secureClassMap.put(obj.getName(), secureAccountClass);
			}
			// Class.newInstance() is deprecated since Java 9; this is its documented replacement.
			return (T) secureAccountClass.getDeclaredConstructor().newInstance();
		} catch (Exception e) {
			e.printStackTrace();
		}
		return null;
	}

	/** Class loader that exposes {@code defineClass} for the generated bytecode. */
	private static class AccountGeneratorClassLoader extends ClassLoader {
		public Class<?> defineClassFromClassFile(String className, byte[] classFile) throws Exception {
			return defineClass(className, classFile, 0, classFile.length);
		}
	}
}

所有AOP类需要继承的类:BaseInvocation.java

package com.wly.aop.inv;

/**
 * Base class for advice types listed in {@code @Aop(values = ...)}.
 *
 * <p>The AOP adapters in this project create each advice with its no-arg constructor. They
 * call {@link #before()} at method entry and {@link #after()} before each return.
 */
public abstract class BaseInvocation {

	/** Hook run before the woven method body. */
	public abstract void before();

	/** Hook run after the woven method body, just before it returns. */
	public abstract void after();
}

AOP注解接口:Aop.java

package com.wly.aop;

import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

import com.wly.aop.inv.BaseInvocation;

/**
 * Class-level annotation that lists the advice classes to weave into the annotated class.
 *
 * <p>Retained at runtime because the ASM adapters read it via {@code Class.getAnnotation}.
 */
@Retention(RetentionPolicy.RUNTIME)
public @interface Aop {

	/**
	 * Advice classes, invoked in array order. Each one is instantiated through a no-arg
	 * constructor that must be accessible from the generated class.
	 */
	Class<? extends BaseInvocation>[] values();
}

AOP测试类:AopTest.java

package com.wly.aop.test;

import com.wly.aop.Aop;

/**
 * Demo target class. AopUtil generates a subclass of it whose non-constructor methods are
 * wrapped with the Invo and Invo2 advice, in that order.
 */
@Aop(values = { Invo.class, Invo2.class })
public class AopTest {

	/** Prints {@code str} on standard output, followed by a line separator. */
	public void m(String str) {
		System.out.println(str);
	}
}

继承 BaseInvocation 的切面（通知）类 Invo：

package com.wly.aop.test;

import com.wly.aop.inv.BaseInvocation;

public class Invo extends BaseInvocation{
	public void after(){
		System.out.println("after");
	}
	public void before(){
		System.out.println("before");
	}
}

继承 BaseInvocation 的切面（通知）类 Invo2（在 after() 中故意制造了一个异常）：

package com.wly.aop.test;

import com.wly.aop.inv.BaseInvocation;

public class Invo2 extends BaseInvocation{
	public void after(){
		System.out.println("aop2 ---- after");
		System.out.println(1/0);
	}
	public void before(){
		System.out.println("aop2 ---- before");
	}
}

测试Main函数:

package com.wly.aop.test;

import com.wly.aop.util.AopUtil;

/** Demo entry point: obtains the AOP-enhanced AopTest and calls {@code m("test")} on it. */
public class Main {
	public static void main(String[] args) {
		AopTest proxy = AopUtil.<AopTest>getClassInstance(AopTest.class);
		proxy.m("test");
	}
}

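测试输出结果（按上述代码推算；异常堆栈由 e.printStackTrace() 输出到 System.err）：

before
aop2 ---- before
test
after
aop2 ---- after
java.lang.ArithmeticException: / by zero（后接堆栈信息）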

猜你喜欢

转载自my.oschina.net/u/3269106/blog/1154508