Using WebSocket for real-time, two-way communication between an H5 client and the server in an AI dialogue scenario

WebSocket simplifies data exchange between client and server and lets the server push data to the client on its own initiative. With the WebSocket API, the browser and the server complete a single handshake, after which a persistent connection exists between them and data can flow in both directions.

1. Why do you need WebSocket?

Anyone new to WebSocket asks the same question: Why do we need another protocol when we already have HTTP? What good does it bring?

The answer is simple: HTTP has a limitation, namely that communication can only be initiated by the client.

For example, to find out today's weather, the client has to send a request to the server, and the server returns the result. In this request/response model the server cannot push information to the client on its own.
Because requests flow in one direction only, it is cumbersome for the client to learn about continuous state changes on the server. The only option is "polling": at regular intervals the client sends a query to see whether the server has anything new. The most typical scenario is a chat room.

Polling is inefficient and wastes resources, because the client must either reconnect over and over or keep an HTTP connection open the whole time. Engineers therefore kept looking for a better way, and WebSocket was the result.
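To put a rough number on that waste, here is a minimal sketch in Java. It is a toy simulation, not a measurement: the class name PollingCostSketch, the tick-based timeline and the chosen values (two server updates in 30 ticks, one poll every 2 ticks) are all assumptions made for illustration.

```java
import java.util.Set;

/** Toy timeline: counts how many polling requests come back with nothing new. */
final class PollingCostSketch {

    /** Returns {totalPolls, emptyPolls} for a client that polls every `interval` ticks. */
    static int[] simulate(Set<Integer> updateTicks, int interval, int totalTicks) {
        int total = 0;
        int empty = 0;
        int lastPoll = 0;
        for (int tick = interval; tick <= totalTicks; tick += interval) {
            total++;
            // Did the server change state since the previous poll?
            boolean sawUpdate = false;
            for (int t = lastPoll + 1; t <= tick; t++) {
                if (updateTicks.contains(t)) {
                    sawUpdate = true;
                }
            }
            if (!sawUpdate) {
                empty++;
            }
            lastPoll = tick;
        }
        return new int[] {total, empty};
    }

    public static void main(String[] args) {
        int[] result = simulate(Set.of(7, 23), 2, 30);
        System.out.println("polls=" + result[0] + ", empty=" + result[1]); // polls=15, empty=13
    }
}
```

In this toy run, 13 of the 15 requests carry no new information, whereas a push-based channel would only need to deliver the two updates.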

2. Introduction

The WebSocket protocol was born in 2008 and became an international standard in 2011. All major browsers now support it.

Its biggest feature is that the server can push information to the client on its own initiative, and the client can likewise send information to the server. It is a true two-way dialogue between equals, and a form of server push technology.
Other features include:

(1) It is built on TCP, so the server side is relatively easy to implement.

(2) It is highly compatible with HTTP. The default ports are also 80 and 443, and the handshake phase uses HTTP, so the handshake is not easily blocked and can pass through various HTTP proxy servers.

(3) The data format is lightweight, the performance overhead is small, and communication is efficient.

(4) It can send either text or binary data.

(5) There is no same-origin restriction, so the client can communicate with any server.

(6) The protocol identifier is ws (or wss if encrypted), and the server address is written as a URL, for example ws://example.com:80/some/path.
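As an illustration of point (2), the HTTP handshake ends with the server proving it understood the upgrade request: it appends a fixed GUID to the client's Sec-WebSocket-Key header, hashes the result with SHA-1 and returns the Base64 form in Sec-WebSocket-Accept. The sketch below reproduces that computation with the sample key from RFC 6455; the class and method names are hypothetical, and real servers (including the Spring module used later) do this for you.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Base64;

/** Computes the Sec-WebSocket-Accept value for a given Sec-WebSocket-Key. */
final class HandshakeSketch {

    /** Fixed GUID defined by RFC 6455. */
    private static final String GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    static String acceptFor(String secWebSocketKey) {
        try {
            MessageDigest sha1 = MessageDigest.getInstance("SHA-1");
            byte[] digest = sha1.digest(
                    (secWebSocketKey + GUID).getBytes(StandardCharsets.US_ASCII));
            return Base64.getEncoder().encodeToString(digest);
        } catch (NoSuchAlgorithmException e) {
            // SHA-1 is required to be present on every Java platform
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        // Sample key taken from RFC 6455
        System.out.println(acceptFor("dGhlIHNhbXBsZSBub25jZQ==")); // s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
    }
}
```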


Server-side implementation

The server relies on the spring-boot-starter-websocket module to implement real-time WebSocket dialogue.

CustomTextWebSocketHandler, which extends TextWebSocketHandler


// Jsons, Strings and OpenAiReplyResponse are used below but not shown in this article;
// they are assumed to be resolvable from the project's own packages.
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.util.concurrent.CountDownLatch;

/**
 * Text message handler that collects a streamed reply.
 *
 * @see org.springframework.web.socket.handler.TextWebSocketHandler
 */
@Slf4j
public class CustomTextWebSocketHandler extends TextWebSocketHandler {

    /**
     * Third-party ID of the message
     */
    private String thirdPartyId;
    /**
     * Content of the reply
     */
    private String replyContent;
    private StringBuilder replyContentBuilder;
    /**
     * Sensitive-word notice returned by the remote service, if any
     */
    private String matchSensitiveWords;
    /**
     * Completion signal
     */
    private final CountDownLatch doneSignal;

    public CustomTextWebSocketHandler(CountDownLatch doneSignal) {
        this.doneSignal = doneSignal;
    }

    public String getThirdPartyId() {
        return thirdPartyId;
    }

    public String getReplyContent() {
        return replyContent;
    }

    public String getMatchSensitiveWords() {
        return matchSensitiveWords;
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        log.info("connection established, session={}", session);
        replyContentBuilder = new StringBuilder(16);
//        super.afterConnectionEstablished(session);
    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
        super.handleMessage(session, message);
    }

    /**
     * Marker indicating that the message has been fully received ("stop")
     */
    private static final String MESSAGE_DONE = "[DONE]";

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
//        super.handleTextMessage(session, message);
        String payload = message.getPayload();
        log.info("payload={}", payload);
        OpenAiReplyResponse replyResponse = Jsons.fromJson(payload, OpenAiReplyResponse.class);
        if (replyResponse != null && replyResponse.isSuccess()) {
            String msg = replyResponse.getMsg();
            if (Strings.isEmpty(msg)) {
                return;
            } else if (msg.startsWith("【超出最大单次回复字数】")) {
                // The prefix means "maximum single-reply length exceeded". Sample payload:
                // {"msg":"【超出最大单次回复字数】该提示由GPT官方返回,非我司限制,请缩减回复字数","code":1,
                // "extParam":"{\"chatId\":\"10056:8889007174\",\"requestId\":\"b6af5830a5a64fa8a4ca9451d7cb5f6f\",\"bizId\":\"\"}",
                // "id":"chatcmpl-7LThw6J9KmBUOcwK1SSOvdBP2vK9w"}
                return;
            } else if (msg.startsWith("发送内容包含敏感词")) {
                // The prefix means "the submitted content contains sensitive words". Sample payload:
                // {"msg":"发送内容包含敏感词,请修改后重试。不合规汇如下:炸弹","code":1,
                // "extParam":"{\"chatId\":\"10024:8889006970\",\"requestId\":\"828068d945c8415d8f32598ef6ef4ad6\",\"bizId\":\"430\"}",
                // "id":"4d4106c3-f7d4-4393-8cce-a32766d43f8b"}
                matchSensitiveWords = msg;
                // Request finished
                doneSignal.countDown();
                return;
            } else if (MESSAGE_DONE.equals(msg)) {
                // Message fully received
                replyContent = replyContentBuilder.toString();
                thirdPartyId = replyResponse.getId();
                // Request finished
                doneSignal.countDown();
                log.info("replyContent={}", replyContent);
                return;
            }
            // Ordinary fragment: append it to the reply collected so far
            replyContentBuilder.append(msg);
        }
    }

    @Override
    protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
        super.handlePongMessage(session, message);
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
        replyContentBuilder = null;
        log.info("handle transport error, session={}", session, exception);
        doneSignal.countDown();
//        super.handleTransportError(session, exception);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        replyContentBuilder = null;
        log.info("connection closed, session={}, status={}", session, status);
        // Only a close code other than NORMAL (1000) is treated as a failure
        if (status.getCode() != CloseStatus.NORMAL.getCode()) {
            log.error("connection closed fail, session={}, status={}", session, status);
        }
        doneSignal.countDown();
//        super.afterConnectionClosed(session, status);
    }
}
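The core idea of this handler is to turn a stream of asynchronous fragments into one blocking result: fragments are appended until the "[DONE]" marker arrives, and a CountDownLatch then releases the waiting caller. The following is a minimal sketch of that pattern without any Spring dependency; StreamCollectorSketch, onFragment and awaitReply are hypothetical names and are not part of the article's code.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/** Collects streamed fragments until "[DONE]" and hands the result to a waiting caller. */
final class StreamCollectorSketch {

    static final String DONE = "[DONE]";

    private final StringBuilder buffer = new StringBuilder();
    private final CountDownLatch done = new CountDownLatch(1);
    private volatile String reply;

    /** Called by the receiving thread for every incoming fragment. */
    void onFragment(String fragment) {
        if (DONE.equals(fragment)) {
            // Publish the result first, then open the gate
            reply = buffer.toString();
            done.countDown();
            return;
        }
        buffer.append(fragment);
    }

    /** Blocks until "[DONE]" arrives; returns null on timeout or interruption. */
    String awaitReply(long timeout, TimeUnit unit) {
        try {
            return done.await(timeout, unit) ? reply : null;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return null;
        }
    }

    public static void main(String[] args) {
        StreamCollectorSketch collector = new StreamCollectorSketch();
        Thread receiver = new Thread(() -> {
            for (String fragment : new String[] {"Hel", "lo", DONE}) {
                collector.onFragment(fragment);
            }
        });
        receiver.start();
        System.out.println(collector.awaitReply(1, TimeUnit.SECONDS)); // Hello
    }
}
```

Writing the result before calling countDown() matters: the latch guarantees that everything written before countDown() is visible to the thread returning from await().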

OpenAiHandler


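/**
 * OpenAI handler
 */
public interface OpenAiHandler<Req, Rsp> {

    /**
     * Hook that runs before the request is sent
     *
     * @param req request
     */
    default void beforeRequest(Req req) {
        // no-op by default
    }

    /**
     * Hook that runs after the response is received
     *
     * @param req request
     * @param rsp response
     */
    default void afterResponse(Req req, Rsp rsp) {
        // no-op by default
    }
}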

OpenAiService


/**
 * OpenAI service
 * <pre>
 * API reference introduction
 * https://platform.openai.com/docs/api-reference/introduction
 * </pre>
 */
public interface OpenAiService<Req, Rsp> extends OpenAiHandler<Req, Rsp> {

    /**
     * Runs a completion: pre-request hook, the actual call, then the post-response hook
     *
     * @param req request
     * @return response
     */
    default Rsp completions(Req req) {
        beforeRequest(req);
        Rsp rsp = doCompletions(req);
        afterResponse(req, rsp);
        return rsp;
    }

    /**
     * Performs the actual completion call
     *
     * @param req request
     * @return response
     */
    Rsp doCompletions(Req req);
}
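Taken together, the two interfaces form a small template method: completions() fixes the order of the steps, and an implementation only supplies the call itself plus whichever hooks it needs. The sketch below shows that call order with simplified stand-ins; HookedCall and EchoCall are hypothetical names invented for this example, not the article's types.

```java
import java.util.ArrayList;
import java.util.List;

/** Demonstrates the before / call / after ordering of a hook-based template method. */
final class TemplateHookSketch {

    /** Simplified stand-in for the handler and service interfaces combined. */
    interface HookedCall<Req, Rsp> {
        default void beforeRequest(Req req) { }

        default void afterResponse(Req req, Rsp rsp) { }

        Rsp doCall(Req req);

        default Rsp call(Req req) {
            beforeRequest(req);
            Rsp rsp = doCall(req);
            afterResponse(req, rsp);
            return rsp;
        }
    }

    /** Toy implementation that records the order in which the steps run. */
    static final class EchoCall implements HookedCall<String, String> {
        final List<String> trace = new ArrayList<>();

        @Override
        public void beforeRequest(String req) {
            trace.add("before");
        }

        @Override
        public String doCall(String req) {
            trace.add("call");
            return req.toUpperCase();
        }

        @Override
        public void afterResponse(String req, String rsp) {
            trace.add("after");
        }
    }

    public static void main(String[] args) {
        EchoCall echo = new EchoCall();
        System.out.println(echo.call("hi") + " " + echo.trace); // HI [before, call, after]
    }
}
```

Note that the post-response hook runs even when the call returns null, so an implementation of afterResponse has to tolerate a null response.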

OpenAiServiceImpl


import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * OpenAI service implementation
 */
@Slf4j
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(OpenAiProperties.class)
@Service("openAiService")
public class OpenAiServiceImpl implements OpenAiService<CompletionReq, CompletionRsp> {

    private final OpenAiProperties properties;
    /**
     * WebSocket client
     */
    private final WebSocketClient webSocketClient;
    /**
     * Service that stores model request records
     */
    private final ModelRequestRecordService modelRequestRecordService;

    private static final String THREAD_NAME_PREFIX = "gpt.openai";

    public OpenAiServiceImpl(
            OpenAiProperties properties,
            ModelRequestRecordService modelRequestRecordService
    ) {
        this.properties = properties;
        this.modelRequestRecordService = modelRequestRecordService;
        webSocketClient = WebSocketUtil.applyWebSocketClient(THREAD_NAME_PREFIX);
        log.info("create OpenAiServiceImpl instance");
    }

    @Override
    public void beforeRequest(CompletionReq req) {
        // Assign a request ID if the caller did not provide one
        if (Strings.isEmpty(req.getRequestId())) {
            req.setRequestId(UuidUtil.getUuid());
        }
    }

    @Override
    public void afterResponse(CompletionReq req, CompletionRsp rsp) {
        if (rsp == null || Strings.isEmpty(rsp.getReplyContent())) {
            return;
        }
        // Sensitive-word check performed by the third party
        String matchSensitiveWords = rsp.getMatchSensitiveWords();
        if (Strings.isNotEmpty(matchSensitiveWords)) {
            // Sensitive word matched
            rsp.setMatchSensitiveWords(matchSensitiveWords);
            return;
        }
        // Time each stage of the task
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        try {
            // Local sensitive-word check (the call itself is left commented out)
            stopWatch.start("checkSensitiveWord");
            String replyContent = rsp.getReplyContent();
//            ApiResult<String> apiResult = checkMsg(replyContent, false);
            // Stop the running task before starting the next one: StopWatch.start()
            // throws IllegalStateException while a task is still running.
            stopWatch.stop();
//            if (!apiResult.isSuccess() && Strings.isNotEmpty(apiResult.getData())) {
//                // Sensitive word matched
//                rsp.setMatchSensitiveWords(apiResult.getData());
//                return;
//            }
            // Persist the record
            stopWatch.start("saveModelRequestRecord");
            ModelRequestRecord entity = applyModelRequestRecord(req, rsp);
            modelRequestRecordService.save(entity);
        } finally {
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("afterResponse execute time, {}", stopWatch);
        }
    }

    private static ModelRequestRecord applyModelRequestRecord(
            CompletionReq req, CompletionRsp rsp) {
        Long orgId = req.getOrgId();
        Long userId = req.getUserId();
        String chatId = applyChatId(orgId, userId);
        return new ModelRequestRecord()
                .setOrgId(orgId)
                .setUserId(userId)
                .setModelType(req.getModelType())
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId())
                .setChatId(chatId)
                .setThirdPartyId(rsp.getThirdPartyId())
                .setInputMessage(req.getMessage())
                .setReplyContent(rsp.getReplyContent());
    }

    private static String applyChatId(Long orgId, Long userId) {
        return orgId + ":" + userId;
    }

    private static String applySessionId(String appId, String chatId) {
        return appId + '_' + chatId;
    }

    private static final String URI_TEMPLATE = "wss://socket.******.com/websocket/{sessionId}";

    @Nullable
    @Override
    public CompletionRsp doCompletions(CompletionReq req) {
        // Time each stage of the task
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        stopWatch.start("doHandshake");

        // Latch: acts like a gate that opens once the reply is complete (a synchronization utility)
        CountDownLatch doneSignal = new CountDownLatch(1);
        CustomTextWebSocketHandler webSocketHandler = new CustomTextWebSocketHandler(doneSignal);
        String chatId = applyChatId(req.getOrgId(), req.getUserId());
        String sessionId = applySessionId(properties.getAppId(), chatId);
        ListenableFuture<WebSocketSession> listenableFuture = webSocketClient
                .doHandshake(webSocketHandler, URI_TEMPLATE, sessionId);
        stopWatch.stop();
        stopWatch.start("getWebSocketSession");
        long connectionTimeout = properties.getConnectionTimeout().getSeconds();
        // try-with-resources closes the session once the reply has been collected
        try (WebSocketSession webSocketSession = listenableFuture.get(connectionTimeout, TimeUnit.SECONDS)) {
            stopWatch.stop();
            stopWatch.start("sendMessage");
            OpenAiParam param = applyParam(chatId, req);
            webSocketSession.sendMessage(new TextMessage(Jsons.toJson(param)));
            long requestTimeout = properties.getRequestTimeout().getSeconds();
            // Wait until the handler signals completion or the timeout expires
            boolean await = doneSignal.await(requestTimeout, TimeUnit.SECONDS);
            if (!await) {
                log.error("await doneSignal fail, req={}", req);
            }
            String replyContent = webSocketHandler.getReplyContent();
            String matchSensitiveWords = webSocketHandler.getMatchSensitiveWords();
            if (Strings.isEmpty(replyContent) && Strings.isEmpty(matchSensitiveWords)) {
                // No usable reply was received
                return null;
            }
            // Strip the delimiters that wrapped the user message from the reply
            String delimiters = properties.getDelimiters();
            replyContent = StrUtil.replaceFirst(replyContent, delimiters, "");
            replyContent = StrUtil.replaceLast(replyContent, delimiters, "");
            String thirdPartyId = webSocketHandler.getThirdPartyId();
            return new CompletionRsp()
                    .setThirdPartyId(thirdPartyId)
                    .setReplyContent(replyContent)
                    .setMatchSensitiveWords(matchSensitiveWords);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag for callers further up the stack
                Thread.currentThread().interrupt();
            }
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (IOException e) {
            log.error("sendMessage fail, req={}", req, e);
        } finally {
            if (stopWatch.isRunning()) {
                stopWatch.stop();
            }
            log.info("doCompletions execute time, {}", stopWatch);
        }
        return null;
    }

    private static final int MIN_TOKENS = 11;

    /**
     * Limits the maximum number of tokens in a single reply
     */
    private static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {
        // Fall back to the configured limit when the requested value is too small or too large
        if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {
            return maxTokensConfig;
        }
        return reqMaxTokens;
    }

    private OpenAiParam applyParam(String chatId, CompletionReq req) {
        OpenAiDataExtParam extParam = new OpenAiDataExtParam()
                .setChatId(chatId)
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId());
        // Prompt
        String prompt = req.getPrompt();
        // Delimiters that wrap the user message
        String delimiters = properties.getDelimiters();
        String message = prompt + delimiters + req.getMessage() + delimiters;
        int maxTokens = applyMaxTokens(req.getMaxTokens(), properties.getMaxTokens());
        OpenAiData data = new OpenAiData()
                .setMsg(message)
                .setContext(properties.getContext())
                .setLimitTokens(maxTokens)
                .setExtParam(extParam);
        String sign = OpenAiUtil.applySign(message, properties.getSecret());
        return new OpenAiParam()
                .setData(data)
                .setSign(sign);
    }
}
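The token limit in applyMaxTokens is easiest to read through a few inputs. The sketch below copies the same rule into a standalone class so it can be run on its own; the class name MaxTokensSketch and the configured limit of 1000 are assumptions chosen for illustration, since the article does not show the real configuration.

```java
/** Standalone copy of the clamping rule: out-of-range requests fall back to the configured limit. */
final class MaxTokensSketch {

    static final int MIN_TOKENS = 11;

    static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {
        if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {
            return maxTokensConfig;
        }
        return reqMaxTokens;
    }

    public static void main(String[] args) {
        int configured = 1000; // hypothetical configured limit
        System.out.println(applyMaxTokens(0, configured));    // 1000 (not set, use the configured limit)
        System.out.println(applyMaxTokens(500, configured));  // 500  (within range, use the request value)
        System.out.println(applyMaxTokens(5000, configured)); // 1000 (too large, capped)
    }
}
```

A request value below MIN_TOKENS is therefore not raised to the minimum; it is replaced by the configured limit, the same as a value that is too large.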

WebSocketUtil


import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;

/**
 * WebSocket helper methods
 */
public final class WebSocketUtil {

    /**
     * Creates a new WebSocket client backed by its own thread pool
     */
    public static WebSocketClient applyWebSocketClient(String threadNamePrefix) {
        StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
        int cpuNum = Runtime.getRuntime().availableProcessors();
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setCorePoolSize(cpuNum);
        taskExecutor.setMaxPoolSize(200);
        taskExecutor.setDaemon(true);
        if (StringUtils.hasText(threadNamePrefix)) {
            taskExecutor.setThreadNamePrefix(threadNamePrefix);
        } else {
            taskExecutor.setThreadNamePrefix("gpt.web.socket");
        }
        taskExecutor.initialize();
        webSocketClient.setTaskExecutor(taskExecutor);
        return webSocketClient;
    }
}
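One detail of this pool setup is worth checking when tuning it: a ThreadPoolExecutor only grows beyond its core size once its queue is full, and ThreadPoolTaskExecutor uses an effectively unbounded queue unless a queue capacity is set. With the settings above, the maximum of 200 threads would therefore rarely be reached. The sketch below illustrates the effect with the plain JDK executor; the class name PoolGrowthSketch and the pool sizes are assumptions picked to keep the example small.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/** Shows how the queue capacity decides whether a pool grows past its core size. */
final class PoolGrowthSketch {

    /** Submits `tasks` blocking jobs and reports the pool size while all of them are pending. */
    static int poolSizeUnderLoad(int core, int max, int queueCapacity, int tasks) {
        ThreadPoolExecutor pool = new ThreadPoolExecutor(
                core, max, 60, TimeUnit.SECONDS, new LinkedBlockingQueue<>(queueCapacity));
        CountDownLatch release = new CountDownLatch(1);
        try {
            for (int i = 0; i < tasks; i++) {
                pool.execute(() -> {
                    try {
                        release.await(); // keep the worker busy until the measurement is taken
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                });
            }
            return pool.getPoolSize();
        } finally {
            release.countDown();
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        // Unbounded queue: extra tasks wait in the queue, the pool stays at core size
        System.out.println(poolSizeUnderLoad(2, 8, Integer.MAX_VALUE, 12)); // 2
        // Small queue: once it is full, the pool grows up to the maximum
        System.out.println(poolSizeUnderLoad(2, 8, 4, 12)); // 8
    }
}
```

If the larger maximum is meant to be used, setting a bounded queue capacity on the executor is one way to let the pool grow under load.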

OpenAiUtil


import org.springframework.util.DigestUtils;

import java.nio.charset.StandardCharsets;

/**
 * OpenAI helper methods
 */
public final class OpenAiUtil {

    /**
     * Signs the message content with an MD5 digest (a hash, not encryption)
     *
     * @param message message content
     * @param secret signing secret
     * @return the digest of message + secret as a hexadecimal string
     */
    public static String applySign(String message, String secret) {
        String data = message + secret;
        byte[] dataBytes = data.getBytes(StandardCharsets.UTF_8);
        return DigestUtils.md5DigestAsHex(dataBytes);
    }
}
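The signature is simply the MD5 digest of the message followed by the secret, written as lowercase hex. The sketch below reproduces it with the JDK only and adds a verification step such as a receiving side might perform; SignSketch and verify are hypothetical names, and how the remote service actually checks the signature is not shown in the article.

```java
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

/** JDK-only version of the message signature, plus a hypothetical verification step. */
final class SignSketch {

    static String applySign(String message, String secret) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5")
                    .digest((message + secret).getBytes(StandardCharsets.UTF_8));
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                hex.append(String.format("%02x", b & 0xff));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // MD5 is required to be present on every Java platform
            throw new IllegalStateException(e);
        }
    }

    /** Recomputes the signature and compares it in constant time. */
    static boolean verify(String message, String secret, String sign) {
        return MessageDigest.isEqual(
                applySign(message, secret).getBytes(StandardCharsets.US_ASCII),
                sign.getBytes(StandardCharsets.US_ASCII));
    }

    public static void main(String[] args) {
        // "hel" + "lo" hashes the same bytes as "hello"
        System.out.println(applySign("hel", "lo")); // 5d41402abc4b2a76b9719d911017c592
        System.out.println(verify("hel", "lo", "5d41402abc4b2a76b9719d911017c592")); // true
    }
}
```

The example also shows a property of plain concatenation: different splits of message and secret that join to the same string produce the same signature.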


Origin blog.csdn.net/shupili141005/article/details/130811504