Spring动态创建bean

最近有个项目场景,多垂类支持,大体业务流程相同,只是一些业务规则的校验参数不同。解决思路是将业务参数作为类的属性,然后创建垂类数量个实例,去处理不同垂类的业务。

看了spring ioc部分的代码,个人感觉在spring完成bean创建的过程后,做一个类实现ApplicationContextAware接口,然后克隆多个需要的BeanDefinition,并为其赋予不同的业务参数属性值的方式比较讨巧。新增加的BeanDefinition会在getBean的过程中,由spring创建。

下面分两部分介绍:
1、动态创建bean的代码实现
2、spring的ioc源码解读,这部分放到另外一篇博客 http://mazhen2010.iteye.com/blog/2283773
<spring.version>4.0.6.RELEASE</spring.version>

【动态创建bean的代码实现】
1、创建一个实现ApplicationContextAware接口的类,然后获取DefaultListableBeanFactory
    /**
     * Resolves the {@link DefaultListableBeanFactory} behind the given application context and
     * stores it in {@code springFactory}. When the context, or its bean factory, is of an
     * unsupported type, an error is logged and {@code springFactory} is left unchanged.
     *
     * @param applicationContext the application context handed over by Spring
     */
    private void setSpringFactory(ApplicationContext applicationContext) {

        if (applicationContext instanceof AbstractRefreshableApplicationContext) {
            // suit both XmlWebApplicationContext and ClassPathXmlApplicationContext
            AbstractRefreshableApplicationContext springContext = (AbstractRefreshableApplicationContext) applicationContext;
            Object beanFactory = springContext.getBeanFactory();
            if (!(beanFactory instanceof DefaultListableBeanFactory)) {
                LOGGER.error("No suitable bean factory! The current factory class is {}",
                        beanFactory.getClass());
                // Stop here: casting an unsupported factory would only throw ClassCastException.
                return;
            }
            springFactory = (DefaultListableBeanFactory) beanFactory;
        } else if (applicationContext instanceof GenericApplicationContext) {
            // suit GenericApplicationContext
            GenericApplicationContext springContext = (GenericApplicationContext) applicationContext;
            springFactory = springContext.getDefaultListableBeanFactory();
        } else {
            LOGGER.error("No suitable application context! The current context class is {}",
                    applicationContext.getClass());
        }
    }


2、定义注解,以找到需要克隆的BeanDefinition和需要赋值的属性
/**
 * Marks a service implementation class that can be instantiated once per template.
 * It is meta-annotated with {@code @Component}, so annotated classes are registered through
 * component scanning. Presumably {@link #value()} becomes the bean name, following Spring's
 * stereotype convention; confirm against the scanning setup.
 */
@Documented
@Component
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface TemplateService {

    /** Name of the implementation (bean name); empty by default. */
    String value() default "";

    /** Name of the service this class implements, e.g. {@code "demoService"}. */
    String serviceName();
}

/**
 * Marks a field of a template service class as a per-template business parameter.
 * The value is injected reflectively, matched by field name, after the template bean is created.
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.FIELD})
public @interface TemplateBizParam {
}

/**
 * Demo implementation of {@code DemoService}.
 * A separate instance is registered per template, and each instance receives its own
 * {@code noVisitDays} value through the {@code @TemplateBizParam} injection.
 */
@TemplateService(value = "demoServiceImpl", serviceName = "demoService")
public class DemoServiceImpl extends AbstractServiceImpl implements DemoService {

    /** Business parameter filled in per template. */
    @TemplateBizParam
    private String noVisitDays;

    /** Prints the POI id together with this instance's business parameter and extension fields. */
    @Override
    public void doDemo(Long poiId) {
        String separator = "//";
        String message = "doDemo" + separator + "poiId:" + poiId
                + separator + noVisitDays + separator + getExtendFields() + separator
                + "abc:" + getExtendField("abc");
        System.out.println(message);
    }

    /** Prints the POI ids together with this instance's business parameter. */
    @Override
    public void doDemos(List<Long> poiIds) {
        StringBuilder line = new StringBuilder("poiIds");
        line.append(poiIds).append("; noVisitDays:").append(noVisitDays);
        System.out.println(line);
    }

}


3、从垂类模板中获取需要动态创建的bean信息,然后注册BeanDefinition
    /**
     * Registers, instantiates and populates the bean of one template service.
     * <p>
     * The implementation class comes from the existing bean definition named
     * {@code serviceEntity.getImplName()}. A new root bean definition of that class is registered
     * under {@code generateTemplateBeanName(templateId, serviceName)}. The bean is created
     * immediately via {@code getBean}, and the template's business parameters and extension
     * fields are injected into it. Nothing is registered when the impl bean definition is missing
     * or its class cannot be resolved.
     * <p>
     * NOTE(review): only the class name is taken over. Scope, property values and other settings
     * of the original definition are not copied into the new definition.
     *
     * @param templateId    template ID
     * @param serviceEntity service definition taken from the template
     */
    private void registerBeanDefinition(String templateId, ServiceEntity serviceEntity) {

        String implName = serviceEntity.getImplName();
        if (!springFactory.containsBeanDefinition(implName)) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
            return;
        }
        String beanKey = generateTemplateBeanName(templateId, serviceEntity.getServiceName());

        try {
            //step1: register one more instance definition for this template
            BeanDefinition templateDefinition = springFactory.getBeanDefinition(implName);
            String className = templateDefinition.getBeanClassName();
            if (className == null) {
                // e.g. factory-method definitions carry no bean class name
                LOGGER.error("Bean definition has no bean class name. implName:{}", implName);
                return;
            }
            Class<?> beanClass;
            try {
                // Use the factory's class loader so the class matches the one Spring instantiates.
                beanClass = Class.forName(className, true, springFactory.getBeanClassLoader());
            } catch (ClassNotFoundException e) {
                LOGGER.error("Can not load bean class. className:" + className + ", implName:" + implName, e);
                return;
            }

            BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.rootBeanDefinition(className);
            // Stored only as definition metadata; the bean name is the key passed to registerBeanDefinition.
            beanDefinitionBuilder.getBeanDefinition().setAttribute("id", beanKey);

            springFactory.registerBeanDefinition(beanKey, beanDefinitionBuilder.getBeanDefinition());
            LOGGER.info("Register bean definition successfully. beanName:{}, implName:{}", beanKey, implName);

            //step2: create the instance and inject its properties
            Object bean = springFactory.getBean(beanKey, beanClass);
            injectParamVaules(bean, beanClass, serviceEntity);

        } catch (NoSuchBeanDefinitionException ex) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
        } catch (BeanDefinitionStoreException ex) {
            LOGGER.error("Register bean definition wrong. beanName:" + beanKey + ", implName:" + implName, ex);
        }
    }

    /**
     * Injects business parameters and extension fields into a freshly created template bean.
     * Only classes annotated with {@code @TemplateService} are processed:
     * <ul>
     *   <li>Fields declared on {@code requiredType} and annotated with {@code @TemplateBizParam}
     *       receive the value stored under the field's name in {@code serviceEntity.getBizParamMap()}.</li>
     *   <li>Fields declared on the class returned by {@code getSuperClass(requiredType)} and
     *       annotated with {@code @TemplateExtendFields} receive {@code serviceEntity.getExtendFields()}.</li>
     * </ul>
     * Fields without a configured value keep their current value. A value whose type does not
     * match the field makes {@code Field.set} throw {@code IllegalArgumentException}, which is not
     * caught here.
     *
     * @param bean          bean instance to populate
     * @param requiredType  class of the bean, used to discover annotated fields
     * @param serviceEntity template service definition carrying the values
     * @param <T>           bean type
     */
    private <T> void injectParamVaules(Object bean, Class<T> requiredType, ServiceEntity serviceEntity) {

        if (!requiredType.isAnnotationPresent(TemplateService.class)) {
            return;
        }

        // Business parameters: fields declared directly on the implementation class.
        for (Field field : requiredType.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateBizParam.class)) {
                continue;
            }
            field.setAccessible(true);
            try {
                if ((serviceEntity.getBizParamMap() != null) && (serviceEntity.getBizParamMap().containsKey(field.getName()))) {
                    field.set(bean, serviceEntity.getBizParamMap().get(field.getName()));
                    LOGGER.info("inject biz param value successfully, paramName = {}, value = {}", field.getName(), serviceEntity.getBizParamMap().get(field.getName()));
                }
            } catch (IllegalAccessException e) {
                LOGGER.error("inject biz param failed. paramName = " + field.getName(), e);
            }
        }

        // Extension fields: declared on the AbstractService ancestor, if any.
        Class<AbstractService> superClass = getSuperClass(requiredType);
        if (superClass == null) {
            return;
        }
        for (Field field : superClass.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateExtendFields.class)) {
                continue;
            }
            field.setAccessible(true);
            try {
                if (serviceEntity.getExtendFields() != null) {
                    field.set(bean, serviceEntity.getExtendFields());
                    LOGGER.info("inject extend fields successfully, extendFields = {}", serviceEntity.getExtendFields());
                }
            } catch (IllegalAccessException e) {
                LOGGER.error("inject extend fields failed. fieldName = " + field.getName(), e);
            }
        }
    }


4、定义一个Context继承AbstractServiceContext,实现运行时根据策略,选取所需的业务实例进行处理
/**
 * Entry-point bean for {@code DemoService}. It looks up the implementation registered for the
 * POI's template and delegates each call to it.
 */
@Service("demoService")
public class DemoServiceContext extends AbstractServiceContext implements DemoService {

    @Override
    public void doDemo(Long poiId) {
        getServiceImpl(poiId, DemoService.class).doDemo(poiId);
    }

    /**
     * Delegates to the implementation selected for the first POI in the list. Callers holding
     * POIs from different templates should split them with {@code groupPoiIds} first.
     */
    @Override
    public void doDemos(java.util.List<Long> poiIds) {
        getServiceImpl(poiIds, DemoService.class).doDemos(poiIds);
    }

}

/**
 * Base class of service contexts. It picks the concrete service implementation (strategy) for a
 * POI's template and passes extension fields to it.
 * User: mazhen01
 * Date: 2016/3/3
 * Time: 10:14
 */
public abstract class AbstractServiceContext {

    @Resource
    private TemplateBeanFactory templateBeanFactory;

    @Autowired
    public TemplateFunction templateFunction;

    // Template bean name last selected by getServiceImpl(Long, Class) on the current thread.
    // NOTE(review): the value is never removed. On pooled threads, a later request that calls
    // setExtendField(String, Object) before selecting a service sees a stale name; confirm callers
    // always select a service first.
    private final ThreadLocal<String> currentTemplateBeanName = new ThreadLocal<String>();

    private static final Logger LOGGER = LoggerFactory.getLogger(AbstractServiceContext.class);

    /**
     * Returns the service instance registered for the template the POI belongs to. The selected
     * bean name is remembered for the current thread (used by {@link #setExtendField(String, Object)}).
     *
     * @param poiId poiId
     * @param clazz service interface
     * @param <T>   instance type
     * @return the template-specific service instance
     * @throws AnnotationException declared for annotation lookup problems
     * @throws BeansException      when the template bean cannot be obtained
     * @throws TemplateException   when no template is configured for the POI
     */
    protected <T> T getServiceImpl(Long poiId, Class<T> clazz) throws AnnotationException, BeansException {
        String serviceName = getServiceName();
        String templateId = templateFunction.getTemplateId(poiId, serviceName);
        if (templateId == null) {
            LOGGER.error("templateId is null. No templateId id configured for poiId = {}.", poiId);
            throw new TemplateException("templateId is null, can not find templateId.");
        }
        String templateBeanName = TemplateBeanFactory.generateTemplateBeanName(templateId, serviceName);
        currentTemplateBeanName.set(templateBeanName);
        return templateBeanFactory.getBean(templateBeanName, clazz);
    }

    /**
     * Returns the service instance for the template of the FIRST POI in the list. All POIs are
     * assumed to share one template; use {@link #groupPoiIds(List)} to split mixed lists.
     *
     * @throws TemplateException when the list is null or empty
     */
    protected <T> T getServiceImpl(List<Long> poiIds, Class<T> clazz) throws AnnotationException, BeansException {
        if (CollectionUtils.isEmpty(poiIds)) {
            LOGGER.error("poiIds List is null");
            throw new TemplateException("poiIds is null.");
        }
        return getServiceImpl(poiIds.get(0), clazz);
    }

    /**
     * Returns the service instance registered under the given template bean name.
     *
     * @param templateBeanName beanName
     * @param clazz            service interface
     * @param <T>              instance type
     * @return the service instance
     * @throws AnnotationException declared for annotation lookup problems
     * @throws BeansException      when the bean cannot be obtained
     */
    protected <T> T getServiceImpl(String templateBeanName, Class<T> clazz) throws AnnotationException, BeansException {
        return templateBeanFactory.getBean(templateBeanName, clazz);
    }

    /**
     * Returns the extension field names of the service instance for the POI's template.
     *
     * @param poiId poiId
     * @return the extension field names; an empty list when the instance has none
     */
    public List<String> getExtendFields(Long poiId) {
        AbstractServiceImpl abstractService = getServiceImpl(poiId, AbstractServiceImpl.class);

        if (abstractService == null || CollectionUtils.isEmpty(abstractService.getExtendFields())) {
            return Lists.newArrayList();
        }

        return abstractService.getExtendFields();
    }

    /**
     * For every extension field of the POI's service instance, copies the request attribute with
     * the same name into the instance's extension field map.
     *
     * @param poiId   poiId
     * @param request user request; nothing happens when null
     */
    public void setExtendField(Long poiId, HttpServletRequest request) {

        if (request == null) {
            return;
        }

        AbstractServiceImpl abstractService = getServiceImpl(poiId, AbstractServiceImpl.class);

        if (abstractService == null || CollectionUtils.isEmpty(abstractService.getExtendFields())) {
            return;
        }

        for (String field : abstractService.getExtendFields()) {
            setExtendField(field, request.getAttribute(field));
        }
    }

    /**
     * Stores an extension field value in the service instance last selected on this thread.
     * Nothing happens when no instance has been selected yet.
     * NOTE(review): the instance is a Spring bean shared by all threads. Unless
     * getExtendFieldMap() is thread-confined, concurrent requests overwrite each other's values.
     * Verify this.
     *
     * @param field field name
     * @param value value
     */
    public void setExtendField(String field, Object value) {
        String templateBeanName = currentTemplateBeanName.get();
        if (StringUtils.isEmpty(templateBeanName)) {
            return;
        }
        AbstractServiceImpl abstractService = getServiceImpl(templateBeanName, AbstractServiceImpl.class);
        abstractService.getExtendFieldMap().put(field, value);
    }

    /**
     * Returns the service name declared by this context's {@code @Service} or {@code @Component}
     * annotation (checked in that order). The value may be empty when the annotation has no name.
     *
     * @return the declared name, or null (logged) when neither annotation is present
     */
    protected String getServiceName() throws AnnotationException {

        Class<?> serviceClass = this.getClass();

        Service service = serviceClass.getAnnotation(Service.class);
        if (service != null) {
            return service.value();
        }

        Component component = serviceClass.getAnnotation(Component.class);
        if (component != null) {
            return component.value();
        }

        LOGGER.error("Has no annotation.");
        return null;
    }

    /**
     * Groups the POI ids by template, as computed by {@code templateFunction.groupPoiIds}.
     *
     * @param poiIds poi ids to group
     * @return the grouping returned by the template function
     */
    public Map<Long, List<Long>> groupPoiIds(List<Long> poiIds) {
        return templateFunction.groupPoiIds(poiIds);
    }

}


5、在springContext.xml中声明TemplateBeanFactory
<bean class="com.baidu.nuomi.tpl.spring.TemplateBeanFactory"/>

TemplateBeanFactory的完整代码,包括模板变化时的刷新
/**
 * Bean factory that creates the service instances defined in the templates and fills in their
 * business parameters and extension fields.
 * It can be refreshed: when the services of a template change, the corresponding instances in
 * the Spring context are removed and re-created.
 * User: mazhen01
 * Date: 2016/3/1
 * Time: 16:46
 */
public class TemplateBeanFactory implements ApplicationContextAware {

    private DefaultListableBeanFactory springFactory;

    private static final Logger LOGGER = LoggerFactory.getLogger(TemplateBeanFactory.class);

    @Autowired
    TemplateFunction templateFunction;

    /**
     * Spring callback. It resolves the underlying bean factory, initializes the template source
     * and registers the beans of every template it returns.
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        setSpringFactory(applicationContext);
        templateFunction.init();
        loadTemplateBeanDefinitions(templateFunction.getAllTemplateEntity());
    }

    /**
     * Refreshes the beans of the given templates: removes their current beans, then registers and
     * populates them again.
     * NOTE(review): two points to confirm.
     * <ul>
     *   <li>Beans are removed before they are re-created, so lookups running concurrently with a
     *       refresh may fail.</li>
     *   <li>Only services listed in the new template are removed; a service dropped from a
     *       template keeps its old bean.</li>
     * </ul>
     *
     * @param changedTemplates templates whose definitions changed; nothing happens when empty
     */
    public void refreshTemplateBeans(List<TemplateEntity> changedTemplates) {
        LOGGER.info("Refresh changed template beans start.");
        if (CollectionUtils.isEmpty(changedTemplates)) {
            LOGGER.info("no template beans is changed");
            return;
        }
        destroyTemplateBeans(changedTemplates);
        loadTemplateBeanDefinitions(changedTemplates);
        LOGGER.info("Refresh changed template beans end.");
    }

    /**
     * Resolves the {@link DefaultListableBeanFactory} behind the application context and stores it
     * in {@code springFactory}. When the context, or its bean factory, is of an unsupported type,
     * an error is logged and {@code springFactory} is left unchanged.
     *
     * @param applicationContext the applicationContext used by the application
     */
    private void setSpringFactory(ApplicationContext applicationContext) {

        if (applicationContext instanceof AbstractRefreshableApplicationContext) {
            // suit both XmlWebApplicationContext and ClassPathXmlApplicationContext
            AbstractRefreshableApplicationContext springContext = (AbstractRefreshableApplicationContext) applicationContext;
            Object beanFactory = springContext.getBeanFactory();
            if (!(beanFactory instanceof DefaultListableBeanFactory)) {
                LOGGER.error("No suitable bean factory! The current factory class is {}",
                        beanFactory.getClass());
                // Stop here: casting an unsupported factory would only throw ClassCastException.
                return;
            }
            springFactory = (DefaultListableBeanFactory) beanFactory;
        } else if (applicationContext instanceof GenericApplicationContext) {
            // suit GenericApplicationContext
            GenericApplicationContext springContext = (GenericApplicationContext) applicationContext;
            springFactory = springContext.getDefaultListableBeanFactory();
        } else {
            LOGGER.error("No suitable application context! The current context class is {}",
                    applicationContext.getClass());
        }
    }

    /**
     * Returns the resolved bean factory. It fails with a descriptive error instead of an NPE when
     * {@link #setSpringFactory} could not resolve one.
     */
    private DefaultListableBeanFactory requireSpringFactory() {
        if (springFactory == null) {
            throw new IllegalStateException("No DefaultListableBeanFactory available, template beans can not be managed.");
        }
        return springFactory;
    }

    /**
     * Registers and populates one bean for every service of every given template. Null templates,
     * templates without services or without an industry id, and null services are skipped.
     */
    private void loadTemplateBeanDefinitions(List<TemplateEntity> templateEntityList) {
        if (CollectionUtils.isEmpty(templateEntityList)) {
            LOGGER.warn("No template entity to load.");
            return;
        }
        for (TemplateEntity templateEntity : templateEntityList) {
            if (templateEntity == null || CollectionUtils.isEmpty(templateEntity.getServiceList())) {
                continue;
            }
            Long templateId = templateEntity.getIndustryId();
            if (templateId == null) {
                LOGGER.warn("Template without industryId is skipped.");
                continue;
            }
            for (ServiceEntity serviceEntity : templateEntity.getServiceList()) {
                if (serviceEntity != null) {
                    registerBeanDefinition(templateId.toString(), serviceEntity);
                }
            }
        }
    }

    /**
     * Registers, instantiates and populates the bean of one template service.
     * <p>
     * The implementation class comes from the existing bean definition named
     * {@code serviceEntity.getImplName()}. A new root bean definition of that class is registered
     * under {@code generateTemplateBeanName(templateId, serviceName)}. The bean is created
     * immediately via {@code getBean}, and the template's business parameters and extension
     * fields are injected into it. Nothing is registered when the impl bean definition is missing
     * or its class cannot be resolved.
     * <p>
     * NOTE(review): only the class name is taken over. Scope, property values and other settings
     * of the original definition are not copied into the new definition.
     *
     * @param templateId    template ID
     * @param serviceEntity service definition taken from the template
     */
    private void registerBeanDefinition(String templateId, ServiceEntity serviceEntity) {

        DefaultListableBeanFactory factory = requireSpringFactory();
        String implName = serviceEntity.getImplName();
        if (!factory.containsBeanDefinition(implName)) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
            return;
        }
        String beanKey = generateTemplateBeanName(templateId, serviceEntity.getServiceName());

        try {
            //step1: register one more instance definition for this template
            BeanDefinition templateDefinition = factory.getBeanDefinition(implName);
            String className = templateDefinition.getBeanClassName();
            if (className == null) {
                // e.g. factory-method definitions carry no bean class name
                LOGGER.error("Bean definition has no bean class name. implName:{}", implName);
                return;
            }
            Class<?> beanClass;
            try {
                // Use the factory's class loader so the class matches the one Spring instantiates.
                beanClass = Class.forName(className, true, factory.getBeanClassLoader());
            } catch (ClassNotFoundException e) {
                LOGGER.error("Can not load bean class. className:" + className + ", implName:" + implName, e);
                return;
            }

            BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.rootBeanDefinition(className);
            // Stored only as definition metadata; the bean name is the key passed to registerBeanDefinition.
            beanDefinitionBuilder.getBeanDefinition().setAttribute("id", beanKey);

            factory.registerBeanDefinition(beanKey, beanDefinitionBuilder.getBeanDefinition());
            LOGGER.info("Register bean definition successfully. beanName:{}, implName:{}", beanKey, implName);

            //step2: create the instance and inject its properties
            Object bean = factory.getBean(beanKey, beanClass);
            injectParamVaules(bean, beanClass, serviceEntity);

        } catch (NoSuchBeanDefinitionException ex) {
            LOGGER.info("No bean definition in spring factory. implName:{}", implName);
        } catch (BeanDefinitionStoreException ex) {
            LOGGER.error("Register bean definition wrong. beanName:" + beanKey + ", implName:" + implName, ex);
        }
    }

    /**
     * Injects business parameters and extension fields into a freshly created template bean.
     * Only classes annotated with {@code @TemplateService} are processed:
     * <ul>
     *   <li>Fields declared on {@code requiredType} and annotated with {@code @TemplateBizParam}
     *       receive the value stored under the field's name in {@code serviceEntity.getBizParamMap()}.</li>
     *   <li>Fields declared on the class returned by {@link #getSuperClass(Class)} and annotated
     *       with {@code @TemplateExtendFields} receive {@code serviceEntity.getExtendFields()}.</li>
     * </ul>
     * Fields without a configured value keep their current value. A value whose type does not
     * match the field makes {@code Field.set} throw {@code IllegalArgumentException}, which is not
     * caught here.
     *
     * @param bean          bean instance to populate
     * @param requiredType  class of the bean, used to discover annotated fields
     * @param serviceEntity template service definition carrying the values
     * @param <T>           bean type
     */
    private <T> void injectParamVaules(Object bean, Class<T> requiredType, ServiceEntity serviceEntity) {

        if (!requiredType.isAnnotationPresent(TemplateService.class)) {
            return;
        }

        // Business parameters: fields declared directly on the implementation class.
        for (Field field : requiredType.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateBizParam.class)) {
                continue;
            }
            field.setAccessible(true);
            try {
                if ((serviceEntity.getBizParamMap() != null) && (serviceEntity.getBizParamMap().containsKey(field.getName()))) {
                    field.set(bean, serviceEntity.getBizParamMap().get(field.getName()));
                    LOGGER.info("inject biz param value successfully, paramName = {}, value = {}", field.getName(), serviceEntity.getBizParamMap().get(field.getName()));
                }
            } catch (IllegalAccessException e) {
                LOGGER.error("inject biz param failed. paramName = " + field.getName(), e);
            }
        }

        // Extension fields: declared on the AbstractService ancestor, if any.
        Class<AbstractService> superClass = getSuperClass(requiredType);
        if (superClass == null) {
            return;
        }
        for (Field field : superClass.getDeclaredFields()) {
            if (!field.isAnnotationPresent(TemplateExtendFields.class)) {
                continue;
            }
            field.setAccessible(true);
            try {
                if (serviceEntity.getExtendFields() != null) {
                    field.set(bean, serviceEntity.getExtendFields());
                    LOGGER.info("inject extend fields successfully, extendFields = {}", serviceEntity.getExtendFields());
                }
            } catch (IllegalAccessException e) {
                LOGGER.error("inject extend fields failed. fieldName = " + field.getName(), e);
            }
        }
    }

    /**
     * Walks up the superclass chain of {@code clazz}.
     *
     * @return {@code AbstractService.class} when it appears as a superclass in that chain;
     *         otherwise null (logged)
     */
    @SuppressWarnings("unchecked") // raw Class from getSuperclass(); guarded by isAssignableFrom
    private Class<AbstractService> getSuperClass(Class clazz) {
        if (!AbstractService.class.isAssignableFrom(clazz)) {
            LOGGER.info("super class is null");
            return null;
        }
        Class<? extends AbstractService> superClass = clazz.getSuperclass();
        if (AbstractService.class != superClass) {
            superClass = getSuperClass(superClass);
        }
        return (Class<AbstractService>) superClass;
    }

    /**
     * Removes the template beans of the given templates from the bean factory. Templates without
     * an industry id and null services are skipped.
     */
    private void destroyTemplateBeans(List<TemplateEntity> changedTemplates) {

        if (CollectionUtils.isEmpty(changedTemplates)) {
            LOGGER.warn("No changed template to destroy.");
            return;
        }
        for (TemplateEntity templateEntity : changedTemplates) {
            if (templateEntity == null || CollectionUtils.isEmpty(templateEntity.getServiceList())
                    || templateEntity.getIndustryId() == null) {
                continue;
            }
            String templateId = templateEntity.getIndustryId().toString();
            for (ServiceEntity serviceEntity : templateEntity.getServiceList()) {
                if (serviceEntity == null) {
                    continue;
                }
                String beanName = generateTemplateBeanName(templateId, serviceEntity.getServiceName());
                DefaultListableBeanFactory factory = requireSpringFactory();
                // Guard on the definition, which is what removeBeanDefinition requires.
                // removeBeanDefinition also destroys the cached singleton, so no explicit
                // destroySingleton call is needed.
                if (factory.containsBeanDefinition(beanName)) {
                    factory.removeBeanDefinition(beanName);
                    LOGGER.info("destroy template beans successfully for beanName = {}", beanName);
                }
            }
        }
    }


    /**
     * Returns a bean from the underlying bean factory.
     *
     * @param name         bean name
     * @param requiredType bean type
     * @param <T>          bean type
     * @return the bean instance
     * @throws BeansException        when the bean cannot be obtained
     * @throws IllegalStateException when no bean factory could be resolved
     */
    public <T> T getBean(String name, Class<T> requiredType) throws BeansException {
        return requireSpringFactory().getBean(name, requiredType);
    }

    /**
     * Builds the bean name of a template service instance: {@code serviceName + "_" + templateId}.
     *
     * @param templateId  template ID
     * @param serviceName service name
     * @return the template bean name
     */
    public static final String generateTemplateBeanName(String templateId, String serviceName) {
        StringBuilder builder = new StringBuilder(serviceName);
        builder.append("_");
        builder.append(templateId);
        return builder.toString();
    }
}

猜你喜欢

转载自mazhen2010.iteye.com/blog/2283592