You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

209 lines
9.5 KiB
Java

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package com.supervision.service.impl;
import cn.hutool.core.codec.Base64;
import com.alibaba.fastjson.JSON;
import com.supervision.dto.dify.ChatResDTO;
import com.supervision.dto.paddlespeech.res.TtsResultDTO;
import com.supervision.dto.robot.AnswerInfo;
import com.supervision.dto.robot.AskInfo;
import com.supervision.dto.robot.RobotTalkDTO;
import com.supervision.model.RobotTalkReq;
import com.supervision.model.dify.DIFYChatReqInputVO;
import com.supervision.model.dify.DifyChatReqVO;
import com.supervision.model.dify.StreamResponse;
import com.supervision.service.IChatService;
import com.supervision.util.AsrUtil;
import com.supervision.util.DifyApiUtil;
import com.supervision.util.TtsUtil;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import net.sourceforge.pinyin4j.PinyinHelper;
import net.sourceforge.pinyin4j.format.HanyuPinyinCaseType;
import net.sourceforge.pinyin4j.format.HanyuPinyinOutputFormat;
import net.sourceforge.pinyin4j.format.HanyuPinyinToneType;
import net.sourceforge.pinyin4j.format.exception.BadHanyuPinyinOutputFormatCombination;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.http.codec.ServerSentEvent;
import org.springframework.stereotype.Service;
import org.springframework.util.StopWatch;
import org.springframework.web.multipart.MultipartFile;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@Slf4j
@Service
@RequiredArgsConstructor
public class ChatServiceImpl implements IChatService {
@Value("${dify.url}")
private String difyUrl;
@Value("${dify.app-auth}")
private String difyAppAuth;
private final WebClient webClient;
private final DifyApiUtil difyApiUtil;
Map<String, String> voiceCache = new HashMap<>();
@Override
public Flux<ServerSentEvent<Map<String, String>>> streamingMessage(String query) {
DifyChatReqVO difyChatReqVO = new DifyChatReqVO();
difyChatReqVO.setUser("admin");
DIFYChatReqInputVO inputs = new DIFYChatReqInputVO();
difyChatReqVO.setQuery(query);
// difyChatReqVO.setQuery("尽可能详细的介绍一下勐赫小镇的医疗服务");
difyChatReqVO.setInputs(inputs);
StringBuilder sentence = new StringBuilder();
log.info("query:{}", query);
return webClient.post()
.uri(difyUrl)
.headers(httpHeaders -> {
httpHeaders.setContentType(MediaType.APPLICATION_JSON);
httpHeaders.setBearerAuth(difyAppAuth);
})
.bodyValue(JSON.toJSONString(difyChatReqVO))
.retrieve()
.bodyToFlux(StreamResponse.class)
.map(response -> {
Map<String, String> map = new HashMap<>();
map.put("event", response.getEvent());
if (response.getEvent().equals("message") && response.getAnswer() != null) {
//遍历answer中的每一个字符判断是否为标点符号如果是说明是句子的结尾将标点符号前的文本拼接到sentence中并打印然后清空sentence如果标点符号后还有文本将文本拼接到sentence中
for (char ch : response.getAnswer().toCharArray()) {
sentence.append(ch);
if (ch == '。' || ch == '' || ch == '' || ch == '' || ch == '、' || ch == '' || ch == '' || ch == '“' || ch == '”') { // Check for punctuation marks
log.info(sentence.toString());
TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString());
String voiceBaseId = UUID.randomUUID().toString();
voiceCache.put(voiceBaseId, ttsResultDTO.getAudio());
map.put("audioId", voiceBaseId);
sentence.setLength(0); // Clear the sentence
return ServerSentEvent.builder(map).build();
}
}
if (response.getEvent().equals("message_end") && !sentence.isEmpty()) {
log.info(sentence.toString());
TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString());
String voiceBaseId = UUID.randomUUID().toString();
voiceCache.put(voiceBaseId, ttsResultDTO.getAudio());
map.put("audioId", voiceBaseId);
return ServerSentEvent.builder(map).build();
}
}
return ServerSentEvent.builder(map).build();
});
}
@Override
public String asr(MultipartFile file) throws IOException {
return replaceTown(AsrUtil.asrTransformByBytes(file.getBytes()));
}
@Override
public RobotTalkDTO talk(MultipartFile file, RobotTalkReq robotTalkReq) {
log.info("robotTalkReq:{}", robotTalkReq);
RobotTalkDTO.RobotTalkDTOBuilder builder = RobotTalkDTO.builder();
try {
byte[] bytes = file.getBytes();
StopWatch stopWatch = new StopWatch();
DifyChatReqVO difyChatReqVO = new DifyChatReqVO();
difyChatReqVO.setUser("admin");
DIFYChatReqInputVO inputs = new DIFYChatReqInputVO();
stopWatch.start("stt");
stopWatch.stop();
difyChatReqVO.setQuery(replaceTown(AsrUtil.asrTransformByBytes(bytes)));
difyChatReqVO.setConversation_id(robotTalkReq.getSessionId());
stopWatch.start("dify");
ChatResDTO chatResDTO = difyApiUtil.chat(difyChatReqVO);
stopWatch.stop();
log.info("response:{}", chatResDTO.getAnswer());
builder.askInfo(AskInfo.builder().contentType(2).message(inputs.getQuery()).audioLength(100L).askId(chatResDTO.getMessage_id()).build());
voiceCache.put(chatResDTO.getMessage_id(), Base64.encode(bytes));
stopWatch.start("tts");
TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(chatResDTO.getAnswer());
stopWatch.stop();
String voiceBaseId = UUID.randomUUID().toString();
builder.answerInfo(AnswerInfo.builder().contentType(2).message(chatResDTO.getAnswer()).voiceBaseId(voiceBaseId).voiceBase64(ttsResultDTO.getAudio()).build());
builder.sessionId(chatResDTO.getConversation_id());
voiceCache.put(voiceBaseId, ttsResultDTO.getAudio());
log.info("耗时:{}", stopWatch.prettyPrint());
} catch (IOException e) {
throw new RuntimeException(e);
}
return builder.build();
}
@Override
public void getAudio(HttpServletResponse response, String audioId) throws IOException {
log.info("audioId:{}", audioId);
Base64.decodeToStream(voiceCache.get(audioId), response.getOutputStream(), false);
}
/**
* 检索字符串中的“小镇”并判断“小镇”前的两个汉字的拼音是否为“menghe”
* 如果为true则替换这四个汉字为“勐赫小镇”返回更新后的字符串。
*
* @param text 输入字符串
* @return 更新后的字符串
*/
public static String replaceTown(String text) {
// 正则模式:匹配两个汉字紧跟“小镇”
Pattern pattern = Pattern.compile("([\u4e00-\u9fa5]{2})小镇");
Matcher matcher = pattern.matcher(text);
StringBuffer sb = new StringBuffer();
// 配置拼音输出格式:小写、无声调
HanyuPinyinOutputFormat format = new HanyuPinyinOutputFormat();
format.setCaseType(HanyuPinyinCaseType.LOWERCASE);
format.setToneType(HanyuPinyinToneType.WITHOUT_TONE);
while (matcher.find()) {
String twoChars = matcher.group(1);
StringBuilder pinyinStr = new StringBuilder();
boolean valid = true;
// 遍历两个汉字,转换为拼音并拼接
for (int i = 0; i < twoChars.length(); i++) {
char ch = twoChars.charAt(i);
try {
String[] pinyinArray = PinyinHelper.toHanyuPinyinStringArray(ch, format);
if (pinyinArray != null && pinyinArray.length > 0) {
pinyinStr.append(pinyinArray[0]);
} else {
valid = false;
break;
}
} catch (BadHanyuPinyinOutputFormatCombination e) {
valid = false;
break;
}
}
// 如果转换后的拼音为 "menghe",则替换为 "勐赫小镇"
if (valid && "menghe".contentEquals(pinyinStr)) {
matcher.appendReplacement(sb, "勐赫小镇");
} else {
matcher.appendReplacement(sb, matcher.group(0));
}
}
matcher.appendTail(sb);
return sb.toString();
}
// 示例测试
public static void main(String[] args) {
String sampleText = "欢迎来到孟河小镇,体验独特的小镇风情;另外,还有梦和小镇等待你探访。";
System.out.println(replaceTown(sampleText));
}
}