概要
在使用ChatGPT等大模型时,接口并不像传统API那样一次性返回全部回答,而是通过字节流一段一段地返回。下文将介绍如何实现这样的流式接口
整体流程
java版本
1、新建sseService,建立sse连接
/**
 * Creates an SSE emitter for the given session, wires up its lifecycle
 * callbacks and stores it so that later pushes can find it.
 *
 * @param sessionId identifier of the session the connection belongs to
 * @return the freshly created emitter
 */
public SseEmitter connect(String sessionId) {
    // 0L is handed to SseEmitter as the timeout value.
    final SseEmitter emitter = new SseEmitter(0L);

    // Lifecycle hooks: completion, error and timeout.
    emitter.onCompletion(completionCallBack(sessionId));
    emitter.onError(errorCallBack(sessionId));
    emitter.onTimeout(timeoutCallBack(sessionId));

    final String cacheKey = getCacheKey(sessionId);
    sseEmitterMap.put(cacheKey, emitter);

    log.info("创建新的sse连接,当前会话:{}", sessionId);
    return emitter;
}
/
/**
 * Builds the callback that runs when the connection of {@code userId} times out:
 * it logs the event and drops the emitter stored under that user's cache key.
 *
 * @param userId identifier the emitter was registered under
 * @return the timeout callback
 */
private Runnable timeoutCallBack(String userId) {
    return new Runnable() {
        @Override
        public void run() {
            log.info("连接超时:{}", userId);
            final String cacheKey = getCacheKey(userId);
            removeUser(cacheKey);
        }
    };
}
/**
 * Builds the callback that runs when the connection of {@code userId} fails:
 * it logs the failure together with its cause and drops the stored emitter.
 *
 * @param userId identifier the emitter was registered under
 * @return the error callback
 */
private Consumer<Throwable> errorCallBack(String userId) {
    return throwable -> {
        // The throwable is passed as the trailing argument so the cause is
        // recorded instead of being silently discarded.
        log.warn("连接异常:{}", userId, throwable);
        removeUser(getCacheKey(userId));
    };
}
2、发字节流消息
/**
 * Pushes one message to the emitter registered under {@code cacheKey}.
 * Does nothing when no emitter is registered for that key.
 *
 * @param cacheKey key the emitter was stored under
 * @param message  text to push to the client
 */
public void sendMessage(String cacheKey, String message) {
    // A single lookup: checking containsKey() and then calling get() could
    // observe the emitter being removed in between and dereference null.
    SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
    if (sseEmitter == null) {
        return;
    }
    try {
        sseEmitter.send(message, MediaType.TEXT_EVENT_STREAM);
        log.info("用户[{}]推送成功:{}", cacheKey, message);
    } catch (IOException e) {
        // The client can no longer be reached; forget the connection.
        log.error("用户[{}]推送异常:{}", cacheKey, e.getMessage());
        removeUser(cacheKey);
    }
}
3、开发接口
/**
 * HTTP entry points for the SSE chat: one endpoint opens a stream and submits
 * a question, the other closes the caller's stream.
 */
@RestController
@RequestMapping("/sse")
public class SseEmitterController {

    // Bound to this controller (the original used SseEmitterServer.class by mistake).
    private static final Logger log = LoggerFactory.getLogger(SseEmitterController.class);

    @Autowired
    private SseEmitterServer sseEmitterServer;

    @Autowired
    private ChatGpt chatGpt;

    /**
     * Opens an SSE connection for the caller and submits the question.
     * Logged-in callers are keyed by their user id; anonymous callers get a
     * generated guest id instead.
     *
     * @param userId id of the logged-in user, or {@code null} for a guest
     * @param issue  the question to ask
     * @return the {@link SseEmitter} the answer is streamed through
     */
    @GetMapping("/send")
    public Object send(@LoginUser Long userId, @RequestParam("issue") String issue) {
        Long guestId = null;
        String sessionId;
        if (userId == null) {
            guestId = RedisService.genId();
            sessionId = "" + guestId;
        } else {
            sessionId = "" + userId;
        }
        SseEmitter sseEmitter = sseEmitterServer.connect(sessionId);
        // NOTE(review): the emitter is only returned after sendIssue() returns;
        // presumably sendIssue() hands the work to another thread — confirm,
        // otherwise nothing is streamed until the whole answer is done.
        chatGpt.sendIssue(userId, issue, guestId);
        return sseEmitter;
    }

    /**
     * Closes the SSE connection of the logged-in caller.
     *
     * @param userId id of the logged-in user, or {@code null} for a guest
     * @return the standard OK response
     */
    @GetMapping("/close")
    public Object close(@LoginUser Long userId) {
        // Without a user id there is no session key to close; previously this
        // looked up the literal key "null".
        if (userId != null) {
            sseEmitterServer.closeSseConnect("" + userId);
        }
        return ResponseUtil.ok();
    }
}
4、sseService完整代码
/**
 * Keeps the open SSE connections in memory, keyed by a prefixed session/user
 * id, and mirrors the number of open connections into a Redis counter.
 *
 * <p>The counter is incremented only when a key is newly registered and
 * decremented only when an entry is actually removed, so the several lifecycle
 * paths that may fire for one connection (explicit close, completion, timeout,
 * error) cannot decrement it more than once.
 */
public class SseEmitterServer {
    private static final Logger log = LoggerFactory.getLogger(SseEmitterServer.class);
    private static final String KEY_PREFIX = "SseEmitter_";
    private static final String ONLINE_SESSION_COUNT = "OnlineSessionCount";

    private final Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * Creates and registers an emitter for the given session. An emitter that
     * was already registered for the same session is completed and replaced.
     *
     * @param sessionId identifier of the session the connection belongs to
     * @return the newly created emitter
     */
    public SseEmitter connect(String sessionId) {
        String cacheKey = getCacheKey(sessionId);
        SseEmitter sseEmitter = new SseEmitter(0L);
        // Each callback is bound to this particular emitter so that a late
        // callback of a replaced emitter cannot evict its successor.
        sseEmitter.onCompletion(completionCallBack(sessionId, sseEmitter));
        sseEmitter.onError(errorCallBack(sessionId, sseEmitter));
        sseEmitter.onTimeout(timeoutCallBack(sessionId, sseEmitter));

        SseEmitter previous = sseEmitterMap.put(cacheKey, sseEmitter);
        if (previous == null) {
            RedisService.increment(ONLINE_SESSION_COUNT);
        } else {
            // Same session reconnected: the slot is reused, so the counter is
            // left alone and the orphaned emitter is finished.
            previous.complete();
        }
        log.info("创建新的sse连接,当前会话:{}", sessionId);
        return sseEmitter;
    }

    /**
     * Completes and unregisters the emitter of the given user, if any.
     *
     * @param userId identifier the emitter was registered under
     */
    public void closeSseConnect(String userId) {
        String cacheKey = getCacheKey(userId);
        SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
        if (sseEmitter != null) {
            sseEmitter.complete();
            removeEmitter(cacheKey, sseEmitter);
        }
    }

    /**
     * Sends a message to one user (unicast).
     *
     * @param userId  identifier the emitter was registered under
     * @param message text to push
     */
    public void sendMsg(String userId, String message) {
        sendMessage(getCacheKey(userId), message);
    }

    /**
     * Sends a message to the emitter stored under {@code cacheKey}; does
     * nothing when no emitter is registered for that key.
     *
     * @param cacheKey prefixed key as produced by {@link #getCacheKey(String)}
     * @param message  text to push
     */
    public void sendMessage(String cacheKey, String message) {
        // Single lookup instead of containsKey() + get(), which could race
        // with a concurrent removal and dereference null.
        SseEmitter sseEmitter = sseEmitterMap.get(cacheKey);
        if (sseEmitter == null) {
            return;
        }
        try {
            sseEmitter.send(message, MediaType.TEXT_EVENT_STREAM);
            log.info("用户[{}]推送成功:{}", cacheKey, message);
        } catch (IOException | IllegalStateException e) {
            // IllegalStateException is what send() raises once the emitter has
            // already completed; either way the connection is unusable.
            log.error("用户[{}]推送异常:{}", cacheKey, e.getMessage());
            removeEmitter(cacheKey, sseEmitter);
        }
    }

    /**
     * Removes whatever emitter is stored under {@code cacheKey}. The online
     * counter is decremented only if an entry was actually removed.
     *
     * @param cacheKey prefixed key as produced by {@link #getCacheKey(String)}
     */
    public void removeUser(String cacheKey) {
        if (sseEmitterMap.remove(cacheKey) != null) {
            RedisService.decrement(ONLINE_SESSION_COUNT);
            log.info("移除用户:{}", cacheKey);
        }
    }

    /**
     * Returns the number of open connections as stored in Redis.
     *
     * @return the stored count, or 0 when the stored value is not an {@code Integer}
     */
    public int getUserCount() {
        Object o = RedisService.get(ONLINE_SESSION_COUNT);
        return o instanceof Integer ? (Integer) o : 0;
    }

    /**
     * Builds the key under which a session's emitter is stored.
     *
     * @param configKey session or user identifier
     * @return the prefixed cache key
     */
    public String getCacheKey(String configKey) {
        return KEY_PREFIX + configKey;
    }

    /** Removes the entry only if it still maps to {@code emitter}; decrements the counter on success. */
    private void removeEmitter(String cacheKey, SseEmitter emitter) {
        if (sseEmitterMap.remove(cacheKey, emitter)) {
            RedisService.decrement(ONLINE_SESSION_COUNT);
            log.info("移除用户:{}", cacheKey);
        }
    }

    /** Callback for normal completion of {@code emitter}. */
    private Runnable completionCallBack(String userId, SseEmitter emitter) {
        return () -> {
            log.info("结束连接:{}", userId);
            removeEmitter(getCacheKey(userId), emitter);
        };
    }

    /** Callback for a timeout of {@code emitter}. */
    private Runnable timeoutCallBack(String userId, SseEmitter emitter) {
        return () -> {
            log.info("连接超时:{}", userId);
            removeEmitter(getCacheKey(userId), emitter);
        };
    }

    /** Callback for a failure of {@code emitter}; the cause is logged rather than dropped. */
    private Consumer<Throwable> errorCallBack(String userId, SseEmitter emitter) {
        return throwable -> {
            log.warn("连接异常:{}", userId, throwable);
            removeEmitter(getCacheKey(userId), emitter);
        };
    }
}
5、最终使用,接入ChatGPT实现问答
我这里使用了代理访问OpenAI的API,这里可以根据自己的情况调整。
/**
 * Posts the question (plus the user's most recent exchange as context) to the
 * upstream streaming endpoint and relays every response line to the user's
 * SSE connection. When the stream ends, a {@code complete:end} marker is
 * pushed and the exchange is stored in the chat log.
 *
 * <p>The SSE connection, both streams and the HTTP connection are always
 * released, whether or not the call succeeds.
 *
 * @param url            upstream endpoint to POST to
 * @param header         extra request headers; may be {@code null} or empty
 * @param issue          the question to ask
 * @param userId         id of the asking user; also used as the SSE session id
 * @param userResourceId not used by this method
 * @throws Exception if the upstream request or reading the response fails
 */
public void sendPostSse(String url, Map<String, String> header, String issue, Long userId,Long userResourceId) throws Exception {
    // NOTE(review): for a null userId this becomes the literal "null", which
    // does not match the guest id the controller registers — confirm how
    // guests are meant to be routed.
    String uid = "" + userId;
    HttpURLConnection conn = null;
    OutputStreamWriter out = null;
    BufferedReader in = null;
    StringBuilder result = new StringBuilder();
    try {
        URL realUrl = new URL(url);
        conn = (HttpURLConnection) realUrl.openConnection();
        // Caller-supplied headers go first; the fixed headers set further
        // down replace any caller value for the same name.
        if (header != null && !header.isEmpty()) {
            for (Map.Entry<String, String> entry : header.entrySet()) {
                conn.setRequestProperty(entry.getKey(), entry.getValue());
            }
        }
        // Both flags are required to send a POST body and read the reply.
        conn.setDoOutput(true);
        conn.setDoInput(true);
        conn.setRequestMethod("POST");
        // Common request headers.
        conn.setRequestProperty("accept", "*/*");
        conn.setRequestProperty("connection", "Keep-Alive");
        conn.setRequestProperty("user-agent",
                "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1;SV1)");
        conn.setRequestProperty("Content-Type", "application/json");
        conn.connect();
        out = new OutputStreamWriter(conn.getOutputStream(), "UTF-8");

        // Build the request payload.
        Map<String, Object> issueReq = new HashMap<>();
        QChatLogEnt qChatLogEnt = QChatLogEnt.chatLogEnt;
        BooleanExpression expression = qChatLogEnt.userId.eq(userId);
        // The latest stored exchange is sent along as conversation context.
        ChatLogEnt ent = chatLogSvc.findOneByDsl(expression, Sort.by("id").descending());
        if (null != ent) {
            List<Object> history = new ArrayList<>();
            Map<String, Object> item = new HashMap<>();
            item.put("role", "user");
            item.put("content", ent.getUserAsk());
            history.add(item);
            Map<String, Object> assistant = new HashMap<>();
            assistant.put("role", "assistant");
            assistant.put("content", ent.getUserAnswer());
            history.add(assistant);
            issueReq.put("history", history);
        }
        issueReq.put("issue", issue);
        issueReq.put("userId", uid);
        out.write(ConvertUtil.o2s(issueReq));
        out.flush();

        // Relay the upstream response line by line while collecting it.
        in = new BufferedReader(
                new InputStreamReader(conn.getInputStream(), "UTF-8"));
        String line;
        while ((line = in.readLine()) != null) {
            sseEmitterServer.sendMsg(uid, line + "\n");
            result.append(line);
        }
        sseEmitterServer.sendMsg(uid, "complete:end\n");
        // Persist the exchange.
        chatLogSvc.add(userId, issue, result.toString());
    } finally {
        try {
            sseEmitterServer.closeSseConnect(uid);
        } finally {
            // Each resource is closed on its own so that a failure closing
            // one cannot leak the others.
            if (out != null) {
                try {
                    out.close();
                } catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ex) {
                    ex.printStackTrace();
                }
            }
            if (conn != null) {
                conn.disconnect();
            }
        }
    }
}
前端js接入
前端使用了event-source-polyfill库实现接入
安装前端库 npm install event-source-polyfill
发送和接收
saveMsg(tomsg) {
this.show = false
if (this.lists && this.lists.length && this.lists[this.lists.length-1].process) {
this.$toast('AI正在回答上一个问题');
return
}
this.lists.push({
id: this.userid,
face: this.userimg,
word: tomsg
});
const es = new EventSourcePolyfill('/sse/send?issue=' + encodeURIComponent(tomsg),{
headers: {
'Content-Type': 'application/json',
'Token':getToken()
},
data: {
"issue":tomsg}
});
this.es = es
let receive = {
id: 1529,
face: '你的头像',
process:true,
word:""
// touserdata.words[Math.floor(Math.random() * touserdata.words.length)]
// .info
}
this.process = true
this.lists.push(receive );
this.scrollToBottom()
es.onopen = (event) => {
console.log("连接成功", event);
};
es.onmessage = (event) =>{
if (event.data.indexOf("complete:") == 0 || event.data.indexOf("error:") == 0) {
this.closeSse(es)
receive.process = false
this.process = false
if (event.data.indexOf("登录失效") != -1) {
this.$toast('请重新登录');
return
}
if (event.data.indexOf("complete:") == 0) {
return
}
}
let msg = event.data
if (/^[A-Za-z]+$/.test(msg)) {
msg = " " + msg+" "
}
let word = receive.word + msg ;
//这里做了返回数据排版布局,不需要可以删除
if (/[;:。!?\.]\d{1,3}\./.test(word)) {
word = word.replace(/([;:。!?\.])(\d{1,3}\.)/,"$1<br/><br/> $2");
}
receive.word = word
this.scrollToBottom() //聊天窗滚动到底部
};
es.onerror = (error) => {
this.closeSse(es)
receive.word += "网络异常:" + error.status;
receive.process = false
this.process = false
console.log("错误", error);
};
}
python接口
本接口以chatGPT接口为例实现
import os
import openai
from django.http import StreamingHttpResponse
from .utils import jsonSuccess,jsonError
from django.views.decorators.csrf import csrf_exempt
import json
# The OpenAI API key is read from the "openai_api_key" environment variable.
openai.api_key = os.environ.get("openai_api_key")
# Optional outbound proxy for the OpenAI client (disabled here).
#openai.proxy = "http://<proxy-host>:<port>"
# NOTE(review): nothing visible reads `suffix`; the only mention is a
# commented-out `suffix=` argument — confirm before removing.
suffix = "The end of the story."
@csrf_exempt
def req(request):
    """Stream a text-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an
    optional ``history`` list of ``{"role": ..., "content": ...}`` dicts. When
    at least two history entries are present, the first two are prepended to
    the prompt as context.

    Returns a ``StreamingHttpResponse`` (``text/event-stream``) that yields
    each generated text chunk followed by a blank line, or ``jsonError(e)``
    if building the response fails.
    """
    postBody = request.body
    json_result = json.loads(postBody)
    # The question.
    issue = json_result['issue']
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    history = ""
    if len(historyMsg) > 1:
        # json.loads yields plain dicts, so the text must be read by key;
        # attribute access (``.content``) raised AttributeError here.
        # A missing or null content is treated as empty text.
        history = "Human:" + str(historyMsg[0].get("content") or "")
        history += str(historyMsg[1].get("content") or "")
    # Prompt sent to OpenAI: optional context followed by the new question.
    issue = history + "\nHuman:" + issue
    try:
        def event_stream():
            response = openai.Completion.create(
                engine='text-davinci-003',
                prompt=issue,
                temperature=0.9,        # sampling temperature; higher is more random
                max_tokens=2000,        # upper bound on generated tokens
                top_p=0.7,              # nucleus-sampling probability mass
                frequency_penalty=0.0,  # penalty on frequently repeated tokens
                presence_penalty=0.6,   # penalty on tokens that already appeared
                stream=True,
                user=userId
            )
            for chunk in response:
                text = chunk["choices"][0]["text"]
                yield text + '\n\n'
        # NOTE: the generator body runs lazily while the response is being
        # streamed, so errors raised by the OpenAI call surface after this
        # function has returned and are not caught by the except below.
        return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
    except Exception as e:
        print(e)
        return jsonError(e)
# NOTE(review): this copy of reqChat stops after reading the request fields —
# it never calls OpenAI or returns a response. The module header and a
# complete reqChat follow again right below, so this looks like a duplicated,
# truncated paste; confirm and drop one copy.
def reqChat(request):
# Take the first address from X-Forwarded-For when the header is present,
# otherwise fall back to REMOTE_ADDR.
forwarded_addresses = request.META.get('HTTP_X_FORWARDED_FOR')
if forwarded_addresses:
client_addr = forwarded_addresses.split(',')[0]
else:
client_addr = request.META.get('REMOTE_ADDR')
print('ip:' + client_addr)
# Parse the JSON request body.
postBody = request.body
json_result = json.loads(postBody)
# The question.
issue = json_result['issue']
print('issue:' + issue)
userId = json_result['userId']
# Optional conversation history; defaults to an empty list.
historyMsg = json_result.get("history",[])
import os
import openai
from django.http import StreamingHttpResponse
from .utils import jsonSuccess,jsonError
from django.views.decorators.csrf import csrf_exempt
import json
# The OpenAI API key is read from the "openai_api_key" environment variable.
openai.api_key = os.environ.get("openai_api_key")
# Outbound proxy for the OpenAI client. The "openai_proxy" environment
# variable overrides the built-in address so deployments on other networks
# do not need a code change.
openai.proxy = os.environ.get("openai_proxy", "http://172.16.247.94:13060")
# NOTE(review): nothing visible reads `suffix`; the only mention is a
# commented-out `suffix=` argument — confirm before removing.
suffix = "The end of the story."
@csrf_exempt
def req(request):
    """Stream a text-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an
    optional ``history`` list of ``{"role": ..., "content": ...}`` dicts. When
    at least two history entries are present, the first two are prepended to
    the prompt as context.

    Returns a ``StreamingHttpResponse`` (``text/event-stream``) that yields
    each generated text chunk followed by a blank line, or ``jsonError(e)``
    if building the response fails.
    """
    postBody = request.body
    json_result = json.loads(postBody)
    # The question.
    issue = json_result['issue']
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    history = ""
    if len(historyMsg) > 1:
        # json.loads yields plain dicts, so the text must be read by key;
        # attribute access (``.content``) raised AttributeError here.
        # A missing or null content is treated as empty text.
        history = "Human:" + str(historyMsg[0].get("content") or "")
        history += str(historyMsg[1].get("content") or "")
    # Prompt sent to OpenAI: optional context followed by the new question.
    issue = history + "\nHuman:" + issue
    try:
        def event_stream():
            response = openai.Completion.create(
                engine='text-davinci-003',
                prompt=issue,
                temperature=0.9,        # sampling temperature; higher is more random
                max_tokens=2000,        # upper bound on generated tokens
                top_p=0.7,              # nucleus-sampling probability mass
                frequency_penalty=0.0,  # penalty on frequently repeated tokens
                presence_penalty=0.6,   # penalty on tokens that already appeared
                stream=True,
                user=userId
            )
            for chunk in response:
                text = chunk["choices"][0]["text"]
                yield text + '\n\n'
        # NOTE: the generator body runs lazily while the response is being
        # streamed, so errors raised by the OpenAI call surface after this
        # function has returned and are not caught by the except below.
        return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
    except Exception as e:
        print(e)
        return jsonError(e)
def reqChat(request):
    """Stream a chat-completion answer for the posted question.

    The request body is JSON with ``issue`` (the question), ``userId`` and an
    optional ``history`` list of chat messages. The new question is appended
    to that list as a ``user`` message and the whole list is sent to the chat
    model.

    Returns a ``StreamingHttpResponse`` (``text/event-stream``) that yields
    each content delta followed by a newline, or ``jsonError(e)`` if building
    the response fails.
    """
    # NOTE(review): unlike req(), this view carries no @csrf_exempt — confirm
    # that POSTs reach it with CSRF protection enabled.
    forwarded_addresses = request.META.get('HTTP_X_FORWARDED_FOR')
    if forwarded_addresses:
        # First entry of the X-Forwarded-For list.
        client_addr = forwarded_addresses.split(',')[0]
    else:
        client_addr = request.META.get('REMOTE_ADDR')
    # str() keeps the log line from raising when no address is available.
    print('ip:' + str(client_addr))
    postBody = request.body
    json_result = json.loads(postBody)
    # The question.
    issue = json_result['issue']
    print('issue:' + issue)
    userId = json_result['userId']
    historyMsg = json_result.get("history", [])
    historyMsg.append({"role": "user", "content": issue})
    try:
        def event_stream():
            response = openai.ChatCompletion.create(
                model='gpt-3.5-turbo',
                messages=historyMsg,
                temperature=0.9,        # sampling temperature; higher is more random
                max_tokens=2000,        # upper bound on generated tokens
                top_p=0.7,              # nucleus-sampling probability mass
                frequency_penalty=0.0,  # penalty on frequently repeated tokens
                presence_penalty=0.6,   # penalty on tokens that already appeared
                request_timeout=60,
                stream=True,
                user=userId
            )
            for chunk in response:
                chunk_message = chunk['choices'][0]['delta']
                # Only deltas that carry a content field are forwarded.
                if hasattr(chunk_message, 'content'):
                    yield chunk_message['content'] + "\n"
        # NOTE: the generator body runs lazily while the response is being
        # streamed, so errors raised by the OpenAI call surface after this
        # function has returned and are not caught by the except below.
        return StreamingHttpResponse(event_stream(), content_type='text/event-stream')
    except Exception as e:
        # Previously both handlers only printed and fell through, so the view
        # returned None, which Django rejects. openai.error.Timeout is covered
        # by this handler as well.
        print("Error: " + str(e))
        return jsonError(e)
def createimg(request):
    """Generate one 512x512 image for the posted prompt and return its URL.

    The request body is JSON with ``issue`` (the image prompt) and ``userId``.
    Returns ``jsonSuccess(url)`` on success and ``jsonError(e)`` when the
    OpenAI call fails.
    """
    payload = json.loads(request.body)
    prompt_text = payload['issue']
    # Read for parity with the other views; the value itself is not used.
    userId = payload['userId']
    try:
        generated = openai.Image.create(
            prompt=prompt_text, n=1, size="512x512", response_format="url"
        )
        first_image = generated['data'][0]
        return jsonSuccess(first_image['url'])
    except Exception as e:
        return jsonError(e)
小结
由于完整的代码中穿插了业务流程,比如用户权限、使用次数限制、UI等等,这里不方便全部提供,以上内容仅供学习参考!