You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
virtual-patient/virtual-patient-common/src/main/java/com/supervision/util/SimilarityUtil.java

92 lines
4.0 KiB
Java

package com.supervision.util;
import cn.hutool.core.collection.CollUtil;
import com.supervision.domain.QaSimilarityQuestionAnswer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionBuilder;
import java.util.*;
/**
 * Static helpers for semantic similarity matching of a question against the
 * questions stored in the Redis vector store.
 *
 * <p>Search hits are mapped to {@link QaSimilarityQuestionAnswer} objects and
 * returned ordered by descending match score (see {@link #computeScore(Document)}).
 */
@Slf4j
public class SimilarityUtil {

    /** Number of nearest neighbours requested from the vector store per search. */
    private static final int TOP_K = 10;

    /**
     * Offset subtracted from {@code 1 - vector_score}. It pulls scores toward the
     * middle of the range so they are less extreme and easier to tune, and keeps
     * them closer to the scoring approach the project used previously.
     */
    private static final double SCORE_OFFSET = 0.25;

    /**
     * Initialization-on-demand holder for the {@link RedisVectorStore} bean.
     *
     * <p>The bean is looked up only when a search is first performed, not when
     * {@code SimilarityUtil} itself is loaded. Loading the outer class (e.g. to
     * call {@link SimilarityUtil#computeScore(Document)}) therefore does not
     * depend on the Spring context being ready.
     */
    private static final class VectorStoreHolder {
        private static final RedisVectorStore INSTANCE = SpringBeanUtil.getBean(RedisVectorStore.class);
    }

    /**
     * Runs a similarity search and returns only the best match.
     *
     * @param question the question text to match
     * @return the highest-scoring match, or {@link Optional#empty()} if nothing
     *         matched or the search failed
     */
    public static Optional<QaSimilarityQuestionAnswer> talkRedisVectorWithScoreByFirst(String question) {
        List<QaSimilarityQuestionAnswer> qaSimilarityQuestionAnswers = talkRedisVectorWithScore(question);
        return qaSimilarityQuestionAnswers.stream().findFirst();
    }

    /**
     * Runs a similarity search, optionally restricted to questions belonging to
     * the given library (standard) question IDs, e.g. those with answers
     * configured in a medical record.
     *
     * <p>Any exception raised during the search or result mapping is logged and
     * an empty (mutable) list is returned instead.
     *
     * @param question              the question text to match
     * @param libraryQuestionIdList library question IDs to restrict matching to;
     *                              {@code null} or empty means no restriction
     * @return up to {@value #TOP_K} matches, sorted by descending match score
     *         (unmodifiable on success)
     */
    public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question, Collection<String> libraryQuestionIdList) {
        log.info("开始调用talkQaSimilarity,问题:{}", question);
        try {
            SearchRequest searchRequest = SearchRequest.query(question).withTopK(TOP_K);
            // Only match questions that belong to the requested standard questions.
            if (CollUtil.isNotEmpty(libraryQuestionIdList)) {
                FilterExpressionBuilder b = new FilterExpressionBuilder();
                // Keep the returned request instead of relying on the fluent setter mutating in place.
                searchRequest = searchRequest.withFilterExpression(
                        b.in("libraryQuestionId", new ArrayList<>(libraryQuestionIdList)).build());
            }
            // Query the Redis vector store for the top-K nearest questions.
            List<Document> documents = VectorStoreHolder.INSTANCE.similaritySearch(searchRequest);
            return documents.stream()
                    .map(SimilarityUtil::toAnswer)
                    // Sort descending so the best match comes first.
                    .sorted(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed())
                    .toList();
        } catch (Exception e) {
            log.error("调用talkQaSimilarity error ", e);
            return new ArrayList<>();
        }
    }

    /**
     * Maps one vector-store hit to a {@link QaSimilarityQuestionAnswer}.
     *
     * @param document a search hit carrying {@code questionId}, {@code dictId},
     *                 {@code libraryQuestionId} and {@code vector_score} metadata
     * @return the populated answer object
     */
    private static QaSimilarityQuestionAnswer toAnswer(Document document) {
        QaSimilarityQuestionAnswer answer = new QaSimilarityQuestionAnswer();
        answer.setMatchQuestion(document.getContent());
        answer.setQuestionId(metadataAsString(document, "questionId"));
        answer.setDictId(metadataAsString(document, "dictId"));
        answer.setLibraryQuestionId(metadataAsString(document, "libraryQuestionId"));
        double score = computeScore(document);
        answer.setMatchScore(score);
        return answer;
    }

    /**
     * Reads a metadata entry as a string via {@link String#valueOf(Object)}.
     * An absent entry yields the literal string {@code "null"}, as before.
     */
    private static String metadataAsString(Document document, String key) {
        return String.valueOf(document.getMetadata().get(key));
    }

    /**
     * Converts the Redis {@code vector_score} metadata into a match score.
     *
     * <p>Redis reports smaller values for more similar vectors, so the value is
     * inverted ({@code 1 - vector_score}), making higher scores mean "more
     * similar" and allowing a natural descending sort. {@link #SCORE_OFFSET} is
     * then subtracted to pull scores toward the middle of the range. The result
     * is clamped so it never goes below 0.
     *
     * @param document a search hit whose metadata contains {@code vector_score}
     * @return the adjusted score, never negative
     * @throws NumberFormatException if {@code vector_score} is missing or not a number
     */
    public static Double computeScore(Document document) {
        double distance = Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score")));
        return Math.max(0, 1 - distance - SCORE_OFFSET);
    }

    /**
     * Runs an unrestricted similarity search over all stored questions.
     *
     * @param question the question text to match
     * @return up to {@value #TOP_K} matches, sorted by descending match score
     */
    public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question) {
        return talkRedisVectorWithScore(question, null);
    }
}