自实现 mvc - DispatcherServlet

package com.xuchen.demo.servlet;

import java.io.File;
import java.io.FileFilter;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;

import javax.servlet.ServletConfig;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.dom4j.Document;
import org.dom4j.Element;
import org.dom4j.io.SAXReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.alibaba.fastjson.JSON;
import com.xuchen.demo.annotation.Autowired;
import com.xuchen.demo.annotation.Controller;
import com.xuchen.demo.annotation.RequestMapping;
import com.xuchen.demo.annotation.ResponseBody;
import com.xuchen.demo.annotation.Service;
import com.xuchen.demo.view.Model;
import com.xuchen.demo.view.ModelAndView;
import com.xuchen.demo.view.ModelMap;

@SuppressWarnings("serial")
@WebServlet
public class DispatcherServlet extends HttpServlet {
    
    private static final Logger logger = LoggerFactory.getLogger(DispatcherServlet.class);
    private static final String PACKAGE_PROTOCOL_FILE = "file";
    private static final String PACKAGE_PROTOCOL_JAR = "jar";
    private List<String> classNames = new ArrayList<String>();
    private Map<String, Object> handlerMapper = new ConcurrentHashMap<String, Object>();
    private Map<String, Map<String, Object>> handlerMethodMap = new ConcurrentHashMap<String, Map<String, Object>>();
    private static final String SERVLET_CLAZZ = "DISPATCHERSERVLET.ROOT.CLAZZ";
    private static final String SERVLET_METHOD = "DISPATCHERSERVLET.ROOT.METHOD";
    private static final String SERVLET_CLAZZ_FIELDS = "DISPATCHERSERVLET.ROOT.FIELDS";
    private static final String SERVLET_REQUESTMETHOD_ANNOTION = "DISPATCHERSERVLET.ROOT.REQUESTMETHOD.ANNOTION";
    
    
    @SuppressWarnings("unchecked")
    @Override
    public void init(ServletConfig config) throws ServletException {
        System.out.println("********** init dispatcherServlet start ***********");
        try {
            String filePath = config.getInitParameter("contextConfigLocation");
            String servletXmlPath = filePath.substring(filePath.lastIndexOf(":") + 1);
            Document document = null;
            SAXReader render = new SAXReader();
            System.out.println(this.getClass().getName() + "{}" + servletXmlPath);
            
            InputStream inputStream = this.getClassLoader().getResourceAsStream(servletXmlPath);
            if (null == inputStream) {
                System.out.println(this.getClass().getName() + " not find springmvc-servlet.xml");
                return;
            }
            document = render.read(inputStream);
            String xmlPacakScanPath = "//context-package-scan";
            List<Element> nodelist = document.selectNodes(xmlPacakScanPath);
            for (Element element : nodelist) {
                Object packageName = element.getData();
                contextScanPackage(String.valueOf(packageName));
            }
            
            Object object = config.getServletContext().getAttribute("classNames");
            List<String> claxxlist = (List<String>) object;
            for (String str : claxxlist) {
                classNames.add(str);
            }
            config.getServletContext().setAttribute("classNames", classNames);
            object = config.getServletContext().getAttribute("classNames");
            claxxlist = (List<String>) object;
            System.out.println(getClass().getName() + " " + JSON.toJSONString(object) + " claxxlist size " + claxxlist.size());
            handle(config.getServletContext());
        } catch (Exception e) {
            logger.error(this.getClass().getName() + " -- init dispacherServlet exception {}" , e);
            throw new RuntimeException("init dispacherServlet exception", e);
        }
        System.out.println("********** init dispatcherServlet end  ***********");
    }
    private void contextScanPackage (String packageName) {
        try {
            Enumeration<URL> urlEnum = getClassLoader().getResources(packageName.replace(".", "/"));
            while(urlEnum.hasMoreElements()) {
                URL url = urlEnum.nextElement();
                String protocol = url.getProtocol();
                if (PACKAGE_PROTOCOL_FILE.equals(protocol)) {
                    String clazz = url.getPath().replaceAll("%20", "");
                    File[] files = new File(clazz).listFiles(new FileFilter() {
                        @Override
                        public boolean accept(File file) {
                            return (file.isFile() && file.getName().endsWith(".class") || file.isDirectory());
                        }
                    });
                    for (File file : files) {
                        String fileName = file.getName();
                        if (file.isDirectory()) {
                            contextScanPackage(packageName + "." + fileName);
                        } else {
                            String className = packageName + "." + fileName.substring(0, fileName.lastIndexOf("."));
                            if (!"".equals(className)) {
                                Class<?> claxx = Class.forName(className, false, getClassLoader());
                                if (claxx.isAnnotationPresent(Controller.class)|| claxx.isAnnotationPresent(Service.class)) {
                                    classNames.add(className);
                                }
                            }
                        }
                    }
                } else if (PACKAGE_PROTOCOL_JAR.equals(protocol)){
                    System.out.println("---------- now package is in a jar file.");
                }
            }
        } catch (Exception e) {
            System.out.println(this.getClass().getName() + " -- analasis object instance package exception "  + e);
            throw new RuntimeException("analasis object instance package exception", e);
        }
    } 
    public ClassLoader getClassLoader() {
        return Thread.currentThread().getContextClassLoader();
    }
    
    @SuppressWarnings("unchecked")
    private void handle(ServletContext context) {
        if (null == classNames || classNames.isEmpty()) {
            return;
        }
        
        try {
            Map<String, Object> methodMap = new ConcurrentHashMap<String, Object>();  
            for (String clazz : classNames) {
                Class<?> claxx = Class.forName(clazz);
                if (claxx.isAnnotationPresent(Controller.class)) {
                    Object controllerObj = claxx.newInstance();
                    RequestMapping requestMap = claxx.getAnnotation(RequestMapping.class);
                    String headerPath = requestMap.value();
                    String requestPath = "";
                    Method[] methods = claxx.getDeclaredMethods();
                    
                    Field[] fields = claxx.getDeclaredFields();
                    List<Object> fieldClazz = new ArrayList<Object>();
                    if (null != fields) {
                        for (Field field : fields) {
                            if(field.isAnnotationPresent(Autowired.class)){
                                Map<String, Object> beanContries = (Map<String, Object>) context.getAttribute("beanContries");
                                Object object = null;
                                Class<?> classInterface = Class.forName(field.getType().getName());
                                if (classInterface.isInterface()) {
                                    for (Entry<String, Object> map : beanContries.entrySet()) {
                                        Class<?> subClabb = map.getValue().getClass();
                                        if (classInterface.isAssignableFrom(subClabb)) {
                                            object = subClabb.newInstance();
                                        }
                                    }
                                }
                                fieldClazz.add(object);
                                field.setAccessible(true);
                                field.set(controllerObj, object);
                            }
                        }
                    }
                    for (Method method : methods) {
                        RequestMapping methodReqMap = method.getAnnotation(RequestMapping.class);
                        String methodPath = methodReqMap.value();
                        requestPath = headerPath + methodPath;
                        System.out.println(this.getClass().getName() + "-- " + requestPath);
                        
                        methodMap.put("methodName", method.getName());
                        
                        handlerMapper.put(requestPath, method);
                        Map<String, Object> controller = new HashMap<String, Object>();
                        controller.put(SERVLET_CLAZZ, controllerObj);
                        controller.put(SERVLET_METHOD, method);
                        controller.put(SERVLET_CLAZZ_FIELDS, fieldClazz);
                        controller.put(SERVLET_REQUESTMETHOD_ANNOTION, methodReqMap.method());
                        handlerMethodMap.put(requestPath, controller);
                    }
                }
            }
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
    
    @Override
    public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doDispatch(request, response);
    }
    
    
    @Override
    public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
        this.doGet(request, response);
    }
    
    public void doDispatch (HttpServletRequest request, HttpServletResponse response) throws IOException {
        response.setCharacterEncoding("utf-8");
        PrintWriter outWriter = response.getWriter();
        
        try {
            String requestURL = request.getRequestURI();
            request.setCharacterEncoding("utf-8");    
            String reqMethod = JSON.toJSONString(request.getMethod());
            
            Object object = null;
            Method method = null;
            String requestMethod = null;
            for (Entry<String, Map<String, Object>> entry : handlerMethodMap.entrySet()) {
                if (requestURL.contains(entry.getKey())) {
                    System.out.println(this.getClass().getName() + "-- 当前访问的路径:" + entry.getKey());
                    System.out.println(this.getClass().getName() + "-- 当前访问的路径Controller :" 
                                       + entry.getValue().get(SERVLET_CLAZZ).getClass().getName());
                    object = entry.getValue().get(SERVLET_CLAZZ);
                    method = (Method) entry.getValue().get(SERVLET_METHOD);
                    requestMethod = JSON.toJSONString(entry.getValue().get(SERVLET_REQUESTMETHOD_ANNOTION));
                    break;
                }
            }
            Object ret = null;
            StringBuffer viewName = new StringBuffer();
            if (null != object) {
                if (!requestMethod.contains(reqMethod)) {
                    throw new ServletException(getClass().getName() + object.getClass().getName()+ "." + method.getName()
                            + " " + reqMethod + " not support");
                }
                Class<?>[] parameterTypes =  method.getParameterTypes();
                Object[] parameters = new Object[parameterTypes.length];
                if (null != parameterTypes) {
                    int i = 0;
                    for (Class<?> class1 : parameterTypes) {
                        if (class1.isAssignableFrom(HttpServletRequest.class)) {
                            parameters[i++] = request;
                        } else if (class1.isAssignableFrom(HttpServletResponse.class)) {
                            parameters[i++] = response;
                        } else if (Model.class.isAssignableFrom(class1)) {
                            parameters[i++] = new ModelMap();
                        } else {
                            parameters[i++] = class1.newInstance();
                        }
                    }
                }
                
                ret = method.invoke(object, parameters);
                ResponseBody responseBody = method.getAnnotation(ResponseBody.class);
                
                // model 数据copy到request
                for (Object value : parameters) {
                    if (value instanceof ModelMap) {
                        ModelMap model = (ModelMap) value;
                        Map<String, Object> dataMa = model.getModelMap().get("data");
                        for (Entry<String, Object> reqData : dataMa.entrySet()) {
                            request.setAttribute(reqData.getKey(), reqData.getValue());
                        }
                    }
                }
                
                if (requestMethod.contains("GET")) {
                    response.setHeader("Content-Type","text/html; charset=utf-8");  
                    if (ret instanceof ModelAndView) {
                        ModelAndView view = (ModelAndView) ret;
                        viewName = viewName.append(view.getViewName());
                        System.out.println(this.getClass().getName() + " 方法返回 modelAndView :" + viewName);
                    } else {
                        viewName = viewName.append(ret);
                        System.out.println(this.getClass().getName() + " 方法直接跳转至:" + viewName);
                    }
                    request.getRequestDispatcher(viewName.toString()).forward(request, response);
                } else if(requestMethod.contains("POST")) {
                    if (null != responseBody) {
                        System.out.println("返回JSON " + JSON.toJSONString(ret));
                        response.setContentType("application/json;charset=utf-8");
                        outWriter.write(JSON.toJSONString(ret));
                    } else {
                        response.sendRedirect("<em>" + getClass().getName() + " @ResponseBody not suppond data </em>");
                    } 
                }
            }
        } catch (Exception e) {
            System.out.println(getClass().getName() + " doDispatch exception " + e);
        } finally {
            outWriter.flush();
            outWriter.close();
        }
    }
}
 

猜你喜欢

转载自my.oschina.net/u/2510361/blog/1794034