集成大模型对话功能

pull/1/head
liu
parent 94a989288c
commit 043c78097a

@ -42,6 +42,11 @@
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter-bootstrap</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>

@ -17,9 +17,9 @@ import java.util.concurrent.*;
@Slf4j
public class AiChatUtil {
public static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 0, "chat", new ThreadPoolExecutor.CallerRunsPolicy());
private static final ExecutorService chatExecutor = ThreadUtil.newFixedExecutor(5, 0, "chat", new ThreadPoolExecutor.CallerRunsPolicy());
public static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class);
private static final OllamaChatClient chatClient = SpringBeanUtil.getBean(OllamaChatClient.class);
/**
*
@ -27,7 +27,7 @@ public class AiChatUtil {
* @param chat
* @return jsonObject
*/
public Optional<JSONObject> chat(String chat) {
public static Optional<JSONObject> chat(String chat) {
Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try {
@ -44,7 +44,7 @@ public class AiChatUtil {
* @param messageList
* @return jsonObject
*/
public Optional<JSONObject> chat(List<Message> messageList) {
public static Optional<JSONObject> chat(List<Message> messageList) {
Prompt prompt = new Prompt(messageList);
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try {
@ -63,7 +63,7 @@ public class AiChatUtil {
* @param <T>
* @return ,
*/
public <T> Optional<T> chat(List<Message> messageList, Class<T> clazz) {
public static <T> Optional<T> chat(List<Message> messageList, Class<T> clazz) {
Prompt prompt = new Prompt(messageList);
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try {
@ -83,7 +83,7 @@ public class AiChatUtil {
* @param <T>
* @return ,
*/
public <T> Optional<T> chat(String chat, Class<T> clazz) {
public static <T> Optional<T> chat(String chat, Class<T> clazz) {
Prompt prompt = new Prompt(List.of(new UserMessage(chat)));
Future<String> submit = chatExecutor.submit(new ChatTask(chatClient, prompt));
try {

@ -32,11 +32,6 @@
</exclusions>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>
<artifactId>spring-cloud-starter-bootstrap</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.cloud</groupId>

@ -128,6 +128,9 @@ public class MedicalRec extends Model<MedicalRec> implements Serializable {
@Schema(description = "全面检查")
private String fullCheck;
@Schema(description = "提交给大模型的病历")
private String medicalRecordAi;
/**
* ID
*/

@ -2,6 +2,7 @@ package com.supervision.controller;
import com.supervision.pojo.vo.TalkResultResVO;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.pojo.vo.TalkVideoTtsResultResVO;
import com.supervision.service.AskService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
@ -32,7 +33,11 @@ public class AskController {
return askService.talkByVideo(talkReqVO);
}
@Operation(summary = "使用无声视频+语音转文字的形式来做")
@PostMapping("talkByVideoAndTts")
public TalkVideoTtsResultResVO talkByVideoAndTts(@RequestBody TalkVideoReqVO talkReqVO) {
return askService.talkByVideoAndTts(talkReqVO);
}
}

@ -0,0 +1,9 @@
package com.supervision.pojo.ai;
import lombok.Data;
@Data
public class AiTalkAnswerDTO {
private String answer;
}

@ -0,0 +1,27 @@
package com.supervision.pojo.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
/**
* +tts
*/
@Data
public class TalkVideoTtsResultResVO {
/**
* 使,action,1,2 3
* 1
*/
@Schema(description = "后端返回给前端时使用,表示该是语音回复还是action动作,1语音回复,2体格检查 3辅助检查")
private Integer type = 1;
@Schema(description = "音频base64位编码")
private String voiceBase64;
@Schema(description = "回复的消息内容,用于调试视频资料内容")
private String answerMessage;
}

@ -0,0 +1,6 @@
package com.supervision.service;
public interface AiService {
String talk(String question, String medicalRecord);
}

@ -2,6 +2,7 @@ package com.supervision.service;
import com.supervision.pojo.vo.TalkResultResVO;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.pojo.vo.TalkVideoTtsResultResVO;
import org.springframework.web.multipart.MultipartFile;
import java.io.IOException;
@ -12,6 +13,7 @@ public interface AskService {
TalkResultResVO talkByVideo(TalkVideoReqVO talkReqVO) throws IOException;
TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO);
}

@ -0,0 +1,61 @@
package com.supervision.service.impl;
import cn.hutool.core.util.StrUtil;
import com.supervision.pojo.ai.AiTalkAnswerDTO;
import com.supervision.service.AiService;
import com.supervision.util.AiChatUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.stereotype.Service;
import java.util.*;
@Slf4j
@Service
public class AiServiceImpl implements AiService {
public static final String userPromptTemplate = """
---
{medicalRecord}
---
json:{"answer":"扮演病人并根据病历回复的内容"}
""";
private static final String systemPrompt = """
,
,
'//'",!
'',','!!!
""";
/**
*
*
* @param question
* @return
*/
public String talk(String question, String medicalRecord) {
Map<String, String> paramMap = new HashMap<>();
paramMap.put("medicalRecord", medicalRecord);
List<Message> messageHistoryList = new ArrayList<>();
messageHistoryList.add(new SystemMessage(systemPrompt));
messageHistoryList.add(new UserMessage(StrUtil.format(userPromptTemplate, paramMap)));
messageHistoryList.add(new AssistantMessage("好的,已了解我要扮演病人的病历。已准备好对话了。"));
messageHistoryList.add(new UserMessage(question));
Optional<AiTalkAnswerDTO> chat = AiChatUtil.chat(messageHistoryList, AiTalkAnswerDTO.class);
if (chat.isPresent()) {
AiTalkAnswerDTO aiTalkAnswerDTO = chat.get();
if (StrUtil.isNotBlank(aiTalkAnswerDTO.getAnswer())) {
return aiTalkAnswerDTO.getAnswer();
}
}
return "医生,我没有听懂你说的是什么";
}
}

@ -18,9 +18,11 @@ import com.supervision.pojo.qaSimilarity.QaSimilarityQuestion;
import com.supervision.pojo.qaSimilarity.QaSimilarityQuestionAnswer;
import com.supervision.pojo.vo.TalkResultResVO;
import com.supervision.pojo.vo.TalkVideoReqVO;
import com.supervision.pojo.vo.TalkVideoTtsResultResVO;
import com.supervision.service.*;
import com.supervision.util.AsrUtil;
import com.supervision.util.MinioUtil;
import com.supervision.util.TtsUtil;
import com.supervision.util.UserUtil;
import com.supervision.vo.rasa.RasaTalkVo;
import lombok.RequiredArgsConstructor;
@ -29,6 +31,7 @@ import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.io.InputStream;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
@ -55,6 +58,10 @@ public class AskServiceImpl implements AskService {
private final CommonDicService commonDicService;
private final AiService aiService;
private final MedicalRecService medicalRecService;
@Override
public String receiveVoiceFile(MultipartFile file) {
if (file.getSize() <= 0) {
@ -221,6 +228,20 @@ public class AskServiceImpl implements AskService {
}
}
public QaSimilarityQuestionAnswer talkQaSimilarityWithScore(String question, String sessionId) {
log.info("开始调用talkQaSimilarity,问题:{}", question);
try {
GlobalResult<List<QaSimilarityQuestionAnswer>> result = askQaSimilarityFeignClient.askQuestionSimilarityAnswer(new QaSimilarityQuestion(question));
// 排序,降序,取最高的
result.getData().sort(Comparator.comparing(QaSimilarityQuestionAnswer::getMatchScore).reversed());
log.info("调用talkQaSimilarity结束,问题:{},返回结果:{}", question, JSONUtil.toJsonStr(result));
return CollUtil.getFirst(result.getData());
} catch (Exception e) {
log.error("调用talkQaSimilarity error ", e);
return null;
}
}
private AskPatientAnswer getMedicalRecErrorAnswer(String medicalRecId) {
//Optional.ofNullable(medicalRecErrorAnswer).orElseGet(() ->new AskPatientAnswer()).getAnswer()
Assert.notEmpty(medicalRecId, "病历id不能为空");
@ -262,4 +283,59 @@ public class AskServiceImpl implements AskService {
}
}
/**
* 使+
*
* @param talkReqVO
* @return
*/
@Override
public TalkVideoTtsResultResVO talkByVideoAndTts(TalkVideoReqVO talkReqVO) {
// 根据processId找到对应的病人
Process process = Optional.ofNullable(processService.getById(talkReqVO.getProcessId())).orElseThrow(() -> new BusinessException("未找到诊疗进程"));
MedicalRec medicalRec = medicalRecService.getById(process.getMedicalRecId());
QaSimilarityQuestionAnswer qaSimilarityQuestionAnswer = talkQaSimilarityWithScore(talkReqVO.getText(), UserUtil.getUser().getId());
TalkVideoTtsResultResVO talkVideoTtsResultResVO = new TalkVideoTtsResultResVO();
// 如果匹配度没有匹配到任何数据,则走大模型
if (ObjectUtil.isEmpty(qaSimilarityQuestionAnswer)) {
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
} else {
// 如果阈值过低,也走大模型
if (qaSimilarityQuestionAnswer.getMatchScore() < 0.5) {
log.info("{}:匹配到的结果阈值过低,走大模型回答", qaSimilarityQuestionAnswer);
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
} else {
// 如果查到的问题不在问题库中,走大模型回答
AskTemplateQuestionLibrary library = askTemplateQuestionLibraryService.getById(qaSimilarityQuestionAnswer.getMatchQuestionCode());
if (ObjectUtil.isEmpty(library)) {
log.info("{}:未从问题库中找到,走大模型回答", qaSimilarityQuestionAnswer);
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
} else {
// 根据问题找这个病历配置的答案
AskPatientAnswer askPatientAnswer = askPatientAnswerService.lambdaQuery().eq(AskPatientAnswer::getMedicalId, process.getMedicalRecId())
.eq(AskPatientAnswer::getLibraryQuestionId, library.getId()).last("limit 1").one();
// 如果找到了,就走病历配置的内容回答
if (ObjectUtil.isNotEmpty(askPatientAnswer)) {
String resText = askPatientAnswer.getAnswer();
log.info("{}:找到了病历配置的回答语句:{},回答内容:{},走病历回答", qaSimilarityQuestionAnswer.getMatchQuestionCode(), askPatientAnswer.getId(), resText);
talkVideoTtsResultResVO.setAnswerMessage(resText);
// 保存记录
saveQaRecord(talkReqVO.getProcessId(), "patient", askPatientAnswer.getId(), talkReqVO.getText(), library, resText);
return talkVideoTtsResultResVO;
} else {
// 如果问题的答案没有配置,还是走大模型的回答
log.info("{}:病历配置,从AskPatientAnswer中未找到回答结果,走大模型", qaSimilarityQuestionAnswer.getMatchQuestionCode());
String talk = aiService.talk(talkReqVO.getText(), medicalRec.getMedicalRecordAi());
talkVideoTtsResultResVO.setAnswerMessage(talk);
}
}
}
}
talkVideoTtsResultResVO.setVoiceBase64(TtsUtil.ttsTransform(talkVideoTtsResultResVO.getAnswerMessage()));
return talkVideoTtsResultResVO;
}
}

Loading…
Cancel
Save