开发像chatGPT一样的流式接口,打造你的AI客服,包含java、python接口和前端实现

概要

在使用chatGPT等大模型时,接口并不像传统API那样一次性返回全部回答,而是通过字节流一段一段地返回。下文将介绍如何实现这样的流式接口。
在这里插入图片描述

整体流程

java版本

1、新建sseService,建立sse连接

    /**
     * Creates an SSE emitter for the given session and registers it in the emitter map.
     *
     * @param sessionId identifier of the session the emitter belongs to
     * @return the newly created emitter (constructed with a timeout argument of {@code 0L})
     */
    public SseEmitter connect(String sessionId) {
        final SseEmitter emitter = new SseEmitter(0L);
        final String cacheKey = getCacheKey(sessionId);

        // Lifecycle hooks: completion, error and timeout each get their own callback.
        emitter.onCompletion(completionCallBack(sessionId));
        emitter.onError(errorCallBack(sessionId));
        emitter.onTimeout(timeoutCallBack(sessionId));

        // The online-session counter increment is commented out in this excerpt.
        sseEmitterMap.put(cacheKey, emitter);
        log.info("创建新的sse连接,当前会话:{}", sessionId);
        return emitter;
    }
/
    /**
     * Builds the callback that is registered for emitter timeouts.
     *
     * @param userId identifier whose emitter entry is removed when the callback runs
     * @return a task that logs the timeout and removes the emitter entry
     */
    private Runnable timeoutCallBack(String userId) {
        return () -> {
            final String cacheKey = getCacheKey(userId);
            log.info("连接超时:{}", userId);
            removeUser(cacheKey);
        };
    }

    /**
     * Builds the callback that is registered for emitter errors.
     *
     * <p>The received {@link Throwable} is passed to the logger so the cause of the failure is
     * not lost; the emitter entry is then removed.
     *
     * @param userId identifier whose emitter entry is removed when the callback runs
     * @return a consumer that logs the failure and removes the emitter entry
     */
    private Consumer<Throwable> errorCallBack(String userId) {
        return cause -> {
            // A trailing Throwable argument is logged by SLF4J together with its stack trace.
            log.warn("连接异常:{}", userId, cause);
            removeUser(getCacheKey(userId));
        };
    }

2、发字节流消息

    /**
     * Pushes one message to the emitter registered under {@code cacheKey}.
     *
     * <p>Does nothing when no emitter is registered for the key. When the push fails, the
     * emitter entry is removed.
     *
     * @param cacheKey key the emitter was registered under (see {@code getCacheKey})
     * @param message  text to push to the client
     */
    public void sendMessage(String cacheKey, String message) {
        // One lookup instead of containsKey() + get(): the entry may be removed by a callback
        // between the two calls, which would leave a null emitter to dereference.
        SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
        if (sseEmitter == null) {
            return;
        }
        try {
            sseEmitter.send(message, MediaType.TEXT_EVENT_STREAM);
            log.info("用户[{}]推送成功:{}", cacheKey, message);
        } catch (IOException | IllegalStateException e) {
            // send() reports I/O failures as IOException and other failures (such as sending on
            // an emitter that was already completed) as IllegalStateException.
            log.error("用户[{}]推送异常:{}", cacheKey, e.getMessage());
            removeUser(cacheKey);
        }
    }

3、开发接口

/**
 * HTTP endpoints for opening and closing the SSE stream used by the chat page.
 */
@RestController
@RequestMapping("/sse")
public class SseEmitterController {

    // Bound to this controller (it was previously created for SseEmitterServer.class).
    private static final Logger log = LoggerFactory.getLogger(SseEmitterController.class);

    @Autowired
    private SseEmitterServer sseEmitterServer;
    @Autowired
    private ChatGpt chatGpt;

    /**
     * Opens an SSE connection for the caller and forwards the question to {@code ChatGpt}.
     *
     * <p>The connection is keyed by the logged-in user id; when no user id is present, an id
     * obtained from {@code RedisService.genId()} is used instead and passed on as the guest id.
     *
     * @param userId id of the logged-in user, or {@code null} for a guest
     * @param issue  the question text
     * @return the {@link SseEmitter} the answer is streamed through
     */
    @GetMapping("/send")
    public Object send(@LoginUser Long userId, @RequestParam("issue") String issue) {
        Long guestId = null;
        if (userId == null) {
            guestId = RedisService.genId();
        }
        Long sessionOwner = (userId == null) ? guestId : userId;
        SseEmitter sseEmitter = sseEmitterServer.connect("" + sessionOwner);

        // NOTE(review): the emitter is only handed to Spring after sendIssue returns; this
        // presumably relies on sendIssue doing its work asynchronously - confirm.
        chatGpt.sendIssue(userId, issue, guestId);
        return sseEmitter;
    }

    /**
     * Closes the SSE connection registered for the logged-in user.
     *
     * @param userId id of the logged-in user
     * @return the result of {@code ResponseUtil.ok()}
     */
    @GetMapping("/close")
    public Object close(@LoginUser Long userId) {
        // NOTE(review): a guest connection is keyed by its generated guest id, which this
        // endpoint does not receive, so it cannot be closed from here - confirm intended.
        sseEmitterServer.closeSseConnect("" + userId);
        return ResponseUtil.ok();
    }
}

4、sseService完整代码

/**
 * Registry of the currently open SSE emitters, keyed by a prefixed session id.
 *
 * <p>Emitters are kept in an in-memory {@link ConcurrentHashMap}; the number of open sessions
 * is tracked through {@code RedisService} under the {@code OnlineSessionCount} key. The counter
 * is only changed when an entry is actually added to or removed from the map, so repeated
 * removal attempts for the same emitter (for example an explicit close followed by the
 * completion callback) cannot decrement it twice.
 */
public class SseEmitterServer {

    private static final Logger log = LoggerFactory.getLogger(SseEmitterServer.class);

    private static final String KEY_PREFIX = "SseEmitter_";
    private static final String ONLINE_SESSION_COUNT = "OnlineSessionCount";

    private final Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * Creates an emitter for the session and registers it, replacing any emitter that was
     * already registered for the same session.
     *
     * @param sessionId identifier of the session the emitter belongs to
     * @return the newly created emitter (constructed with a timeout argument of {@code 0L})
     */
    public SseEmitter connect(String sessionId) {
        SseEmitter sseEmitter = new SseEmitter(0L);
        // Each callback is tied to this emitter instance so that a late callback of a replaced
        // emitter cannot evict the emitter that replaced it.
        sseEmitter.onCompletion(completionCallBack(sessionId, sseEmitter));
        sseEmitter.onError(errorCallBack(sessionId, sseEmitter));
        sseEmitter.onTimeout(timeoutCallBack(sessionId, sseEmitter));

        SseEmitter previous = sseEmitterMap.put(getCacheKey(sessionId), sseEmitter);
        if (previous == null) {
            // Only a brand-new session adds to the online count.
            RedisService.increment(ONLINE_SESSION_COUNT);
        } else {
            // The replaced emitter is no longer reachable through the map; complete it so it is
            // not left open.
            previous.complete();
        }

        log.info("创建新的sse连接,当前会话:{}", sessionId);
        return sseEmitter;
    }

    /**
     * Completes and unregisters the emitter of the given user, if one is registered.
     *
     * @param userId identifier the emitter was registered with
     */
    public void closeSseConnect(String userId) {
        String cacheKey = getCacheKey(userId);
        SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
        if (sseEmitter != null) {
            sseEmitter.complete();
            release(cacheKey, sseEmitter);
        }
    }

    /**
     * Sends a message to a single user.
     *
     * @param userId  identifier the emitter was registered with
     * @param message text to push to the client
     */
    public void sendMsg(String userId, String message) {
        sendMessage(getCacheKey(userId), message);
    }

    /**
     * Pushes one message to the emitter registered under {@code cacheKey}; does nothing when no
     * emitter is registered. When the push fails, the emitter entry is removed.
     *
     * @param cacheKey key returned by {@link #getCacheKey(String)}
     * @param message  text to push to the client
     */
    public void sendMessage(String cacheKey, String message) {
        // One lookup instead of containsKey() + get(): the entry may be removed by a callback
        // between the two calls.
        SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
        if (sseEmitter == null) {
            return;
        }
        try {
            sseEmitter.send(message, MediaType.TEXT_EVENT_STREAM);
            log.info("用户[{}]推送成功:{}", cacheKey, message);
        } catch (IOException | IllegalStateException e) {
            // send() reports I/O failures as IOException and other failures (such as sending on
            // an emitter that was already completed) as IllegalStateException.
            log.error("用户[{}]推送异常:{}", cacheKey, e.getMessage());
            release(cacheKey, sseEmitter);
        }
    }

    /**
     * Removes whatever emitter is registered under {@code cacheKey}.
     *
     * <p>The online counter is decremented only when an entry was actually removed.
     *
     * @param cacheKey key returned by {@link #getCacheKey(String)}
     */
    public void removeUser(String cacheKey) {
        if (sseEmitterMap.remove(cacheKey) != null) {
            RedisService.decrement(ONLINE_SESSION_COUNT);
            log.info("移除用户:{}", cacheKey);
        }
    }

    /**
     * Returns the online-session count stored by {@code RedisService}.
     *
     * @return the stored count, or {@code 0} when the stored value is not numeric
     */
    public int getUserCount() {
        Object stored = RedisService.get(ONLINE_SESSION_COUNT);
        // Accept any numeric representation, not only Integer.
        return stored instanceof Number ? ((Number) stored).intValue() : 0;
    }

    /**
     * Removes the entry for {@code cacheKey} only if it still maps to {@code emitter}, and
     * decrements the online counter when it did.
     */
    private void release(String cacheKey, SseEmitter emitter) {
        if (sseEmitterMap.remove(cacheKey, emitter)) {
            RedisService.decrement(ONLINE_SESSION_COUNT);
            log.info("移除用户:{}", cacheKey);
        }
    }

    /** Callback run when the emitter completes. */
    private Runnable completionCallBack(String userId, SseEmitter emitter) {
        return () -> {
            log.info("结束连接:{}", userId);
            release(getCacheKey(userId), emitter);
        };
    }

    /** Callback run when the emitter times out. */
    private Runnable timeoutCallBack(String userId, SseEmitter emitter) {
        return () -> {
            log.info("连接超时:{}", userId);
            release(getCacheKey(userId), emitter);
        };
    }

    /** Callback run when the emitter fails; the cause is logged instead of being dropped. */
    private Consumer<Throwable> errorCallBack(String userId, SseEmitter emitter) {
        return cause -> {
            log.warn("连接异常:{}", userId, cause);
            release(getCacheKey(userId), emitter);
        };
    }

    /**
     * Builds the map key for a session id.
     *
     * @param configKey the session or user identifier
     * @return the identifier prefixed with {@code SseEmitter_}
     */
    public String getCacheKey(String configKey) {
        return KEY_PREFIX + configKey;
    }

}

5、最终使用,接入chatGPT实现问答
我这里使用了代理访问openAi的api,这里可以根据自己的情况调整。

/**
 * Posts the question to the given URL as JSON and relays the response, line by line, to the
 * user's SSE emitter.
 *
 * <p>The most recent chat log entry of the user, if any, is added to the request as
 * {@code history}. After the response has been fully read, a {@code complete:end} message is
 * sent and the question plus the concatenated answer are stored through {@code chatLogSvc}.
 * The user's SSE connection is closed and the HTTP connection released in all cases.
 *
 * @param url            endpoint to post to
 * @param header         extra request headers; may be {@code null} or empty
 * @param issue          the question text
 * @param userId         id of the asking user; also used (as a string) as the SSE session id
 * @param userResourceId not used by this method
 * @throws Exception if building the request fails; I/O failures are rethrown as
 *                   {@link IOException}
 */
public void sendPostSse(String url, Map<String, String> header, String issue, Long userId, Long userResourceId) throws Exception {
    // NOTE(review): for a guest, userId is presumably null, which makes the session id "null"
    // and userResourceId is ignored - confirm against the caller how guests should be handled.
    String uid = "" + userId;
    StringBuilder result = new StringBuilder();
    HttpURLConnection conn = null;

    try {
        // Build the payload first so a failing history lookup does not leave a connection open.
        Map<String, Object> issueReq = buildIssueRequest(issue, userId, uid);

        conn = (HttpURLConnection) new URL(url).openConnection();
        if (header != null && !header.isEmpty()) {
            for (Map.Entry<String, String> entry : header.entrySet()) {
                conn.setRequestProperty(entry.getKey(), entry.getValue());
            }
        }
        // Both flags are required for a POST that sends a body and reads a response.
        conn.setDoOutput(true);
        conn.setDoInput(true);
        conn.setRequestMethod("POST");

        // Common request properties.
        conn.setRequestProperty("accept", "*/*");
        conn.setRequestProperty("connection", "Keep-Alive");
        conn.setRequestProperty("user-agent",
                "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1;SV1)");
        conn.setRequestProperty("Content-Type", "application/json");
        conn.connect();

        // try-with-resources closes each stream even when a later step throws.
        try (OutputStreamWriter out = new OutputStreamWriter(conn.getOutputStream(), "UTF-8")) {
            out.write(ConvertUtil.o2s(issueReq));
            out.flush();
        }

        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(conn.getInputStream(), "UTF-8"))) {
            String line;
            while ((line = in.readLine()) != null) {
                // Relay every response line as soon as it arrives.
                sseEmitterServer.sendMsg(uid, line + "\n");
                result.append(line);
            }
        }
        sseEmitterServer.sendMsg(uid, "complete:end\n");

        // Persist the exchange so it can serve as context for the next question.
        chatLogSvc.add(userId, issue, result.toString());
    } catch (IOException e) {
        e.printStackTrace();
        throw e;
    } finally {
        try {
            sseEmitterServer.closeSseConnect(uid);
        } finally {
            if (conn != null) {
                conn.disconnect();
            }
        }
    }
}

/** Builds the JSON request map: optional one-turn history, the question and the user id. */
private Map<String, Object> buildIssueRequest(String issue, Long userId, String uid) {
    Map<String, Object> issueReq = new HashMap<>();
    QChatLogEnt qChatLogEnt = QChatLogEnt.chatLogEnt;
    BooleanExpression expression = qChatLogEnt.userId.eq(userId);
    // Latest chat log entry of this user (highest id first).
    ChatLogEnt ent = chatLogSvc.findOneByDsl(expression, Sort.by("id").descending());
    if (null != ent) {
        List<Object> history = new ArrayList<>();
        history.add(chatMessage("user", ent.getUserAsk()));
        history.add(chatMessage("assistant", ent.getUserAnswer()));
        issueReq.put("history", history);
    }
    issueReq.put("issue", issue);
    issueReq.put("userId", uid);
    return issueReq;
}

/** Creates one {@code {role, content}} history entry. */
private Map<String, Object> chatMessage(String role, Object content) {
    Map<String, Object> message = new HashMap<>();
    message.put("role", role);
    message.put("content", content);
    return message;
}

前端js接入

前端使用了event-source-polyfill库实现接入
安装前端库 npm install event-source-polyfill
发送和接收

saveMsg(tomsg) {
    
    
                this.show = false
                if (this.lists && this.lists.length && this.lists[this.lists.length-1].process) {
    
    
                    this.$toast('AI正在回答上一个问题');
                    return
                }
                this.lists.push({
    
    
                    id: this.userid,
                    face: this.userimg,
                    word: tomsg
                });
                const es = new EventSourcePolyfill('/sse/send?issue=' + encodeURIComponent(tomsg),{
    
    
                    headers: {
    
    
                        'Content-Type': 'application/json',
                        'Token':getToken()
                    },
                    data: {
    
    "issue":tomsg}
                });
                this.es = es
                let receive = {
    
    
                    id: 1529,
                    face: '你的头像',
                    process:true,
                    word:""
                    // touserdata.words[Math.floor(Math.random() * touserdata.words.length)]
                    //     .info
                }
                this.process = true
                this.lists.push(receive );
                this.scrollToBottom()
                es.onopen =  (event) => {
    
    
                    console.log("连接成功", event);
                };
                es.onmessage = (event) =>{
    
    
                    if (event.data.indexOf("complete:") == 0 || event.data.indexOf("error:") == 0) {
    
    
                        this.closeSse(es)
                        receive.process = false
                        this.process = false
                        if (event.data.indexOf("登录失效") != -1) {
    
    
                            this.$toast('请重新登录');
                            return
                        }
                        if (event.data.indexOf("complete:") == 0) {
    
    
                            return
                        }
                    }
                    let msg = event.data
                    if (/^[A-Za-z]+$/.test(msg)) {
    
    
                        msg = " " + msg+" "
                    }
                    let word = receive.word + msg ;
                    //这里做了返回数据排版布局,不需要可以删除
                    if (/[;:。!?\.]\d{1,3}\./.test(word)) {
    
    
                        word = word.replace(/([;:。!?\.])(\d{1,3}\.)/,"$1<br/><br/>&nbsp;&nbsp;&nbsp;&nbsp;$2");
                    }

                    receive.word = word
                    this.scrollToBottom() //聊天窗滚动到底部
                };
                es.onerror = (error) => {
    
    
                    this.closeSse(es)
                    receive.word += "网络异常:" +  error.status;
                    receive.process = false
                    this.process = false
                    console.log("错误", error);
                };
            }

python接口

本接口以chatGPT接口为例实现

import os
import openai
from django.http import StreamingHttpResponse
from .utils import jsonSuccess,jsonError
from django.views.decorators.csrf import csrf_exempt
import json
# The OpenAI API key is read from the "openai_api_key" environment variable (None when unset).
openai.api_key = os.getenv("openai_api_key")
# Optional outbound proxy, disabled here:
# openai.proxy = "http://<proxy-host>:<port>"
suffix = "The end of the story."
@csrf_exempt
def req(request):
    """Stream a text-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an optional
    ``history`` list of ``{"role", "content"}`` objects. When at least two history entries
    are present, the first two are prepended to the prompt.

    Returns a ``StreamingHttpResponse`` that yields each generated text chunk followed by a
    blank line. If the OpenAI call fails while streaming, a single ``error:<message>`` chunk
    is yielded instead of aborting the response.
    """
    postBody = request.body
    json_result = json.loads(postBody)
    # The question being asked
    issue = json_result['issue']
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    history = ""
    if len(historyMsg) > 1:
        # History entries come from json.loads and are therefore dicts: they must be indexed
        # by key (attribute access raises AttributeError). A null content becomes "".
        # NOTE(review): the second entry is appended without a separator or speaker label,
        # as before - confirm whether an "AI:" prefix was intended.
        history = "Human:" + str(historyMsg[0].get("content") or "")
        history += str(historyMsg[1].get("content") or "")

    prompt = history + "\nHuman:" + issue

    def event_stream():
        # The API call runs lazily while the response is being iterated, so errors have to be
        # handled here; a try/except around the creation of the generator never sees them.
        try:
            response = openai.Completion.create(
                engine='text-davinci-003',
                prompt=prompt,
                temperature=0.9,  # sampling randomness; higher values give more varied output
                max_tokens=2000,  # upper bound on the number of generated tokens
                top_p=0.7,  # nucleus sampling threshold
                frequency_penalty=0.0,  # penalty on repeating the same tokens
                presence_penalty=0.6,  # penalty on staying on the same topics
                stream=True,
                user=userId
            )
            for chunk in response:
                yield chunk["choices"][0]["text"] + '\n\n'
        except Exception as e:
            print(e)
            yield "error:" + str(e) + '\n\n'

    return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
# NOTE(review): this definition looks like a truncated duplicate - it stops after reading the
# request fields and never builds a response. A complete reqChat is defined further down in
# this file and, being defined later, is the one bound to the name once the module is loaded.
def  reqChat(request):
    # Prefer the first address listed in X-Forwarded-For; fall back to REMOTE_ADDR.
    forwarded_addresses = request.META.get('HTTP_X_FORWARDED_FOR')
    if forwarded_addresses:
        client_addr = forwarded_addresses.split(',')[0]
    else:
        client_addr = request.META.get('REMOTE_ADDR')
    print('ip:' + client_addr)
    # The request body is parsed as JSON.
    postBody = request.body
    json_result = json.loads(postBody)
    # The question being asked
    issue = json_result['issue']
    print('issue:' + issue)
    userId = json_result['userId']
    # Optional earlier messages; defaults to an empty list.
    historyMsg = json_result.get("history",[])
import os
import openai
from django.http import StreamingHttpResponse
from .utils import jsonSuccess,jsonError
from django.views.decorators.csrf import csrf_exempt
import json
# The OpenAI API key is read from the "openai_api_key" environment variable (None when unset).
openai.api_key = os.environ.get("openai_api_key")
# Outbound proxy for OpenAI requests. It can be overridden with the "openai_proxy" environment
# variable; the default keeps the address that was previously hard-coded here.
openai.proxy = os.environ.get("openai_proxy", "http://172.16.247.94:13060")
suffix = "The end of the story."
@csrf_exempt
def req(request):
    """Stream a text-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an optional
    ``history`` list of ``{"role", "content"}`` objects. When at least two history entries
    are present, the first two are prepended to the prompt.

    Returns a ``StreamingHttpResponse`` that yields each generated text chunk followed by a
    blank line. If the OpenAI call fails while streaming, a single ``error:<message>`` chunk
    is yielded instead of aborting the response.
    """
    postBody = request.body
    json_result = json.loads(postBody)
    # The question being asked
    issue = json_result['issue']
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    history = ""
    if len(historyMsg) > 1:
        # History entries come from json.loads and are therefore dicts: they must be indexed
        # by key (attribute access raises AttributeError). A null content becomes "".
        # NOTE(review): the second entry is appended without a separator or speaker label,
        # as before - confirm whether an "AI:" prefix was intended.
        history = "Human:" + str(historyMsg[0].get("content") or "")
        history += str(historyMsg[1].get("content") or "")

    prompt = history + "\nHuman:" + issue

    def event_stream():
        # The API call runs lazily while the response is being iterated, so errors have to be
        # handled here; a try/except around the creation of the generator never sees them.
        try:
            response = openai.Completion.create(
                engine='text-davinci-003',
                prompt=prompt,
                temperature=0.9,  # sampling randomness; higher values give more varied output
                max_tokens=2000,  # upper bound on the number of generated tokens
                top_p=0.7,  # nucleus sampling threshold
                frequency_penalty=0.0,  # penalty on repeating the same tokens
                presence_penalty=0.6,  # penalty on staying on the same topics
                stream=True,
                user=userId
            )
            for chunk in response:
                yield chunk["choices"][0]["text"] + '\n\n'
        except Exception as e:
            print(e)
            yield "error:" + str(e) + '\n\n'

    return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
def reqChat(request):
    """Stream a chat-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an optional
    ``history`` list of chat messages; the question is appended to that list as a ``user``
    message before it is sent.

    Returns a ``StreamingHttpResponse`` that yields each content delta followed by a newline.
    If the OpenAI call fails while streaming, a single ``error:<message>`` line is yielded.
    """
    # Prefer the first address listed in X-Forwarded-For; fall back to REMOTE_ADDR.
    forwarded_addresses = request.META.get('HTTP_X_FORWARDED_FOR')
    if forwarded_addresses:
        client_addr = forwarded_addresses.split(',')[0]
    else:
        client_addr = request.META.get('REMOTE_ADDR')
    # str() keeps the log line from raising TypeError when no address is available.
    print('ip:' + str(client_addr))
    postBody = request.body
    json_result = json.loads(postBody)
    # The question being asked
    issue = json_result['issue']
    print('issue:' + issue)
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    historyMsg.append({"role": "user", "content": issue})

    def event_stream():
        # The API call runs lazily while the response is being iterated, so errors (including
        # timeouts) have to be handled here; except clauses around the creation of the
        # generator never see them, and returning None from them is not a valid response.
        try:
            response = openai.ChatCompletion.create(
                model='gpt-3.5-turbo',
                messages=historyMsg,
                temperature=0.9,  # sampling randomness; higher values give more varied output
                max_tokens=2000,  # upper bound on the number of generated tokens
                top_p=0.7,  # nucleus sampling threshold
                frequency_penalty=0.0,  # penalty on repeating the same tokens
                presence_penalty=0.6,  # penalty on staying on the same topics
                request_timeout=60,
                stream=True,
                user=userId
            )
            for chunk in response:
                content = chunk['choices'][0]['delta'].get('content')
                # Some deltas carry no content (or a null one); skip them.
                if content is not None:
                    yield content + "\n"
        except Exception as e:
            print("Error: " + str(e))
            yield "error:" + str(e) + "\n"

    return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
def createimg(request):
    """Generate one 512x512 image for the posted prompt and return its URL.

    The request body is JSON with ``issue`` (the image prompt) and ``userId``. On success the
    URL of the generated image is wrapped with ``jsonSuccess``; any exception raised by the
    OpenAI call is wrapped with ``jsonError``.
    """
    payload = json.loads(request.body)
    # The image prompt
    prompt = payload['issue']
    # Read but not used afterwards; a missing key still raises KeyError.
    _user_id = payload['userId']
    try:
        generated = openai.Image.create(
            prompt=prompt,
            n=1,
            size="512x512",
            response_format="url",
        )
        first_image = generated['data'][0]
        return jsonSuccess(first_image['url'])
    except Exception as err:
        return jsonError(err)

小结

由于完整的代码穿插了业务流程(比如用户权限、使用次数限制、UI等),这里不方便全部提供,以上代码仅供学习参考!