优化字段

pull/1/head
liu 11 months ago
parent 82429a733f
commit 0b17a4462a

@ -1,15 +1,14 @@
package com.supervision.util;
import cn.hutool.core.collection.CollUtil;
import com.supervision.domain.QaSimilarityQuestionAnswer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.*;
@Slf4j
public class SimilarityUtil {
@ -28,18 +27,26 @@ public class SimilarityUtil {
}
/**
* ,
* ,
*
* @param question
* @return top10
* @param question
* @param libraryQuestionIdList ID
* @return
*/
public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question) {
public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question, Collection<String> libraryQuestionIdList) {
log.info("开始调用talkQaSimilarity,问题:{}", question);
try {
// 走Redis向量库进行比较,找出最高的top10
List<Document> documents = redisVectorStore.similaritySearch(SearchRequest
SearchRequest searchRequest = SearchRequest
.query(question)
.withTopK(10));
.withTopK(10);
// 添加条件,只匹配对应标准问的问题
if (CollUtil.isNotEmpty(libraryQuestionIdList)) {
FilterExpressionBuilder b = new FilterExpressionBuilder();
searchRequest.withFilterExpression(b.in("libraryQuestionId", libraryQuestionIdList).build());
}
// 走Redis向量库进行比较,找出最高的top10
List<Document> documents = redisVectorStore.similaritySearch(searchRequest);
return documents.stream().map(document -> {
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer();
qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent());
@ -59,5 +66,15 @@ public class SimilarityUtil {
}
}
/**
* ,
*
* @param question
* @return top10
*/
public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question) {
return talkRedisVectorWithScore(question, null);
}
}

@ -25,7 +25,9 @@ import org.springframework.web.multipart.MultipartFile;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
@Slf4j
@Service
@ -92,8 +94,23 @@ public class AskServiceImpl implements AskService {
// 根据processId找到对应的病人
Process process = Optional.ofNullable(processService.getById(processId)).orElseThrow(() -> new BusinessException("未找到诊疗进程"));
MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId());
// 找到这个病历里面配置的所有配置了答案的问题
List<AskPatientAnswer> medicalAnswerList = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, medicalRec.getId()).select(AskPatientAnswer::getLibraryQuestionId).list();
if (CollUtil.isEmpty(medicalAnswerList)) {
// 记录流转信息
circulationList.add(AskCirculationDetail.builder().failInfo("病历里面没有配置任何答案,直接略过相似度匹配,都大模型").build());
// 如果没有配置答案,就走大模型
String answer = aiService.talk(question, medicalRec.getMedicalRecordAi());
// 记录大模型的流转记录
buildAiCirculationDetail(circulationList, answer, medicalRec);
// 保存消息到记录表
saveQaRecord(process.getId(), medicalRec, 2, question, null, answer, circulationList, stopWatch);
return answer;
}
Set<String> libraryQuestionIdSet = medicalAnswerList.stream().map(AskPatientAnswer::getLibraryQuestionId).collect(Collectors.toSet());
// 进行相似度匹配
List<QaSimilarityQuestionAnswer> similarityAnswerList = SimilarityUtil.talkRedisVectorWithScore(question);
List<QaSimilarityQuestionAnswer> similarityAnswerList = SimilarityUtil.talkRedisVectorWithScore(question, libraryQuestionIdSet);
Optional<QaSimilarityQuestionAnswer> first = similarityAnswerList.stream().findFirst();
// 如果匹配度没有匹配到任何数据
if (first.isEmpty()) {

Loading…
Cancel
Save