需要引入注解 @EnableWebSocket。先写一个 config 类实现 WebSocketConfigurer 接口,用来注册 WebSocket 请求的处理器:通过 registry.addHandler(socketHander, "/dds", "/dds/test"); 可将路径为 /dds 或 /dds/test 的请求交由 socketHander 来处理。
直接使用 registry.addHandler(new SocketHandler(), "/") 可处理所有的 WebSocket 请求(路径为 "/" 时代表所有请求)。在此处 new 一个 SocketHandler 和利用 Spring 注入一个 SocketHandler 对象效果是一样的;WebSocket 的处理逻辑写在 SocketHandler 中。
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {
@Autowired
public SocketHandler socketHander;
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
registry.addHandler(socketHander, "/dds","/dds/test"); // 可以注册多个 registry.addHandler(XXX,"");
}
}
@Component
public class SocketHandler extends AbstractWebSocketHandler {
// Registry of all live connections: the key is the WebSocketSession, the value is the
// SocketTask that serves it. Declared final so the shared map can never be swapped out.
private static final ConcurrentHashMap<WebSocketSession, SocketTask> webSockets = new ConcurrentHashMap<>();
/**
 * Creates a {@code SocketTask} for the new session, stores it in the registry and
 * lets the task send its opening response. No authentication is performed here.
 *
 * @param session the session that has just been established
 * @throws Exception declared by the overridden handler method
 */
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
    SocketTask task = new SocketTask(session);
    webSockets.put(session, task);
    task.onOpen();
    // Print the number of registered connections (the label says "count", so report the size).
    System.out.println("count" + webSockets.size());
}
/**
 * Removes the session from the registry and, if a task was registered for it,
 * notifies that task that the connection is closed.
 *
 * @param session the session that was closed
 * @param status  the close status reported for the session
 * @throws Exception declared by the overridden handler method
 */
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
    SocketTask task = webSockets.remove(session);
    if (task != null) {
        task.onClose();
    }
}
/**
 * Forwards a binary frame (e.g. audio data) to the task registered for the session.
 * Frames for sessions without a registered task are ignored.
 *
 * @param session the session the frame arrived on
 * @param message the binary frame sent by the client
 * @throws Exception declared by the overridden handler method
 */
@Override
protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
    SocketTask task = webSockets.get(session);
    if (task != null) {
        task.handleBinaryMessage(message);
    }
}
/**
 * Answers a text frame: first through the session's task (when one is registered),
 * then with a second text message written directly to the session.
 *
 * @param session the session the frame arrived on
 * @param message the text frame sent by the client (its content is not inspected)
 * @throws Exception if the direct write to the session fails
 */
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
    SocketTask task = webSockets.get(session);
    if (task != null) {
        task.send("你好客户端");
    }
    // Always sent, whether or not a task exists for this session.
    session.sendMessage(new TextMessage("你发送了一个文本针"));
}
public class SocketTask {
// Canned responses for a newly opened connection; onOpen() picks one at random.
static List<JSONObject> openResp;
// Canned responses for the end of a binary stream; one is picked at random
// when the "--end--" marker arrives.
static List<JSONObject> dataResp;
// Session this task writes to; clear() sets it to null.
private WebSocketSession session;
// Fill both canned-response lists once, when the class is initialized.
static {
init();
}
/**
 * Builds the two canned-response lists: {@code openResp} (three entries) and
 * {@code dataResp} (five entries).
 */
private static void init() {
    // Responses used when a connection opens.
    JSONObject openSuccess = fromInput(0, "success");
    JSONObject openTimeout = fromInput(10114, "time out");
    JSONObject openIllegal = fromInput(10105, "illegal access");
    openResp = Arrays.asList(openSuccess, openTimeout, openIllegal);

    // Responses used when the end-of-data marker is received.
    JSONObject dataOk = fromOutPut(0, "ok");
    JSONObject dataIllegalAccess = fromOutPut(10105, "illegal access");
    JSONObject dataIllegalParam = fromOutPut(10106, "illegal paramater");
    JSONObject dataIllegalData = fromOutPut(10109, "illegal data");
    JSONObject dataNoLicense = fromOutPut(10110, "no license");
    dataResp = Arrays.asList(dataOk, dataIllegalAccess, dataIllegalParam, dataIllegalData, dataNoLicense);
}
/**
 * Builds a response for the connection-open phase.
 *
 * @param code 0 produces a "started" response with an empty "data" field;
 *             any other value produces an "error" response
 * @param info text stored under "desc"
 * @return a JSON object with "action", "desc", "code" and "sid" (always "1"),
 *         plus "data" when {@code code} is 0
 */
private static JSONObject fromInput(int code, String info) {
    JSONObject resp = new JSONObject();
    boolean started = (code == 0);
    resp.put("action", started ? "started" : "error");
    if (started) {
        resp.put("data", "");
    }
    resp.put("desc", info);
    resp.put("code", code);
    resp.put("sid", "1");
    return resp;
}
/**
 * Builds a response for the data phase.
 *
 * @param code 0 produces a "result" response; any other value produces an "error" response
 * @param info for code 0 the value stored at {@code $.data.text}; otherwise the "desc" text
 * @return a JSON object carrying "action", "desc" and "code"; for code 0 it also carries
 *         {@code data.is_last}, {@code data.is_finish} (both true) and {@code data.text}
 */
private static JSONObject fromOutPut(int code, String info) {
    JSONObject resp = new JSONObject();
    if (code != 0) {
        // Failure: the description is the caller-supplied text.
        resp.put("action", "error");
        resp.put("desc", info);
    } else {
        // Success: fixed description, caller text goes into data.text.
        resp.put("action", "result");
        resp.put("desc", "success");
        JSONPath.set(resp, "$.data.is_last", Boolean.TRUE);
        JSONPath.set(resp, "$.data.is_finish", Boolean.TRUE);
        JSONPath.set(resp, "$.data.text", info);
    }
    resp.put("code", code);
    return resp;
}
/**
 * Creates a task bound to the given session.
 *
 * @param session the session this task will send to and close
 */
public SocketTask(WebSocketSession session) {
this.session = session;
}
/** Drops the reference to the session by setting the field to null. */
public void clear() {
session = null;
}
/** Sends one randomly chosen entry of {@code openResp} to the client as JSON text. */
public void onOpen() {
    Random rng = new Random();
    int index = rng.nextInt(openResp.size());
    String json = openResp.get(index).toJSONString();
    send(json);
}
/** Called when the connection has closed; delegates to clear(). */
public void onClose() {
clear();
}
/**
 * Handles one binary frame from the client.
 * <p>
 * A frame whose content is exactly the bytes of {@code "--end--"} ends the stream:
 * a randomly chosen entry of {@code dataResp} is sent and the session is closed.
 * Any other frame is answered with an intermediate result
 * ({@code is_last} and {@code is_finish} both false, empty text).
 *
 * @param message the binary frame received from the client
 */
public void handleBinaryMessage(BinaryMessage message) {
    // Copy only the readable region of the payload. ByteBuffer.array() would expose the
    // whole backing array (ignoring position/limit) and throws for read-only or direct
    // buffers. duplicate() keeps the message's own buffer position untouched.
    java.nio.ByteBuffer payload = message.getPayload().duplicate();
    byte[] bytes = new byte[payload.remaining()];
    payload.get(bytes);

    // Explicit charset so the marker does not depend on the platform default encoding.
    byte[] endMarker = "--end--".getBytes(java.nio.charset.StandardCharsets.UTF_8);

    if (Arrays.equals(bytes, endMarker)) {
        int code = new Random().nextInt(dataResp.size());
        JSONObject obj = dataResp.get(code);
        send(obj.toJSONString());
        close();
    } else {
        JSONObject obj = fromOutPut(0, "");
        // fromOutPut marks the result as final; flip both flags for an intermediate reply.
        JSONPath.set(obj, "$.data.is_last", Boolean.FALSE);
        JSONPath.set(obj, "$.data.is_finish", Boolean.FALSE);
        send(obj.toJSONString());
    }
}
/**
 * Sends a text message to the client. Does nothing when the session reference has
 * already been dropped; an I/O failure is reported on stderr and not rethrown.
 *
 * @param message the text to send
 */
public void send(String message) {
    // Snapshot the field: it is set to null once the task has been cleared, and
    // dereferencing it then would throw a NullPointerException.
    WebSocketSession target = session;
    if (target == null) {
        return;
    }
    try {
        target.sendMessage(new TextMessage(message));
    } catch (IOException e) {
        e.printStackTrace();
    }
}
/**
 * Closes the underlying session. Does nothing when the session reference has already
 * been dropped or the session is no longer open; an I/O failure is reported on stderr
 * and not rethrown.
 */
public void close() {
    // Snapshot the field: it is set to null once the task has been cleared.
    WebSocketSession target = session;
    if (target == null || !target.isOpen()) {
        return;
    }
    try {
        target.close();
    } catch (IOException e) {
        e.printStackTrace();
    }
}
使用 AsyncHttpClient 作为客户端发送 WebSocket 请求:
/**
 * Exposes an {@code AsyncHttpClient} bean built from the given configuration.
 * The bean's {@code close} method is registered as its destroy method.
 *
 * @param config client configuration bean
 * @return a new {@code DefaultAsyncHttpClient}
 */
@Bean(destroyMethod = "close")
public AsyncHttpClient asyncHttpClient(AsyncHttpClientConfig config) {
    AsyncHttpClient client = new DefaultAsyncHttpClient(config);
    return client;
}
/**
 * Exposes the client configuration bean; the maximum WebSocket frame size is taken
 * from {@code properties.getMaxSize()}.
 *
 * @return the built {@code AsyncHttpClientConfig}
 */
@Bean
public AsyncHttpClientConfig clientConfig() {
    return new DefaultAsyncHttpClientConfig.Builder()
            .setWebSocketMaxFrameSize(properties.getMaxSize())
            .build();
}
@Autowired
private AsyncHttpClient httpClient;
String url="ws//localhost:8090/dds"
httpClient.prepareGet(url.toString()).execute(new WebSocketUpgradeHandler.Builder()
.addWebSocketListener(innerListener).build());
/**
 * Listener for the ASR WebSocket connection. It tracks the connection state on the
 * enclosing {@code SbcServiceImpl}, runs or discards queued tasks when the connection
 * opens, closes or fails, and turns incoming text frames into recognition results
 * that are forwarded to {@code innerService}.
 */
private class InnerWebSocketListener implements WebSocketListener {
    /** Tasks queued until the connection outcome is known. Accessed under {@code asrLock}. */
    ArrayList<PendingTask> pendingTasks = new ArrayList<>();
    /**
     * Header attached to every result forwarded to {@code innerService}.
     * NOTE(review): never assigned inside this class, so it is presumably set by the
     * enclosing class before frames arrive. Confirm; a null value makes onAsrMessage fail.
     */
    VoiceHeader originalHeader;

    /**
     * Records the open socket, marks the state CONNECTED and runs every pending task
     * with a {@code null} error.
     */
    @Override
    public void onOpen(WebSocket websocket) {
        logger.debug("InnerWebSocketListener->onOpen websocket opened");
        synchronized (asrLock) { // synchronized access for thread safety
            SbcServiceImpl.this.websocket = websocket;
            SbcServiceImpl.this.asrStatus = WebSocketConnectStatus.CONNECTED;
            // DANGER: nothing time-consuming may run in here.
            // At present the tasks only contain asynchronous send code.
            // Never add blocking or long-running work in the future: it runs while holding the lock.
            pendingTasks.forEach(runner -> runner.run(null));
            pendingTasks.clear();
        }
    }

    /** Queues a task. The caller already holds the lock. */
    public void addTask(PendingTask task) {
        pendingTasks.add(task);
    }

    /** Discards all queued tasks without running them. The caller already holds the lock. */
    public void clearTasks() {
        pendingTasks.clear();
    }

    /** Clears the socket reference, marks the state DISCONNECTED and drops pending tasks. */
    @Override
    public void onClose(WebSocket websocket, int code, String reason) {
        logger.debug("asr websocket closed");
        synchronized (asrLock) { // synchronized access for thread safety
            SbcServiceImpl.this.websocket = null;
            SbcServiceImpl.this.asrStatus = WebSocketConnectStatus.DISCONNECTED;
            pendingTasks.clear();
        }
    }

    /**
     * Clears the socket reference, marks the state DISCONNECTED and runs every pending
     * task with the error so each task can react to the failure.
     */
    @Override
    public void onError(Throwable t) {
        logger.debug("InnerWebSocketListener->onError websocket err:{} ", t.getMessage());
        synchronized (asrLock) { // synchronized access for thread safety
            SbcServiceImpl.this.websocket = null;
            SbcServiceImpl.this.asrStatus = WebSocketConnectStatus.DISCONNECTED;
            pendingTasks.forEach(runner -> runner.run(t));
            pendingTasks.clear();
        }
    }

    /** Receives the data returned by the ASR service and hands it to onAsrMessage. */
    @Override
    public void onTextFrame(String payload, boolean finalFragment, int rsv) {
        logger.info("InnerWebSocketListener->onTextFrame websocket text:{}", payload);
        onAsrMessage(payload, originalHeader);
    }

    /**
     * Parses one ASR text frame, appends its {@code $.result.rec} text to the accumulated
     * {@code asrText} (spaces and punctuation stripped) and reports the result to
     * {@code innerService} as FINAL when {@code eof} is 1, otherwise as INTERMEDIATE.
     *
     * @param text           JSON text of the frame
     * @param originalHeader header to attach to the result; its topic is overwritten
     */
    private void onAsrMessage(String text, VoiceHeader originalHeader) {
        JSONObject json = JSONObject.parseObject(text);
        if (json == null) {
            // Nothing was parsed (e.g. empty frame); there is no result to report.
            logger.debug("InnerWebSocketListener->onAsrMessage empty payload ignored");
            return;
        }

        AsrTextResponse textResult = new AsrTextResponse();
        AsrTextPayload payload = new AsrTextPayload();
        VoiceHeader header = originalHeader;
        header.setTopic("asr.speech.text");
        textResult.setHeader(header);

        String rec = (String) JSONPath.eval(json, "$.result.rec");
        if (rec == null) {
            rec = "";
        }
        asrText += rec;
        asrText = asrText.replaceAll("[ \\p{P}]", "");
        payload.setText(asrText.trim());
        textResult.setPayload(payload);

        // Read eof as a boxed value first: comparing a missing field directly with
        // "== 1" would unbox null and throw a NullPointerException.
        Integer eof = json.getInteger("eof");
        boolean finalResult = eof != null && eof == 1;

        if (finalResult) {
            long asrEndTime = System.currentTimeMillis();
            logger.debug("InnerWebSocketListener->onAsrMessage 语音识别最终结果:{}, 耗时:{}", asrText, asrEndTime - asrStartTime);
            closeSocket();
            // Final result.
            payload.setType("FINAL");
            // Report the final result to innerService.
            innerService.asrOutput(textResult);
        } else {
            // Intermediate result.
            payload.setType("INTERMEDIATE");
            if ("".equals(asrText)) {
                // Nothing accumulated yet: use the value at $.result.var as the text instead.
                payload.setText((String) JSONPath.eval(json, "$.result.var"));
            }
            logger.debug("InnerWebSocketListener->onAsrMessage 中间识别结果:{}", asrText);
            // Report the intermediate result to innerService.
            innerService.asrOutput(textResult);
        }
    }

    /** Binary frames coming back from the service are not processed, only logged. */
    @Override
    public void onBinaryFrame(byte[] payload, boolean finalFragment, int rsv) {
        logger.debug("don't support binary {}", payload);
    }
}