自定义 Component、Controller、Service、Repository、Autowired 和 Transactional 注解并实现其功能
- 首先定义注解类(@Component等),这里参考了Spring
/**
 * Marks a class as a generic bean that BeanFactory instantiates and registers.
 *
 * <p>Meta-annotations: applicable to types (classes, interfaces incl. annotation types, enums)
 * and retained at runtime so it can be read reflectively.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Component {
    /** Optional bean name; when empty, BeanFactory falls back to the class name's last segment. */
    String value() default "";
}
/**
 * Marks a class as a controller bean that BeanFactory instantiates and registers.
 *
 * <p>Meta-annotations: applicable to types (classes, interfaces incl. annotation types, enums)
 * and retained at runtime so it can be read reflectively.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Controller {
    /** Optional bean name; empty by default. */
    String value() default "";
}
/**
 * Marks a class as a service bean that BeanFactory instantiates and registers.
 * Type-level only, and retained at runtime for reflective lookup.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Service {
    /** Optional bean name (e.g. {@code "transferService"}); empty by default. */
    String value() default "";
}
/**
 * Marks a class as a data-access (repository) bean that BeanFactory instantiates and registers.
 * Type-level only, and retained at runtime for reflective lookup.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface Repository {
    /** Optional bean name (e.g. {@code "accountDao"}); empty by default. */
    String value() default "";
}
/**
 * Marks a dependency to be injected by BeanFactory.
 *
 * <p>The BeanFactory in this write-up only inspects fields, and injects by calling the bean's
 * public {@code set<TypeName>} method.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target({
        ElementType.CONSTRUCTOR,
        ElementType.METHOD,
        ElementType.PARAMETER,
        ElementType.FIELD,
        ElementType.ANNOTATION_TYPE
})
public @interface Autowired {
    /**
     * Whether the dependency is injected; BeanFactory skips fields where this is {@code false}.
     * NOTE: "requred" is misspelled but is part of the API (BeanFactory reads it by this name).
     */
    boolean requred() default true;
}
/**
 * Requests that calls to the annotated bean run inside a transaction.
 *
 * <p>Applicable to types and methods, retained at runtime, and {@link Inherited}: a subclass of an
 * annotated class is treated as annotated too. BeanFactory checks it at class level only.
 */
@Documented
@Inherited
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
public @interface Transactional {
    /** Free-form value; empty by default. */
    String value() default "";

    /**
     * Exception types that should trigger a rollback.
     * NOTE(review): not read anywhere in the code shown here; the proxy rolls back on any failure.
     */
    Class<? extends Throwable>[] rollbackFor() default {};
}
- 接着是一些基本的工具类和实现类,这里为它们加上相应的注解。其中被 @Autowired 标注的依赖需要提供对应的 set 方法,用于后续的依赖注入。
@Component
public class ConnectionUtils {
    /** Holds the connection bound to each thread. */
    private final ThreadLocal<Connection> threadLocal = new ThreadLocal<>();

    /**
     * Returns the connection bound to the calling thread. On first use, it obtains one from
     * DruidUtils' pool and binds it.
     *
     * <p>NOTE(review): nothing in this class ever unbinds or closes the bound connection — confirm
     * that it is released elsewhere.
     *
     * @return the calling thread's connection
     * @throws SQLException if a new connection cannot be obtained
     */
    public Connection getCurrentThreadConn() throws SQLException {
        Connection bound = threadLocal.get();
        if (bound != null) {
            return bound;
        }
        // No connection on this thread yet: borrow one and remember it for subsequent calls
        Connection fresh = DruidUtils.getInstance().getConnection();
        threadLocal.set(fresh);
        return fresh;
    }
}
@Component
public class TransactionManager {
    /** Supplies the thread-bound connection that every operation below acts on. */
    @Autowired
    private ConnectionUtils connectionUtils;

    /** Setter used by BeanFactory to inject {@link ConnectionUtils}. */
    public void setConnectionUtils(ConnectionUtils connectionUtils) {
        this.connectionUtils = connectionUtils;
    }

    /**
     * Starts manual transaction control by turning auto-commit off on the current connection.
     * NOTE(review): auto-commit is never switched back on in this class.
     */
    public void beginTransaction() throws SQLException {
        currentConnection().setAutoCommit(false);
    }

    /** Commits the work done on the current thread's connection. */
    public void commit() throws SQLException {
        currentConnection().commit();
    }

    /** Rolls back the work done on the current thread's connection. */
    public void rollback() throws SQLException {
        currentConnection().rollback();
    }

    /** The connection bound to the calling thread. */
    private java.sql.Connection currentConnection() throws SQLException {
        return connectionUtils.getCurrentThreadConn();
    }
}
@Component
public class ProxyFactory {
    /** Demarcates the transaction around every proxied call; injected by BeanFactory via the setter. */
    @Autowired
    private TransactionManager transactionManager;

    /** Setter used by BeanFactory to inject the {@link TransactionManager}. */
    public void setTransactionManager(TransactionManager transactionManager) {
        this.transactionManager = transactionManager;
    }

    /**
     * Wraps {@code obj} in a JDK dynamic proxy that runs each call inside a transaction.
     * Only the interfaces implemented by {@code obj}'s class are exposed by the proxy.
     *
     * @param obj the target (delegate) object
     * @return the proxy object
     */
    public Object getJdkProxy(Object obj) {
        return Proxy.newProxyInstance(obj.getClass().getClassLoader(), obj.getClass().getInterfaces(),
                (proxy, method, args) -> invokeInTransaction(obj, method, args));
    }

    /**
     * Wraps {@code obj} in a CGLIB subclass proxy that runs each call inside a transaction.
     * Used for targets that implement no interface.
     *
     * @param obj the target (delegate) object
     * @return the proxy object
     */
    public Object getCglibProxy(Object obj) {
        return Enhancer.create(obj.getClass(),
                (MethodInterceptor) (o, method, objects, methodProxy) -> invokeInTransaction(obj, method, objects));
    }

    /**
     * Invokes {@code method} on {@code target} between begin and commit. On any failure it rolls
     * back and rethrows the target's original exception, not the reflective wrapper.
     */
    private Object invokeInTransaction(Object target, java.lang.reflect.Method method, Object[] args)
            throws Throwable {
        // toString/hashCode/equals are routed through the proxy too; they must not open a transaction
        if (method.getDeclaringClass() == Object.class) {
            return invokeTarget(target, method, args);
        }
        try {
            // Start the transaction (turns auto-commit off)
            transactionManager.beginTransaction();
            Object result = invokeTarget(target, method, args);
            transactionManager.commit();
            return result;
        } catch (Throwable failure) {
            try {
                transactionManager.rollback();
            } catch (Exception rollbackFailure) {
                // Keep the business exception primary; a failed rollback must not mask it
                failure.addSuppressed(rollbackFailure);
            }
            // Rethrow so upper layers (e.g. a servlet) can handle it
            throw failure;
        }
    }

    /**
     * Calls the target reflectively and unwraps {@link java.lang.reflect.InvocationTargetException}.
     * Without this, callers would receive the wrapper (or an UndeclaredThrowableException around it)
     * instead of the exception the target actually threw.
     */
    private static Object invokeTarget(Object target, java.lang.reflect.Method method, Object[] args)
            throws Throwable {
        try {
            return method.invoke(target, args);
        } catch (java.lang.reflect.InvocationTargetException e) {
            throw e.getTargetException();
        }
    }
}
@Service("transferService")
@Transactional(rollbackFor = Exception.class)
public class TransferServiceImpl implements TransferService {
    /** DAO used to load and persist accounts; injected by BeanFactory via the setter. */
    @Autowired
    private AccountDao accountDao;

    /** Setter used by BeanFactory to inject the {@link AccountDao}. */
    public void setAccountDao(AccountDao accountDao) {
        this.accountDao = accountDao;
    }

    /**
     * Moves {@code money} from the account {@code fromCardNo} to the account {@code toCardNo}.
     *
     * <p>Demo only: it deliberately divides by zero between the two updates, so the call always
     * throws ArithmeticException after the target account has been written. This exercises rollback.
     */
    @Override
    public void transfer(String fromCardNo, String toCardNo, int money) throws Exception {
        Account source = accountDao.queryAccountByCardNo(fromCardNo);
        Account target = accountDao.queryAccountByCardNo(toCardNo);

        source.setMoney(source.getMoney() - money);
        target.setMoney(target.getMoney() + money);

        accountDao.updateAccountByCardNo(target);
        int rollbackProbe = 1 / 0; // intentional failure to test rollback
        accountDao.updateAccountByCardNo(source);
    }
}
@Repository("accountDao")
public class JdbcAccountDaoImpl implements AccountDao {
    /** Supplies the thread-bound connection, so DAO calls share the caller's transaction. */
    @Autowired
    private ConnectionUtils connectionUtils;

    /** Setter used by BeanFactory to inject {@link ConnectionUtils}. */
    public void setConnectionUtils(ConnectionUtils connectionUtils) {
        this.connectionUtils = connectionUtils;
    }

    /**
     * Executes the update statement for {@code money}.
     *
     * @param money supplies the int parameter #1 and the String parameter #2
     * @return the number of affected rows
     * @throws Exception on any database error
     */
    @Override
    public int update(Money money) throws Exception {
        // Not closed here on purpose: the connection is bound to the current thread by
        // ConnectionUtils and is the same one TransactionManager commits/rolls back.
        Connection con = connectionUtils.getCurrentThreadConn();
        // NOTE(review): placeholder SQL from the write-up; must contain two '?' parameters (int, String)
        String sql = "xxxx";
        // try-with-resources closes the statement even when binding or executing throws
        try (PreparedStatement preparedStatement = con.prepareStatement(sql)) {
            preparedStatement.setInt(1, money.getMoney());
            preparedStatement.setString(2, money.getId());
            return preparedStatement.executeUpdate();
        }
    }
}
- 最后也是最关键的一步:实现注解的解析与事务控制。事务控制基于动态代理,分为 JDK 和 CGLIB 两种方式,其中 JDK 动态代理要求目标类实现了接口。
public class BeanFactory {
    /**
     * Bean registry: bean name (first letter upper-cased) -> bean instance. For classes annotated
     * with {@code @Transactional}, the value is the transactional proxy. Filled once by the static
     * initializer and only read afterwards.
     */
    private static final Map<String,Object> MAP = new HashMap<>();
    private static final int COMPONENT = 1;
    private static final int CONTROLLER = 2;
    private static final int SERVICE = 3;
    private static final int REPOSITORY = 4;

    static {
        try {
            System.out.println("-========================================================================================");
            // Load every class under the base package (from all classpath roots)
            List<Class<?>> classList = getClass("com.lossdate.learning");
            // Instantiate annotated classes. On a name clash the later registration wins
            // (component < repository < service < controller), as in the original order.
            registerBeans(getClassByType(COMPONENT, classList), c -> c.getAnnotation(Component.class).value());
            registerBeans(getClassByType(REPOSITORY, classList), c -> c.getAnnotation(Repository.class).value());
            registerBeans(getClassByType(SERVICE, classList), c -> c.getAnnotation(Service.class).value());
            registerBeans(getClassByType(CONTROLLER, classList), c -> c.getAnnotation(Controller.class).value());
            // Keep the raw instances: dependencies are injected into these objects, while MAP ends
            // up holding what other beans and getBean callers must see (proxies where applicable).
            Map<String, Object> rawBeans = new HashMap<>(MAP);
            // Proxy first, wire second. Every dependent then receives the transactional proxy,
            // independent of HashMap iteration order.
            applyTransactions(rawBeans);
            for (Object bean : rawBeans.values()) {
                injectDependencies(bean);
            }
        } catch (IOException | ReflectiveOperationException e) {
            // Fail fast instead of leaving a half-initialised container that hands out null beans
            throw new IllegalStateException("BeanFactory initialization failed", e);
        }
    }

    /**
     * Instantiates each class through its no-arg constructor and registers it under its bean name.
     * The name is the annotation value with the first letter upper-cased, or the class name's
     * last segment when the value is empty.
     */
    private static void registerBeans(Set<Class<?>> classes, java.util.function.Function<Class<?>, String> valueOf)
            throws ReflectiveOperationException {
        for (Class<?> aClass : classes) {
            Object bean = aClass.getDeclaredConstructor().newInstance();
            String value = valueOf.apply(aClass);
            String name = StringUtils.isNullOrEmpty(value) ? simpleName(aClass) : originNameWrapper(value);
            MAP.put(name, bean);
        }
    }

    /**
     * Replaces every {@code @Transactional} bean in MAP with a transactional proxy of its raw
     * instance. JDK proxies are used when the class implements an interface, CGLIB otherwise.
     */
    private static void applyTransactions(Map<String, Object> rawBeans) {
        ProxyFactory proxyFactory = (ProxyFactory) getBean("ProxyFactory");
        for (Map.Entry<String, Object> entry : rawBeans.entrySet()) {
            Object bean = entry.getValue();
            Class<?> c = bean.getClass();
            if (!c.isAnnotationPresent(Transactional.class)) {
                continue;
            }
            // Keep the returned proxy: it, not the raw bean, must be published
            Object proxy = c.getInterfaces().length > 0
                    ? proxyFactory.getJdkProxy(bean)
                    : proxyFactory.getCglibProxy(bean);
            MAP.put(entry.getKey(), proxy);
        }
    }

    /**
     * For every {@code @Autowired(requred = true)} field of {@code bean}, calls each public method
     * named (case-insensitively) {@code "set" + <field type's simple name>}. The argument is the
     * registered bean of that name (null if none is registered).
     */
    private static void injectDependencies(Object bean) throws ReflectiveOperationException {
        Class<?> c = bean.getClass();
        Method[] methods = c.getMethods();
        for (Field field : c.getDeclaredFields()) {
            Autowired autowired = field.getAnnotation(Autowired.class);
            if (autowired == null || !autowired.requred()) {
                continue;
            }
            // Look up by the field's type name, not the field name, to avoid first-letter case
            // mismatches with the registry keys
            String name = simpleName(field.getType());
            for (Method method : methods) {
                if (method.getName().equalsIgnoreCase("set" + name)) {
                    method.invoke(bean, MAP.get(name));
                }
            }
        }
    }

    /** Last dot-separated segment of the class's binary name, e.g. {@code com.a.Foo -> Foo}. */
    private static String simpleName(Class<?> type) {
        String[] parts = type.getName().split("\\.");
        return parts[parts.length - 1];
    }

    /**
     * Upper-cases the first character of {@code name}; null or empty input is returned unchanged.
     */
    private static String originNameWrapper(String name) {
        if (name == null || name.isEmpty()) {
            return name;
        }
        return name.substring(0,1).toUpperCase() + name.substring(1);
    }

    /**
     * Returns the classes from {@code classList} carrying the annotation selected by {@code type}.
     */
    private static Set<Class<?>> getClassByType(int type, List<Class<?>> classList) {
        Set<Class<?>> classSet = new HashSet<>();
        for (Class<?> aClass : classList) {
            Object annotation = null;
            switch (type) {
                case COMPONENT:
                    annotation = aClass.getAnnotation(Component.class);
                    break;
                case CONTROLLER:
                    annotation = aClass.getAnnotation(Controller.class);
                    break;
                case SERVICE:
                    annotation = aClass.getAnnotation(Service.class);
                    break;
                case REPOSITORY:
                    annotation = aClass.getAnnotation(Repository.class);
                    break;
                default:
                    break;
            }
            if (annotation != null) {
                classSet.add(aClass);
            }
        }
        return classSet;
    }

    /**
     * Loads every class found under {@code packageName}. It searches every classpath root that
     * contains the package, such as both main and test output directories. Only directory roots
     * ({@code file:} URLs) are scanned; packages inside jars are skipped.
     */
    private static List<Class<?>> getClass(String packageName) throws ClassNotFoundException, IOException {
        if(StringUtils.isNullOrEmpty(packageName)) {
            throw new RuntimeException("无效的初始化路径");
        }
        List<Class<?>> classList = new ArrayList<>();
        String path = packageName.replace(".", "/");
        java.util.Enumeration<java.net.URL> roots =
                Thread.currentThread().getContextClassLoader().getResources(path);
        while (roots.hasMoreElements()) {
            java.net.URL root = roots.nextElement();
            if ("file".equals(root.getProtocol())) {
                loadClassFactory(packageName, toFilePath(root), classList);
            }
        }
        return classList;
    }

    /**
     * Converts a {@code file:} URL to a decoded filesystem path. {@code URL.getPath()} would keep
     * %-escapes such as %20 or the encoding of non-ASCII directory names.
     */
    private static String toFilePath(java.net.URL url) {
        try {
            return new File(url.toURI()).getAbsolutePath();
        } catch (java.net.URISyntaxException | IllegalArgumentException e) {
            // Not a well-formed file URI: fall back to the raw path, as the original code did
            return url.getPath();
        }
    }

    /**
     * Recursively collects the classes in directory {@code path}, which corresponds to
     * {@code packageName}, into {@code classList}.
     */
    private static void loadClassFactory(String packageName, String path, List<Class<?>> classList) throws ClassNotFoundException {
        File currentFile = new File(path);
        // Keep only sub-directories and .class files
        File[] files = currentFile.listFiles(file -> file.isDirectory() || file.getName().endsWith(".class"));
        if (files == null) {
            return;
        }
        ClassLoader loader = BeanFactory.class.getClassLoader();
        for (File file : files) {
            if (file.isDirectory()) {
                // Sub-package: recurse
                loadClassFactory(packageName + "." + file.getName(), file.getAbsolutePath(), classList);
                continue;
            }
            String fileName = file.getName();
            String className = fileName.substring(0, fileName.length() - ".class".length());
            // package-info / module-info are not ordinary classes and can never be beans
            if (className.contains("-")) {
                continue;
            }
            // initialize=false: scanning must not run every class's static initializer
            classList.add(Class.forName(packageName + "." + className, false, loader));
        }
    }

    /**
     * Returns the bean registered under {@code id}. The first letter is matched case-insensitively,
     * since keys are stored capitalised. Returns null when no such bean exists.
     */
    public static Object getBean(String id) {
        return MAP.get(originNameWrapper(id));
    }
}
- 测试类
public class MyTest {
    /**
     * Runs a transfer through the "transferService" bean. Since TransferServiceImpl is
     * {@code @Transactional}, BeanFactory registers it behind a proxy.
     * NOTE: TransferServiceImpl deliberately divides by zero, so transfer() is expected to throw
     * and the println below is only reached once that demo line is removed.
     */
    @Test
    public void test() throws Exception {
        TransferService service = (TransferService) BeanFactory.getBean("transferService");
        service.transfer("111", "112", 100);
        System.out.println(service);
    }
}