优化得分项

pull/1/head
liu 11 months ago
parent d7d7e2e5d3
commit dfcd836ba5

@ -0,0 +1,22 @@
package com.supervision.domain;
import lombok.Data;
@Data
public class QaSimilarityQuestionAnswer {
/**
*
*/
private String matchQuestion;
/**
* ID
*/
private String matchQuestionCode;
/**
* cosine,0.5
*/
private Double matchScore;
}

@ -0,0 +1,56 @@
package com.supervision.util;
import com.supervision.domain.QaSimilarityQuestionAnswer;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.ai.vectorstore.SearchRequest;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
@Slf4j
public class SimilarityUtil {
private static final RedisVectorStore redisVectorStore = SpringBeanUtil.getBean(RedisVectorStore.class);
/**
* ,
*
* @param question
* @return TOP
*/
public static Optional<QaSimilarityQuestionAnswer> talkRedisVectorWithScoreByFirst(String question) {
List<QaSimilarityQuestionAnswer> qaSimilarityQuestionAnswers = talkRedisVectorWithScore(question);
return qaSimilarityQuestionAnswers.stream().findFirst();
}
/**
* ,
*
* @param question
* @return top10
*/
public static List<QaSimilarityQuestionAnswer> talkRedisVectorWithScore(String question) {
log.info("开始调用talkQaSimilarity,问题:{}", question);
try {
// 走Redis向量库进行比较,找出最高的top10
List<Document> documents = redisVectorStore.similaritySearch(SearchRequest
.query(question)
.withTopK(10));
return documents.stream().map(document -> {
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer();
qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent());
qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId")));
qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))));
return qaSimilarityQuestionAnswer;
// 排序,降序,取最高的
}).sorted(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed()).toList();
} catch (Exception e) {
log.error("调用talkQaSimilarity error ", e);
return new ArrayList<>();
}
}
}

@ -1,13 +0,0 @@
package com.supervision.pojo.qaSimilarity;
import lombok.Data;
@Data
public class QaSimilarityQuestionAnswer {
private String matchQuestion;
private String matchQuestionCode;
private Double matchScore;
}

@ -1,6 +1,7 @@
package com.supervision.pojo.vo;
import com.fasterxml.jackson.annotation.JsonProperty;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@ -17,6 +18,7 @@ public class TalkVideoTtsResultResVO {
@Schema(description = "后端返回给前端时使用,表示该是语音回复还是action动作,1语音回复,2体格检查 3辅助检查")
private Integer type = 1;
@JsonProperty(index = 100)
@Schema(description = "音频base64位编码")
private String voiceBase64;

@ -2,27 +2,24 @@ package com.supervision.service.impl;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.domain.QaSimilarityQuestionAnswer;
import com.supervision.exception.BusinessException;
import com.supervision.model.Process;
import com.supervision.model.*;
import com.supervision.pojo.qaSimilarity.QaSimilarityQuestionAnswer;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.pojo.vo.TalkVideoTtsResultResVO;
import com.supervision.service.*;
import com.supervision.util.AsrUtil;
import com.supervision.util.SimilarityUtil;
import com.supervision.util.TtsUtil;
import com.supervision.util.UserUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
@Slf4j
@ -79,28 +76,6 @@ public class AskServiceImpl implements AskService {
record.insert();
}
public QaSimilarityQuestionAnswer talkRedisVectorWithScore(String question) {
log.info("开始调用talkQaSimilarity,问题:{}", question);
try {
// 走Redis进行比较
List<Document> documents = redisVectorStore.similaritySearch(question);
Optional<QaSimilarityQuestionAnswer> first = documents.stream().map(document -> {
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer();
qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent());
qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId")));
qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))));
return qaSimilarityQuestionAnswer;
}).max(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore));
// 排序,降序,取最高的
log.info("调用talkQaSimilarity结束,问题:{},返回结果:{}", question, JSONUtil.toJsonStr(first.orElse(null)));
return first.orElse(null);
} catch (Exception e) {
log.error("调用talkQaSimilarity error ", e);
return null;
}
}
/**
* 使+
*
@ -112,14 +87,15 @@ public class AskServiceImpl implements AskService {
// 根据processId找到对应的病人
Process process = Optional.ofNullable(processService.getById(talkReqVO.getProcessId())).orElseThrow(() -> new BusinessException("未找到诊疗进程"));
MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId());
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = talkRedisVectorWithScore(talkReqVO.getText());
Optional<QaSimilarityQuestionAnswer> qaSimilarityQuestionAnswerOptional = SimilarityUtil.talkRedisVectorWithScoreByFirst(talkReqVO.getText());
TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO();
// 如果匹配度没有匹配到任何数据,则走大模型
if (ObjectUtil.isEmpty(qaSimilarityQuestionAnswer)) {
if (qaSimilarityQuestionAnswerOptional.isEmpty()) {
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage());
} else {
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = qaSimilarityQuestionAnswerOptional.get();
// 如果阈值过低,也走大模型
double thresholdValue = Double.parseDouble(threshold);
if (qaSimilarityQuestionAnswer.getMatchScore() < thresholdValue) {

Loading…
Cancel
Save