springboot集成shiro自定义登陆过滤器

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第9天,点击查看活动详情

在上一篇博客springboot简单集成shiro权限管理中,用户在登录的过程中,有以下几个问题:

  • 用户在没有登陆的情况下,访问需要权限的接口,服务器自动跳转到登陆页面,前端无法控制;
  • 用户在登录成功后,服务器自动跳转到成功页,前端无法控制;
  • 用户在登录失败后,服务器自动刷新登录页面,前端无法控制;

很显然,这样的交互方式,用户体验上不是很好,并且在某些程度上也无法满足业务上的要求。所以,我们要对默认的FormAuthenticationFilter进行覆盖,实现我们自定义的Filter来解决用户交互的问题。

自定义UsernamePasswordAuthenticationFilter

  • 首先我们需要继承原先的FormAuthenticationFilter

    之所以继承这个FormAuthenticationFilter,有以下几点原因:

    1. FormAuthenticationFilter是默认拦截登录功能的过滤器,我们本身就是要改造登录功能,所以继承它很正常;
    2. 我们自定义的Filter需要复用里面的逻辑;
    public class UsernamePasswordAuthenticationFilter extends FormAuthenticationFilter{}
    复制代码
  • 其次,为了解决第一个问题,我们需要重写saveRequestAndRedirectToLogin方法

    /**
     * 没有登陆的情况下,访问需要权限的接口,需要引导用户登陆
     *
     * @param request
     * @param response
     * @throws IOException
     */
    @Override
    protected void saveRequestAndRedirectToLogin(ServletRequest request, ServletResponse response) throws IOException {
        //  保存当前请求,以便后续登陆成功后重新请求
        this.saveRequest(request);
        // 1. 服务端直接跳转
        //   - 服务端重定向登陆页面
        if (autoRedirectToLogin) {
            this.redirectToLogin(request, response);
        } else {
            // 2. json模式
            //   - json数据格式告知前端需要跳转到登陆页面,前端根据指令跳转登陆页面
            HttpServletRequest req = (HttpServletRequest) request;
            HttpServletResponse res = (HttpServletResponse) response;
            Map<String, String> metaInfo = new HashMap<>();
            // 告知前端需要跳转的登陆页面
            metaInfo.put("loginUrl", getLoginUrl());
            // 告知前端当前请求的url;这个信息也可以保存在前端
            metaInfo.put("currentRequest", req.getRequestURL().toString());
            ResultWrap.failure(802, "请登陆后再操作!", metaInfo)
              .writeToResponse(res);
        }
    }
    复制代码

    在这个方法中,我们通过配置autoRedirectToLogin参数的方式,既保留了原来服务器自动跳转的功能,又增强了服务器返回json给前端,让前端根据返回结果跳转到登陆页面的功能。这样就增强了应用程序的可控性和灵活性了。

  • 重写登陆成功的处理方法onLoginSuccess

    @Override
    protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, ServletResponse response) throws Exception {
        // 查询当前用户自定义的登陆成功需要跳转的页面,可以更加灵活控制用户页面跳转
        String successUrl = loginSuccessPageFetch.successUrl(token, subject);
        // 如果没有自定义的成功页面,那么跳转默认成功页
        if (StringUtils.isEmpty(successUrl)) {
            successUrl = this.getSuccessUrl();
        }
        if (loginSuccessAutoRedirect) {
            // 服务端直接重定向到目标页面
            WebUtils.redirectToSavedRequest(request, response, successUrl);
        } else {
            SavedRequest savedRequest = WebUtils.getAndClearSavedRequest(request);
            if (savedRequest != null && savedRequest.getMethod().equalsIgnoreCase("GET")) {
                successUrl = savedRequest.getRequestUrl();
            }
            // 返回json数据格式告知前端跳转目标页面
            HttpServletResponse res = (HttpServletResponse) response;
            Map<String, String> data = new HashMap<>();
            // 登陆成功后跳转的目标页面
            data.put("successUrl", successUrl);
            ResultWrap.success(data).writeToResponse(res);
        }
        return false;
    }
    复制代码
    1. 登陆成功后,我们内置了一个个性化的成功页,用于保证针对不同的用户会有定制化的登陆成功页。
    2. 通过自定义的loginSuccessAutoRedirect属性来决定用户登陆成功后是直接由服务端控制页面跳转还是返回json让前端控制交互行为。
    3. 我们在用户登陆成功后,会获取前面保存的请求,以便用户在登录成功后能直接回到登录前点击的页面。
  • 重写用户登录失败的方法onLoginFailure

    @Override
      protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, ServletRequest request, ServletResponse response) {
        if (log.isDebugEnabled()) {
          log.debug("Authentication exception", e);
        }
        this.setFailureAttribute(request, e);
        if (!loginFailureAutoRedirect) {
          // 返回json数据格式告知前端跳转目标页面
          HttpServletResponse res = (HttpServletResponse) response;
          ResultWrap.failure(803, "用户名或密码错误,请核对后无误后重新提交!", null).writeToResponse(res);
        }
        return true;
      }
    复制代码
    1. 登陆失败我们使用自定义属性loginFailureAutoRedirect来控制失败的动作是由服务端直接跳转页面还是返回json由前端控制用户交互。
    2. 在这个方法的逻辑里面没有看到跳转的功能,是因为我们直接把父类的默认实现拿过来了,在原有的逻辑上做了修改。既然默认是服务端跳转的功能,那么我们只需要补充返回json的功能即可。

覆盖默认的FormAuthenticationFilter

现在我们已经写好了自定义的用户名密码登陆过滤器,下面我们就把它加入到shiro的配置中去,这样才能生效:

@Bean
  public ShiroFilterFactoryBean shiroFilterFactoryBean() {
    ShiroFilterFactoryBean shiroFilterFactoryBean = new ShiroFilterFactoryBean();
    shiroFilterFactoryBean.setSecurityManager(securityManager());
    Map<String, String> filterChainDefinitionMap = new LinkedHashMap<>();
    // 设置不需要权限的url
    String[] permitUrls = properties.getPermitUrls();
    if (ArrayUtils.isNotEmpty(permitUrls)) {
      for (String permitUrl : permitUrls) {
        filterChainDefinitionMap.put(permitUrl, "anon");
      }
    }
    // 设置退出的url
    String logoutUrl = properties.getLogoutUrl();
    filterChainDefinitionMap.put(logoutUrl, "logout");
    // 设置需要权限验证的url
    filterChainDefinitionMap.put("/**", "authc");
    shiroFilterFactoryBean.setFilterChainDefinitionMap(filterChainDefinitionMap);
    // 设置提交登陆的url
    String loginUrl = properties.getLoginUrl();
    shiroFilterFactoryBean.setLoginUrl(loginUrl);
    // 设置登陆成功跳转的url
    String successUrl = properties.getSuccessUrl();
    shiroFilterFactoryBean.setSuccessUrl(successUrl);
    // 添加自定义Filter
    shiroFilterFactoryBean.setFilters(customFilters());
    return shiroFilterFactoryBean;
  }
​
/**
   * 自定义过滤器
   *
   * @return
   */
  private Map<String, Filter> customFilters() {
    Map<String, Filter> filters = new LinkedHashMap<>();
    // 自定义FormAuthenticationFilter,用于管理用户登陆的,包括登陆成功后的动作、登陆失败的动作
    // 可查看org.apache.shiro.web.filter.mgt.DefaultFilter,可覆盖里面对应的authc
    UsernamePasswordAuthenticationFilter usernamePasswordAuthenticationFilter = new UsernamePasswordAuthenticationFilter();
    // 不允许服务器自动控制页面跳转
    usernamePasswordAuthenticationFilter.setAutoRedirectToLogin(false);
    usernamePasswordAuthenticationFilter.setLoginSuccessAutoRedirect(false);
    usernamePasswordAuthenticationFilter.setLoginFailureAutoRedirect(false);
    filters.put("authc", usernamePasswordAuthenticationFilter);
    return filters;
  }
复制代码

上面的代码重点看 【添加自定义Filte】 ,其实原理就是把默认的authc过滤器给覆盖掉,换成我们自定义的过滤器,这样的话,我们的过滤器才能生效。

完整UsernamePasswordAuthenticationFilter代码

import com.example.awesomespring.vo.ResultWrap;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.AuthenticationToken;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
import org.apache.shiro.web.util.SavedRequest;
import org.apache.shiro.web.util.WebUtils;
​
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
​
/**
 * @author zouwei
 * @className UsernamePasswordAuthenticationFilter
 * @date: 2022/8/2 上午12:14
 * @description:
 */
@Data
@Slf4j
public class UsernamePasswordAuthenticationFilter extends FormAuthenticationFilter {
  //  如果用户没有登陆的情况下访问需要权限的接口,服务端是否自动调整到登陆页面
  private boolean autoRedirectToLogin = true;
  // 登陆成功后是否自动跳转
  private boolean loginSuccessAutoRedirect = true;
  // 登陆失败后是否跳转
  private boolean loginFailureAutoRedirect = true;
  /**
   * 个性化定制每个登陆成功的账号跳转的url
   */
  private LoginSuccessPageFetch loginSuccessPageFetch = new LoginSuccessPageFetch(){};
​
  public UsernamePasswordAuthenticationFilter() {
  }
​
  /**
   * 没有登陆的情况下,访问需要权限的接口,需要引导用户登陆
   *
   * @param request
   * @param response
   * @throws IOException
   */
  @Override
  protected void saveRequestAndRedirectToLogin(ServletRequest request, ServletResponse response) throws IOException {
    //  保存当前请求,以便后续登陆成功后重新请求
    this.saveRequest(request);
    // 1. 服务端直接跳转
    //   - 服务端重定向登陆页面
    if (autoRedirectToLogin) {
      this.redirectToLogin(request, response);
    } else {
      // 2. json模式
      //   - json数据格式告知前端需要跳转到登陆页面,前端根据指令跳转登陆页面
      HttpServletRequest req = (HttpServletRequest) request;
      HttpServletResponse res = (HttpServletResponse) response;
      Map<String, String> metaInfo = new HashMap<>();
      // 告知前端需要跳转的登陆页面
      metaInfo.put("loginUrl", getLoginUrl());
      // 告知前端当前请求的url;这个信息也可以保存在前端
      metaInfo.put("currentRequest", req.getRequestURL().toString());
      ResultWrap.failure(802, "请登陆后再操作!", metaInfo)
          .writeToResponse(res);
    }
  }
​
  @Override
  protected boolean onLoginSuccess(AuthenticationToken token, Subject subject, ServletRequest request, ServletResponse response) throws Exception {
    // 查询当前用户自定义的登陆成功需要跳转的页面,可以更加灵活控制用户页面跳转
    String successUrl = loginSuccessPageFetch.successUrl(token, subject);
    // 如果没有自定义的成功页面,那么跳转默认成功页
    if (StringUtils.isEmpty(successUrl)) {
      successUrl = this.getSuccessUrl();
    }
    if (loginSuccessAutoRedirect) {
      // 服务端直接重定向到目标页面
      WebUtils.redirectToSavedRequest(request, response, successUrl);
    } else {
      SavedRequest savedRequest = WebUtils.getAndClearSavedRequest(request);
      if (savedRequest != null && savedRequest.getMethod().equalsIgnoreCase("GET")) {
        successUrl = savedRequest.getRequestUrl();
      }
      // 返回json数据格式告知前端跳转目标页面
      HttpServletResponse res = (HttpServletResponse) response;
      Map<String, String> data = new HashMap<>();
      // 登陆成功后跳转的目标页面
      data.put("successUrl", successUrl);
      ResultWrap.success(data).writeToResponse(res);
    }
    return false;
  }
​
  @Override
  protected boolean onLoginFailure(AuthenticationToken token, AuthenticationException e, ServletRequest request, ServletResponse response) {
    if (log.isDebugEnabled()) {
      log.debug("Authentication exception", e);
    }
    this.setFailureAttribute(request, e);
    if (!loginFailureAutoRedirect) {
      // 返回json数据格式告知前端跳转目标页面
      HttpServletResponse res = (HttpServletResponse) response;
      ResultWrap.failure(803, "用户名或密码错误,请核对后无误后重新提交!", null).writeToResponse(res);
    }
    return true;
  }
​
  /**
   * 针对不同的人员登陆成功后有不同的跳转页面而设计
   */
  public interface LoginSuccessPageFetch {
​
    default String successUrl(AuthenticationToken token, Subject subject) {
      return StringUtils.EMPTY;
    }
  }
}
复制代码

ResultWrap.java

import com.example.awesomespring.util.JsonUtil;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.apache.commons.lang3.StringUtils;
import org.springframework.http.HttpStatus;
​
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Objects;
​
/**
 * @author zouwei
 * @className ResultWrap
 * @date: 2022/8/2 下午2:02
 * @description:
 */
@Data
@AllArgsConstructor
public class ResultWrap<T, M> {
  //  方便前端判断当前请求处理结果是否正常
  private int code;
  //  业务处理结果
  private T data;
  //  产生错误的情况下,提示用户信息
  private String message;
  //  产生错误情况下的异常堆栈,提示开发人员
  private String error;
  //  发生错误的时候,返回的附加信息
  private M metaInfo;
​
  /**
   * 成功带处理结果
   *
   * @param data
   * @param <T>
   * @return
   */
  public static <T> ResultWrap success(T data) {
    return new ResultWrap(HttpStatus.OK.value(), data, StringUtils.EMPTY, StringUtils.EMPTY, null);
  }
​
  /**
   * 成功不带处理结果
   *
   * @return
   */
  public static ResultWrap success() {
    return success(HttpStatus.OK.name());
  }
​
  /**
   * 失败
   *
   * @param code
   * @param message
   * @param error
   * @return
   */
  public static <M> ResultWrap failure(int code, String message, String error, M metaInfo) {
    return new ResultWrap(code, null, message, error, metaInfo);
  }
​
  /**
   * 失败
   *
   * @param code
   * @param message
   * @param error
   * @param metaInfo
   * @param <M>
   * @return
   */
  public static <M> ResultWrap failure(int code, String message, Exception error, M metaInfo) {
    return failure(code, message, error.getStackTrace().toString(), metaInfo);
  }
​
  /**
   * 失败
   *
   * @param code
   * @param message
   * @param error
   * @return
   */
  public static ResultWrap failure(int code, String message, Exception error) {
    String errorMessage = StringUtils.EMPTY;
    if (Objects.nonNull(error)) {
      errorMessage = error.getStackTrace().toString();
    }
    return failure(code, message, errorMessage, null);
  }
​
  /**
   * 失败
   *
   * @param code
   * @param message
   * @param metaInfo
   * @param <M>
   * @return
   */
  public static <M> ResultWrap failure(int code, String message, M metaInfo) {
    return failure(code, message, StringUtils.EMPTY, metaInfo);
  }
​
  private static final String APPLICATION_JSON_VALUE = "application/json;charset=UTF-8";
​
  /**
   * 把结果写入响应中
   *
   * @param response
   */
  public void writeToResponse(HttpServletResponse response) {
    int code = this.getCode();
    if (Objects.isNull(HttpStatus.resolve(code))) {
      response.setStatus(HttpStatus.OK.value());
    } else {
      response.setStatus(code);
    }
    response.setContentType(APPLICATION_JSON_VALUE);
    try (PrintWriter writer = response.getWriter()) {
      writer.write(JsonUtil.obj2String(this));
      writer.flush();
    } catch (IOException e) {
      e.printStackTrace();
    }
  }
}
复制代码

JsonUtil.java

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
​
import java.util.Objects;
​
/**
 * @author zouwei
 * @className JsonUtil
 * @date: 2022/8/2 下午3:08
 * @description:
 */
@Slf4j
public final class JsonUtil {
​
  /** 防止使用者直接new JsonUtil() */
  private JsonUtil() {}
​
  private static ObjectMapper objectMapper = new ObjectMapper();
​
  static {
    // 对象所有字段全部列入序列化
    objectMapper.setSerializationInclusion(JsonInclude.Include.ALWAYS);
    /** 所有日期全部格式化成时间戳 因为即使指定了DateFormat,也不一定能满足所有的格式化情况,所以统一为时间戳,让使用者按需转换 */
    objectMapper.configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, true);
    /** 忽略空Bean转json的错误 假设只是new方式创建对象,并且没有对里面的属性赋值,也要保证序列化的时候不报错 */
    objectMapper.configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
    /** 忽略反序列化中json字符串中存在,但java对象中不存在的字段 */
    objectMapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
  }
​
  /**
   * 对象转换成json字符串
   *
   * @param obj
   * @param <T>
   * @return
   */
  public static <T> String obj2String(T obj) {
    return obj2String(obj, null);
  }
  /**
   * 对象转换成json字符串
   *
   * @param obj
   * @param <T>
   * @return
   */
  public static <T> String obj2String(T obj, String defaultValue) {
    if (Objects.isNull(obj)) {
      return defaultValue;
    }
    try {
      return obj instanceof String ? (String) obj : objectMapper.writeValueAsString(obj);
    } catch (Exception e) {
      log.warn("Parse object to String error", e);
      // 即使序列化出错,也要保证程序走下去
      return null;
    }
  }
​
  /**
   * 对象转json字符串(带美化效果)
   *
   * @param obj
   * @param <T>
   * @return
   */
  public static <T> String obj2StringPretty(T obj) {
    if (Objects.isNull(obj)) {
      return null;
    }
    try {
      return obj instanceof String
          ? (String) obj
          : objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(obj);
    } catch (Exception e) {
      log.warn("Parse object to String error", e);
      // 即使序列化出错,也要保证程序走下去
      return null;
    }
  }
​
  /**
   * json字符串转简单对象
   *
   * @param <T>
   * @param json
   * @param clazz
   * @return
   */
  public static <T> T string2Obj(String json, Class<T> clazz) {
    if (StringUtils.isEmpty(json) || Objects.isNull(clazz)) {
      return null;
    }
    try {
      return clazz.equals(String.class) ? (T) json : objectMapper.readValue(json, clazz);
    } catch (Exception e) {
      log.warn("Parse String to Object error", e);
      // 即使序列化出错,也要保证程序走下去
      return null;
    }
  }
​
  /**
   * json字符串转复杂对象
   *
   * @param json
   * @param typeReference 例如:new TypeReference<List<User>>(){}
   * @param <T> 例如:List<User>
   * @return
   */
  public static <T> T string2Obj(String json, TypeReference<T> typeReference) {
    if (StringUtils.isEmpty(json) || Objects.isNull(typeReference)) {
      return null;
    }
    try {
      return (T)
          (typeReference.getType().equals(String.class)
              ? (T) json
              : objectMapper.readValue(json, typeReference));
    } catch (Exception e) {
      log.warn("Parse String to Object error", e);
      // 即使序列化出错,也要保证程序走下去
      return null;
    }
  }
​
  /**
   * json字符串转复杂对象
   *
   * @param json
   * @param collectionClass 例如:List.class
   * @param elementClasses 例如:User.class
   * @param <T> 例如:List<User>
   * @return
   */
  public static <T> T string2Obj(
      String json, Class<?> collectionClass, Class<?>... elementClasses) {
    if (StringUtils.isEmpty(json)
        || Objects.isNull(collectionClass)
        || Objects.isNull(elementClasses)) {
      return null;
    }
    JavaType javaType =
        objectMapper
            .getTypeFactory()
            .constructParametricType(collectionClass, elementClasses);
    try {
      return objectMapper.readValue(json, javaType);
    } catch (Exception e) {
      log.warn("Parse String to Object error", e);
      // 即使序列化出错,也要保证程序走下去
      return null;
    }
  }
}
复制代码

至此,在shiro中如何实现更灵活的登陆控制就编写完毕了。后面会陆续讲解我在使用shiro时遇到的其他问题,以及相应的解决方案。

猜你喜欢

转载自juejin.im/post/7127694462062428173