From b2848b8ccde2386ca0fc7247e424b496f9b15024 Mon Sep 17 00:00:00 2001 From: liu Date: Thu, 26 Oct 2023 14:26:56 +0800 Subject: [PATCH] =?UTF-8?q?=E9=97=AE=E8=AF=8Asocket=E6=B5=81=E7=A8=8B?= =?UTF-8?q?=E6=8F=90=E4=BA=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../supervision/controller/AskController.java | 4 +- .../com/supervision/paddlespeech/TtsUtil.java | 8 +-- .../supervision/pojo/vo/ReplyVoiceResVO.java | 11 ---- .../com/supervision/service/AskService.java | 7 +- .../service/impl/AskServiceImpl.java | 66 +++++++++++++++---- .../service/impl/RasaServiceImpl.java | 4 +- .../websocket/dto/SocketMessageDTO.java | 6 +- .../handler/AskWebSocketHandler.java | 4 +- 8 files changed, 67 insertions(+), 43 deletions(-) delete mode 100644 virtual-patient-web/src/main/java/com/supervision/pojo/vo/ReplyVoiceResVO.java diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java b/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java index f07ccff1..56ff0f1c 100644 --- a/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java +++ b/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java @@ -1,11 +1,9 @@ package com.supervision.controller; -import com.supervision.pojo.vo.ReplyVoiceResVO; import com.supervision.service.AskService; import com.supervision.websocket.cache.WebSocketUserCache; import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; -import lombok.RequiredArgsConstructor; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; @@ -39,7 +37,7 @@ public class AskController { @ApiOperation("回复语音及文字消息") @GetMapping("replyVoice") - public ReplyVoiceResVO replyVoice() { + public String replyVoice() { return askService.replyVoice(); } diff --git a/virtual-patient-web/src/main/java/com/supervision/paddlespeech/TtsUtil.java b/virtual-patient-web/src/main/java/com/supervision/paddlespeech/TtsUtil.java index 9d307f5f..e43c6ab8 100644 --- a/virtual-patient-web/src/main/java/com/supervision/paddlespeech/TtsUtil.java +++ b/virtual-patient-web/src/main/java/com/supervision/paddlespeech/TtsUtil.java @@ -9,7 +9,6 @@ import com.supervision.exception.BusinessException; import com.supervision.paddlespeech.dto.req.TtsReqDTO; import com.supervision.paddlespeech.dto.res.PaddleSpeechResDTO; import com.supervision.paddlespeech.dto.res.TtsResultDTO; -import com.supervision.pojo.vo.ReplyVoiceResVO; import com.supervision.util.SpringBeanUtil; import org.springframework.core.env.Environment; @@ -19,7 +18,7 @@ public class TtsUtil { private static final ObjectMapper objectMapper = new ObjectMapper(); - public static ReplyVoiceResVO ttsTransform(String str) { + public static String ttsTransform(String str) { // 构建 String post = HttpUtil.post(TTS_URL, JSONUtil.toJsonStr(new TtsReqDTO(str))); try { @@ -28,10 +27,7 @@ public class TtsUtil { if (!response.getSuccess() || ObjectUtil.isEmpty(response.getResult())) { throw new BusinessException("文字转换语音失败"); } - ReplyVoiceResVO resVO = new ReplyVoiceResVO(); - resVO.setVoice(response.getResult().getAudio()); - resVO.setText(str); - return resVO; + return response.getResult().getAudio(); } catch (Exception e) { throw new BusinessException("语音转换文字失败"); } diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/ReplyVoiceResVO.java b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/ReplyVoiceResVO.java deleted file mode 100644 index 3459e2bf..00000000 --- a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/ReplyVoiceResVO.java +++ /dev/null @@ -1,11 +0,0 @@ -package com.supervision.pojo.vo; - -import lombok.Data; - -@Data -public class ReplyVoiceResVO { - - private String voice; - - private String text; -} diff --git a/virtual-patient-web/src/main/java/com/supervision/service/AskService.java b/virtual-patient-web/src/main/java/com/supervision/service/AskService.java index c84f30ac..68653521 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/AskService.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/AskService.java @@ -1,9 +1,6 @@ package com.supervision.service; -import com.fasterxml.jackson.core.JsonProcessingException; -import com.supervision.pojo.vo.ReplyVoiceResVO; import com.supervision.websocket.dto.SocketMessageDTO; -import org.springframework.web.bind.annotation.RequestParam; import org.springframework.web.multipart.MultipartFile; import java.io.IOException; @@ -11,11 +8,11 @@ import java.util.List; public interface AskService { - void handlerMessageBySocket(SocketMessageDTO socketMessageDTO); + void handlerMessageBySocket(SocketMessageDTO socketMessageDTO) throws IOException; String receiveVoiceFile(MultipartFile file) throws IOException; - ReplyVoiceResVO replyVoice(); + String replyVoice(); List conversation(String question, String sessionId); diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java index 49b9a328..6488a518 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -1,25 +1,25 @@ package com.supervision.service.impl; import cn.hutool.core.util.StrUtil; -import cn.hutool.http.HttpUtil; import cn.hutool.json.JSONUtil; import com.supervision.exception.BusinessException; +import com.supervision.model.User; import com.supervision.paddlespeech.AsrUtil; import com.supervision.paddlespeech.TtsUtil; import com.supervision.rasa.RasaUtil; -import com.supervision.rasa.dto.RasaReqDTO; -import com.supervision.rasa.dto.RasaResDTO; -import com.supervision.pojo.vo.ReplyVoiceResVO; import com.supervision.service.AskService; +import com.supervision.util.UserUtil; +import com.supervision.websocket.cache.WebSocketUserCache; +import com.supervision.websocket.dto.ActionDTO; import com.supervision.websocket.dto.SocketMessageDTO; import lombok.RequiredArgsConstructor; -import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.socket.TextMessage; +import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.util.List; -import java.util.stream.Collectors; @Service @RequiredArgsConstructor @@ -27,7 +27,7 @@ public class AskServiceImpl implements AskService { @Override - public void handlerMessageBySocket(SocketMessageDTO socketMessageDTO) { + public void handlerMessageBySocket(SocketMessageDTO socketMessageDTO) throws IOException { // 首先获取消息的类型 String text; if (0 == socketMessageDTO.getMessageType()) { @@ -41,8 +41,52 @@ public class AskServiceImpl implements AskService { throw new BusinessException("语音消息不能为空"); } // 进行rasa对话 - - + List rasaResultList = RasaUtil.talkRasa(text, socketMessageDTO.getSocketId()); + WebSocketSession session = WebSocketUserCache.getSession(socketMessageDTO.getSocketId()); + for (String rasaResult : rasaResultList) { + if (StrUtil.isNotBlank(rasaResult)){ + // 这里校验,rasa回复的结果是不是action + // 这里设置的模板,对于action的动作全部是用---进行标记,详情看生成rasa的yml的代码:RasaServiceImpl.generateDomain + // ---ancillary---xxx + // ---tool---xxx + if (rasaResult.startsWith("---")){ + // ["","ancillary","xxx"] + List actionList = StrUtil.split(rasaResult, "---"); + if (actionList.size() > 2){ + ActionDTO actionDTO = new ActionDTO(); + actionDTO.setActionType(actionList.get(1)); + actionDTO.setActionId(actionList.get(2)); + // 在这里给socket回复,设置为动作 + SocketMessageDTO res = new SocketMessageDTO(); + res.setSocketId(socketMessageDTO.getSocketId()); + res.setUserId(UserUtil.getUser().getId()); + res.setAction(actionDTO); + res.setType(2); + session.sendMessage(new TextMessage(JSONUtil.toJsonStr(res))); + return; + } + }else { + // 走到这里,说明是文字,这个时候文字转语音 + String replyVoiceResVO = TtsUtil.ttsTransform(rasaResult); + // 在这里给socket回复 + SocketMessageDTO res = new SocketMessageDTO(); + res.setSocketId(socketMessageDTO.getSocketId()); + res.setUserId(UserUtil.getUser().getId()); + res.setTextMessage(rasaResult); + res.setVoiceMessage(replyVoiceResVO); + res.setType(1); + session.sendMessage(new TextMessage(JSONUtil.toJsonStr(res))); + return; + } + } + } + // 兜底,如果走到了这里,就直接返回未识别 + SocketMessageDTO res = new SocketMessageDTO(); + res.setSocketId(socketMessageDTO.getSocketId()); + res.setUserId(UserUtil.getUser().getId()); + res.setTextMessage("医生,我们有听懂您说的什么"); + res.setType(1); + session.sendMessage(new TextMessage(JSONUtil.toJsonStr(res))); } @Override @@ -55,7 +99,7 @@ public class AskServiceImpl implements AskService { @Override - public ReplyVoiceResVO replyVoice() { + public String replyVoice() { String text = "测试:这是文字转语音的测试,测试是否OK"; return TtsUtil.ttsTransform(text); } @@ -63,6 +107,6 @@ public class AskServiceImpl implements AskService { @Override public List conversation(String question, String sessionId) { - return RasaUtil.talkRasa(question,sessionId); + return RasaUtil.talkRasa(question, sessionId); } } diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java index 423a47c8..def758b4 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/RasaServiceImpl.java @@ -168,12 +168,10 @@ public class RasaServiceImpl implements RasaService { itemCodeIdMap.put(itemIntent, ancillary); } - // 生成后生成yml对象 -// createYmlFile(nluYml, "nlu.yml"); - // 加载模板配置 NluYmlTemplate nluYmlTemplate = new NluYmlTemplate(); nluYmlTemplate.setNlu(nluList); + // 生成后生成yml文件 createYmlFile(NluYmlTemplate.class, "nlu.ftl", nluYmlTemplate, "nlu.yml", ymalFileMap); } diff --git a/virtual-patient-web/src/main/java/com/supervision/websocket/dto/SocketMessageDTO.java b/virtual-patient-web/src/main/java/com/supervision/websocket/dto/SocketMessageDTO.java index 16c1fd74..4146d79e 100644 --- a/virtual-patient-web/src/main/java/com/supervision/websocket/dto/SocketMessageDTO.java +++ b/virtual-patient-web/src/main/java/com/supervision/websocket/dto/SocketMessageDTO.java @@ -24,14 +24,14 @@ public class SocketMessageDTO { private ActionDTO action; /** - * 后端返回给前端时使用,表示该是消息还是action动作,0消息,1动作 + * 后端返回给前端时使用,表示该是消息还是action动作,1消息,2动作 */ - @ApiModelProperty("后端返回给前端时使用,表示该是消息还是action动作,0消息,1动作") + @ApiModelProperty("后端返回给前端时使用,表示该是消息还是action动作,1消息,2动作") private Integer type; /** * 表示是消息还是action动作,0语音,1文字 */ - @ApiModelProperty("前端到后端使用,表示是语音还是文字,0语音,1文字") + @ApiModelProperty("前端到后端使用,表示是语音还是文字,1语音,2文字") private Integer messageType; } diff --git a/virtual-patient-web/src/main/java/com/supervision/websocket/handler/AskWebSocketHandler.java b/virtual-patient-web/src/main/java/com/supervision/websocket/handler/AskWebSocketHandler.java index 32f93691..f05b4729 100644 --- a/virtual-patient-web/src/main/java/com/supervision/websocket/handler/AskWebSocketHandler.java +++ b/virtual-patient-web/src/main/java/com/supervision/websocket/handler/AskWebSocketHandler.java @@ -12,6 +12,8 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; +import java.io.IOException; + @Slf4j public class AskWebSocketHandler extends TextWebSocketHandler { @@ -28,7 +30,7 @@ public class AskWebSocketHandler extends TextWebSocketHandler { } @Override - protected void handleTextMessage(WebSocketSession session, TextMessage message) { + protected void handleTextMessage(WebSocketSession session, TextMessage message) throws IOException { // 处理接收到的消息 log.info("收到消息:{}", message.toString()); // 这里反序列化消息,将消息形成固定的格式