如题,第一次用websocket,做了个这玩意,只做了上下文的聊天,没做流式。
中间还有个低级报错但卡了好久,具体可以看【错误记录】websocket连接失败,但后端毫无反应,还有【错误记录】ruoyi-vue@Autowired注入自定义mapper时为null解决
,感兴趣可前往观看。
实际上我后端用的是ruoyi-vue,前端用的ruoyi-app,但不重要。因为功能就是基于websocket和文心一言千帆大模型的接口,完全可以独立出来。
每个新建的账号会送一张20元的代金券,期限一个月内。而聊天服务接口单价约1分/千token,总之用来练手肯定够用了。
文档中心-ERNIE-Bot-turbo
百度文心一言接入教程
若依插件-集成websocket实现简单通信
大致这样。
2023.10.13更新:昨天和朋友聊了一下,发现他的想法和我的不同——根本不用实体类去保存解析复杂的json,直接保存消息内容。有一说一,在这个小demo这里,确实可以更快更简单的实现,因为这个demo最耗时的就是看又臭又长的参数,然后写请求体和返回值的实体类,至少请求体实体类是可以不写的。
下面进入正题。
有三个角色,大模型 ←→ 后端 ←→ 前端。
大模型:接受后端发过来的消息,返回响应消息
后端:接受前端发过来的消息,封装发给大模型;接收大模型返回的消息,回给后端;发送的消息和返回的消息都要保存到数据库
前端:发送消息,接受后端返回的响应消息,实时回显在聊天页面。
显然,websocket用在前后端之间进行交互,后端类似一个中间人,前端是一个用户,大模型是ai服务。
1.1 注册到spring
@Configuration public class WebSocketConfig { @Bean public ServerEndpointExporter serverEndpointExporter() { return new ServerEndpointExporter(); } }1.2 实现一个WebSocket的服务(别看这么长,其实参考了若依插件-集成websocket实现简单通信,但没涉及信号量之类所以没什么用,除了onMessage外,其他如onOpen打印一条消息就行了,更多如WebSocketUsers可以去链接那下载)
@CrossOrigin @Component @ServerEndpoint("/websocket/message") public class WebSocketServer { private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class); /** * WebSocketServer 日志控制器 */ private static final Logger LOGGER = LoggerFactory.getLogger(WebSocketServer.class); /** * 默认最多允许同时在线人数100 */ public static int socketMaxOnlineCount = 100; private static Semaphore socketSemaphore = new Semaphore(socketMaxOnlineCount); /** * 连接建立成功调用的方法 */ @OnOpen public void onOpen(Session session) throws Exception { boolean semaphoreFlag = false; // 尝试获取信号量 semaphoreFlag = SemaphoreUtils.tryAcquire(socketSemaphore); if (!semaphoreFlag) { // 未获取到信号量 LOGGER.error("\n 当前在线人数超过限制数- {}", socketMaxOnlineCount); WebSocketUsers.sendMessageToUserByText(session, "当前在线人数超过限制数:" + socketMaxOnlineCount); session.close(); } else { // 添加用户 WebSocketUsers.put(session.getId(), session); LOGGER.info("\n 建立连接 - {}", session); LOGGER.info("\n 当前人数 - {}", WebSocketUsers.getUsers().size()); WebSocketUsers.sendMessageToUserByText(session, "连接成功"); } } /** * 连接关闭时处理 */ @OnClose public void onClose(Session session) { LOGGER.info("\n 关闭连接 - {}", session); // 移除用户 WebSocketUsers.remove(session.getId()); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 抛出异常时处理 */ @OnError public void onError(Session session, Throwable exception) throws Exception { if (session.isOpen()) { // 关闭连接 session.close(); } String sessionId = session.getId(); LOGGER.info("\n 连接异常 - {}", sessionId); LOGGER.info("\n 异常信息 - {}", exception); // 移出用户 WebSocketUsers.remove(sessionId); // 获取到信号量则需释放 SemaphoreUtils.release(socketSemaphore); } /** * 服务器接收到客户端消息时调用的方法 */ @OnMessage public void onMessage(String message, Session session) { // 首先,接收到一条消息 LOGGER.info("\n 收到消息 - {}", message); // 1. 调用大模型API,把上下文和这次问题传入,得到回复 BigModelService bigModelService = new BigModelService(); TurboResponse response = bigModelService.callModelAPI(session.getId(),message); if (response == null) { WebSocketUsers.sendMessageToUserByText(session, "抱歉,似乎出了点问题,请联系管理员"); return; } WebSocketUsers.sendMessageToUserByText(session, response.getResult()); } }
2.1 先写实体类,包括BaiduChatMessage(最基本的聊天消息)、ErnieBotTurboParam(ErnieBot-Turbo的请求参数,包括了List@Data
@SuperBuilder
@NoArgsConstructor
@AllArgsConstructor
public class BaiduChatMessage implements Serializable {
private String role;
private String content;
}
@Data
@SuperBuilder
public class ErnieBotTurboParam implements Serializable {
/**
* 聊天上下文信息。说明:
* (1)messages成员不能为空,1个成员表示单轮对话,多个成员表示多轮对话
* (2)最后一个message为当前请求的信息,前面的message为历史对话信息
* (3)必须为奇数个成员,成员中message的role必须依次为user、assistant
* (4)最后一个message的content长度(即此轮对话的问题)不能超过2000个字符;如果messages中content总长度大于2000字符,系统会依次遗忘最早的历史会话,直到content的总长度不超过2000个字符
*/
protected List
@Data
public class TurboResponse implements Serializable {
private String id;
private String object;
private Integer created;
private String sentence_id;
private Boolean is_end;
private Boolean is_truncated;
private String result;
private Boolean need_clear_history;
private Usage usage;
@Data
public static class Usage implements Serializable {
private Integer prompt_tokens;
private Integer completion_tokens;
private Integer total_tokens;
}
}
2.2 请求接口实现(注释很详细就不多说了)public class BigModelService {
private ChatRecordMapper chatRecordMapper = SpringUtils.getBean(ChatRecordMapper.class);
private static final Logger LOGGER = LoggerFactory.getLogger(BigModelService.class);
private static final OkHttpClient HTTP_CLIENT = new OkHttpClient().newBuilder().build();
public static final String API_KEY = "你的apikey";
public static final String SECRET_KEY = "你的secretkey";
static String getAccessToken() throws IOException {
MediaType mediaType = MediaType.parse("application/x-www-form-urlencoded");
RequestBody body = RequestBody.create(mediaType, "grant_type=client_credentials&client_id=" + API_KEY
+ "&client_secret=" + SECRET_KEY);
Request request = new Request.Builder()
.url("https://aip.baidubce.com/oauth/2.0/token")
.method("POST", body)
.addHeader("Content-Type", "application/x-www-form-urlencoded")
.build();
Response response = HTTP_CLIENT.newCall(request).execute();
// 解析返回的access_token
JSONObject jsonObject = JSONObject.parseObject(response.body().string());
return jsonObject.getString("access_token");
}
public TurboResponse callModelAPI(String sessionId, String message) {
// 1. 构建请求体
// 1.1 调用大模型API,要从数据库去查询上下文
ChatRecord cr = chatRecordMapper.selectChatRecordBySessionId(sessionId);
String records = cr == null ? "{}" : cr.getRecords();
// 1.2 把message加进请求体
// 1.2.1 解析上下文,获取聊天记录,把新的message封装加入到聊天记录中
ErnieBotTurboParam param = JSONObject.parseObject(records, ErnieBotTurboParam.class);
List
3.1 数据库里建个表
CREATE TABLE `chat_record` ( `record_id` varchar(20) NOT NULL COMMENT '记录id', `session_id` varchar(10) DEFAULT NULL COMMENT '所属用户', `records` json DEFAULT NULL COMMENT '聊天记录', `create_time` datetime DEFAULT NULL COMMENT '创建时间(判断过期)', PRIMARY KEY (`record_id`) ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_ai_ci COMMENT='聊天记录表';3.2 对应实体类
@Data @NoArgsConstructor @AllArgsConstructor public class ChatRecord { private String recordId; private String sessionId; private String records; private LocalDateTime createTime; }3.3 再写个mapper就行了
@Mapper public interface ChatRecordMapper { @Insert("INSERT INTO chat_record (record_id, session_id, records, create_time) " + "VALUES (#{recordId}, #{sessionId}, #{records}, #{createTime})") void insertChatRecord(ChatRecord chatRecord); @Select("SELECT COUNT(*) FROM chat_record WHERE session_id = #{sessionId}") int selectRecordCountBySessionId(String sessionId); @Results({ // id是用来给@ResultMap注解引用的,到时候在xml中可以直接使用@ResultMap(value = "chatRecord") @Result(property = "recordId", column = "record_id"), @Result(property = "sessionId", column = "session_id"), @Result(property = "records", column = "records"), @Result(property = "createTime", column = "create_time") }) @Select("SELECT * FROM chat_record WHERE session_id = #{sessionId}") ChatRecord selectChatRecordBySessionId(String sessionId); @Update("UPDATE chat_record SET records = #{records} WHERE session_id = #{sessionId}") void updateChatRecord(ChatRecord chatRecord); }
4.1 聊天页面写一个(这里前端是uniapp,样式用到了些colorUI)
4.2 js里写一个websocket(见上4.1的connect()){{ message.content }} {{ message.time }}
以上就大功告成了,这玩意还有很多缺漏和细节没做,像现在还是根据会话id去做,没有匹配用户id,15min清除聊天记录,但前端那没清……不过能跑能动就行,本来就是一个小任务,也懒得继续花时间调整。
记录一下,有问题可以交流
上一篇:判断当前shell版本