From 2e170833decb25a7250c4ecbb8ad7b853d41a2db Mon Sep 17 00:00:00 2001 From: xueqingkun Date: Tue, 4 Mar 2025 17:14:23 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9=E8=AF=9D=E5=88=97?= =?UTF-8?q?=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/contoller/ChatController.java | 14 ++-- .../model/dify/StreamResponse.java | 2 +- .../com/supervision/service/IChatService.java | 5 +- .../service/impl/ChatServiceImpl.java | 75 ++++++++++++++++--- 4 files changed, 76 insertions(+), 20 deletions(-) diff --git a/src/main/java/com/supervision/contoller/ChatController.java b/src/main/java/com/supervision/contoller/ChatController.java index 7a41377..56cb91e 100644 --- a/src/main/java/com/supervision/contoller/ChatController.java +++ b/src/main/java/com/supervision/contoller/ChatController.java @@ -1,11 +1,8 @@ package com.supervision.contoller; -import cn.hutool.core.io.resource.ResourceUtil; -import cn.hutool.core.util.StrUtil; import com.supervision.dto.R; import com.supervision.dto.robot.RobotTalkDTO; import com.supervision.model.RobotTalkReq; -import com.supervision.model.dify.StreamResponse; import com.supervision.service.DifyService; import com.supervision.service.IChatService; import jakarta.servlet.http.HttpServletResponse; @@ -18,8 +15,6 @@ import org.springframework.web.multipart.MultipartFile; import reactor.core.publisher.Flux; import java.io.IOException; -import java.time.Duration; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -44,8 +39,9 @@ public class ChatController { } @GetMapping("/talkList") - public List talkList(String sessionId) { - return new ArrayList(); + public R> talkList(@RequestParam("sessionId") String sessionId) { + List commonDialogVos = chatService.talkList(sessionId); + return R.ok(commonDialogVos); } @GetMapping("/getAudio") @@ -55,7 +51,7 @@ public class ChatController { } @GetMapping(value="/stream",produces = MediaType.TEXT_EVENT_STREAM_VALUE) - public Flux>> stream(@RequestParam("query") String query) { - return chatService.streamingMessage(query); + public Flux>> stream(@RequestParam("query") String query,@RequestParam("sessionId") String sessionId) { + return chatService.streamingMessage(query,sessionId); } } diff --git a/src/main/java/com/supervision/model/dify/StreamResponse.java b/src/main/java/com/supervision/model/dify/StreamResponse.java index a8e8726..b57ec8d 100644 --- a/src/main/java/com/supervision/model/dify/StreamResponse.java +++ b/src/main/java/com/supervision/model/dify/StreamResponse.java @@ -40,5 +40,5 @@ public class StreamResponse implements Serializable { /** * 会话 ID. */ - private String conversationId; + private String conversation_id; } \ No newline at end of file diff --git a/src/main/java/com/supervision/service/IChatService.java b/src/main/java/com/supervision/service/IChatService.java index 3e16e8e..4fc22d6 100644 --- a/src/main/java/com/supervision/service/IChatService.java +++ b/src/main/java/com/supervision/service/IChatService.java @@ -8,14 +8,17 @@ import org.springframework.web.multipart.MultipartFile; import reactor.core.publisher.Flux; import java.io.IOException; +import java.util.List; import java.util.Map; public interface IChatService { - Flux>> streamingMessage(String query); + Flux>> streamingMessage(String query,String sessionId); String asr(MultipartFile file) throws IOException; RobotTalkDTO talk(MultipartFile file, RobotTalkReq robotTalkReq); void getAudio(HttpServletResponse response, String audioId) throws IOException; + + List talkList(String sessionId); } diff --git a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java index 3c1e650..6e1ae66 100644 --- a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java @@ -1,6 +1,10 @@ package com.supervision.service.impl; import cn.hutool.core.codec.Base64; +import cn.hutool.core.collection.CollUtil; +import cn.hutool.core.lang.Assert; +import cn.hutool.core.util.StrUtil; +import cn.hutool.crypto.digest.MD5; import com.alibaba.fastjson.JSON; import com.supervision.dto.dify.ChatResDTO; import com.supervision.dto.paddlespeech.res.TtsResultDTO; @@ -33,12 +37,14 @@ import org.springframework.web.reactive.function.client.WebClient; import reactor.core.publisher.Flux; import java.io.IOException; +import java.util.*; import java.util.HashMap; import java.util.Map; import java.util.Objects; import java.util.UUID; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; @Slf4j @Service @@ -51,18 +57,31 @@ public class ChatServiceImpl implements IChatService { private String difyAppAuth; private final WebClient webClient; private final DifyApiUtil difyApiUtil; - Map voiceCache = new HashMap<>(); + // 音频缓存 key: audioId, value: audioBase64 + private final Map audioCache = new HashMap<>(); + + // 对话缓存 key: sessionId, value: 对话列表 + private final Map> dialogCache = new HashMap<>(); + @Override - public Flux>> streamingMessage(String query) { + public Flux>> streamingMessage(String query,String sessionId) { DifyChatReqVO difyChatReqVO = new DifyChatReqVO(); difyChatReqVO.setUser("admin"); DIFYChatReqInputVO inputs = new DIFYChatReqInputVO(); difyChatReqVO.setQuery(query); + difyChatReqVO.setConversation_id(sessionId); difyChatReqVO.setInputs(inputs); + + RobotTalkDTO.RobotTalkDTOBuilder builder = RobotTalkDTO.builder(); + + + StringBuilder fullAskString = new StringBuilder(); + List audioList = new ArrayList<>(); StringBuilder sentence = new StringBuilder(); log.info("query:{}", query); + return webClient.post() .uri(difyUrl) .headers(httpHeaders -> { @@ -79,25 +98,36 @@ public class ChatServiceImpl implements IChatService { //遍历answer中的每一个字符,判断是否为标点符号,如果是,说明是句子的结尾,将标点符号前的文本拼接到sentence中,并打印,然后清空sentence,如果标点符号后还有文本,将文本拼接到sentence中 for (char ch : response.getAnswer().toCharArray()) { sentence.append(ch); + fullAskString.append(ch); if (ch == '。' || ch == '!' || ch == '?' || ch == ',' || ch == '、' || ch == '‘' || ch == '’' || ch == '“' || ch == '”') { // Check for punctuation marks log.info(sentence.toString()); TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString()); String voiceBaseId = UUID.randomUUID().toString(); - voiceCache.put(voiceBaseId, ttsResultDTO.getAudio()); + audioCache.put(voiceBaseId, ttsResultDTO.getAudio()); map.put("audioId", voiceBaseId); + audioList.add(voiceBaseId); sentence.setLength(0); return ServerSentEvent.builder(map).build(); } } } if (response.getEvent().equals("message_end")) { + map.put("sessionId", response.getConversation_id()); if (!sentence.isEmpty()) { log.info(sentence.toString()); TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString()); String voiceBaseId = UUID.randomUUID().toString(); - voiceCache.put(voiceBaseId, ttsResultDTO.getAudio()); + audioList.add(voiceBaseId); + audioCache.put(voiceBaseId, ttsResultDTO.getAudio()); map.put("audioId", voiceBaseId); } + String fullAnswer = audioList.stream().map(audioId -> map.get("audioId")).filter(Objects::nonNull).collect(Collectors.joining()); + String uuid = cn.hutool.core.lang.UUID.randomUUID().toString(); + audioCache.put(uuid,fullAnswer); + builder.answerInfo(AnswerInfo.builder().contentType(2).message(fullAskString.toString()).voiceBaseId(uuid).build()); + builder.sessionId(response.getConversation_id()); + builder.askInfo(AskInfo.builder().contentType(2).message(query).audioLength(100L).build()); + this.appendDialogCache(response.getConversation_id(), builder.build()); return ServerSentEvent.builder(map).build(); } return null; @@ -106,6 +136,9 @@ public class ChatServiceImpl implements IChatService { @Override public String asr(MultipartFile file) throws IOException { + String text = replaceTown(AsrUtil.asrTransformByBytes(file.getBytes())); + String md5 = MD5.create().digestHex(text); + audioCache.put(md5, Base64.encode(file.getBytes())); return replaceTown(AsrUtil.asrTransformByBytes(file.getBytes())); } @@ -129,14 +162,14 @@ public class ChatServiceImpl implements IChatService { stopWatch.stop(); log.info("response:{}", chatResDTO.getAnswer()); builder.askInfo(AskInfo.builder().contentType(2).message(inputs.getQuery()).audioLength(100L).askId(chatResDTO.getMessage_id()).build()); - voiceCache.put(chatResDTO.getMessage_id(), Base64.encode(bytes)); + audioCache.put(chatResDTO.getMessage_id(), Base64.encode(bytes)); stopWatch.start("tts"); TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(chatResDTO.getAnswer()); stopWatch.stop(); String voiceBaseId = UUID.randomUUID().toString(); builder.answerInfo(AnswerInfo.builder().contentType(2).message(chatResDTO.getAnswer()).voiceBaseId(voiceBaseId).voiceBase64(ttsResultDTO.getAudio()).build()); builder.sessionId(chatResDTO.getConversation_id()); - voiceCache.put(voiceBaseId, ttsResultDTO.getAudio()); + audioCache.put(voiceBaseId, ttsResultDTO.getAudio()); log.info("耗时:{}", stopWatch.prettyPrint()); } catch (IOException e) { throw new RuntimeException(e); @@ -149,7 +182,31 @@ public class ChatServiceImpl implements IChatService { public void getAudio(HttpServletResponse response, String audioId) throws IOException { log.info("audioId:{}", audioId); - Base64.decodeToStream(voiceCache.get(audioId), response.getOutputStream(), false); + if (StrUtil.isEmpty(audioId) && StrUtil.equals(audioId, "undefined")) { + return; + } + Base64.decodeToStream(audioCache.get(audioId), response.getOutputStream(), false); + } + + @Override + public List talkList(String sessionId) { + if (StrUtil.isNotEmpty(sessionId)) { + return getDialogCache(sessionId); + } + return CollUtil.newArrayList(); + } + + private void appendDialogCache(String sessionId, RobotTalkDTO dialog) { + Assert.notEmpty(sessionId, "sessionId不能为空"); + List dialogList = dialogCache.getOrDefault(sessionId,new ArrayList<>()); + dialogList.add(dialog); + dialogCache.put(sessionId, dialogList); + } + + + private List getDialogCache(String sessionId) { + Assert.notEmpty(sessionId, "sessionId不能为空"); + return dialogCache.getOrDefault(sessionId,new ArrayList<>()); } /** @@ -159,7 +216,7 @@ public class ChatServiceImpl implements IChatService { * @param text 输入字符串 * @return 更新后的字符串 */ - public static String replaceTown(String text) { + public String replaceTown(String text) { // 正则模式:匹配两个汉字紧跟“小镇” Pattern pattern = Pattern.compile("([\u4e00-\u9fa5]{2})小镇"); Matcher matcher = pattern.matcher(text); @@ -204,7 +261,7 @@ public class ChatServiceImpl implements IChatService { // 示例测试 public static void main(String[] args) { String sampleText = "欢迎来到孟河小镇,体验独特的小镇风情;另外,还有梦和小镇等待你探访。"; - System.out.println(replaceTown(sampleText)); + System.out.println(new ChatServiceImpl(null,null).replaceTown(sampleText)); } }