diff --git a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java index 41107161..ef6541ca 100644 --- a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java +++ b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java @@ -43,8 +43,9 @@ public class SimilarityUtil { return documents.stream().map(document -> { QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer(); qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent()); + qaSimilarityQuestionAnswer.setMatchQuestionId(String.valueOf(document.getMetadata().get("matchQuestionId"))); qaSimilarityQuestionAnswer.setDictId(String.valueOf(document.getMetadata().get("dictId"))); - qaSimilarityQuestionAnswer.setLibraryQuestionId(String.valueOf(document.getMetadata().get("standardQuestionId"))); + qaSimilarityQuestionAnswer.setLibraryQuestionId(String.valueOf(document.getMetadata().get("libraryQuestionId"))); // 1- 可以使数据进行排序,相似度越高,数值越大(redis相似度给的数据是越小相似度越高) // -0.25目的是使数据趋近于中间,相似度不要太大(太大也不好调整),以使我们数据和张总之前提供的方法相似度差异稍小一点,但是不能小于0,如果小于0,取一个较大的值 double score = Math.max(0, 1 - Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))) - 0.25); diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java index 6f333f93..27f6a292 100644 --- a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java +++ b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java @@ -25,5 +25,4 @@ public class TalkVideoTtsResultResVO { @Schema(description = "回复的消息内容,用于调试视频资料内容") private String answerMessage; - } diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java index de902739..0e9402cf 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -1,5 +1,6 @@ package com.supervision.service.impl; +import cn.hutool.core.collection.CollUtil; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; import cn.hutool.json.JSONUtil; @@ -18,6 +19,7 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; +import org.springframework.util.StopWatch; import org.springframework.web.multipart.MultipartFile; import java.util.ArrayList; @@ -73,6 +75,7 @@ public class AskServiceImpl implements AskService { */ @Override public TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO) { + String answer = talkByVideoAndTts(talkReqVO.getProcessId(), talkReqVO.getText()); TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO(); talkVideoTtsResultResVO.setVoiceBase64(TtsUtil.ttsTransform(answer)); @@ -81,6 +84,8 @@ public class AskServiceImpl implements AskService { } private String talkByVideoAndTts(String processId, String question) { + StopWatch stopWatch = new StopWatch(); + stopWatch.start(); // 流转记录表 List circulationList = new ArrayList<>(); @@ -99,7 +104,7 @@ public class AskServiceImpl implements AskService { // 记录大模型的流转记录 buildAiCirculationDetail(circulationList, answer, medicalRec); // 保存消息到记录表 - saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList); + saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch); return answer; } QaSimilarityQuestionAnswer similarityResult = first.get(); @@ -115,7 +120,7 @@ public class AskServiceImpl implements AskService { String answer = aiService.talk(question, medicalRec.getMedicalRecordAi()); // 记录流转记录 buildAiCirculationDetail(circulationList, answer, medicalRec); - saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList); + saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch); return answer; } // 根据对应的标准问题,从标准问题表中找到标准问题 @@ -130,7 +135,7 @@ public class AskServiceImpl implements AskService { String answer = aiService.talk(question, medicalRec.getMedicalRecordAi()); // 记录流转记录 buildAiCirculationDetail(circulationList, answer, medicalRec); - saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList); + saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch); return answer; } // 根据问题找这个病历配置的答案 @@ -146,7 +151,7 @@ public class AskServiceImpl implements AskService { String answer = aiService.talk(question, medicalRec.getMedicalRecordAi()); // 记录流转记录 buildAiCirculationDetail(circulationList, answer, medicalRec); - saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList); + saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch); return answer; } // 如果找到了,就走病历配置的内容回答 @@ -165,7 +170,7 @@ public class AskServiceImpl implements AskService { .matchQuestion(similarityResult.getMatchQuestion()) .successType(1) .build()); - saveQaRecord(process.getId(), medicalRec, 1, question, similarityResult.getLibraryQuestionId(), patientAnswer, circulationList); + saveQaRecord(process.getId(), medicalRec, 1, question, similarityResult.getLibraryQuestionId(), patientAnswer, circulationList, stopWatch); return patientAnswer; } @@ -178,7 +183,7 @@ public class AskServiceImpl implements AskService { } - private void saveQaRecord(String processId, MedicalRec medicalRec, Integer matchType, String question, String libraryId, String answer, List circulationList) { + private void saveQaRecord(String processId, MedicalRec medicalRec, Integer matchType, String question, String libraryId, String answer, List circulationList, StopWatch stopWatch) { DiagnosisQaRecord record = new DiagnosisQaRecord(); record.setProcessId(processId); record.setMatchType(matchType); @@ -196,6 +201,11 @@ public class AskServiceImpl implements AskService { e.setQuestion(question); } ); + stopWatch.stop(); + // 记录本次执行耗时 + if (CollUtil.isNotEmpty(circulationList)) { + circulationList.get(circulationList.size() - 1).setRemark(stopWatch.toString()); + } askCirculationDetailService.saveBatch(circulationList); } }