diff --git a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java index 741b89d6..490cf53f 100644 --- a/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java +++ b/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java @@ -1,15 +1,14 @@ package com.supervision.util; +import cn.hutool.core.collection.CollUtil; import com.supervision.domain.QaSimilarityQuestionAnswer; import lombok.extern.slf4j.Slf4j; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.RedisVectorStore; import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder; -import java.util.ArrayList; -import java.util.Comparator; -import java.util.List; -import java.util.Optional; +import java.util.*; @Slf4j public class SimilarityUtil { @@ -28,18 +27,26 @@ public class SimilarityUtil { } /** - * 相似度比较,找出所有的 + * 相似度比较,找到病历里面配置答案的问题 * - * @param question 问题 - * @return 相似度最高的top10 + * @param question 问题 + * @param libraryQuestionIdList 配置答案的问题ID + * @return 匹配到的结果列表 */ - public static List talkRedisVectorWithScore(String question) { + public static List talkRedisVectorWithScore(String question, Collection libraryQuestionIdList) { log.info("开始调用talkQaSimilarity,问题:{}", question); try { - // 走Redis向量库进行比较,找出最高的top10 - List documents = redisVectorStore.similaritySearch(SearchRequest + + SearchRequest searchRequest = SearchRequest .query(question) - .withTopK(10)); + .withTopK(10); + // 添加条件,只匹配对应标准问的问题 + if (CollUtil.isNotEmpty(libraryQuestionIdList)) { + FilterExpressionBuilder b = new FilterExpressionBuilder(); + searchRequest.withFilterExpression(b.in("libraryQuestionId", libraryQuestionIdList).build()); + } + // 走Redis向量库进行比较,找出最高的top10 + List documents = redisVectorStore.similaritySearch(searchRequest); return documents.stream().map(document -> { QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer(); qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent()); @@ -59,5 +66,15 @@ public class SimilarityUtil { } } + /** + * 相似度比较,找出所有的 + * + * @param question 问题 + * @return 相似度最高的top10 + */ + public static List talkRedisVectorWithScore(String question) { + return talkRedisVectorWithScore(question, null); + } + } diff --git a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java index 07b74174..046f9212 100644 --- a/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java +++ b/virtual-patient-web/src/main/java/com/supervision/service/impl/AskServiceImpl.java @@ -25,7 +25,9 @@ import org.springframework.web.multipart.MultipartFile; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; @Slf4j @Service @@ -92,8 +94,23 @@ public class AskServiceImpl implements AskService { // 根据processId找到对应的病人 Process process = Optional.ofNullable(processService.getById(processId)).orElseThrow(() -> new BusinessException("未找到诊疗进程")); MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId()); + // 找到这个病历里面配置的所有配置了答案的问题 + List medicalAnswerList = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, medicalRec.getId()).select(AskPatientAnswer::getLibraryQuestionId).list(); + if (CollUtil.isEmpty(medicalAnswerList)) { + // 记录流转信息 + circulationList.add(AskCirculationDetail.builder().failInfo("病历里面没有配置任何答案,直接略过相似度匹配,都大模型").build()); + // 如果没有配置答案,就走大模型 + String answer = aiService.talk(question, medicalRec.getMedicalRecordAi()); + // 记录大模型的流转记录 + buildAiCirculationDetail(circulationList, answer, medicalRec); + // 保存消息到记录表 + saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch); + return answer; + } + Set libraryQuestionIdSet = medicalAnswerList.stream().map(AskPatientAnswer::getLibraryQuestionId).collect(Collectors.toSet()); + // 进行相似度匹配 - List similarityAnswerList = SimilarityUtil.talkRedisVectorWithScore(question); + List similarityAnswerList = SimilarityUtil.talkRedisVectorWithScore(question, libraryQuestionIdSet); Optional first = similarityAnswerList.stream().findFirst(); // 如果匹配度没有匹配到任何数据 if (first.isEmpty()) {