La escena de interacción de diálogo de IA utiliza WebSocket para establecer una comunicación de información bidireccional en tiempo real entre el cliente H5 y el servidor

WebSocket facilita el intercambio de datos entre el cliente y el servidor, lo que permite que el servidor envíe datos al cliente de forma activa . En la API de WebSocket, el navegador y el servidor solo necesitan completar un protocolo de enlace, y se puede establecer una conexión persistente entre los dos y se puede realizar una transmisión de datos bidireccional.

1. ¿Por qué necesita WebSocket?

Cualquier persona nueva en WebSocket hace la misma pregunta: ¿Por qué necesitamos otro protocolo cuando ya tenemos HTTP? ¿Qué bien trae?

La respuesta es simple, porque el protocolo HTTP tiene una falla: solo el cliente puede iniciar la comunicación.

Por ejemplo, si queremos saber el clima de hoy, solo el cliente envía una solicitud al servidor y el servidor devuelve el resultado de la consulta. El protocolo HTTP no puede permitir que el servidor envíe información al cliente de forma activa.
inserte la descripción de la imagen aquí
Las características de esta solicitud unidireccional están destinadas a ser muy problemáticas para que el cliente sepa si el servidor tiene cambios de estado continuos. Solo podemos usar "sondeo": de vez en cuando, se envía una consulta para ver si el servidor tiene nueva información. El escenario más típico es la sala de chat.

El sondeo es ineficiente y desperdicia recursos (ya que la conexión debe mantenerse abierta o la conexión HTTP siempre está abierta). Por lo tanto, los ingenieros han estado pensando, ¿hay una mejor manera? Así fue como se inventó WebSocket.

2. Introducción

El protocolo WebSocket nació en 2008 y se convirtió en un estándar internacional en 2011. Todos los navegadores ya lo soportan.

Su característica más importante es que el servidor puede enviar información de forma activa al cliente, y el cliente también puede enviar información de forma activa al servidor. Es un diálogo real de dos vías y pertenece a un tipo de tecnología de envío de servidor .
inserte la descripción de la imagen aquí
Otras características incluyen:

(1) Basado en el protocolo TCP, la implementación del lado del servidor es relativamente fácil.

(2) Tiene buena compatibilidad con el protocolo HTTP. Los puertos predeterminados también son 80 y 443, y la fase de negociación utiliza el protocolo HTTP, por lo que no es fácil protegerse durante la negociación y puede pasar a través de varios servidores proxy HTTP.

(3) El formato de datos es relativamente ligero, la sobrecarga de rendimiento es pequeña y la comunicación es eficiente.

(4) Se pueden enviar texto o datos binarios.

(5) No hay restricción del mismo origen y el cliente puede comunicarse con cualquier servidor.

(6) El identificador de protocolo es ws (o wss si está encriptado) y la URL del servidor es la URL.

inserte la descripción de la imagen aquí

implementacion del servidor

Confíe en spring-boot-starter-websocketel módulo para realizar la interacción de diálogo en tiempo real de WebSocket.

CustomTextWebSocketHandler,expandidoTextWebSocketHandler


import cn.hutool.json.JSONUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PongMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

import java.util.concurrent.CountDownLatch;

/**
 * 文本处理器
 *
 * @see org.springframework.web.socket.handler.TextWebSocketHandler
 */
@Slf4j
public class CustomTextWebSocketHandler extends TextWebSocketHandler {
    
    
    /**
     * 第三方身份,消息身份
     */
    private String thirdPartyId;
    /**
     * 回复消息内容
     */
    private String replyContent;
    private StringBuilder replyContentBuilder;
    /**
     * 完成信号
     */
    private final CountDownLatch doneSignal;

    public CustomTextWebSocketHandler(CountDownLatch doneSignal) {
    
    
        this.doneSignal = doneSignal;
    }

    public String getThirdPartyId() {
    
    
        return thirdPartyId;
    }

    public String getReplyContent() {
    
    
        return replyContent;
    }

    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
    
    
        log.info("connection established, session={}", session);
        replyContentBuilder = new StringBuilder(16);
//        super.afterConnectionEstablished(session);
    }

    @Override
    public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
    
    
        super.handleMessage(session, message);
    }

    /**
     * 消息已接收完毕("stop")
     */
    private static final String MESSAGE_DONE = "[DONE]";

    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
    
    
//        super.handleTextMessage(session, message);
        String payload = message.getPayload();
        log.info("payload={}", payload);
        OpenAiReplyResponse replyResponse = Jsons.fromJson(payload, OpenAiReplyResponse.class);
        if (replyResponse != null && replyResponse.isSuccess()) {
    
    
            String msg = replyResponse.getMsg();
            if (Strings.isEmpty(msg)) {
    
    
                return;
            } else if (msg.startsWith("【超出最大单次回复字数】")) {
    
    
                // {"msg":"【超出最大单次回复字数】该提示由GPT官方返回,非我司限制,请缩减回复字数","code":1,
                // "extParam":"{\"chatId\":\"10056:8889007174\",\"requestId\":\"b6af5830a5a64fa8a4ca9451d7cb5f6f\",\"bizId\":\"\"}",
                // "id":"chatcmpl-7LThw6J9KmBUOcwK1SSOvdBP2vK9w"}
                return;
            } else if (msg.startsWith("发送内容包含敏感词")) {
    
    
                // {"msg":"发送内容包含敏感词,请修改后重试。不合规汇如下:炸弹","code":1,
                // "extParam":"{\"chatId\":\"10024:8889006970\",\"requestId\":\"828068d945c8415d8f32598ef6ef4ad6\",\"bizId\":\"430\"}",
                // "id":"4d4106c3-f7d4-4393-8cce-a32766d43f8b"}
                matchSensitiveWords = msg;
                // 请求完成
                doneSignal.countDown();
                return;
            } else if (MESSAGE_DONE.equals(msg)) {
    
    
                // 消息已接收完毕
                replyContent = replyContentBuilder.toString();
                thirdPartyId = replyResponse.getId();
                // 请求完成
                doneSignal.countDown();
                log.info("replyContent={}", replyContent);
                return;
            }
            replyContentBuilder.append(msg);
        }
    }

    @Override
    protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception {
    
    
        super.handlePongMessage(session, message);
    }

    @Override
    public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
    
    
        replyContentBuilder = null;
        log.info("handle transport error, session={}", session, exception);
        doneSignal.countDown();
//        super.handleTransportError(session, exception);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
    
    
        replyContentBuilder = null;
        log.info("connection closed, session={}, status={}", session, status);
        if (status == CloseStatus.NORMAL) {
    
    
            log.error("connection closed fail, session={}, status={}", session, status);
        }
        doneSignal.countDown();
//        super.afterConnectionClosed(session, status);
    }
}

OpenAiHandler


/**
 * OpenAI处理器
 */
public interface OpenAiHandler<Req, Rsp> {
    
    
    /**
     * 请求前置处理
     *
     * @param req 入参
     */
    default void beforeRequest(Req req) {
    
    
        //
    }

    /**
     * 响应后置处理
     *
     * @param req 入参
     * @param rsp 出参
     */
    default void afterResponse(Req req, Rsp rsp) {
    
    
        //
    }
}

OpenAiService


/**
 * OpenAI服务
 * <pre>
 * API reference introduction
 * https://platform.openai.com/docs/api-reference/introduction
 * </pre>
 */
public interface OpenAiService<Req, Rsp> extends OpenAiHandler<Req, Rsp> {
    
    
    /**
     * 补全指令
     *
     * @param req 入参
     * @return 出参
     */
    default Rsp completions(Req req) {
    
    
        beforeRequest(req);
        Rsp rsp = doCompletions(req);
        afterResponse(req, rsp);
        return rsp;
    }

    /**
     * 操作补全指令
     *
     * @param req 入参
     * @return 出参
     */
    Rsp doCompletions(Req req);
}

OpenAiServiceImpl


import cn.hutool.core.util.StrUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.client.WebSocketClient;

import javax.annotation.Nullable;
import java.io.IOException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

/**
 * OpenAI服务实现
 */
@Slf4j
@Configuration(proxyBeanMethods = false)
@EnableConfigurationProperties(OpenAiProperties.class)
@Service("openAiService")
public class OpenAiServiceImpl implements OpenAiService<CompletionReq, CompletionRsp> {
    
    

    private final OpenAiProperties properties;
    /**
     * 套接字客户端
     */
    private final WebSocketClient webSocketClient;
    /**
     * 模型请求记录服务
     */
    private final ModelRequestRecordService modelRequestRecordService;

    private static final String THREAD_NAME_PREFIX = "gpt.openai";

    public OpenAiServiceImpl(
            OpenAiProperties properties,
            ModelRequestRecordService modelRequestRecordService
    ) {
    
    
        this.properties = properties;
        this.modelRequestRecordService = modelRequestRecordService;
        webSocketClient = WebSocketUtil.applyWebSocketClient(THREAD_NAME_PREFIX);
        log.info("create OpenAiServiceImpl instance");
    }

    @Override
    public void beforeRequest(CompletionReq req) {
    
    
        // 请求身份
        if (Strings.isEmpty(req.getRequestId())) {
    
    
            req.setRequestId(UuidUtil.getUuid());
        }
    }

    @Override
    public void afterResponse(CompletionReq req, CompletionRsp rsp) {
    
    
        if (rsp == null || Strings.isEmpty(rsp.getReplyContent())) {
    
    
            return;
        }
        // 三方敏感词检测
        String matchSensitiveWords = rsp.getMatchSensitiveWords();
        if (Strings.isNotEmpty(matchSensitiveWords)) {
    
    
            // 敏感词命中
            rsp.setMatchSensitiveWords(matchSensitiveWords);
            return;
        }
        // 阶段任务耗时统计
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        try {
    
    
            // 敏感词检测
            stopWatch.start("checkSensitiveWord");
            String replyContent = rsp.getReplyContent();
//            ApiResult<String> apiResult = checkMsg(replyContent, false);
//            stopWatch.stop();
//            if (!apiResult.isSuccess() && Strings.isNotEmpty(apiResult.getData())) {
    
    
//                // 敏感词命中
//                rsp.setMatchSensitiveWords(apiResult.getData());
//                return;
//            }
            // 记录落库
            stopWatch.start("saveModelRequestRecord");
            ModelRequestRecord entity = applyModelRequestRecord(req, rsp);
            modelRequestRecordService.save(entity);
        } finally {
    
    
            if (stopWatch.isRunning()) {
    
    
                stopWatch.stop();
            }
            log.info("afterResponse execute time, {}", stopWatch);
        }
    }

    private static ModelRequestRecord applyModelRequestRecord(
            CompletionReq req, CompletionRsp rsp) {
    
    
        Long orgId = req.getOrgId();
        Long userId = req.getUserId();
        String chatId = applyChatId(orgId, userId);
        return new ModelRequestRecord()
                .setOrgId(orgId)
                .setUserId(userId)
                .setModelType(req.getModelType())
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId())
                .setChatId(chatId)
                .setThirdPartyId(rsp.getThirdPartyId())
                .setInputMessage(req.getMessage())
                .setReplyContent(rsp.getReplyContent());
    }

    private static String applyChatId(Long orgId, Long userId) {
    
    
        return orgId + ":" + userId;
    }

    private static String applySessionId(String appId, String chatId) {
    
    
        return appId + '_' + chatId;
    }

    private static final String URI_TEMPLATE = "wss://socket.******.com/websocket/{sessionId}";

    @Nullable
    @Override
    public CompletionRsp doCompletions(CompletionReq req) {
    
    
        // 阶段任务耗时统计
        StopWatch stopWatch = new StopWatch(req.getRequestId());
        stopWatch.start("doHandshake");

        // 闭锁,相当于一扇门(同步工具类)
        CountDownLatch doneSignal = new CountDownLatch(1);
        CustomTextWebSocketHandler webSocketHandler = new CustomTextWebSocketHandler(doneSignal);
        String chatId = applyChatId(req.getOrgId(), req.getUserId());
        String sessionId = applySessionId(properties.getAppId(), chatId);
        ListenableFuture<WebSocketSession> listenableFuture = webSocketClient
                .doHandshake(webSocketHandler, URI_TEMPLATE, sessionId);
        stopWatch.stop();
        stopWatch.start("getWebSocketSession");
        long connectionTimeout = properties.getConnectionTimeout().getSeconds();
        try (WebSocketSession webSocketSession = listenableFuture.get(connectionTimeout, TimeUnit.SECONDS)) {
    
    
            stopWatch.stop();
            stopWatch.start("sendMessage");
            OpenAiParam param = applyParam(chatId, req);
            webSocketSession.sendMessage(new TextMessage(Jsons.toJson(param)));
            long requestTimeout = properties.getRequestTimeout().getSeconds();
            // wait for all to finish
            boolean await = doneSignal.await(requestTimeout, TimeUnit.SECONDS);
            if (!await) {
    
    
                log.error("await doneSignal fail, req={}", req);
            }
            String replyContent = webSocketHandler.getReplyContent();
            String matchSensitiveWords = webSocketHandler.getMatchSensitiveWords();
            if (Strings.isEmpty(replyContent) && Strings.isEmpty(matchSensitiveWords)) {
    
    
                // 消息回复异常
                return null;
            }
            String delimiters = properties.getDelimiters();
            replyContent = StrUtil.replaceFirst(replyContent, delimiters, "");
            replyContent = StrUtil.replaceLast(replyContent, delimiters, "");
            String thirdPartyId = webSocketHandler.getThirdPartyId();
            return new CompletionRsp()
                    .setThirdPartyId(thirdPartyId)
                    .setReplyContent(replyContent)
                    .setMatchSensitiveWords(matchSensitiveWords);
        } catch (InterruptedException | ExecutionException | TimeoutException e) {
    
    
            log.error("get WebSocketSession fail, req={}", req, e);
        } catch (IOException e) {
    
    
            log.error("sendMessage fail, req={}", req, e);
        } finally {
    
    
            if (stopWatch.isRunning()) {
    
    
                stopWatch.stop();
            }
            log.info("doCompletions execute time, {}", stopWatch);
        }
        return null;
    }

    private static final int MIN_TOKENS = 11;

    /**
     * 限制单次最大回复单词数(tokens)
     */
    private static int applyMaxTokens(int reqMaxTokens, int maxTokensConfig) {
    
    
        if (reqMaxTokens < MIN_TOKENS || maxTokensConfig < reqMaxTokens) {
    
    
            return maxTokensConfig;
        }
        return reqMaxTokens;
    }

    private OpenAiParam applyParam(String chatId, CompletionReq req) {
    
    
        OpenAiDataExtParam extParam = new OpenAiDataExtParam()
                .setChatId(chatId)
                .setRequestId(req.getRequestId())
                .setBizId(req.getBizId());
        // 提示
        String prompt = req.getPrompt();
        // 分隔符
        String delimiters = properties.getDelimiters();
        String message = prompt + delimiters + req.getMessage() + delimiters;
        int maxTokens = applyMaxTokens(req.getMaxTokens(), properties.getMaxTokens());
        OpenAiData data = new OpenAiData()
                .setMsg(message)
                .setContext(properties.getContext())
                .setLimitTokens(maxTokens)
                .setExtParam(extParam);
        String sign = OpenAiUtil.applySign(message, properties.getSecret());
        return new OpenAiParam()
                .setData(data)
                .setSign(sign);
    }
}

WebSocketUtil


import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.StringUtils;
import org.springframework.web.socket.client.WebSocketClient;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;

/**
 * WebSocket辅助方法
 */
public final class WebSocketUtil {
    
    
    /**
     * 创建一个新的WebSocket客户端
     */
    public static WebSocketClient applyWebSocketClient(String threadNamePrefix) {
    
    
        StandardWebSocketClient webSocketClient = new StandardWebSocketClient();
        int cpuNum = Runtime.getRuntime().availableProcessors();
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setCorePoolSize(cpuNum);
        taskExecutor.setMaxPoolSize(200);
        taskExecutor.setDaemon(true);
        if (StringUtils.hasText(threadNamePrefix)) {
    
    
            taskExecutor.setThreadNamePrefix(threadNamePrefix);
        } else {
    
    
            taskExecutor.setThreadNamePrefix("gpt.web.socket");
        }
        taskExecutor.initialize();
        webSocketClient.setTaskExecutor(taskExecutor);
        return webSocketClient;
    }
}

OpenAiUtil


import org.springframework.util.DigestUtils;

import java.nio.charset.StandardCharsets;

/**
 * OpenAi辅助方法
 */
public final class OpenAiUtil {
    
    
    /**
     * 对消息内容进行md5加密
     *
     * @param message 消息内容
     * @param secret 加签密钥
     * @return 十六进制加密后的消息内容
     */
    public static String applySign(String message, String secret) {
    
    
        String data = message + secret;
        byte[] dataBytes = data.getBytes(StandardCharsets.UTF_8);
        return DigestUtils.md5DigestAsHex(dataBytes);
    }
}

Referencias

Supongo que te gusta

Origin blog.csdn.net/shupili141005/article/details/130811504
Recomendado
Clasificación