笔记3:自定义注解的实现

自定义 @Component、@Controller、@Service、@Repository、@Autowired 和 @Transactional 注解，并实现它们的解析（依赖注入与事务控制）。

  1. 首先定义注解类（@Component 等），定义方式参考了 Spring 中的同名注解。
/**
 * Marks a class as a generic bean that BeanFactory instantiates and registers.
 */
@Documented
// Retained at runtime so the annotation can be read via reflection.
@Retention(value = RetentionPolicy.RUNTIME)
// Applicable to classes, interfaces (including annotation types) and enum declarations.
@Target(value = ElementType.TYPE)
public @interface Component {

    /**
     * Optional bean name. When empty, BeanFactory derives the name from the class name
     * (the part after the last '.'); otherwise the value is used with its first letter upper-cased.
     */
    String value() default "";
}
/**
 * Marks a class as a controller bean that BeanFactory instantiates and registers.
 */
@Documented
// Retained at runtime so the annotation can be read via reflection.
@Retention(value = RetentionPolicy.RUNTIME)
// Applicable to classes, interfaces (including annotation types) and enum declarations.
@Target(value = ElementType.TYPE)
public @interface Controller {

    /** Optional bean name; empty means the name is derived from the class name. */
    String value() default "";
}
/**
 * Marks a class as a service bean that BeanFactory instantiates and registers.
 */
@Documented
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = ElementType.TYPE)
public @interface Service {

    /** Optional bean name; empty means the name is derived from the class name. */
    String value() default "";
}
/**
 * Marks a class as a data-access (repository) bean that BeanFactory instantiates and registers.
 */
@Documented
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = ElementType.TYPE)
public @interface Repository {

    /** Optional bean name; empty means the name is derived from the class name. */
    String value() default "";
}
/**
 * Marks a dependency for injection. BeanFactory in this file only processes it on fields,
 * injecting through the bean's matching public {@code setXxx} method.
 */
@Documented
@Retention(value = RetentionPolicy.RUNTIME)
@Target(value = {
        ElementType.CONSTRUCTOR,
        ElementType.METHOD,
        ElementType.PARAMETER,
        ElementType.FIELD,
        ElementType.ANNOTATION_TYPE})
public @interface Autowired {

    /**
     * Whether the dependency is required.
     * NOTE: the member name is misspelled ("requred"); it is kept because BeanFactory reads it
     * by this name, and renaming an annotation member breaks every usage site.
     */
    boolean requred() default true;
}
/**
 * Requests transactional execution. BeanFactory checks it at class level and wraps the bean in
 * a transactional proxy (method-level use is permitted by @Target but not processed in this file).
 */
@Documented
// A subclass of an annotated class also carries this annotation (only affects class-level use).
@Inherited
// Retained at runtime so the annotation can be read via reflection.
@Retention(value = RetentionPolicy.RUNTIME)
// Applicable to type declarations and methods.
@Target(value = {ElementType.TYPE, ElementType.METHOD})
public @interface Transactional {

    String value() default "";

    /**
     * Exception types that should trigger a rollback.
     * NOTE(review): declared but not consulted by ProxyFactory/BeanFactory in this file.
     */
    Class<? extends Throwable>[] rollbackFor() default {};
}
  2. 接着是一些基本工具类和实现类，并为它们加上相应的注解。其中被 @Autowired 标注的依赖需要提供对应的 set 方法，用于后续的依赖注入。
/**
 * Hands out one JDBC connection per thread, so that all DAO calls and the TransactionManager
 * on the same thread operate on the same connection.
 */
@Component
public class ConnectionUtils {

    /** Connection bound to each calling thread; empty until the thread first asks for one. */
    private final ThreadLocal<Connection> connectionHolder = new ThreadLocal<>();

    /**
     * Returns the connection bound to the calling thread, borrowing one from the Druid pool
     * and binding it to the thread on first use.
     *
     * @return the calling thread's connection
     * @throws SQLException if a new connection cannot be obtained from the pool
     */
    public Connection getCurrentThreadConn() throws SQLException {
        Connection bound = connectionHolder.get();
        if (bound != null) {
            return bound;
        }
        // First request on this thread: borrow from the pool and remember it for later calls.
        Connection borrowed = DruidUtils.getInstance().getConnection();
        connectionHolder.set(borrowed);
        return borrowed;
    }
}
/**
 * Manual transaction control on the calling thread's connection (as provided by ConnectionUtils).
 */
@Component
public class TransactionManager {

    /** Source of the per-thread connection; injected by BeanFactory through the setter below. */
    @Autowired
    private ConnectionUtils connectionUtils;

    public void setConnectionUtils(ConnectionUtils connectionUtils) {
        this.connectionUtils = connectionUtils;
    }

    /** Starts a manual transaction by turning auto-commit off on the current thread's connection. */
    public void beginTransaction() throws SQLException {
        threadConnection().setAutoCommit(false);
    }

    /** Commits the transaction running on the current thread's connection. */
    public void commit() throws SQLException {
        threadConnection().commit();
    }

    /** Rolls back the transaction running on the current thread's connection. */
    public void rollback() throws SQLException {
        threadConnection().rollback();
    }

    /** Looks up the connection bound to the calling thread. */
    private java.sql.Connection threadConnection() throws SQLException {
        return connectionUtils.getCurrentThreadConn();
    }
}
/**
 * Wraps bean instances in proxies that run each business method inside a JDBC transaction
 * controlled by TransactionManager: begin, invoke, commit — or roll back and rethrow on failure.
 */
@Component
public class ProxyFactory {

    @Autowired
    private TransactionManager transactionManager;

    public void setTransactionManager(TransactionManager transactionManager) {
        this.transactionManager = transactionManager;
    }

    /**
     * Creates a JDK dynamic proxy. JDK proxies can only implement interfaces, so {@code obj}
     * must implement at least one interface (only its directly declared interfaces are used).
     *
     * @param obj the target (delegate) object
     * @return the proxy object
     */
    public Object getJdkProxy(Object obj) {
        return Proxy.newProxyInstance(obj.getClass().getClassLoader(), obj.getClass().getInterfaces(),
                (proxy, method, args) -> invokeInTransaction(obj, method, args));
    }

    /**
     * Creates a CGLIB subclass proxy; usable for classes that implement no interface.
     *
     * @param obj the target (delegate) object
     * @return the proxy object
     */
    public Object getCglibProxy(Object obj) {
        return Enhancer.create(obj.getClass(),
                (MethodInterceptor) (o, method, objects, methodProxy) -> invokeInTransaction(obj, method, objects));
    }

    /**
     * Invokes {@code method} on {@code target} inside a transaction: commit on success,
     * rollback on any failure (including Errors), then rethrow the target's own exception.
     */
    private Object invokeInTransaction(Object target, java.lang.reflect.Method method, Object[] args)
            throws Throwable {
        // toString/equals/hashCode never touch the database; don't open a transaction for them.
        if (method.getDeclaringClass() == Object.class) {
            return invokeTarget(target, method, args);
        }
        try {
            // Begin (turn off auto-commit) inside the try so a failure here also triggers a rollback.
            transactionManager.beginTransaction();
            Object result = invokeTarget(target, method, args);
            transactionManager.commit();
            return result;
        } catch (Throwable failure) {
            try {
                transactionManager.rollback();
            } catch (Throwable rollbackFailure) {
                // Keep the original failure as the primary exception; record the rollback problem.
                failure.addSuppressed(rollbackFailure);
            }
            // Rethrow so the caller (e.g. a servlet) can handle it.
            throw failure;
        }
    }

    /**
     * Calls {@code method} reflectively and unwraps InvocationTargetException, so callers see the
     * exception the target actually threw rather than the reflection wrapper.
     */
    private static Object invokeTarget(Object target, java.lang.reflect.Method method, Object[] args)
            throws Throwable {
        try {
            return method.invoke(target, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw e.getTargetException();
        }
    }
}
/**
 * Money transfer service, registered under the bean name "transferService". Because the class is
 * annotated with @Transactional, BeanFactory wraps it in a transactional proxy.
 * NOTE(review): rollbackFor is not read by any code in this file.
 */
@Service("transferService")
@Transactional(rollbackFor = Exception.class)
public class TransferServiceImpl implements TransferService {

    /** Data-access dependency; injected by BeanFactory through the setter below. */
    @Autowired
    private AccountDao accountDao;

    public void setAccountDao(AccountDao accountDao) {
        this.accountDao = accountDao;
    }

    /**
     * Moves {@code money} from the account {@code fromCardNo} to {@code toCardNo} and saves both.
     * The deliberate division by zero between the two updates always throws ArithmeticException;
     * it is a demo hook showing that the first update gets rolled back.
     */
    @Override
    public void transfer(String fromCardNo, String toCardNo, int money) throws Exception {
        Account payer = accountDao.queryAccountByCardNo(fromCardNo);
        Account payee = accountDao.queryAccountByCardNo(toCardNo);

        payer.setMoney(payer.getMoney() - money);
        payee.setMoney(payee.getMoney() + money);

        accountDao.updateAccountByCardNo(payee);
        int rollbackTrigger = 1 / 0; // intentional failure to test rollback
        accountDao.updateAccountByCardNo(payer);
    }
}
/**
 * JDBC implementation of AccountDao, registered under the bean name "accountDao".
 * NOTE(review): TransferServiceImpl calls queryAccountByCardNo/updateAccountByCardNo, which this
 * excerpt does not show — presumably abbreviated in the original article.
 */
@Repository("accountDao")
public class JdbcAccountDaoImpl implements AccountDao {

    /** Per-thread connection source; injected by BeanFactory through the setter below. */
    @Autowired
    private ConnectionUtils connectionUtils;

    public void setConnectionUtils(ConnectionUtils connectionUtils) {
        this.connectionUtils = connectionUtils;
    }

    /**
     * Executes the update statement with {@code money}'s amount and id as parameters.
     * The connection is deliberately left open: it is the calling thread's connection from
     * ConnectionUtils, which TransactionManager later commits or rolls back.
     *
     * @param money supplies parameter 1 (int amount) and parameter 2 (String id)
     * @return the number of affected rows
     * @throws Exception if obtaining the connection or executing the statement fails
     */
    @Override
    public int update(Money money) throws Exception {
        Connection con = connectionUtils.getCurrentThreadConn();
        // NOTE(review): "xxxx" is a placeholder; it must be a statement with two '?' parameters.
        String sql = "xxxx";
        // try-with-resources closes the statement even if a setter or executeUpdate throws.
        try (PreparedStatement preparedStatement = con.prepareStatement(sql)) {
            preparedStatement.setInt(1, money.getMoney());
            preparedStatement.setString(2, money.getId());
            return preparedStatement.executeUpdate();
        }
    }
}
  3. 最后，也是最关键的一步：实现注解的解析和事务控制。事务通过动态代理实现，分为 JDK 和 CGLIB 两种：JDK 动态代理要求目标类实现了接口，没有接口时则改用 CGLIB 动态代理。
/**
 * Minimal IoC container. On class initialization it scans the package "com.lossdate.learning",
 * instantiates classes annotated with @Component/@Repository/@Service/@Controller, wraps
 * @Transactional beans in transaction proxies, and injects @Autowired fields through setters.
 */
public class BeanFactory {

    /**
     * Bean registry: bean name -> instance (or its transactional proxy).
     * A bean's name is its annotation value with the first letter upper-cased, or, when the
     * value is empty, the class name after the last '.'.
     */
    private static final Map<String, Object> MAP = new HashMap<>();

    private static final int COMPONENT = 1;
    private static final int CONTROLLER = 2;
    private static final int SERVICE = 3;
    private static final int REPOSITORY = 4;

    static {
        try {
            System.out.println("-========================================================================================");
            // Scan the base package for every class.
            List<Class<?>> classList = getClass("com.lossdate.learning");

            // Instantiate and register annotated classes. Order is component, repository, service,
            // controller; a later registration overwrites an earlier bean with the same name.
            for (Class<?> aClass : getClassByType(COMPONENT, classList)) {
                registerBean(aClass, aClass.getAnnotation(Component.class).value());
            }
            for (Class<?> aClass : getClassByType(REPOSITORY, classList)) {
                registerBean(aClass, aClass.getAnnotation(Repository.class).value());
            }
            for (Class<?> aClass : getClassByType(SERVICE, classList)) {
                registerBean(aClass, aClass.getAnnotation(Service.class).value());
            }
            for (Class<?> aClass : getClassByType(CONTROLLER, classList)) {
                registerBean(aClass, aClass.getAnnotation(Controller.class).value());
            }

            // Snapshot the raw instances before any of them is replaced by a proxy.
            Map<String, Object> rawBeans = new HashMap<>(MAP);

            // Proxy @Transactional beans first so every dependent injected below receives the proxy.
            // (Doing both in one pass made that depend on HashMap iteration order.)
            for (Map.Entry<String, Object> entry : rawBeans.entrySet()) {
                Object bean = entry.getValue();
                if (bean.getClass().isAnnotationPresent(Transactional.class)) {
                    MAP.put(entry.getKey(), createTransactionalProxy(bean));
                }
            }

            // Inject into the raw instances; proxies delegate to them, so they see the dependencies.
            for (Object bean : rawBeans.values()) {
                injectDependencies(bean);
            }
        } catch (ReflectiveOperationException | IOException e) {
            // Fail fast: a half-populated container would only surface later as confusing NPEs.
            throw new IllegalStateException("BeanFactory initialization failed", e);
        }
    }

    /** Instantiates {@code beanClass} via its no-arg constructor and registers it under its bean name. */
    private static void registerBean(Class<?> beanClass, String annotationValue)
            throws ReflectiveOperationException {
        Object instance = beanClass.getDeclaredConstructor().newInstance();
        String name = StringUtils.isNullOrEmpty(annotationValue)
                ? lastNameSegment(beanClass)
                : originNameWrapper(annotationValue);
        MAP.put(name, instance);
    }

    /** Returns the part of the class's binary name after the last '.'. */
    private static String lastNameSegment(Class<?> type) {
        String[] parts = type.getName().split("\\.");
        return parts[parts.length - 1];
    }

    /** Wraps {@code bean} in a JDK proxy when it implements an interface, otherwise in a CGLIB proxy. */
    private static Object createTransactionalProxy(Object bean) {
        ProxyFactory proxyFactory = (ProxyFactory) getBean("ProxyFactory");
        if (proxyFactory == null) {
            throw new IllegalStateException("@Transactional bean " + bean.getClass().getName()
                    + " requires a ProxyFactory bean");
        }
        // JDK dynamic proxies need at least one interface; fall back to CGLIB subclassing otherwise.
        return bean.getClass().getInterfaces().length > 0
                ? proxyFactory.getJdkProxy(bean)
                : proxyFactory.getCglibProxy(bean);
    }

    /**
     * Injects each @Autowired field of {@code bean} by calling the public one-argument setter named
     * "set" + the field type's simple name (case-insensitive).
     *
     * @throws IllegalStateException if a required dependency is not registered or has no setter
     */
    private static void injectDependencies(Object bean) throws ReflectiveOperationException {
        Class<?> beanClass = bean.getClass();
        for (Field field : beanClass.getDeclaredFields()) {
            Autowired autowired = field.getAnnotation(Autowired.class);
            if (autowired == null) {
                continue;
            }
            // Look up by the field's type name, not the field name, to avoid first-letter case mismatches.
            String name = lastNameSegment(field.getType());
            Object dependency = MAP.get(name);
            if (dependency == null) {
                if (autowired.requred()) {
                    throw new IllegalStateException("No bean named " + name + " for @Autowired field "
                            + beanClass.getName() + "." + field.getName());
                }
                continue;
            }
            boolean injected = false;
            for (Method method : beanClass.getMethods()) {
                if (method.getParameterCount() == 1 && method.getName().equalsIgnoreCase("set" + name)) {
                    method.invoke(bean, dependency);
                    injected = true;
                }
            }
            if (!injected && autowired.requred()) {
                throw new IllegalStateException("No public setter set" + name + " for @Autowired field "
                        + beanClass.getName() + "." + field.getName());
            }
        }
    }

    /**
     * Upper-cases the first letter of {@code name}; expects a non-empty string.
     */
    private static String originNameWrapper(String name) {
        return name.substring(0, 1).toUpperCase() + name.substring(1);
    }

    /**
     * Returns the classes in {@code classList} that carry the annotation selected by {@code type}
     * (one of COMPONENT, CONTROLLER, SERVICE, REPOSITORY); any other type yields an empty set.
     */
    private static Set<Class<?>> getClassByType(int type, List<Class<?>> classList) {
        Set<Class<?>> classSet = new HashSet<>();
        classList.forEach(aClass -> {
            Object annotation = null;
            switch (type) {
                case COMPONENT:
                    annotation = aClass.getAnnotation(Component.class);
                    break;
                case CONTROLLER:
                    annotation = aClass.getAnnotation(Controller.class);
                    break;
                case SERVICE:
                    annotation = aClass.getAnnotation(Service.class);
                    break;
                case REPOSITORY:
                    annotation = aClass.getAnnotation(Repository.class);
                    break;
                default:
                    break;
            }
            if (annotation != null) {
                classSet.add(aClass);
            }
        });
        return classSet;
    }

    /**
     * Collects every class file under {@code packageName} from all directory entries of the
     * classpath. Classes packaged inside jars are not scanned.
     */
    private static List<Class<?>> getClass(String packageName) throws ClassNotFoundException, IOException {
        if (StringUtils.isNullOrEmpty(packageName)) {
            throw new RuntimeException("无效的初始化路径");
        }

        List<Class<?>> classList = new ArrayList<>();
        String path = packageName.replace(".", "/");
        // Resolve the package itself on every classpath root (e.g. both classes/ and test-classes/),
        // instead of only the first root returned by getResource("").
        java.util.Enumeration<java.net.URL> packageUrls =
                Thread.currentThread().getContextClassLoader().getResources(path);
        while (packageUrls.hasMoreElements()) {
            java.net.URL url = packageUrls.nextElement();
            if (!"file".equals(url.getProtocol())) {
                continue;
            }
            File packageDir;
            try {
                // toURI() decodes escapes such as %20 that URL.getPath() would leave in the file name.
                packageDir = new File(url.toURI());
            } catch (java.net.URISyntaxException e) {
                throw new IOException("Invalid classpath URL: " + url, e);
            }
            loadClassFactory(packageName, packageDir.getAbsolutePath(), classList);
        }
        return classList;
    }

    /** Recursively loads the classes found under directory {@code path} (package {@code packageName}). */
    private static void loadClassFactory(String packageName, String path, List<Class<?>> classList)
            throws ClassNotFoundException {
        // Keep only sub-directories and .class files.
        File[] files = new File(path).listFiles(file -> file.isDirectory() || file.getName().endsWith(".class"));
        if (files == null) {
            return;
        }
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        for (File file : files) {
            if (file.isDirectory()) {
                loadClassFactory(packageName + "." + file.getName(), file.getAbsolutePath(), classList);
            } else {
                String fileName = file.getName();
                String className = fileName.substring(0, fileName.length() - ".class".length());
                // initialize=false: merely scanning must not run every class's static initializer.
                classList.add(Class.forName(packageName + "." + className, false, loader));
            }
        }
    }

    /**
     * Returns the bean registered under {@code id} (first letter is case-insensitive),
     * or null when no such bean exists or {@code id} is null/empty.
     */
    public static Object getBean(String id) {
        if (StringUtils.isNullOrEmpty(id)) {
            return null;
        }
        return MAP.get(originNameWrapper(id));
    }
}
  4. 测试类
/**
 * Manual check of the container: fetches the "transferService" bean and performs a transfer.
 */
public class MyTest {

    /**
     * Transfers 100 from account "111" to "112" through the bean, then prints the bean.
     */
    @Test
    public void test() throws Exception {
        final String fromId = "111";
        final String toId = "112";
        final int money = 100;

        TransferService transferService = (TransferService) BeanFactory.getBean("transferService");
        transferService.transfer(fromId, toId, money);
        System.out.println(transferService);
    }
}

转载自 blog.csdn.net/Lossdate/article/details/111466286