Spring Boot + Shiro + Redis: Session Sharing and Cache Sharing

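In a distributed application behind a load balancer, a user's first request may reach server A, where the user logs in and enters the system. If the next request is forwarded to server B, where the user has never logged in, the user lands on the login page again. That is confusing and unacceptable, and it is the problem session sharing solves.
How does the Shiro security framework share sessions? Sharing in Shiro has two parts: session sharing and cache sharing. This article integrates Shiro into a Spring Boot project and uses Redis to implement both.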
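
The load-balancing problem described above can be reproduced in a few lines of plain Java. This is only a sketch: AppNode and SessionSharingDemo are hypothetical names, and an in-memory Map stands in for Redis.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

// Hypothetical stand-in for one application server behind a load balancer.
class AppNode {
    private final Map<String, String> sessions; // session id -> username

    AppNode(Map<String, String> sessions) {
        this.sessions = sessions;
    }

    // Logs the user in on this node and returns the new session id.
    String login(String username) {
        String sessionId = UUID.randomUUID().toString();
        sessions.put(sessionId, username);
        return sessionId;
    }

    // Returns true if this node can find the session.
    boolean isLoggedIn(String sessionId) {
        return sessions.containsKey(sessionId);
    }
}

class SessionSharingDemo {
    // Each node keeps its own private map: a login on A is invisible to B.
    static boolean visibleWithLocalStores() {
        AppNode a = new AppNode(new HashMap<>());
        AppNode b = new AppNode(new HashMap<>());
        return b.isLoggedIn(a.login("alice"));
    }

    // Both nodes use one shared map, which plays the role of Redis here.
    static boolean visibleWithSharedStore() {
        Map<String, String> shared = new HashMap<>();
        AppNode a = new AppNode(shared);
        AppNode b = new AppNode(shared);
        return b.isLoggedIn(a.login("alice"));
    }

    public static void main(String[] args) {
        System.out.println("local stores: " + visibleWithLocalStores());
        System.out.println("shared store: " + visibleWithSharedStore());
    }
}
```

With private stores, node B cannot see the login made on node A. With one shared store it can, which is the effect the Redis-backed SessionDAO later in this article aims for.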

1. Shiro

Apache Shiro is a Java security framework that provides authentication, authorization, cryptography, session management, web integration, caching, and more.

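  • SecurityManager is the core of Shiro and manages security for every subject. Authentication and authorization of a subject go through it: internally, SecurityManager authenticates through Authenticator, authorizes through Authorizer, and manages sessions through SessionManager.
    SecurityManager is an interface that extends the three interfaces Authenticator, Authorizer, and SessionManager.
  • SessionManager handles session management. Shiro defines its own session management that does not depend on the web container's session, so Shiro can be used in non-web applications and can manage the sessions of a distributed application in one central place. This is what makes single sign-on possible.
  • CacheManager handles cache management. It stores user permission data in a cache to improve performance.
  • SessionDAO is the set of interfaces for session persistence. For example, to store sessions in a database, a SessionDAO implementation can write them through JDBC.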

The modules and the classes each one provides are as follows:
shiro-spring: depends on shiro-core and shiro-web
shiro-redis: depends on shiro-core
AuthorizationAttributeSourceAdvisor -> shiro-spring
ShiroFilterFactoryBean -> shiro-spring
DefaultWebSecurityManager -> shiro-web
DefaultWebSessionManager -> shiro-web
SecurityManager -> shiro-core
SessionManager -> shiro-core
CacheManager -> shiro-core
AbstractSessionDAO -> shiro-core
SessionDAO -> shiro-core
Cache -> shiro-core
RedisCache -> shiro-redis
RedisCacheManager -> shiro-redis
RedisSessionDAO -> shiro-redis
RedisClusterManager -> shiro-redis
RedisManager -> shiro-redis

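As the list shows, session sharing and cache sharing come down to providing Redis-backed implementations of Cache, CacheManager, and SessionDAO. The shiro-redis plugin already implements both, and it serves as a reference for writing your own version.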

2. Integrating Redis with Spring Boot

Create a Spring Boot project (version 2.7.x) and add the Redis dependency.

RedisTemplate is used as the central store for sessions and caches. A serialization utility implements RedisSerializer:

public class SerializeUtils implements RedisSerializer<Object> {

    private static Logger logger = LoggerFactory.getLogger(SerializeUtils.class);

    public static boolean isEmpty(byte[] data) {
        return (data == null || data.length == 0);
    }

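    /**
     * Serializes an object with JDK serialization.
     * @param object the object to serialize; it must implement Serializable
     * @return the serialized bytes, an empty array if object is null, or null if serialization fails
     * @throws SerializationException
     */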
    @Override
    public byte[] serialize(Object object) throws SerializationException {
        byte[] result = null;

        if (object == null) {
            return new byte[0];
        }
        try (
                ByteArrayOutputStream byteStream = new ByteArrayOutputStream(128);
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteStream)
        ){

            if (!(object instanceof Serializable)) {
                throw new IllegalArgumentException(SerializeUtils.class.getSimpleName() + " requires a Serializable payload " +
                        "but received an object of type [" + object.getClass().getName() + "]");
            }

            objectOutputStream.writeObject(object);
            objectOutputStream.flush();
            result =  byteStream.toByteArray();
        } catch (Exception ex) {
            logger.error("Failed to serialize",ex);
        }
        return result;
    }

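    /**
     * Deserializes bytes produced by serialize().
     * @param bytes the serialized bytes
     * @return the deserialized object, or null if bytes is empty or deserialization fails
     * @throws SerializationException
     */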
    @Override
    public Object deserialize(byte[] bytes) throws SerializationException {

        Object result = null;

        if (isEmpty(bytes)) {
            return null;
        }
        try (
                ByteArrayInputStream byteStream = new ByteArrayInputStream(bytes);
                ObjectInputStream objectInputStream = new ObjectInputStream(byteStream)
        ){
            result = objectInputStream.readObject();
        } catch (Exception e) {
            logger.error("Failed to deserialize",e);
        }
        return result;
    }
}
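
What SerializeUtils does to a value can be tried without Redis. The following is a minimal sketch of the same JDK serialization round trip; DemoUser and JdkSerializationDemo are hypothetical names used only for illustration.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.UncheckedIOException;

// Hypothetical principal class; anything stored this way must be Serializable.
class DemoUser implements Serializable {
    private static final long serialVersionUID = 1L;
    final String id;
    final String name;

    DemoUser(String id, String name) {
        this.id = id;
        this.name = name;
    }
}

class JdkSerializationDemo {
    // Writes the object graph to a byte array, as serialize() does.
    static byte[] toBytes(Object value) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    // Rebuilds the object from the byte array, as deserialize() does.
    static Object fromBytes(byte[] data) {
        try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(data))) {
            return in.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException("class of the stored object is not on the classpath", e);
        }
    }

    public static void main(String[] args) {
        DemoUser copy = (DemoUser) fromBytes(toBytes(new DemoUser("42", "alice")));
        System.out.println(copy.id + " " + copy.name);
    }
}
```

One consequence of this choice is that every object placed in a session or cache, including the principal class, has to implement Serializable and be present on the classpath of every node that reads it.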

The RedisConfig configuration class:

@Bean
    public RedisTemplate<String, Object> redisTemplate(LettuceConnectionFactory lettuceConnectionFactory) {
        // 1. Create the RedisTemplate
        RedisTemplate<String, Object> redisTemplate = new RedisTemplate<String, Object>();
        // 2. Attach the Redis connection factory
        redisTemplate.setConnectionFactory(lettuceConnectionFactory);
        // 3. Configure key serialization
        RedisSerializer<?> stringRedisSerializer = new StringRedisSerializer();
        redisTemplate.setKeySerializer(stringRedisSerializer);
        redisTemplate.setHashKeySerializer(stringRedisSerializer);
        // 4. Configure value serialization
//        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<Object>(Object.class);
//        ObjectMapper objMapper = new ObjectMapper();
//        objMapper.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
//        objMapper.activateDefaultTyping(objMapper.getPolymorphicTypeValidator(), ObjectMapper.DefaultTyping.NON_FINAL);
//        jackson2JsonRedisSerializer.setObjectMapper(objMapper);
        SerializeUtils serializeUtils = new SerializeUtils();
        redisTemplate.setValueSerializer(serializeUtils);
        redisTemplate.setHashValueSerializer(serializeUtils);
        // 5. Initialize the RedisTemplate
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

The RedisService utility class:

@Component
public class RedisService {
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    //=============================common============================
    /**
     * Sets the expiration time of a key.
     * @param key the key
     * @param time time to live, in seconds
     */
    public void expire(String key,long time){
        redisTemplate.expire(key, time, TimeUnit.SECONDS);
    }

    /**
     * Checks whether a key exists.
     * @param key the key
     * @return true if the key exists, false otherwise
     */
    public Boolean hasKey(String key){
        return redisTemplate.hasKey(key);
    }

    /**
     * Deletes cached entries.
     * @param key one or more keys
     */
    @SuppressWarnings("unchecked")
    public void del(String ... key){
        if(key!=null&&key.length>0){
            if(key.length==1){
                redisTemplate.delete(key[0]);
            }else{
                redisTemplate.delete((Collection<String>) CollectionUtils.arrayToList(key));
            }
        }
    }

    /**
     * Deletes a batch of keys.
     * @param keys the keys to delete
     */
    public void del(Collection<String> keys){
        redisTemplate.delete(keys);
    }

    //============================String=============================
    /**
     * Reads a plain cached value.
     * @param key the key
     * @return the value
     */
    public Object get(String key){
        return redisTemplate.opsForValue().get(key);
    }

    /**
     * Stores a plain cached value.
     * @param key the key
     * @param value the value
     */
    public void set(String key,Object value) {
        redisTemplate.opsForValue().set(key, value);
    }

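    /**
     * Stores a plain cached value with a time to live.
     * @param key the key
     * @param value the value
     * @param time time to live, in seconds; if time is less than or equal to 0 the key never expires
     */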
    public void set(String key,Object value,long time){
        if(time>0){
            redisTemplate.opsForValue().set(key, value, time, TimeUnit.SECONDS);
        }else{
            set(key, value);
        }
    }

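    /**
     * Finds the keys matching a pattern, for example all keys with a given prefix.
     * Note: redisTemplate.keys() issues the KEYS command rather than SCAN, and KEYS blocks Redis while it walks the whole keyspace.
     * @param key the key pattern, e.g. "shiro:session:*"
     * @return the matching keys
     */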
    public Set<String> scan(String key){
        Set<String> keys = this.redisTemplate.keys(key);
        return keys;
    }

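    /**
     * Counts the keys matching a pattern using the SCAN command.
     * Used to get the current number of sessions, i.e. the number of online users.
     * @param key the key pattern
     * @return the number of matching keys
     */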
    public Long scanSize(String key){
        long dbSize = this.redisTemplate.execute(new RedisCallback<Long>() {

            @Override
            public Long doInRedis(RedisConnection connection) throws DataAccessException {
                long count = 0L;
                Cursor<byte[]> cursor = connection.scan(ScanOptions.scanOptions().match(key).count(1000).build());
                while (cursor.hasNext()) {
                    cursor.next();
                    count++;
                }
                return count;
            }
        });
        return dbSize;
    }
}

3. Session sharing

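Shiro provides its own session manager, sessionManager. One of its properties, sessionDAO, handles all session data, and Shiro ships its own implementations of it. To share sessions through Redis, the sessionDAO has to be replaced with a Redis-backed implementation.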

Create the implementation class ShiroRedisSessionDAO, which extends AbstractSessionDAO:

@Component
public class ShiroRedisSessionDAO extends AbstractSessionDAO {

    private static Logger logger = LoggerFactory.getLogger(ShiroRedisSessionDAO.class);

    private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:";
    private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX;
    private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;
    /**
     * doReadSession is called about 10 times during a single login.
     * To avoid reading Redis each time, the session is kept in a ThreadLocal; sessionInMemoryTimeout is how long that copy stays valid.
     * The default value is 1000 milliseconds (1s).
     * Most of the time you don't need to change it.
     */
    private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;
    /**
     * expire time in seconds
     */
    private static final int DEFAULT_EXPIRE = -2;
    private static final int NO_EXPIRE = -1;
    /**
     * Please make sure expire is longer than session.getTimeout()
     */
    private int expire = DEFAULT_EXPIRE;
    private static final int MILLISECONDS_IN_A_SECOND = 1000;
    @Autowired
    private RedisService redisService;
    @Autowired
    private ShiroRedisSessionIdGenerator shiroRedisSessionIdGenerator;
    private static final ThreadLocal<Map<Serializable, SessionInMemory>> sessionsInThread = new ThreadLocal<>();

    @Override
    public void update(Session session) throws UnknownSessionException {
        // No need to update a session that has expired or been stopped
        try {
            if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {
                return;
            }

            this.saveSession(session);
        } catch (Exception e) {
            logger.warn("update Session is failed", e);
        }
    }

    private void saveSession(Session session) throws UnknownSessionException {
        if (session == null || session.getId() == null) {
            logger.error("session or session id is null");
            throw new UnknownSessionException("session or session id is null");
        }
        String key = getRedisSessionKey(session.getId());
        if (expire == DEFAULT_EXPIRE) {
            this.redisService.set(key, session, (int) (session.getTimeout() / MILLISECONDS_IN_A_SECOND));
            return;
        }
        if (expire != NO_EXPIRE && expire * MILLISECONDS_IN_A_SECOND < session.getTimeout()) {
            logger.warn("Redis session expire time: "
                    + (expire * MILLISECONDS_IN_A_SECOND)
                    + " is less than Session timeout: "
                    + session.getTimeout()
                    + " . It may cause some problems.");
        }
        this.redisService.set(key, session, expire);
    }

    @Override
    public void delete(Session session) {
        if (session == null || session.getId() == null) {
            logger.error("session or session id is null");
            return;
        }
        try {
            redisService.del(getRedisSessionKey(session.getId()));
        } catch (Exception e) {
            logger.error("delete session error. session id= {}", session.getId(), e);
        }
    }

    @Override
    public Collection<Session> getActiveSessions() {
        Set<Session> sessions = new HashSet<Session>();
        try {
            Set<String> keys = redisService.scan(this.keyPrefix + "*");
            if (keys != null && keys.size() > 0) {
                for (String key:keys) {
                    Session s = (Session) redisService.get(key);
                    // The key may expire between the key lookup and this read
                    if (s != null) {
                        sessions.add(s);
                    }
                }
            }
        } catch (Exception e) {
            logger.error("get active sessions error.", e);
        }
        return sessions;
    }

    @Override
    protected Serializable doCreate(Session session) {
        if (session == null) {
            logger.error("session is null");
            throw new UnknownSessionException("session is null");
        }
        Serializable sessionId = this.generateSessionId(session);
        this.assignSessionId(session, sessionId);
        this.saveSession(session);
        return sessionId;
    }

    @Override
    protected Session doReadSession(Serializable sessionId) {
        if (sessionId == null) {
            logger.warn("session id is null");
            return null;
        }
        Session s = getSessionFromThreadLocal(sessionId);

        if (s != null) {
            return s;
        }

        logger.debug("read session from redis");
        try {
            s = (Session) redisService.get(getRedisSessionKey(sessionId));
            setSessionToThreadLocal(sessionId, s);
        } catch (Exception e) {
            logger.error("read session error. sessionId= {}", sessionId, e);
        }
        return s;
    }

    private void setSessionToThreadLocal(Serializable sessionId, Session s) {
        Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
        if (sessionMap == null) {
            sessionMap = new HashMap<Serializable, SessionInMemory>();
            sessionsInThread.set(sessionMap);
        }
        SessionInMemory sessionInMemory = new SessionInMemory();
        sessionInMemory.setCreateTime(new Date());
        sessionInMemory.setSession(s);
        sessionMap.put(sessionId, sessionInMemory);
    }

    private Session getSessionFromThreadLocal(Serializable sessionId) {
        Session s = null;

        if (sessionsInThread.get() == null) {
            return null;
        }

        Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();
        SessionInMemory sessionInMemory = sessionMap.get(sessionId);
        if (sessionInMemory == null) {
            return null;
        }
        Date now = new Date();
        long duration = now.getTime() - sessionInMemory.getCreateTime().getTime();
        if (duration < sessionInMemoryTimeout) {
            s = sessionInMemory.getSession();
            logger.debug("read session from memory");
        } else {
            sessionMap.remove(sessionId);
        }

        return s;
    }

    @Override
    public void setSessionIdGenerator(SessionIdGenerator sessionIdGenerator) {
        super.setSessionIdGenerator(shiroRedisSessionIdGenerator);
    }

    private String getRedisSessionKey(Serializable sessionId) {
        return this.keyPrefix + sessionId;
    }

    public RedisService getRedisService() {
        return redisService;
    }

    public void setRedisService(RedisService redisService) {
        this.redisService = redisService;
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public void setKeyPrefix(String keyPrefix) {
        this.keyPrefix = keyPrefix;
    }

    public long getSessionInMemoryTimeout() {
        return sessionInMemoryTimeout;
    }

    public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {
        this.sessionInMemoryTimeout = sessionInMemoryTimeout;
    }

    public int getExpire() {
        return expire;
    }

    public void setExpire(int expire) {
        this.expire = expire;
    }
}
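
The expire handling in saveSession can be read as a small pure function. The sketch below restates it without Shiro or Redis; the class SessionTtl and its method names are hypothetical and exist only for illustration.

```java
// Hypothetical helper that restates the TTL decision made in saveSession.
class SessionTtl {
    static final int DEFAULT_EXPIRE = -2; // follow the Shiro session timeout
    static final int NO_EXPIRE = -1;      // keep the key in Redis forever

    // Returns the TTL in seconds handed to RedisService.set();
    // a value <= 0 makes RedisService store the key without expiry.
    static long redisTtlSeconds(int expire, long sessionTimeoutMillis) {
        if (expire == DEFAULT_EXPIRE) {
            return sessionTimeoutMillis / 1000;
        }
        return expire;
    }

    // True when a fixed TTL would let Redis drop the key before Shiro
    // considers the session timed out (the case that logs a warning).
    static boolean expiresBeforeSession(int expire, long sessionTimeoutMillis) {
        return expire != DEFAULT_EXPIRE
                && expire != NO_EXPIRE
                && expire * 1000L < sessionTimeoutMillis;
    }

    public static void main(String[] args) {
        long thirtyMinutes = 30 * 60 * 1000L;
        System.out.println(redisTtlSeconds(DEFAULT_EXPIRE, thirtyMinutes));
        System.out.println(expiresBeforeSession(600, thirtyMinutes));
    }
}
```

With the default setting the Redis key lives exactly as long as the Shiro session timeout, and because update() saves the session again, the TTL is renewed on each update.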
public class ShiroRedisSessionManager extends DefaultWebSessionManager {
    /**
     * Inject our own SessionDAO
     */
    @Autowired
    private ShiroRedisSessionDAO shiroRedisSessionDAO;

    /**
     * @description: Overrides the constructor
     * @author: ldc
     * @date: 2021/5/22 15:20
     */
    public ShiroRedisSessionManager() {
        super();
    }

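    /**
     * @description: Overrides getSessionId to read the token from a request header, which keeps the API uniform. On every request, Shiro looks up the value (the token) of the Authorization request header.
     * @author: ldc
     * @date: 2021/5/22 15:20
     */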
    @Override
    public Serializable getSessionId(ServletRequest request, ServletResponse response) {
        String token = WebUtils.toHttp(request).getHeader("Authorization");
        // If the request header carries a token, use it as the session ID
        if (!ObjectUtils.isEmpty(token)) {
            // Keep the JSESSIONID out of the URL
            request.setAttribute(ShiroHttpServletRequest.REFERENCED_SESSION_ID_SOURCE, "Stateless request");
            request.setAttribute(ShiroHttpServletRequest.REFERENCED_SESSION_ID, token);
            request.setAttribute(ShiroHttpServletRequest.REFERENCED_SESSION_ID_IS_VALID, Boolean.TRUE);
            return token;
        } else {
            // Otherwise fall back to the default rule and read the session ID from the cookie
            return super.getSessionId(request, response);
        }
    }

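    @Override
    public void setSessionDAO(SessionDAO sessionDAO) {
        super.setSessionDAO(shiroRedisSessionDAO);
    }

    /**
     * Installs the Redis-backed SessionDAO once Spring has injected it.
     * Nothing in the configuration shown in this article calls setSessionDAO, so without this hook
     * the in-memory SessionDAO created by the parent constructor would stay in place.
     * On Spring Boot 2.7.x the annotation is javax.annotation.PostConstruct.
     */
    @PostConstruct
    public void initSessionDAO() {
        super.setSessionDAO(shiroRedisSessionDAO);
    }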
}
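
The lookup order used by getSessionId (Authorization header first, cookie second) can be isolated from the servlet API. This is a sketch under stated assumptions: SessionIdResolver is a hypothetical name, the headers are passed as a plain Map, and header-name case-insensitivity is ignored.

```java
import java.util.Map;

// Hypothetical helper that mirrors the decision made in getSessionId.
class SessionIdResolver {
    // Returns the Authorization header value when it is present and non-empty;
    // otherwise returns the session id taken from the cookie (which may be null).
    static String resolve(Map<String, String> headers, String cookieSessionId) {
        String token = headers.get("Authorization");
        if (token != null && !token.isEmpty()) {
            return token;
        }
        return cookieSessionId;
    }
}
```

Because the header value is used directly as the session ID, a client that sends the token in every request does not need cookies at all, which suits API clients and separated front ends.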

4. Cache sharing

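Shiro provides its own cache manager, cacheManager, so it is enough to implement Shiro's CacheManager and Cache interfaces. The former exposes the method that returns Cache instances; the latter provides the create, read, update, and delete operations on a Cache instance. To share caches through Redis, both Cache and CacheManager need Redis-backed implementations.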

public class ShiroRedisCache<K, V> implements Cache<K, V> {

    private static Logger logger = LoggerFactory.getLogger(ShiroRedisCache.class);

    private RedisService redisService;
    private String keyPrefix = "shiro:cache:";
    private int expire;
    private String principalIdFieldName = "id";

    public ShiroRedisCache(RedisService redisService, String prefix, int expire, String principalIdFieldName) {
        if (redisService == null) {
            throw new IllegalArgumentException("redisService cannot be null.");
        }
        this.redisService = redisService;
        if (prefix != null && !"".equals(prefix)) {
            this.keyPrefix = prefix;
        }
        if (expire != -1) {
            this.expire = expire;
        }
        if (principalIdFieldName != null && !"".equals(principalIdFieldName)) {
            this.principalIdFieldName = principalIdFieldName;
        }
    }

    @Override
    public V get(K key) throws CacheException {
        logger.debug("get key [{}]",key);
        if (key == null) {
            return null;
        }
        try {
            String redisCacheKey = getRedisCacheKey(key);
            Object rawValue = redisService.get(redisCacheKey);
            if (rawValue == null) {
                return null;
            }
            V value = (V) rawValue;
            return value;
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    @Override
    public V put(K key, V value) throws CacheException {
        logger.debug("put key [{}]",key);
        if (key == null) {
            logger.warn("Saving a null key is meaningless, return value directly without call Redis.");
            return value;
        }
        try {
            String redisCacheKey = getRedisCacheKey(key);
            redisService.set(redisCacheKey, value, expire);
            return value;
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    @Override
    public V remove(K key) throws CacheException {
        logger.debug("remove key [{}]",key);
        if (key == null) {
            return null;
        }
        try {
            String redisCacheKey = getRedisCacheKey(key);
            Object rawValue = redisService.get(redisCacheKey);
            V previous = (V) rawValue;
            redisService.del(redisCacheKey);
            return previous;
        } catch (Exception e) {
            throw new CacheException(e);
        }
    }

    private String getRedisCacheKey(K key) {
        if (key == null) {
            return null;
        }
        return this.keyPrefix + getStringRedisKey(key);
    }

    private String getStringRedisKey(K key) {
        String redisKey;
        if (key instanceof PrincipalCollection) {
            redisKey = getRedisKeyFromPrincipalIdField((PrincipalCollection) key);
        } else {
            redisKey = key.toString();
        }
        return redisKey;
    }

    private String getRedisKeyFromPrincipalIdField(PrincipalCollection key) {
        Object principalObject = key.getPrimaryPrincipal();
        if (principalObject instanceof String) {
            return principalObject.toString();
        } else {
            Method pincipalIdGetter = this.getPrincipalIdGetter(principalObject);
            return this.getIdObj(principalObject, pincipalIdGetter);
        }
    }

    private String getIdObj(Object principalObject, Method pincipalIdGetter) {
        try {
            Object idObj = pincipalIdGetter.invoke(principalObject);
            if (idObj == null) {
                throw new PrincipalIdNullException(principalObject.getClass(), this.principalIdFieldName);
            } else {
                String redisKey = idObj.toString();
                return redisKey;
            }
        } catch (IllegalAccessException | InvocationTargetException e) {
            // Use the PrincipalInstanceException defined in this article rather than the shiro-redis one
            throw new PrincipalInstanceException(principalObject.getClass(), this.principalIdFieldName, e);
        }
    }

    private Method getPrincipalIdGetter(Object principalObject) {
        Method pincipalIdGetter = null;
        String principalIdMethodName = this.getPrincipalIdMethodName();

        try {
            pincipalIdGetter = principalObject.getClass().getMethod(principalIdMethodName);
            return pincipalIdGetter;
        } catch (NoSuchMethodException var5) {
            throw new PrincipalInstanceException(principalObject.getClass(), this.principalIdFieldName);
        }
    }

    private String getPrincipalIdMethodName() {
        if (this.principalIdFieldName != null && !"".equals(this.principalIdFieldName)) {
            return "get" + this.principalIdFieldName.substring(0, 1).toUpperCase() + this.principalIdFieldName.substring(1);
        } else {
            throw new CacheManagerPrincipalIdNotAssignedException();
        }
    }

    @Override
    public void clear() throws CacheException {
        logger.debug("clear cache");
        Set<String> keys = null;
        try {
            keys = redisService.scan(this.keyPrefix + "*");
        } catch (Exception e) {
            logger.error("get keys error", e);
        }
        if (keys == null || keys.size() == 0) {
            return;
        }
        for (String key: keys) {
            redisService.del(key);
        }
    }

    @Override
    public int size() {
        Long longSize = 0L;
        try {
            longSize = redisService.scanSize(this.keyPrefix + "*");
        } catch (Exception e) {
            logger.error("get keys error", e);
        }
        return longSize.intValue();
    }

    @Override
    public Set<K> keys() {
        Set<String> keys = null;
        try {
            keys = redisService.scan(this.keyPrefix + "*");
        } catch (Exception e) {
            logger.error("get keys error", e);
            return Collections.emptySet();
        }

        if (CollectionUtils.isEmpty(keys)) {
            return Collections.emptySet();
        }

        Set<K> convertedKeys = new HashSet<K>();
        for (String key:keys) {
            try {
                convertedKeys.add((K) key);
            } catch (Exception e) {
                logger.error("deserialize keys error", e);
            }
        }
        return convertedKeys;
    }

    @Override
    public Collection<V> values() {
        Set<String> keys = null;
        try {
            keys = redisService.scan(this.keyPrefix + "*");
        } catch (Exception e) {
            logger.error("get values error", e);
            return Collections.emptySet();
        }

        if (CollectionUtils.isEmpty(keys)) {
            return Collections.emptySet();
        }

        List<V> values = new ArrayList<V>(keys.size());
        for (String key : keys) {
            V value = null;
            try {
                value = (V) redisService.get(key);
            } catch (Exception e) {
                logger.error("deserialize values error", e);
            }
            if (value != null) {
                values.add(value);
            }
        }
        return Collections.unmodifiableList(values);
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public void setKeyPrefix(String keyPrefix) {
        this.keyPrefix = keyPrefix;
    }

    public String getPrincipalIdFieldName() {
        return principalIdFieldName;
    }

    public void setPrincipalIdFieldName(String principalIdFieldName) {
        this.principalIdFieldName = principalIdFieldName;
    }
}
public class ShiroRedisCacheManager implements CacheManager {

    private final Logger logger = LoggerFactory.getLogger(ShiroRedisCacheManager.class);

    private final ConcurrentMap<String, Cache> caches = new ConcurrentHashMap<>();
    @Autowired
    private RedisService redisService;
    /**
     * expire time in seconds
     */
    private static final int DEFAULT_EXPIRE = 1800;
    private int expire = DEFAULT_EXPIRE;
    /**
     * The Redis key prefix for caches
     */
    public static final String DEFAULT_CACHE_KEY_PREFIX = "shiro:cache:";
    private String keyPrefix = DEFAULT_CACHE_KEY_PREFIX;
    public static final String DEFAULT_PRINCIPAL_ID_FIELD_NAME = "id";
    private String principalIdFieldName = DEFAULT_PRINCIPAL_ID_FIELD_NAME;

    @Override
    public <K, V> Cache<K, V> getCache(String name) throws CacheException {
        logger.debug("get cache, name={}",name);
        // computeIfAbsent creates the cache atomically, so concurrent callers share one instance per name
        Cache cache = caches.computeIfAbsent(name,
                n -> new ShiroRedisCache(redisService, keyPrefix + n + ":", expire, principalIdFieldName));
        return cache;
    }

    public RedisService getRedisService() {
        return redisService;
    }

    public void setRedisService(RedisService redisService) {
        this.redisService = redisService;
    }

    public String getKeyPrefix() {
        return keyPrefix;
    }

    public void setKeyPrefix(String keyPrefix) {
        this.keyPrefix = keyPrefix;
    }

    public int getExpire() {
        return expire;
    }

    public void setExpire(int expire) {
        this.expire = expire;
    }

    public String getPrincipalIdFieldName() {
        return principalIdFieldName;
    }

    public void setPrincipalIdFieldName(String principalIdFieldName) {
        this.principalIdFieldName = principalIdFieldName;
    }
}
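
The way ShiroRedisCache turns a principal into a Redis key (prefix plus the value of a getter found by reflection) can be tried in isolation. The sketch below assumes a hypothetical DemoPrincipal class with a getId() method; PrincipalKeyDemo and its method names are also made up for illustration.

```java
import java.lang.reflect.Method;
import java.util.Locale;

// Hypothetical principal with the default "id" field and its getter.
class DemoPrincipal {
    private final long id;

    DemoPrincipal(long id) {
        this.id = id;
    }

    public long getId() {
        return id;
    }
}

class PrincipalKeyDemo {
    // "id" -> "getId", "userName" -> "getUserName"
    static String getterName(String fieldName) {
        return "get" + fieldName.substring(0, 1).toUpperCase(Locale.ROOT) + fieldName.substring(1);
    }

    // Builds the Redis key: a String principal is used as is,
    // any other principal contributes the value returned by its id getter.
    static String redisKey(String prefix, Object principal, String idFieldName) {
        if (principal instanceof String) {
            return prefix + principal;
        }
        try {
            Method getter = principal.getClass().getMethod(getterName(idFieldName));
            Object id = getter.invoke(principal);
            if (id == null) {
                throw new IllegalStateException("principal id is null");
            }
            return prefix + id;
        } catch (ReflectiveOperationException e) {
            throw new IllegalStateException("principal needs a public getter for " + idFieldName, e);
        }
    }

    public static void main(String[] args) {
        System.out.println(redisKey("shiro:cache:authz:", new DemoPrincipal(7), "id"));
    }
}
```

Keying by a stable id rather than by the principal's toString() keeps the cache entry reachable from every node, since each node builds the same key for the same user.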

Custom exception classes:

public class PrincipalIdNullException extends RuntimeException {
    private static final long serialVersionUID = -3311055610511133072L;

    private static final String MESSAGE = "Principal Id shouldn't be null!";

    public PrincipalIdNullException(Class clazz, String idMethodName) {
        super(clazz + " id field: " +  idMethodName + ", value is null\n" + MESSAGE);
    }
}
public class PrincipalInstanceException extends RuntimeException {
    private static final long serialVersionUID = 5011741906229816259L;

    private static final String MESSAGE = "We need a field to identify this Cache Object in Redis. "
            + "So you need to define an id field from which a unique id can be obtained to identify this principal. "
            + "For example, if you use UserInfo as Principal class, the id field maybe userId, userName, email, etc. "
            + "For example, getUserId(), getUserName(), getEmail(), etc.\n"
            + "Default value is authCacheKey or id, that means your principal object has a method called \"getAuthCacheKey()\" or \"getId()\"";

    public PrincipalInstanceException(Class clazz, String idMethodName) {
        super(clazz + " must have a getter for field: " + idMethodName + "\n" + MESSAGE);
    }

    public PrincipalInstanceException(Class clazz, String idMethodName, Exception e) {
        super(clazz + " must have a getter for field: " + idMethodName + "\n" + MESSAGE, e);
    }

}
public class CacheManagerPrincipalIdNotAssignedException extends RuntimeException {
    private static final String MESSAGE = "CacheManager didn't assign Principal Id field name!";

    public CacheManagerPrincipalIdNotAssignedException() {
        super(MESSAGE);
    }
}

5. Configuring the sharing

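The session and cache implementations are now written, but Shiro is not using them yet. Configuration is needed so that the Spring container uses the custom session and cache implementations.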

    @Bean
    public SecurityManager securityManager() {
        DefaultWebSecurityManager securityManager = new DefaultWebSecurityManager();
        // Custom session management
        securityManager.setSessionManager(sessionManager());
        // Custom cache management
        securityManager.setCacheManager(cacheManager());
        // Custom Realm for authentication and authorization
        securityManager.setRealm(myShiroRealm());
        return securityManager;
    }

    @Bean
    public CacheManager cacheManager() {
        ShiroRedisCacheManager shiroRedisCacheManager = new ShiroRedisCacheManager();
        return shiroRedisCacheManager;
    }

    @Bean
    public SessionManager sessionManager() {
        ShiroRedisSessionManager shiroRedisSessionManager = new ShiroRedisSessionManager();
        return shiroRedisSessionManager;
    }

6. Summary

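Implementing session sharing and cache sharing in Shiro mainly means replacing its storage implementations with custom Redis-backed ones. When Redis runs as a cluster, a Spring Boot project can also use the Lettuce client to refresh the cluster topology dynamically, which keeps authentication and authorization highly available. An implementation based on the shiro-redis plugin cannot refresh the cluster topology dynamically: when a master node in the Redis cluster goes down, authentication and authorization calls in the application start failing. Each business scenario needs its own analysis, and reading the Shiro source code closely helps in applying it more effectively.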


Reposted from blog.csdn.net/leijie0322/article/details/131314631