Use Sa-Token to solve WebSocket handshake authentication

Foreword

Unlike HTTP's one-way request/response model, WebSocket lets the server actively push messages to the browser. This makes it a good fit for scenarios such as order notifications and real-time IM chat.

However, WebSocket itself provides no built-in identity authentication, and by default it accepts a connection from any client. Authentication and authorization are therefore up to us.

Sa-Token is a Java permission authentication framework. It covers a range of authorization concerns, including login authentication, permission authentication, single sign-on, OAuth2, and microservice gateway authentication.
GitHub repository: https://github.com/dromara/sa-token

The following shows how to integrate Sa-Token authentication into WebSocket to secure its connections.

Two ways to integrate

We will cover the two most common ways to integrate WebSocket, one at a time:

  • Java native version: javax.websocket.Session
  • Spring package version: WebSocketSession

Let's get started.

Method 1: Java native version javax.websocket.Session

1. First, add the pom.xml dependencies
<!-- Spring Boot dependency -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- WebSocket dependency -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

<!-- Sa-Token permission authentication, online docs: http://sa-token.dev33.cn/ -->
<dependency>
	<groupId>cn.dev33</groupId>
	<artifactId>sa-token-spring-boot-starter</artifactId>
	<version>1.29.0</version>
</dependency>
2. Login endpoint, used to obtain a session token
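import cn.dev33.satoken.stp.StpUtil;
import cn.dev33.satoken.util.SaResult;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Login test
 */
@RestController
@RequestMapping("/acc/")
public class LoginController {

	// Test login ---- http://localhost:8081/acc/doLogin?name=zhang&pwd=123456
	@RequestMapping("doLogin")
	public SaResult doLogin(String name, String pwd) {
		// Mock example only: a real project must look the user up in the database and verify the password
		if("zhang".equals(name) && "123456".equals(pwd)) {
			// Log in as user 10001; Sa-Token issues a token for this session
			StpUtil.login(10001);
			return SaResult.ok("Login successful").set("token", StpUtil.getTokenValue());
		}
		return SaResult.error("Login failed");
	}

	// ...

}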
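A Java client calling the login endpoint above needs to pull the token out of the JSON response. The sketch below assumes the response is a flat JSON object with a top-level "token" string field, which SaResult.set("token", ...) suggests. LoginTokenParser is a hypothetical helper. A real project should use a proper JSON library such as Jackson instead of a regular expression.

```java
import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/** Hypothetical helper: extracts the "token" field from the login response body. */
final class LoginTokenParser {

    // Matches "token" : "<value>" where the value contains no quotes or backslashes.
    // UUID-style tokens like the ones in this article satisfy this; escaped strings would not.
    private static final Pattern TOKEN_FIELD =
            Pattern.compile("\"token\"\\s*:\\s*\"([^\"\\\\]*)\"");

    private LoginTokenParser() {
    }

    /** Returns the token if present, or empty (for example when login failed). */
    static Optional<String> extractToken(String responseBody) {
        Matcher m = TOKEN_FIELD.matcher(responseBody);
        return m.find() ? Optional.of(m.group(1)) : Optional.empty();
    }
}
```

Because the pattern requires a quote immediately before token, a differently named key such as "satoken" is not matched by accident.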
3. WebSocket connection handling
import cn.dev33.satoken.exception.SaTokenException;
import cn.dev33.satoken.stp.StpUtil;
import cn.dev33.satoken.util.SaFoxUtil;
import org.springframework.stereotype.Component;

import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

@Component
@ServerEndpoint("/ws-connect/{satoken}")
public class WebSocketConnect {

	/**
	 * Fixed key prefix
	 */
	private static final String USER_ID = "user_id_";

	/**
	 * Session map used for pushing messages (javax.websocket.Session).
	 * Static because the container creates a new endpoint instance per connection.
	 */
	private static final ConcurrentHashMap<String, Session> sessionMap = new ConcurrentHashMap<>();

	// Listener: connection opened
	@OnOpen
	public void onOpen(Session session, @PathParam("satoken") String satoken) throws IOException {
		// Resolve the userId from the token (null if the token is invalid or expired)
		Object loginId = StpUtil.getLoginIdByToken(satoken);
		if(loginId == null) {
			session.close();
			throw new SaTokenException("Connection failed, invalid token: " + satoken);
		}

		// Store the session so later pushes can find it by userId
		long userId = SaFoxUtil.getValueByType(loginId, long.class);
		sessionMap.put(USER_ID + userId, session);

		// Notify the client
		String tips = "Web-Socket connected, sid=" + session.getId() + ", userId=" + userId;
		System.out.println(tips);
		sendMessage(session, tips);
	}

	// Listener: connection closed
	@OnClose
	public void onClose(Session session) {
		System.out.println("Connection closed, sid=" + session.getId());
		// Remove every entry that points at this session
		sessionMap.values().removeIf(s -> s.getId().equals(session.getId()));
	}

	// Listener: message received from the client
	@OnMessage
	public void onMessage(Session session, String message) {
		System.out.println("sid=" + session.getId() + " sent: " + message);
	}

	// Listener: error occurred
	@OnError
	public void onError(Session session, Throwable error) {
		System.out.println("sid=" + session.getId() + ", error occurred");
		error.printStackTrace();
	}

	// ---------

	// Push a message to a specific client
	public static void sendMessage(Session session, String message) {
		try {
			System.out.println("Sending to sid=" + session.getId() + ": " + message);
			session.getBasicRemote().sendText(message);
		} catch (IOException e) {
			throw new RuntimeException(e);
		}
	}

	// Push a message to a specific user
	public static void sendMessage(long userId, String message) {
		Session session = sessionMap.get(USER_ID + userId);
		if(session != null) {
			sendMessage(session, message);
		}
	}

}
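The sessionMap bookkeeping above is plain ConcurrentHashMap logic, so it can be exercised without a server. The sketch below is a hypothetical, framework-free stand-in. FakeSession replaces javax.websocket.Session, since only the session id matters here. removeBySessionId mirrors what onClose does: it removes every entry whose session id matches the closed one.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical stand-in for javax.websocket.Session: only the id is needed here. */
final class FakeSession {
    private final String id;

    FakeSession(String id) {
        this.id = id;
    }

    String getId() {
        return id;
    }
}

/** Framework-free sketch of the userId -> session bookkeeping. */
final class SessionRegistry {

    private static final String USER_ID = "user_id_";

    private final Map<String, FakeSession> sessions = new ConcurrentHashMap<>();

    /** Corresponds to onOpen, after the token has been resolved to a userId. */
    void register(long userId, FakeSession session) {
        sessions.put(USER_ID + userId, session);
    }

    /** Corresponds to onClose: drop every entry that points at the closed session. */
    void removeBySessionId(String sessionId) {
        // Removing through the live values() view is safe for ConcurrentHashMap.
        sessions.values().removeIf(s -> s.getId().equals(sessionId));
    }

    /** Corresponds to sendMessage(userId, ...): find the push target, if connected. */
    Optional<FakeSession> find(long userId) {
        return Optional.ofNullable(sessions.get(USER_ID + userId));
    }
}
```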
4. WebSocket configuration
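import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Enable WebSocket support: registers @ServerEndpoint beans with the embedded container
 */
@Configuration
public class WebSocketConfig {

	@Bean
	public ServerEndpointExporter serverEndpointExporter() {
		return new ServerEndpointExporter();
	}

}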
5. Startup class
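import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

@SpringBootApplication
public class SaTokenWebSocketApplication {

	public static void main(String[] args) {
		SpringApplication.run(SaTokenWebSocketApplication.class, args);
	}

}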

With everything in place, start the project. The examples below assume it listens on port 8081.

6. Test

1. First, call the login endpoint to obtain a session token:

http://localhost:8081/acc/doLogin?name=zhang&pwd=123456

As shown in the figure:

[Figure: login response containing the token]

2. Next, connect from any online WebSocket test page, for example: https://www.bejson.com/httputil/websocket/

Connection address:

ws://localhost:8081/ws-connect/302ee2f8-60aa-42aa-8ecb-eeae5ba57015
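You can also open the connection from Java 11+ with the JDK's built-in java.net.http.WebSocket client instead of an online test page. The sketch below assumes the server from this section is running on localhost:8081. WsClientSketch and endpoint are hypothetical names. Because this approach carries the token as a path segment, it is appended to /ws-connect/ exactly as in the address above.

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.WebSocket;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;

final class WsClientSketch {

    /** Builds ws://host/ws-connect/{satoken}, matching @ServerEndpoint("/ws-connect/{satoken}"). */
    static URI endpoint(String hostAndPort, String satoken) {
        return URI.create("ws://" + hostAndPort + "/ws-connect/" + satoken);
    }

    /** Opens the connection and prints every text frame the server pushes. */
    static CompletableFuture<WebSocket> connect(URI uri) {
        WebSocket.Listener listener = new WebSocket.Listener() {
            @Override
            public CompletionStage<?> onText(WebSocket ws, CharSequence data, boolean last) {
                System.out.println("received: " + data);
                ws.request(1); // ask for the next message
                return null;
            }

            @Override
            public CompletionStage<?> onClose(WebSocket ws, int statusCode, String reason) {
                System.out.println("closed: " + statusCode + " " + reason);
                return null;
            }
        };
        return HttpClient.newHttpClient().newWebSocketBuilder().buildAsync(uri, listener);
    }

    public static void main(String[] args) throws InterruptedException {
        // Replace with the token returned by /acc/doLogin
        String token = args.length > 0 ? args[0] : "302ee2f8-60aa-42aa-8ecb-eeae5ba57015";
        WebSocket ws = connect(endpoint("localhost:8081", token)).join();
        Thread.sleep(2000); // give the server time to push its greeting
        if (!ws.isOutputClosed()) {
            ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye").join();
        }
    }
}
```

With an invalid token, this client should see the close callback shortly after connecting, because the server closes the session in onOpen.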

As shown in the figure:

[Figure: test connection]

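3. What happens if we enter an invalid token?

[Figure: Connection failed]

As you can see, the connection is closed immediately. The handshake itself succeeds, but onOpen finds no login ID for the token and closes the session.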

Method 2: Spring package version WebSocketSession

1. As before, add the pom.xml dependencies
<!-- Spring Boot dependency -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- WebSocket dependency -->
<dependency>
	<groupId>org.springframework.boot</groupId>
	<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

<!-- Sa-Token permission authentication, online docs: http://sa-token.dev33.cn/ -->
<dependency>
	<groupId>cn.dev33</groupId>
	<artifactId>sa-token-spring-boot-starter</artifactId>
	<version>1.29.0</version>
</dependency>
2. Login endpoint, used to obtain a session token
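import cn.dev33.satoken.stp.StpUtil;
import cn.dev33.satoken.util.SaResult;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * Login test
 */
@RestController
@RequestMapping("/acc/")
public class LoginController {

	// Test login ---- http://localhost:8081/acc/doLogin?name=zhang&pwd=123456
	@RequestMapping("doLogin")
	public SaResult doLogin(String name, String pwd) {
		// Mock example only: a real project must look the user up in the database and verify the password
		if("zhang".equals(name) && "123456".equals(pwd)) {
			// Log in as user 10001; Sa-Token issues a token for this session
			StpUtil.login(10001);
			return SaResult.ok("Login successful").set("token", StpUtil.getTokenValue());
		}
		return SaResult.error("Login failed");
	}

	// ...

}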
3. WebSocket connection handling
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.io.IOException;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Handles WebSocket connections
 */
public class MyWebSocketHandler extends TextWebSocketHandler {

    /**
     * Fixed key prefix
     */
    private static final String USER_ID = "user_id_";

    /**
     * Session map used for pushing messages
     */
    private static final ConcurrentHashMap<String, WebSocketSession> webSocketSessionMaps = new ConcurrentHashMap<>();

    // Listener: connection opened
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        // "userId" was put into the attributes by WebSocketInterceptor during the handshake
        String userId = session.getAttributes().get("userId").toString();
        webSocketSessionMaps.put(USER_ID + userId, session);

        // Notify the client
        String tips = "Web-Socket connected, sid=" + session.getId() + ", userId=" + userId;
        System.out.println(tips);
        sendMessage(session, tips);
    }

    // Listener: connection closed
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        // Remove from the map, but only if the entry still points at this session
        String userId = session.getAttributes().get("userId").toString();
        webSocketSessionMaps.remove(USER_ID + userId, session);

        String tips = "Web-Socket closed, sid=" + session.getId() + ", userId=" + userId;
        System.out.println(tips);
    }

    // Listener: message received
    @Override
    public void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException {
        System.out.println("sid=" + session.getId() + " sent: " + message.getPayload());
    }

    // -----------

    // Push a message to a specific client
    public static void sendMessage(WebSocketSession session, String message) {
        try {
            System.out.println("Sending to sid=" + session.getId() + ": " + message);
            session.sendMessage(new TextMessage(message));
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    // Push a message to a specific user
    public static void sendMessage(long userId, String message) {
        WebSocketSession session = webSocketSessionMaps.get(USER_ID + userId);
        if(session != null) {
            sendMessage(session, message);
        }
    }

}
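When a closed session is removed from a userId-keyed map, the two-argument ConcurrentHashMap.remove(key, value) is safer than remove(key). Suppose the same user reconnects before the old connection's close callback runs. An unconditional remove would then evict the new session. The conditional form deletes the mapping only if it still points at the closing session. The small demonstration below uses plain strings as stand-ins for WebSocketSession objects. ReconnectRaceDemo is a hypothetical name.

```java
import java.util.concurrent.ConcurrentHashMap;

/** Framework-free demo: Strings stand in for WebSocketSession objects. */
final class ReconnectRaceDemo {

    private static final String KEY = "user_id_10001";

    /** Simulates: old session open, user reconnects, then the old session's close arrives late. */
    static String entryAfterOldSessionCloses(boolean conditionalRemove) {
        ConcurrentHashMap<String, String> sessions = new ConcurrentHashMap<>();
        sessions.put(KEY, "old-session");
        sessions.put(KEY, "new-session"); // the reconnect overwrites the old entry

        if (conditionalRemove) {
            sessions.remove(KEY, "old-session"); // no-op: KEY now maps to "new-session"
        } else {
            sessions.remove(KEY); // evicts the live "new-session" by mistake
        }
        return sessions.get(KEY);
    }
}
```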
4. WebSocket handshake interceptor
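import cn.dev33.satoken.stp.StpUtil;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.HandshakeInterceptor;

import java.util.Map;

/**
 * Pre-handshake interceptor for WebSocket
 */
public class WebSocketInterceptor implements HandshakeInterceptor {

	// Triggered before the handshake (the handshake succeeds only if this returns true)
	@Override
	public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler handler,
			Map<String, Object> attr) {
		// The handshake is an ordinary HTTP request, so Sa-Token can read the token from it
		System.out.println("---- Before handshake " + StpUtil.getTokenValue());

		// Refuse the handshake if the client is not logged in
		if(!StpUtil.isLogin()) {
			System.out.println("---- Unauthorized client, connection refused");
			return false;
		}

		// Record the userId for the handler; the handshake succeeds
		attr.put("userId", StpUtil.getLoginIdAsLong());
		return true;
	}

	// Triggered after the handshake
	@Override
	public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
			Exception exception) {
		System.out.println("---- After handshake");
	}

}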
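The interceptor's contract is simple. It returns false to refuse the upgrade. Otherwise it records the user id in the attribute map and returns true, so the handler can later read that id through session.getAttributes(). The sketch below models this contract without Spring or Sa-Token. HandshakeGateSketch is a hypothetical name. Its in-memory tokenStore is a hypothetical stand-in for Sa-Token's token-to-login-id lookup, not Sa-Token's real storage.

```java
import java.util.Map;

/** Framework-free model of the beforeHandshake decision. */
final class HandshakeGateSketch {

    private final Map<String, Long> tokenStore; // hypothetical token -> loginId lookup

    HandshakeGateSketch(Map<String, Long> tokenStore) {
        this.tokenStore = tokenStore;
    }

    /** Mirrors beforeHandshake: false rejects, true proceeds with "userId" recorded. */
    boolean beforeHandshake(String token, Map<String, Object> attributes) {
        Long loginId = token == null ? null : tokenStore.get(token);
        if (loginId == null) {
            return false; // not logged in: refuse the handshake
        }
        attributes.put("userId", loginId); // read later by the WebSocket handler
        return true;
    }
}
```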
5. WebSocket configuration
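import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;

/**
 * WebSocket configuration
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

	// Register the WebSocket handler
	@Override
	public void registerWebSocketHandlers(WebSocketHandlerRegistry webSocketHandlerRegistry) {
		webSocketHandlerRegistry
				// WebSocket connection handler
				.addHandler(new MyWebSocketHandler(), "/ws-connect")
				// WebSocket handshake interceptor
				.addInterceptors(new WebSocketInterceptor())
				// Allow cross-origin connections
				.setAllowedOrigins("*");
	}

}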
6. Startup class
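import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Sa-Token + WebSocket authentication example
 */
@SpringBootApplication
public class SaTokenWebSocketSpringApplication {

	public static void main(String[] args) {
		SpringApplication.run(SaTokenWebSocketSpringApplication.class, args);
	}

}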

Start the project and begin testing.

7. Test

1. First, call the login endpoint to obtain a session token:

http://localhost:8081/acc/doLogin?name=zhang&pwd=123456

As shown in the figure:

[Figure: login response containing the token]

2. Then open an online WebSocket test page and connect, for example: https://www.bejson.com/httputil/websocket/

Connection address:

ws://localhost:8081/ws-connect?satoken=fe6e7dbd-38b8-4de2-ae05-cda7e36bf2f7

As shown in the figure:

[Figure: test connection]

Note: The token is passed as a URL parameter here only because that is the easiest option on a third-party test page. In a real project you can pass the session token through a Cookie, a request header, or a URL parameter. All three work the same way.
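The resolver below makes the three transport options concrete. It checks the URL parameter first, then a header, then a cookie, all under Sa-Token's default token name, satoken. It is a simplified, hypothetical sketch for illustration only. Sa-Token performs its own lookup inside StpUtil, and its actual search order and configuration options may differ.

```java
import java.net.URLDecoder;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Optional;

/** Hypothetical resolver: URL parameter, then header, then cookie. */
final class TokenResolverSketch {

    static final String TOKEN_NAME = "satoken"; // Sa-Token's default token name

    /**
     * @param query        raw query string, e.g. "satoken=abc&x=1" (may be null)
     * @param headers      request headers, keys lower-cased by the caller
     * @param cookieHeader raw Cookie header value, e.g. "a=1; satoken=abc" (may be null)
     */
    static Optional<String> resolve(String query, Map<String, String> headers, String cookieHeader) {
        return fromPairs(query, "&", true)
                .or(() -> Optional.ofNullable(headers.get(TOKEN_NAME)))
                .or(() -> fromPairs(cookieHeader, ";", false));
    }

    // Finds TOKEN_NAME in "k=v<sep>k=v..." and optionally URL-decodes its value.
    private static Optional<String> fromPairs(String raw, String separator, boolean urlDecode) {
        if (raw == null || raw.isBlank()) {
            return Optional.empty();
        }
        for (String pair : raw.split(separator)) {
            String[] kv = pair.trim().split("=", 2);
            if (kv.length == 2 && kv[0].equals(TOKEN_NAME) && !kv[1].isEmpty()) {
                String value = urlDecode ? URLDecoder.decode(kv[1], StandardCharsets.UTF_8) : kv[1];
                return Optional.of(value);
            }
        }
        return Optional.empty();
    }
}
```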

3. If you enter an invalid token:

[Figure: Connection failed]

The connection fails: the interceptor returns false and the handshake is rejected.

Example code

The code above has been uploaded to Git. Example repository:
Gitee: sa-token-demo-websocket


Source: blog.csdn.net/shengzhang_/article/details/122916798