手把手教你写一个简单的IOC容器和DI

在Spring中IOC是个绝佳的解耦合手段,为了更好的理解我就动手自己写了一个

预备知识:

注解,反射,集合类,lambda表达式,流式API

IOC

如何把一个类注册进去呢?首先我们要让容器“发现”它,所以使用注解,声明它应当加入容器

其中的value即对应的是Spring中的Bean name

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE})
public @interface Part {
    String value() default "";
}

复制代码

扫描包生成类的工具

当然,有人会说hutool的ClassScaner很好用,但是这里为了加深理解,我就自己写一个

思路就是利用文件名利用Class.forName()得到类的反射再生成实例

public static List<Object> find(String packName, ClassFilter classFilter) throws IOException {
        //获取当前路径
        Enumeration<URL> entity = Thread.currentThread().getContextClassLoader().getResources(packName);
        HashSet<String> classPaths = new HashSet<>();
        ArrayList<Object> classes = new ArrayList<>();
        //拿到处理后的路径,处理前为/..../target/classes
        //处理后为/..../target/classes
        if (entity.hasMoreElements()) {
            String path = entity.nextElement().getPath().substring(1);
            classPaths.add(path);
        }
        //这里跳转到我写的一个把路径下的.class文件生成为类名的方法,后面会讲述
        //set的元素为类名 比如Entity.Student
        Set<String> set = loadClassName(classPaths);
        for (String s : set) {
            try {
                Class<?> c = Class.forName(s);
                //利用过滤器判断需不需要生成实例
                if (classFilter.test(c)){
                    //这里为了简单使用无参构造器
                    Constructor<?> constructor = c.getConstructor();
                    constructor.setAccessible(true);
                    //将生成的实例加入返回的list集合中
                    classes.add(constructor.newInstance());
                }
            }catch (ClassNotFoundException| InstantiationException | IllegalAccessException| InvocationTargetException e) {
                throw new RuntimeException(e);
            }catch (NoSuchMethodException e){
                System.err.println(e.getMessage());
            }
        }
        return classes;
    }
复制代码

到来其中的一个核心函数loadClassName

/**
     * @param classPaths 路径名集合
     * @return 类名的集合
     */
    private static Set<String> loadClassName(HashSet<String> classPaths){
        Queue<File> queue = new LinkedList<>();
        HashSet<String> classNames = new HashSet<>();
        //对每一个路径得到对应所有以.class结尾的文件
        classPaths.forEach(p -> {
            //迭代的方法,树的层次遍历
            queue.offer(new File(p));
            while (!queue.isEmpty()){
                File file = queue.poll();
                if (file.isDirectory()) {
                    File[] files = file.listFiles();
                    for (File file1 : files) {
                        queue.offer(file1);
                    }
                }else if(file.getName().endsWith(".class")){
                    //对文件名处理得到类名
                    // ..../target/classes处理完为  \....\target\classes
                    String replace = p.replace("/", "\\");
                    //对于每个.class文件都是以....\target\classes开头,去掉开头,去掉后缀就是类名了
                   String className = file.getPath()
                            .replace(replace, "")
                            .replace(".class", "").replace("\\", ".");
                    classNames.add(className);
                }
            }
        });
        return classNames;
    }
复制代码

好了,现在就可以扫描包了

上面我也提到了不是所有的类都必须放到容器中,现在让我们看看这个 ClassFilter 过滤器是什么东西吧

@FunctionalInterface
public interface ClassFilter{
    boolean test(Class c);
}
复制代码

是个函数式接口,这就意味着使用lambda表达式会很方便

通过这个接口我们就很容易地构造这么一个函数帮我们把所有有@Part注解的类生成好

public static<T> List<Object> findByAnnotation(String packName, Class<T> annotation) throws IOException{
        if (!annotation.isAnnotation()) {
            throw new RuntimeException("it not an annotation"+annotation.getTypeName());
        }
        ClassFilter classFilter =(c) -> c.getAnnotation(annotation) != null;
        return find(packName, classFilter);
    }
复制代码

IOC容器

上面的准备工作做的差不多了

该动手写IOC容器了

思考一下在Spring中我们很容易通过bean name得到java bean,所以使用一个Map<String,Object>可以模拟一下。

这里我们在IOCContainer中添加一个变量

private Map<String,Object> context;
复制代码

构造函数

public IOCContainer(String packName){
        try {
            init(packName);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    public IOCContainer(){
        //默认扫描所有的包
        this("");
    }
复制代码

初始化函数:

 /**
     * @param packName 路径名在ClassScannerUtil中的函数要使用
     * @throws IOException
     * @author dreamlike_ocean
     */
    public void init(String packName) throws IOException {
        //做一个bean name 的映射。如果@Part注解中的值不为空则使用value的值做bean name
        //如果为空就用这个 java bean的类名做bean name
        Function<Object,String> keyMapper = (o) -> {
            Class<?> aClass = o.getClass();
            String s = aClass.getAnnotation(Part.class).value();
            if (s.isBlank()) {
                return o.getClass().getTypeName();
            }
            return s;
        };
        context = new HashMap<String,Object>();
        //获取所有添加@Part注解的类实例
        List<Object> objectList = ClassScannerUtil.findByAnnotation(packName, Part.class);
        //先把自己注入进去
        context.put("IOCContainer", this);
        for (Object o : objectList) {
            //利用上面写好的映射函数接口 获取bean name
            String beanName = keyMapper.apply(o);
            //bean name冲突情况,直接报错
            if (context.containsKey(beanName)) {
                String msg = new StringBuilder().append("duplicate bean name: ")
                        .append(beanName)
                        .append("in")
                        .append(o.getClass())
                        .append(" and ")
                        .append(context.get(beanName).getClass()).toString();
                throw new RuntimeException(msg);
            }
            //加入容器
            context.put(beanName, o);
        }
        //帮助垃圾回收,这个复杂度为O(n),理论上objectList = null也能帮助回收
        objectList.clear();
    }
复制代码

对外暴露的获取Bean的api

    /**
     * 
     * @param beanName
     * @return 记得判断空指针
     * @author dreamlike_ocean
     */
    public Optional<Object> getBean(String beanName){
        return Optional.ofNullable(context.get(beanName));
    }

    /**
     * 
     * @param beanName
     * @param aclass
     * @param <T> 需要返回的类型,类型强转
     * @exception ClassCastException 类型强转可能导致无法转化的异常          
     * @return @author dreamlike_ocean
     */
    public<T> Optional<T> getBean(String beanName,Class<T> aclass){
        return Optional.ofNullable((T)context.get(beanName));
    }

    /**
     *
     * @param interfaceType
     * @param <T>
     * @return 所有继承这个接口的集合
     * @author dreamlike_ocean
     */
    public<T> List<T> getBeanByInterfaceType(Class<T> interfaceType){
        if (!interfaceType.isInterface()) {
            throw new RuntimeException("it is not an interface type:"+interfaceType.getTypeName());
        }
        return context.values().stream()
                .filter(o -> interfaceType.isAssignableFrom(o.getClass()))
                .map(o -> (T)o)
                .collect(Collectors.toList());
    }

    /**
     * 
     * @param type
     * @param <T>
     * @return 所有这个类型的集合
     * @author dreamlike_ocean
     */
    
    public<T> List<T> getBeanByType(Class<T> type){
        return context.values().stream()
                .filter(o -> type.isAssignableFrom(o.getClass()))
                .map(o -> (T)o)
                .collect(Collectors.toList());
    }

    /**
     * 
     * @return 获取所有值
     * @author dreamlike_ocean 
     */
    public Collection<Object> getBeans(){
        return context.values();
    }

    /**
     * 
     * @return 获取容器
     * @author dreamlike_ocean
     */
    public Map<String,Object> getContext(){
        return context;
    }

复制代码

DI

上面我们获取的都是利用无参的构造函数得到的java bean,这和想的差的有点远,我想要的是一幅画,他却给了我一张白纸。这怎么能行!DI模块上,给他整个活!

为了区别通过类型注入还是名称注入,我写了两个注解用于区分

@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.FIELD)
public @interface InjectByName {
    String value();
}
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD,ElementType.METHOD})
public @interface InjectByType {

}

复制代码

首先DI先必须知道到底对哪个容器注入,所以通过构造函数传入一个

 private IOCContainer iocContainer;
    public DI(IOCContainer iocContainer) {
        Objects.requireNonNull(iocContainer);
        this.iocContainer = iocContainer;
    }
复制代码

先是对字段的按类型注入

/**
     * 
     * @param o 需要被注入的类
     * @author dreamlike_ocean          
     */

    private void InjectFieldByType(Object o){
        try {
            //获取内部所有字段
            Field[] declaredFields = o.getClass().getDeclaredFields();
            for (Field field : declaredFields) {
                //判断当前字段是否有注解标识
                if (field.getAnnotation(InjectByType.class) != null) {
                    //防止因为private而抛出异常
                    field.setAccessible(true);
                    List list = iocContainer.getBeanByType(field.getType());
                    //如果找不到,那么注入失败
                    //这里我选择抛出异常,也可给他赋值为null
                    if(list.size() == 0){
                        throw new RuntimeException("not find "+field.getType());
                    }
                    //多于一个也注入失败,和Spring一致
                    if (list.size()!=1){
                        throw new RuntimeException("too many");
                    }
                    //正常注入
                    field.set(o, list.get(0));
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }

复制代码

对字段按名称注入

 /**
     *
     * @param o 需要被注入的类
     * @author dreamlike_ocean
     */
    private void InjectFieldByName(Object o){
        try {
            Field[] declaredFields = o.getClass().getDeclaredFields();
            for (Field field : declaredFields) {
                InjectByName annotation = field.getAnnotation(InjectByName.class);
                if (annotation != null) {
                    field.setAccessible(true);
                    //通过注解中的bean name寻找注入的值
                    //这里optional类没有发挥它自己的函数式优势,因为我觉得在lambda表达式里面写异常处理属实不好看
                    //借用在Stack overflow看的一句话,Oracle用受检异常把lambda玩砸了
                    Object v = iocContainer.getBean(annotation.value()).get();
                    if (v != null) {
                        field.set(o, v);
                    }else{
                        //同样找不到就抛异常
                        throw new RuntimeException("not find "+field.getType());
                    }
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        }
    }
复制代码

对函数按类型注入

 /**
     * 这个函数必须是setter函数
     * @param o 要被注入的类
     * @author dreamlike_ocean
     */
    private void InjectMethod(Object o){
        Method[] declaredMethods = o.getClass().getDeclaredMethods();
        try {
            for (Method method : declaredMethods) {
                //获取添加注解的函数
                if (method.getAnnotation(InjectByType.class) != null) {
                    //获取参数列表
                    Class<?>[] parameterTypes = method.getParameterTypes();
                    method.setAccessible(true);
                    int i = method.getParameterCount();
                    //为储存实参做准备
                    Object[] param = new Object[i];
                    //变量重用,现在它代表当前下标了
                    i=0;
                    for (Class<?> parameterType : parameterTypes) {
                        List<?> list = iocContainer.getBeanByType(parameterType);
                        if(list.size() == 0){
                            throw new RuntimeException("not find "+parameterType+"。method :"+method+"class:"+o.getClass());
                        }
                        if (list.size()!=1){
                            throw new RuntimeException("too many");
                        }
                        //暂时存储实参
                        param[i++] = list.get(0);
                    }
                    //调用对应实例的函数
                    method.invoke(o, param);
                }
            }
        } catch (IllegalAccessException e) {
            e.printStackTrace();
        } catch (InvocationTargetException e) {
            e.printStackTrace();
        }
    }
复制代码

你会发现上面都是私有方法,因为我想对外暴露一个简洁的API

  /**
     * 对字段依次进行按类型注入和按名称注入
     * 再对setter方法注入
     * @author dreamlike_ocean
     */
    public void inject(){
        iocContainer.getBeans().forEach(o -> {
            InjectFieldByType(o);
            InjectFieldByName(o);
            InjectMethod(o);
        });
    }
复制代码

测试

做好了,来让我们测一测

@Part("testA")
class A{
    @InjectByType
    private B b;
    public A(){

    }

    public B getB() {
        return b;
    }
}
@Part
class B{
    private UUID uuid;
public B(){
    uuid = UUID.randomUUID();
}

    public UUID getUuid() {
        return uuid;
    }
}
@Part
class C{
    public C(){
    }

}
复制代码

测试方法

@Test
public void test(){
 IOCContainer container = new IOCContainer();
       DI di = new DI(container);
       di.inject();
       System.out.println(container.getBeanByType(A.class).get(0).getB().getUuid());
       System.out.println(container.getBeanByType(B.class).get(0).getUuid());
}
复制代码

好了这就可以了

猜你喜欢

转载自juejin.im/post/5e561077518825492c0504fd