177 lines
8.8 KiB
Java

package com.supervision.service.impl;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.json.JSONUtil;
import com.supervision.exception.BusinessException;
import com.supervision.model.Process;
import com.supervision.model.*;
import com.supervision.pojo.qaSimilarity.QaSimilarityQuestionAnswer;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.pojo.vo.TalkVideoTtsResultResVO;
import com.supervision.service.*;
import com.supervision.util.AsrUtil;
import com.supervision.util.TtsUtil;
import com.supervision.util.UserUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.RedisVectorStore;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
@Slf4j
@Service
@RequiredArgsConstructor
public class AskServiceImpl implements AskService {
private final ProcessService processService;
private final AskTemplateQuestionLibraryService askTemplateQuestionLibraryService;
private final AskPatientAnswerService askPatientAnswerService;
private final AiService aiService;
private final MedicalRecService medicalRecService;
private final DiagnosisAiRecordService diagnosisAiRecordService;
private final RedisVectorStore redisVectorStore;
@Value("${threshold:0.7}")
private String threshold;
@Override
public String receiveVoiceFile(MultipartFile file) {
if (file.getSize() <= 0) {
throw new BusinessException("语音内容为空");
}
// 获取音频对应的文字
String text = null;
try {
text = AsrUtil.asrTransformByBytes(file.getBytes());
} catch (Exception e) {
throw new BusinessException("获取语音失败");
}
if (StrUtil.isEmpty(text)) {
throw new BusinessException("语音内容为空");
}
return text;
}
private void saveQaRecord(String processId, String answerType, String answerId, String question, AskTemplateQuestionLibrary library, String resText) {
DiagnosisQaRecord record = new DiagnosisQaRecord();
record.setProcessId(processId);
record.setAnswerType(answerType);
record.setAnswerId(answerId);
if (ObjectUtil.isNotEmpty(library)) {
record.setQuestionLibraryId(library.getId());
}
record.setQuestion(question);
record.setAnswer(resText);
record.setCreateUserId(UserUtil.getUser().getId());
record.insert();
}
public QaSimilarityQuestionAnswer talkRedisVectorWithScore(String question) {
log.info("开始调用talkQaSimilarity,问题:{}", question);
try {
// 走Redis进行比较
List<Document> documents = redisVectorStore.similaritySearch(question);
Optional<QaSimilarityQuestionAnswer> first = documents.stream().map(document -> {
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = new QaSimilarityQuestionAnswer();
qaSimilarityQuestionAnswer.setMatchQuestion(document.getContent());
qaSimilarityQuestionAnswer.setMatchQuestionCode(String.valueOf(document.getMetadata().get("standardQuestionId")));
qaSimilarityQuestionAnswer.setMatchScore(Double.parseDouble(String.valueOf(document.getMetadata().get("vector_score"))));
return qaSimilarityQuestionAnswer;
}).max(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore));
// 排序,降序,取最高的
log.info("调用talkQaSimilarity结束,问题:{},返回结果:{}", question, JSONUtil.toJsonStr(first.orElse(null)));
return first.orElse(null);
} catch (Exception e) {
log.error("调用talkQaSimilarity error ", e);
return null;
}
}
/**
* 使+
*
* @param talkReqVO
* @return
*/
@Override
public TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO) {
// 根据processId找到对应的病人
Process process = Optional.ofNullable(processService.getById(talkReqVO.getProcessId())).orElseThrow(() -> new BusinessException("未找到诊疗进程"));
MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId());
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = talkRedisVectorWithScore(talkReqVO.getText());
TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO();
// 如果匹配度没有匹配到任何数据,则走大模型
if (ObjectUtil.isEmpty(qaSimilarityQuestionAnswer)) {
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage());
} else {
// 如果阈值过低,也走大模型
double thresholdValue = Double.parseDouble(threshold);
if (qaSimilarityQuestionAnswer.getMatchScore() < thresholdValue) {
log.info("{}:匹配到的结果阈值过低,走大模型回答", qaSimilarityQuestionAnswer);
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage());
} else {
// 如果查到的问题不在问题库中,走大模型回答
AskTemplateQuestionLibrary library = askTemplateQuestionLibraryService.getById(qaSimilarityQuestionAnswer.getMatchQuestionCode());
if (ObjectUtil.isEmpty(library)) {
log.info("{}:未从问题库中找到,走大模型回答", qaSimilarityQuestionAnswer);
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage());
} else {
// 根据问题找这个病历配置的答案
AskPatientAnswer askPatientAnswer = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, process.getMedicalRecId())
.eq(AskPatientAnswer::getLibraryQuestionId, library.getId()).last("limit 1").one();
// 如果找到了,就走病历配置的内容回答
if (ObjectUtil.isNotEmpty(askPatientAnswer)) {
String resText = askPatientAnswer.getAnswer();
log.info("{}:找到了病历配置的回答语句:{},回答内容:{},走病历回答", qaSimilarityQuestionAnswer.getMatchQuestionCode(), askPatientAnswer.getId(), resText);
talkVideoTtsResultResVO.setAnswerMessage(resText);
// 保存记录到问答记录表
saveQaRecord(talkReqVO.getProcessId(), "patient", askPatientAnswer.getId(), talkReqVO.getText(), library, resText);
} else {
// 如果问题的答案没有配置,还是走大模型的回答
log.info("{}:病历配置,从AskPatientAnswer中未找到回答结果,走大模型", qaSimilarityQuestionAnswer.getMatchQuestionCode());
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
saveAiRecord(process.getId(), talkReqVO.getText(), talkVideoTtsResultResVO.getAnswerMessage());
}
}
}
}
talkVideoTtsResultResVO.setVoiceBase64(TtsUtil.ttsTransform(talkVideoTtsResultResVO.getAnswerMessage()));
return talkVideoTtsResultResVO;
}
/**
* AI,便AI
*/
private void saveAiRecord(String processId, String question, String answer) {
DiagnosisAiRecord diagnosisAiRecord = new DiagnosisAiRecord();
diagnosisAiRecord.setProcessId(processId);
diagnosisAiRecord.setQuestion(question);
diagnosisAiRecord.setAnswer(answer);
diagnosisAiRecord.setCreateUserId(UserUtil.getUser().getId());
diagnosisAiRecord.setUpdateUserId(UserUtil.getUser().getId());
diagnosisAiRecordService.save(diagnosisAiRecord);
}
}