From dfcd836ba5cf51db5657900e7409081d57674be3 Mon Sep 17 00:00:00 2001 From: liu Date: Thu, 6 Jun 2024 13:23:59 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=BE=97=E5=88=86=E9=A1=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../domain/QaSimilarityQuestionAnswer.java | 22 ++++++++ .../com/supervision/util/SimilarityUtil.java | 56 +++++++++++++++++++ .../QaSimilarityQuestionAnswer.java | 13 ----- .../pojo/vo/TalkVideoTtsResultResVO.java | 2 + .../service/impl/AskServiceImpl.java | 34 ++--------- 5 files changed, 85 insertions(+), 42 deletions(-) create mode 100644 virtual-patient-common/src/main/java/com/supervision/domain/QaSimilarityQuestionAnswer.java create mode 100644 virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java delete mode 100644 virtual-patient-web/src/main/java/com/supervision/pojo/qaSimilarity/QaSimilarityQuestionAnswer.java diff --git a/virtual-patient-common/src/main/java/com/supervision/domain/QaSimilarityQuestionAnswer.java b/virtual-patient-common/src/main/java/com/supervision/domain/QaSimilarityQuestionAnswer.java new file mode 100644 index 00000000..30f45d53 --- /dev/null +++ b/virtual-patient-common/src/main/java/com/supervision/domain/QaSimilarityQuestionAnswer.java @@ -0,0 +1,22 @@ +package com.supervision.domain; + +import lombok.Data; + +@Data +public class QaSimilarityQuestionAnswer { + + /** + * 匹配到的问题 + */ + private String matchQuestion; + + /** + * 匹配到的问题ID + */ + private String matchQuestionCode; + + /** + * cosine余弦得分,一般0.5以上匹配度就比较高了 + */ + private Double matchScore; +} diff --git a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java new file mode 100644 index 00000000..bc930e8a --- /dev/null +++ b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java @@ -0,0 +1,56 @@ +package com.supervision.util; + +import com.supervision.domain.QaSimilarityQuestionAnswer; +import lombok.extern.slf4j.Slf4j; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.RedisVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; + +@Slf4j +public class SimilarityUtil { + + private static final RedisVectorStore redisVectorStore = SpringBeanUtil.getBean(RedisVectorStore.class); + + /** + * 相似度比较,只找最高的 + * + * @param question 问题 + * @return 最高的TOP + */ + public static Optional talkRedisVectorWithScoreByFirst(String question) { + List qaSimilarityQuestionAnswers = talkRedisVectorWithScore(question); + return qaSimilarityQuestionAnswers.stream().findFirst(); + } + + /** + * 相似度比较,找出所有的 + * + * @param question 问题 + * @return 相似度最高的top10 + */ + public static List talkRedisVectorWithScore(String question) { + log.info("开始调用talkQaSimilarity,问题:{}", question); + try { + // 走Redis向量库进行比较,找出最高的top10 + List documents = redisVectorStore.similaritySearch(SearchRequest + .query(question) + .withTopK(10)); + return documents.stream().map(document -> { + QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer(); + qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent()); + qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId"))); + qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score")))); + return qaSimilarityQuestionAnswer; + // 排序,降序,取最高的 + }).sorted(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed()).toList(); + } catch (Exception e) { + log.error("调用talkQaSimilarity error ", e); + return new ArrayList<>(); + } + } +} diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/qaSimilarity/QaSimilarityQuestionAnswer.java b/virtual-patient-web/src/main/java/com/supervision/pojo/qaSimilarity/QaSimilarityQuestionAnswer.java deleted file mode 100644 index 2d018a7b..00000000 --- a/virtual-patient-web/src/main/java/com/supervision/pojo/qaSimilarity/QaSimilarityQuestionAnswer.java +++ /dev/null @@ -1,13 +0,0 @@ -package com.supervision.pojo.qaSimilarity; - -import lombok.Data; - -@Data -public class QaSimilarityQuestionAnswer { - - private String matchQuestion; - - private String matchQuestionCode; - - private Double matchScore; -} diff --git a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java index f13896ce..6f333f93 100644 --- a/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java +++ b/virtual-patient-web/src/main/java/com/supervision/pojo/vo/TalkVideoTtsResultResVO.java @@ -1,6 +1,7 @@ package com.supervision.pojo.vo; +import com.fasterxml.jackson.annotation.JsonProperty; import io.swagger.v3.oas.annotations.media.Schema; import lombok.Data; @@ -17,6 +18,7 @@ public class TalkVideoTtsResultResVO { @Schema(description = "后端返回给前端时使用,表示该是语音回复还是action动作,1语音回复,2体格检查 3辅助检查") private Integer type = 1; + @JsonProperty(index = 100) @Schema(description = "音频base64位编码") private String voiceBase64; diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java index 5efc6796..2ce52d8a 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -2,27 +2,24 @@ package com.supervision.service.impl; import cn.hutool.core.util.ObjectUtil; import cn.hutool.core.util.StrUtil; -import cn.hutool.json.JSONUtil; +import com.supervision.domain.QaSimilarityQuestionAnswer; import com.supervision.exception.BusinessException; import com.supervision.model.Process; import com.supervision.model.*; -import com.supervision.pojo.qaSimilarity.QaSimilarityQuestionAnswer; import com.supervision.pojo.vo.TalkVideoReqVO; import com.supervision.pojo.vo.TalkVideoTtsResultResVO; import com.supervision.service.*; import com.supervision.util.AsrUtil; +import com.supervision.util.SimilarityUtil; import com.supervision.util.TtsUtil; import com.supervision.util.UserUtil; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; -import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.web.multipart.MultipartFile; -import java.util.Comparator; -import java.util.List; import java.util.Optional; @Slf4j @@ -79,28 +76,6 @@ public class AskServiceImpl implements AskService { record.insert(); } - public QaSimilarityQuestionAnswer talkRedisVectorWithScore(String question) { - log.info("开始调用talkQaSimilarity,问题:{}", question); - try { - // 走Redis进行比较 - List documents = redisVectorStore.similaritySearch(question); - Optional first = documents.stream().map(document -> { - QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer(); - qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent()); - qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId"))); - qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score")))); - return qaSimilarityQuestionAnswer; - }).max(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore)); - // 排序,降序,取最高的 - log.info("调用talkQaSimilarity结束,问题:{},返回结果:{}", question, JSONUtil.toJsonStr(first.orElse(null))); - return first.orElse(null); - } catch (Exception e) { - log.error("调用talkQaSimilarity error ", e); - return null; - } - } - - /** * 使用无声视频+语音转文字的形式来做 * @@ -112,14 +87,15 @@ public class AskServiceImpl implements AskService { // 根据processId找到对应的病人 Process process = Optional.ofNullable(processService.getById(talkReqVO.getProcessId())).orElseThrow(() -> new BusinessException("未找到诊疗进程")); MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId()); - QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = talkRedisVectorWithScore(talkReqVO.getText()); + Optional qaSimilarityQuestionAnswerOptional = SimilarityUtil.talkRedisVectorWithScoreByFirst(talkReqVO.getText()); TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO(); // 如果匹配度没有匹配到任何数据,则走大模型 - if (ObjectUtil.isEmpty(qaSimilarityQuestionAnswer)) { + if (qaSimilarityQuestionAnswerOptional.isEmpty()) { String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi()); talkVideoTtsResultResVO.setAnswerMessage(talk); saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage()); } else { + QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = qaSimilarityQuestionAnswerOptional.get(); // 如果阈值过低,也走大模型 double thresholdValue = Double.parseDouble(threshold); if (qaSimilarityQuestionAnswer.getMatchScore() < thresholdValue) {