diff --git a/virtual-patient-common/pom.xml b/virtual-patient-common/pom.xml index 8727129f..419cb110 100644 --- a/virtual-patient-common/pom.xml +++ b/virtual-patient-common/pom.xml @@ -42,6 +42,11 @@ spring-ai-ollama-spring-boot-starter + + org.springframework.cloud + spring-cloud-starter-bootstrap + + org.springframework.boot diff --git a/virtual-patient-common/src/main/java/com/supervision/util/AiChatUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/AiChatUtil.java index f578be45..5e99feaf 100644 --- a/virtual-patient-common/src/main/java/com/supervision/util/AiChatUtil.java +++ b/virtual-patient-common/src/main/java/com/supervision/util/AiChatUtil.java @@ -17,9 +17,9 @@ import java.util.concurrent.*; @Slf4j public class AiChatUtil { - public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 0, "chat", new ThreadPoolExecutor.CallerRunsPolicy()); + private static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 0, "chat", new ThreadPoolExecutor.CallerRunsPolicy()); - public static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class); + private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class); /** * 单轮对话 @@ -27,7 +27,7 @@ public class AiChatUtil { * @param chat 对话的内容 * @return jsonObject */ - public Optional chat(String chat) { + public static Optional chat(String chat) { Prompt prompt = new Prompt(List.of(new UserMessage(chat))); Future submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); try { @@ -44,7 +44,7 @@ public class AiChatUtil { * @param messageList 消息列表 * @return jsonObject */ - public Optional chat(List messageList) { + public static Optional chat(List messageList) { Prompt prompt = new Prompt(messageList); Future submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); try { @@ -63,7 +63,7 @@ public class AiChatUtil { * @param 需要序列化的对象的泛型 * @return 对应对象类型, 不支持列表类型 */ - public Optional chat(List messageList, Class clazz) { + public static Optional chat(List messageList, Class clazz) { Prompt prompt = new Prompt(messageList); Future submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); try { @@ -83,7 +83,7 @@ public class AiChatUtil { * @param 需要序列化的对象的泛型 * @return 对应对象类型, 不支持列表类型 */ - public Optional chat(String chat, Class clazz) { + public static Optional chat(String chat, Class clazz) { Prompt prompt = new Prompt(List.of(new UserMessage(chat))); Future submit = chatExecutor.submit(new ChatTask(chatClient, prompt)); try { diff --git a/virtual-patient-graph/pom.xml b/virtual-patient-graph/pom.xml index c5ec6ced..b3b9038a 100644 --- a/virtual-patient-graph/pom.xml +++ b/virtual-patient-graph/pom.xml @@ -32,11 +32,6 @@ - - org.springframework.cloud - spring-cloud-starter-bootstrap - - org.springframework.cloud diff --git a/virtual-patient-model/src/main/java/com/supervision/model/MedicalRec.java b/virtual-patient-model/src/main/java/com/supervision/model/MedicalRec.java index 9e87461a..5ddc1b29 100644 --- a/virtual-patient-model/src/main/java/com/supervision/model/MedicalRec.java +++ b/virtual-patient-model/src/main/java/com/supervision/model/MedicalRec.java @@ -128,6 +128,9 @@ public class MedicalRec extends Model implements Serializable { @Schema(description = "全面检查") private String fullCheck; + @Schema(description = "提交给大模型的病历") + private String medicalRecordAi; + /** * 创建人ID */ diff --git a/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java b/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java index dba3edc9..b7647730 100644 --- a/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java +++ b/virtual-patient-web/src/main/java/com/supervision/controller/AskController.java @@ -2,6 +2,7 @@ package com.supervision.controller; import com.supervision.pojo.vo.TalkResultResVO; import com.supervision.pojo.vo.TalkVideoReqVO; +import com.supervision.pojo.vo.TalkVideoTtsResultResVO; import com.supervision.service.AskService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; @@ -32,7 +33,11 @@ public class AskController { return askService.talkByVideo(talkReqVO); } - + @Operation(summary = "使用无声视频+语音转文字的形式来做") + @PostMapping("talkByVideoAndTts") + public TalkVideoTtsResultResVO talkByVideoAndTts(@RequestBody TalkVideoReqVO talkReqVO) { + return askService.talkByVideoAndTts(talkReqVO); + } } diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/ai/AiTalkAnswerDTO.java b/virtual-patient-web/src/main/java/com/supervision/pojo/ai/AiTalkAnswerDTO.java new file mode 100644 index 00000000..52134580 --- /dev/null +++ b/virtual-patient-web/src/main/java/com/supervision/pojo/ai/AiTalkAnswerDTO.java @@ -0,0 +1,9 @@ +package com.supervision.pojo.ai; + +import lombok.Data; + +@Data +public class AiTalkAnswerDTO { + + private String answer; +} diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java new file mode 100644 index 00000000..f13896ce --- /dev/null +++ b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java @@ -0,0 +1,27 @@ +package com.supervision.pojo.vo; + + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +/** + * 无声视频+tts相应 + */ +@Data +public class TalkVideoTtsResultResVO { + + /** + * 后端返回给前端时使用,表示该是语音回复还是action动作,1语音回复,2体格检查 3辅助检查 + * 本方案固定为1 + */ + @Schema(description = "后端返回给前端时使用,表示该是语音回复还是action动作,1语音回复,2体格检查 3辅助检查") + private Integer type = 1; + + @Schema(description = "音频base64位编码") + private String voiceBase64; + + @Schema(description = "回复的消息内容,用于调试视频资料内容") + private String answerMessage; + + +} diff --git a/virtual-patient-web/src/main/java/com/supervision/service/AiService.java b/virtual-patient-web/src/main/java/com/supervision/service/AiService.java new file mode 100644 index 00000000..826b4591 --- /dev/null +++ b/virtual-patient-web/src/main/java/com/supervision/service/AiService.java @@ -0,0 +1,6 @@ +package com.supervision.service; + +public interface AiService { + + String talk(String question, String medicalRecord); +} diff --git a/virtual-patient-web/src/main/java/com/supervision/service/AskService.java b/virtual-patient-web/src/main/java/com/supervision/service/AskService.java index a7f5db14..87871080 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/AskService.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/AskService.java @@ -2,6 +2,7 @@ package com.supervision.service; import com.supervision.pojo.vo.TalkResultResVO; import com.supervision.pojo.vo.TalkVideoReqVO; +import com.supervision.pojo.vo.TalkVideoTtsResultResVO; import org.springframework.web.multipart.MultipartFile; import java.io.IOException; @@ -12,6 +13,7 @@ public interface AskService { TalkResultResVO talkByVideo(TalkVideoReqVO talkReqVO) throws IOException; + TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO); } diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AiServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AiServiceImpl.java new file mode 100644 index 00000000..38943b2c --- /dev/null +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AiServiceImpl.java @@ -0,0 +1,61 @@ +package com.supervision.service.impl; + +import cn.hutool.core.util.StrUtil; +import com.supervision.pojo.ai.AiTalkAnswerDTO; +import com.supervision.service.AiService; +import com.supervision.util.AiChatUtil; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.stereotype.Service; + +import java.util.*; + +@Slf4j +@Service +public class AiServiceImpl implements AiService { + + public static final String userPromptTemplate = """ + 预设病例的基本情况如下: + --- + {medicalRecord} + --- + 回复格式为json:{"answer":"扮演病人并根据病历回复的内容"} + """; + + private static final String systemPrompt = """ + 我们现在来进行一个角色扮演连续对话场景。 + 你现在扮演一个病人的角色,我来扮演医生,我们进行模拟问诊。 + 我会给你设定病人的病历,你根据给定病历内容进行回复。 + 如果病历中信息不能回问题,则回复'我不知道/没注意/不清楚'",所有问话都必须给回复! + 如果医生说'你好'等礼貌用语,你就回复'你好,医生'!!! + """; + + /** + * 根据病历进行回复 + * + * @param question 医生的问题 + * @return 大模型回答的问题 + */ + public String talk(String question, String medicalRecord) { + Map paramMap = new HashMap<>(); + paramMap.put("medicalRecord", medicalRecord); + List messageHistoryList = new ArrayList<>(); + messageHistoryList.add(new SystemMessage(systemPrompt)); + messageHistoryList.add(new UserMessage(StrUtil.format(userPromptTemplate, paramMap))); + messageHistoryList.add(new AssistantMessage("好的,已了解我要扮演病人的病历。已准备好对话了。")); + messageHistoryList.add(new UserMessage(question)); + Optional chat = AiChatUtil.chat(messageHistoryList, AiTalkAnswerDTO.class); + if (chat.isPresent()) { + AiTalkAnswerDTO aiTalkAnswerDTO = chat.get(); + if (StrUtil.isNotBlank(aiTalkAnswerDTO.getAnswer())) { + return aiTalkAnswerDTO.getAnswer(); + } + } + return "医生,我没有听懂你说的是什么"; + } + + +} diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java index 06fda067..ec346212 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -18,9 +18,11 @@ import com.supervision.pojo.qaSimilarity.QaSimilarityQuestion; import com.supervision.pojo.qaSimilarity.QaSimilarityQuestionAnswer; import com.supervision.pojo.vo.TalkResultResVO; import com.supervision.pojo.vo.TalkVideoReqVO; +import com.supervision.pojo.vo.TalkVideoTtsResultResVO; import com.supervision.service.*; import com.supervision.util.AsrUtil; import com.supervision.util.MinioUtil; +import com.supervision.util.TtsUtil; import com.supervision.util.UserUtil; import com.supervision.vo.rasa.RasaTalkVo; import lombok.RequiredArgsConstructor; @@ -29,6 +31,7 @@ import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; import java.io.InputStream; +import java.util.Comparator; import java.util.List; import java.util.Optional; @@ -55,6 +58,10 @@ public class AskServiceImpl implements AskService { private final CommonDicService commonDicService; + private final AiService aiService; + + private final MedicalRecService medicalRecService; + @Override public String receiveVoiceFile(MultipartFile file) { if (file.getSize() <= 0) { @@ -221,6 +228,20 @@ public class AskServiceImpl implements AskService { } } + public QaSimilarityQuestionAnswer talkQaSimilarityWithScore(String question, String sessionId) { + log.info("开始调用talkQaSimilarity,问题:{}", question); + try { + GlobalResult> result = askQaSimilarityFeignClient.askQuestionSimilarityAnswer(new QaSimilarityQuestion(question)); + // 排序,降序,取最高的 + result.getData().sort(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed()); + log.info("调用talkQaSimilarity结束,问题:{},返回结果:{}", question, JSONUtil.toJsonStr(result)); + return CollUtil.getFirst(result.getData()); + } catch (Exception e) { + log.error("调用talkQaSimilarity error ", e); + return null; + } + } + private AskPatientAnswer getMedicalRecErrorAnswer(String medicalRecId) { //Optional.ofNullable(medicalRecErrorAnswer).orElseGet(() ->new AskPatientAnswer()).getAnswer() Assert.notEmpty(medicalRecId, "病历id不能为空"); @@ -262,4 +283,59 @@ public class AskServiceImpl implements AskService { } } + + /** + * 使用无声视频+语音转文字的形式来做 + * + * @param talkReqVO 请求 + * @return 返回结果 + */ + @Override + public TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO) { + // 根据processId找到对应的病人 + Process process = Optional.ofNullable(processService.getById(talkReqVO.getProcessId())).orElseThrow(() -> new BusinessException("未找到诊疗进程")); + MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId()); + QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = talkQaSimilarityWithScore(talkReqVO.getText(), UserUtil.getUser().getId()); + TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO(); + // 如果匹配度没有匹配到任何数据,则走大模型 + if (ObjectUtil.isEmpty(qaSimilarityQuestionAnswer)) { + String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi()); + talkVideoTtsResultResVO.setAnswerMessage(talk); + } else { + // 如果阈值过低,也走大模型 + if (qaSimilarityQuestionAnswer.getMatchScore() < 0.5) { + log.info("{}:匹配到的结果阈值过低,走大模型回答", qaSimilarityQuestionAnswer); + String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi()); + talkVideoTtsResultResVO.setAnswerMessage(talk); + } else { + // 如果查到的问题不在问题库中,走大模型回答 + AskTemplateQuestionLibrary library = askTemplateQuestionLibraryService.getById(qaSimilarityQuestionAnswer.getMatchQuestionCode()); + if (ObjectUtil.isEmpty(library)) { + log.info("{}:未从问题库中找到,走大模型回答", qaSimilarityQuestionAnswer); + String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi()); + talkVideoTtsResultResVO.setAnswerMessage(talk); + } else { + // 根据问题找这个病历配置的答案 + AskPatientAnswer askPatientAnswer = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, process.getMedicalRecId()) + .eq(AskPatientAnswer::getLibraryQuestionId, library.getId()).last("limit 1").one(); + // 如果找到了,就走病历配置的内容回答 + if (ObjectUtil.isNotEmpty(askPatientAnswer)) { + String resText = askPatientAnswer.getAnswer(); + log.info("{}:找到了病历配置的回答语句:{},回答内容:{},走病历回答", qaSimilarityQuestionAnswer.getMatchQuestionCode(), askPatientAnswer.getId(), resText); + talkVideoTtsResultResVO.setAnswerMessage(resText); + // 保存记录 + saveQaRecord(talkReqVO.getProcessId(), "patient", askPatientAnswer.getId(), talkReqVO.getText(), library, resText); + return talkVideoTtsResultResVO; + } else { + // 如果问题的答案没有配置,还是走大模型的回答 + log.info("{}:病历配置,从AskPatientAnswer中未找到回答结果,走大模型", qaSimilarityQuestionAnswer.getMatchQuestionCode()); + String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi()); + talkVideoTtsResultResVO.setAnswerMessage(talk); + } + } + } + } + talkVideoTtsResultResVO.setVoiceBase64(TtsUtil.ttsTransform(talkVideoTtsResultResVO.getAnswerMessage())); + return talkVideoTtsResultResVO; + } }