SpringCloud利用网关拦截做Token验证(JWT方式)

背景:前后端分离的项目。

关于JWT的内容,请看以下链接,主要看它的原理,以及缺点!

https://blog.csdn.net/memmsc/article/details/78122931

步骤1:前端把userName和password传给后端。后端为Spring Cloud架构,请求会先经过网关的拦截器;该拦截器通过@Component注解在项目启动时被加载。

步骤2:如果是登录请求,拦截器直接放行,进入用JWT生成Token的阶段(还可以把登录用户的其他信息写入JWT的claims中,之后可以从Token中取出该用户信息)。生成Token时使用固定的密钥进行签名,并在claims中加入一个随机数(UUID),使每次生成的Token都不相同。Token本身的失效时间设置为一天;同时把Token存入redis,并将redis中该Token的过期时间设置为30分钟,最后将Token返回给前端。注意:JWT只是签名而不是加密,payload可以被任何人解码,不要在其中存放敏感信息。

步骤3:以后任何的请求都带Token到后端去请求。

步骤4:拦截到非登录请求时,校验Token的签名和有效期并鉴权。如果鉴权通过,就更新redis中token的过期时间:如果剩余不足5分钟,就重新设置为30分钟。目的是让Token的过期时间随用户的活跃操作而顺延(滑动过期)。

大致流程如上。核心代码主要是两部分:登录接口生成(签名)Token,以及网关拦截器校验Token并鉴权。

0. controller层:将生成的token保存到redis,并设置过期时间:

 compactJws = authService.generateJwt(username, password, userBean);
 // Store the token in Redis with a 30-minute TTL in ONE atomic SET ... EX command, so the key can
 // never be left without an expiry if the process dies between a separate SET and EXPIRE.
 // NOTE(review): the fixed key "token" is shared by ALL users, so every login overwrites the
 // previous user's entry; a per-user key (e.g. "token:" + userId) would need the gateway filter
 // to read the same key.
 stringRedisTemplate.opsForValue().set("token", compactJws, 1800, TimeUnit.SECONDS);

1. 登录接口生成Token(签名):

package com.movitech.user.service.imp;

import com.movitech.commons.entity.UserBean;
import com.movitech.commons.utils.CommonConstants;
import com.movitech.user.service.AuthService;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
import org.joda.time.DateTime;
import org.springframework.stereotype.Service;

import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

/**
 * User authentication service: issues signed JWTs for logged-in users.
 */
@Service(value = "authService")
public class AuthServiceImpl implements AuthService {

    /**
     * Builds a compact, HS512-signed JWT for the given user.
     *
     * <p>The payload carries the subject (the user name), a fixed {@code "user"} role, the user
     * name, the user id and a random UUID, so that two tokens issued for the same user differ.
     * The JWT itself expires one day after it is issued.
     *
     * @param userName     user name; stored as the JWT subject and as the {@code "userName"} claim
     * @param userPassword not used when building the token (it is never written into the payload)
     * @param userBean     the authenticated user; its id is stored under {@code CommonConstants.USER_ID}
     * @return the compact serialized JWT
     * @throws NullPointerException if {@code userBean} is null
     */
    @Override
    public String generateJwt(String userName, String userPassword, UserBean userBean) {
        java.util.Objects.requireNonNull(userBean, "userBean");

        // Signing key: the Base64 encoding of the shared SECURITY_KEY bytes. The verifying side must
        // derive the key in exactly the same way (same charset for getBytes()).
        byte[] secretKey = Base64.getEncoder().encode(CommonConstants.SECURITY_KEY.getBytes());
        // Hard expiry of the JWT itself: one day from now.
        DateTime expirationDate = new DateTime().plusDays(1);

        // Custom claims stored in the token payload. The payload is signed, NOT encrypted:
        // never put secrets (such as the password) here.
        Map<String, Object> claims = new HashMap<>();
        claims.put("role", "user");
        claims.put("userName", userName);
        claims.put(CommonConstants.USER_ID, userBean.getId());
        // Random nonce so repeated logins of the same user yield distinct tokens. Fully qualified
        // because java.util.UUID is not among this file's imports.
        claims.put("uuid", java.util.UUID.randomUUID().toString());

        return Jwts.builder()
                // Subject: the user's unique identifier (currently the user name / user code).
                .setSubject(userName)
                .setExpiration(expirationDate.toDate())
                .addClaims(claims)
                // HMAC-SHA512 signature; the parser must use the same key.
                .signWith(SignatureAlgorithm.HS512, secretKey)
                .compact();
    }

}

以上常量类和pojo此处省略。。。。

2. 网关拦截器校验Token并鉴权:

package com.movitech.gateway.filter;

import com.movitech.commons.dto.ErrorResponseMap;
import com.movitech.commons.enums.ErrorCode;
import com.movitech.commons.utils.CommonConstants;
import com.movitech.commons.utils.JsonUtil;
import com.movitech.commons.utils.ResponseUtil;
import com.netflix.zuul.ZuulFilter;
import com.netflix.zuul.context.RequestContext;
import io.jsonwebtoken.*;
import org.springframework.cloud.netflix.zuul.filters.support.FilterConstants;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import javax.servlet.http.HttpServletRequest;
import java.util.Base64;

@Component
public class SecurityFilter extends ZuulFilter {
    @Override
    public String filterType() {
        return FilterConstants.PRE_TYPE;
    }

    @Override
    public int filterOrder() {
        return FilterConstants.PRE_DECORATION_FILTER_ORDER - 1;
    }

    @Override
    public boolean shouldFilter() {
        RequestContext ctx = RequestContext.getCurrentContext();
        HttpServletRequest request = ctx.getRequest();
        if (request.getRequestURL().toString().contains("loginInfo") || request.getRequestURL().toString().contains("info")) {
            return false;
        }
        // TODO
        return true;
    }

    @Override
    public Object run() {
        RequestContext ctx = RequestContext.getCurrentContext();
        HttpServletRequest request = ctx.getRequest();
        final String authorizationHeader = request.getHeader(HttpHeaders.AUTHORIZATION);

        if (HttpMethod.OPTIONS.name().equals(request.getMethod())) {
            return null;
        } else {
            if (StringUtils.isEmpty(authorizationHeader) || !authorizationHeader.startsWith(CommonConstants.BEARER)) {
                // Missing or invalid Authorization header
                ErrorResponseMap errorResponseMap = ResponseUtil.createErrorResponse(null, "Missing or invalid Authorization header!",
                        ErrorCode.INVALID_AUTHORIZATION_HEADER, request, null);
                denyAccess(ctx,errorResponseMap);
                return JsonUtil.serializeToString(errorResponseMap);
            }
            final String token = authorizationHeader.substring(7);
            try {
                byte[] secretKey = Base64.getEncoder().encode(CommonConstants.SECURITY_KEY.getBytes());
                Claims claims = Jwts.parser().setSigningKey(secretKey).parseClaimsJws(token).getBody();
                if (claims != null) {
//获取redis里面数据的存活时间
                    Long expirationDate = stringRedisTemplate.getExpire("token",TimeUnit.SECONDS);
                    //如果还剩余5分钟,重置redis里面数据的存活时间
                    if(expirationDate > 300){
                        stringRedisTemplate.expire("token",1800,TimeUnit.SECONDS);
                    }else {
                        ErrorResponseMap errorResponseMap = new ErrorResponseMap();
                        Error error = new Error(null, "Token expired!",  "", 1003,"");
                        errorResponseMap.setSuccess(false);
                        errorResponseMap.setMessage(null);
                        errorResponseMap.setError(error);
                        denyAccess(ctx,errorResponseMap);
                        return JsonUtil.serializeToString(errorResponseMap);
                    }
                    String userName = (String) claims.get(CommonConstants.USER_CODE);
                    Integer userId = (Integer) claims.get(CommonConstants.USER_ID);
                    ctx.addZuulRequestHeader(CommonConstants.USER_CODE, userName);
                    ctx.addZuulRequestHeader(CommonConstants.USER_ID, String.valueOf(userId));
                }
            } catch (MalformedJwtException ex) {
                ErrorResponseMap errorResponseMap = ResponseUtil.createErrorResponse(null, "Invalid token!",
                        ErrorCode.INVALID_AUTHORIZATION_HEADER, request, ex);
                denyAccess(ctx,errorResponseMap);
                return JsonUtil.serializeToString(errorResponseMap);
            } catch (SignatureException ex) {
                ErrorResponseMap errorResponseMap = ResponseUtil.createErrorResponse(null, "Token Signature error!",
                        ErrorCode.SIGNATURE_EXCEPTION, request, ex);
                denyAccess(ctx,errorResponseMap);
                return JsonUtil.serializeToString(errorResponseMap);
            } catch (ExpiredJwtException ex) {
                ErrorResponseMap errorResponseMap = ResponseUtil.createErrorResponse(null, "Token expired!",
                        ErrorCode.EXPIRED_JWT_EXCEPTION, request, ex);
                denyAccess(ctx,errorResponseMap);
                return JsonUtil.serializeToString(errorResponseMap);
            }
        }
        return null;
    }

    private void denyAccess(RequestContext ctx, ErrorResponseMap authResult) {
        String result = JsonUtil.serializeToString(authResult);
        ctx.setSendZuulResponse(false);
        ctx.setResponseStatusCode(401);
        try {
            ctx.getResponse().getWriter().write(result);
        }catch (Exception e){}
    }
}

以上是核心代码。下面是我项目中用到的异常包装类和工具类,仅供参考,可以不用这些,按自己的项目自行封装。

以上涉及到的类:

(1)ErrorResponseMap

package com.movitech.commons.dto;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.movitech.commons.exception.Error;
import lombok.Getter;
import lombok.Setter;

@Getter
@Setter
public class ErrorResponseMap extends ResponseMap {
    @JsonProperty(value = "error")
    private Error error;
    @JsonProperty(value = "stackTrace")
    private String stackTrace;
}

(1.1)Error

package com.movitech.commons.exception;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

/**
 * Error details embedded in an error response body.
 *
 * <p>NOTE: Lombok's {@code @AllArgsConstructor} takes the fields in declaration order
 * (httpStatusCode, errorMsg, exceptionMsg, errorCode, exceptionClassName); callers pass arguments
 * positionally, so do not reorder the fields.
 * NOTE(review): the simple name {@code Error} shadows {@code java.lang.Error}; code using this
 * class must import it explicitly, otherwise {@code new Error(...)} resolves to java.lang.Error.
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
public class Error {
    // Standard HTTP status code
    private Integer httpStatusCode;
    // Custom, human-readable error description
    private String errorMsg;
    // Message of the underlying exception
    private String exceptionMsg;
    // Custom application error code
    private Integer errorCode;
    // Class name of the underlying exception
    private String exceptionClassName;
}

(2)ErrorCode

package com.movitech.commons.enums;

/**
 * Custom application error codes returned in error responses.
 */
public enum ErrorCode {
    /** Token signature verification failed. */
    SIGNATURE_EXCEPTION(1000),
    /** Token has expired. */
    EXPIRED_JWT_EXCEPTION(1001),
    /** Authorization header is missing or malformed. */
    INVALID_AUTHORIZATION_HEADER(1002);

    /** Numeric code; final because enum constants are shared, process-wide singletons. */
    private final Integer errorCode;

    ErrorCode(Integer errorCode) {
        this.errorCode = errorCode;
    }

    /** Returns the numeric code as a decimal string, e.g. {@code "1000"}. */
    @Override
    public String toString() {
        return errorCode.toString();
    }

    /** Returns the numeric code. */
    public Integer value() {
        return errorCode;
    }
}

(3)JsonUtil

package com.movitech.commons.utils;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.*;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import java.io.IOException;

/**
 * Static JSON serialization / deserialization helpers backed by Spring's Jackson configuration.
 *
 * <p>The shared {@link ObjectMapper} is built from the injected
 * {@link Jackson2ObjectMapperBuilder} when Spring instantiates this component. The static methods
 * must therefore not be called before the application context has created it.
 *
 * <p>Root-value wrapping is chosen per call through immutable, thread-safe {@link ObjectWriter} /
 * {@link ObjectReader} instances. Reconfiguring the shared mapper and restoring it afterwards
 * would mutate global state, and concurrent callers could observe each other's setting.
 */
@Component
public class JsonUtil {

    /** Shared mapper built from Spring's Jackson configuration. */
    public static ObjectMapper mapper;

    public JsonUtil(Jackson2ObjectMapperBuilder jackson2ObjectMapperBuilder) {
        mapper = jackson2ObjectMapperBuilder.build();
    }

    /**
     * Serializes {@code object} to JSON without root-name wrapping.
     *
     * @return the JSON text, or {@code ""} if {@code object} is null or serialization fails
     */
    public static String serializeToString(Object object) {
        return serializeToString(object, false);
    }

    /**
     * Serializes {@code object} to JSON.
     *
     * @param object         value to serialize; null yields {@code ""}
     * @param rootValueState {@code true} to wrap the value in its root name; null is treated as false
     * @return the JSON text, or {@code ""} if {@code object} is null or serialization fails
     */
    public static String serializeToString(Object object, Boolean rootValueState) {
        if (object == null) {
            return "";
        }
        ObjectWriter writer = Boolean.TRUE.equals(rootValueState)
                ? mapper.writer().with(SerializationFeature.WRAP_ROOT_VALUE)
                : mapper.writer().without(SerializationFeature.WRAP_ROOT_VALUE);
        try {
            return writer.writeValueAsString(object);
        } catch (JsonProcessingException e) {
            e.printStackTrace();
            return "";
        }
    }

    /**
     * Serializes {@code object} to JSON bytes using the mapper's own configuration.
     *
     * @return the encoded bytes, or {@code null} if {@code object} is null or serialization fails
     */
    public static byte[] serializeToBytes(Object object) {
        if (object != null) {
            try {
                return mapper.writeValueAsBytes(object);
            } catch (JsonProcessingException e) {
                e.printStackTrace();
            }
        }
        return null;
    }

    /**
     * Deserializes JSON data.
     *
     * @param jsonData       JSON text
     * @param valueType      target type
     * @param rootValueState whether the JSON is wrapped in a root name that must be unwrapped
     * @param <T>            target type
     * @return the deserialized POJO, or {@code null} if the input is empty, {@code rootValueState}
     *         is null, or parsing fails
     */
    public static <T> T deserialize(String jsonData, Class<T> valueType, Boolean rootValueState) {
        if (StringUtils.isEmpty(jsonData) || rootValueState == null) {
            return null;
        }
        ObjectReader reader = rootValueState
                ? mapper.readerFor(valueType).with(DeserializationFeature.UNWRAP_ROOT_VALUE)
                : mapper.readerFor(valueType).without(DeserializationFeature.UNWRAP_ROOT_VALUE);
        try {
            return reader.readValue(jsonData);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
    }

    /**
     * Deserializes JSON data, unwrapping the root name by default.
     *
     * @param jsonData  JSON text
     * @param valueType target type
     * @param <T>       target type
     * @return the deserialized POJO, or {@code null} on empty input or parse failure
     */
    public static <T> T deserialize(String jsonData, Class<T> valueType) {
        return deserialize(jsonData, valueType, true);
    }

    /** Builds a parametric type such as {@code List<Foo>} for generic deserialization. */
    public static JavaType getCollectionType(Class<?> collectionClass, Class<?>... elementClasses) {
        return mapper.getTypeFactory().constructParametricType(collectionClass, elementClasses);
    }

    /**
     * Returns the JSON text of the value stored under {@code field}.
     *
     * @param jsonString     JSON text
     * @param field          top-level field to read
     * @param rootValueState whether to unwrap the root name first
     * @return the field's value as JSON text, or {@code ""} if absent or unreadable
     */
    public static String getValue(String jsonString, String field, Boolean rootValueState) {
        JsonNode node = getJsonNode(jsonString, field, rootValueState);
        return node == null ? "" : node.toString();
    }

    /**
     * Returns the JSON text of the value stored under {@code field}, unwrapping the root name by
     * default.
     *
     * @param jsonString JSON text
     * @param field      top-level field to read
     * @return the field's value as JSON text, or {@code ""} if absent or unreadable
     */
    public static String getValue(String jsonString, String field) {
        return getValue(jsonString, field, true);
    }

    /**
     * Returns the node stored under {@code field}.
     *
     * @param jsonString     JSON text
     * @param field          top-level field to read
     * @param rootValueState whether to unwrap the root name first
     * @return the field's node, or {@code null} if any argument is empty/null, parsing fails,
     *         the content is blank, or the field is absent
     */
    public static JsonNode getJsonNode(String jsonString, String field, Boolean rootValueState) {
        if (StringUtils.isEmpty(jsonString) || StringUtils.isEmpty(field) || rootValueState == null) {
            return null;
        }
        ObjectReader reader = rootValueState
                ? mapper.reader().with(DeserializationFeature.UNWRAP_ROOT_VALUE)
                : mapper.reader().without(DeserializationFeature.UNWRAP_ROOT_VALUE);
        JsonNode node;
        try {
            node = reader.readTree(jsonString);
        } catch (IOException e) {
            e.printStackTrace();
            return null;
        }
        // readTree may return null for blank content (e.g. "   "), depending on the Jackson version.
        return node == null ? null : node.get(field);
    }
}

(4)ResponseUtil

package com.movitech.commons.utils;

import com.movitech.commons.dto.ErrorResponseMap;
import com.movitech.commons.enums.ErrorCode;
import com.movitech.commons.exception.Error;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

import javax.servlet.http.HttpServletRequest;

/**
 * Helpers for building the project's standard {@link ErrorResponseMap} error body.
 */
public class ResponseUtil {

    /**
     * Wraps an error body in a {@link ResponseEntity} whose HTTP status is taken from the request.
     *
     * @param errorMsg  custom error description
     * @param errorCode application error code; null maps to 500
     * @param request   current request, used to resolve the HTTP status
     * @param ex        underlying exception; may be null
     * @return the response entity carrying the error body
     */
    public static ResponseEntity createResponseEntity(String errorMsg, ErrorCode errorCode, HttpServletRequest request, Throwable ex) {
        HttpStatus status = getStatus(request);
        ErrorResponseMap errorResponseMap = createErrorResponse(status, errorMsg, errorCode, request, ex);
        return new ResponseEntity<>(errorResponseMap, status);
    }

    /**
     * Builds an error body.
     *
     * @param status    HTTP status to report; when null it is resolved from {@code request}
     * @param errorMsg  custom error description
     * @param errorCode application error code; null maps to 500
     * @param request   current request, used only when {@code status} is null
     * @param ex        underlying exception; may be null (exception fields then become {@code ""}
     *                  and the top-level message null)
     * @return the populated error body, with {@code success = false}
     */
    public static ErrorResponseMap createErrorResponse(HttpStatus status, String errorMsg, ErrorCode errorCode, HttpServletRequest request, Throwable ex) {
        if (status == null) {
            status = getStatus(request);
        }
        String exceptionMsg = ex == null ? "" : ex.getMessage();
        String exceptionClassName = ex == null ? "" : ex.getClass().getCanonicalName();
        int code = errorCode == null ? HttpStatus.INTERNAL_SERVER_ERROR.value() : errorCode.value();

        ErrorResponseMap errorResponseMap = new ErrorResponseMap();
        errorResponseMap.setSuccess(false);
        // ex is optional (e.g. a missing Authorization header has no exception); dereferencing it
        // unconditionally threw NullPointerException.
        errorResponseMap.setMessage(ex == null ? null : ex.getMessage());
        errorResponseMap.setError(new Error(status.value(), errorMsg, exceptionMsg, code, exceptionClassName));
        return errorResponseMap;
    }

    /**
     * Resolves the HTTP status from the servlet error attribute
     * {@code javax.servlet.error.status_code}.
     *
     * @return that status, or 500 when the attribute is absent or not a known {@link HttpStatus}
     */
    public static HttpStatus getStatus(HttpServletRequest request) {
        Integer statusCode = (Integer) request.getAttribute("javax.servlet.error.status_code");
        if (statusCode == null) {
            return HttpStatus.INTERNAL_SERVER_ERROR;
        }
        try {
            return HttpStatus.valueOf(statusCode);
        } catch (IllegalArgumentException unknownStatus) {
            // Non-standard codes must not break error reporting itself.
            return HttpStatus.INTERNAL_SERVER_ERROR;
        }
    }
}

猜你喜欢

转载自blog.csdn.net/qq_34707991/article/details/82898187