diff --git a/pom.xml b/pom.xml index 62a2a3f..5edcd75 100644 --- a/pom.xml +++ b/pom.xml @@ -64,6 +64,11 @@ pinyin4j 2.5.0 + + org.springframework.boot + spring-boot-starter-webflux + + diff --git a/src/main/java/com/supervision/SpeechDemoServiceApplication.java b/src/main/java/com/supervision/SpeechDemoServiceApplication.java index c02a086..c929887 100644 --- a/src/main/java/com/supervision/SpeechDemoServiceApplication.java +++ b/src/main/java/com/supervision/SpeechDemoServiceApplication.java @@ -2,6 +2,9 @@ package com.supervision; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.function.client.WebClient; @SpringBootApplication public class SpeechDemoServiceApplication { @@ -9,5 +12,14 @@ public class SpeechDemoServiceApplication { public static void main(String[] args) { SpringApplication.run(SpeechDemoServiceApplication.class, args); } + @Bean + public RestTemplate restTemplate() { /* Registers a RestTemplate bean created with the no-arg constructor. NOTE(review): no code visible in this diff injects it - confirm it is needed. */ + return new RestTemplate(); + } + + @Bean + public WebClient webClient() { /* Registers the WebClient bean (WebClient.create(), no base URL) that DifyService, ChatServiceImpl and DifyApiUtil inject. */ + return WebClient.create(); + } } diff --git a/src/main/java/com/supervision/contoller/ChatController.java b/src/main/java/com/supervision/contoller/ChatController.java index 62731cc..42ddb47 100644 --- a/src/main/java/com/supervision/contoller/ChatController.java +++ b/src/main/java/com/supervision/contoller/ChatController.java @@ -1,18 +1,28 @@ package com.supervision.contoller; +import cn.hutool.core.io.resource.ResourceUtil; +import cn.hutool.core.util.StrUtil; import com.supervision.dto.R; import com.supervision.dto.robot.RobotTalkDTO; import com.supervision.model.RobotTalkReq; +import com.supervision.model.dify.StreamResponse; +import com.supervision.service.DifyService; import com.supervision.service.IChatService; import jakarta.servlet.http.HttpServletResponse; import
lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.bind.annotation.*; import org.springframework.web.multipart.MultipartFile; +import reactor.core.publisher.Flux; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.List; +import java.util.Map; + @Slf4j @RestController @RequestMapping("/chat") @@ -20,6 +30,7 @@ import java.util.List; public class ChatController { private final IChatService chatService; + private final DifyService difyService; /* NOTE(review): injected but not referenced by any handler visible in this diff - confirm it is needed. */ @PostMapping("/talk") public R talk(@RequestParam("file") MultipartFile multipartFile, @ModelAttribute RobotTalkReq robotTalkReq) { @@ -27,6 +38,11 @@ public class ChatController { return R.ok(talk); } + @PostMapping("/asr") + public R asr(@RequestParam("file") MultipartFile file) throws IOException { /* Passes the uploaded file to chatService.asr and wraps the result in R.ok(...). */ + return R.ok(chatService.asr(file)); + } + @GetMapping("/talkList") public List talkList(String sessionId) { return new ArrayList(); @@ -37,4 +53,16 @@ public class ChatController { @RequestParam("audioId")String audioId) throws IOException { chatService.getAudio(response,audioId); } + + @GetMapping(value="/stream",produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux>> test2(@RequestParam("query") String query) { /* text/event-stream endpoint: returns the Flux from chatService.streamingMessage(query) unchanged. */ + return chatService.streamingMessage(query); + } + + @GetMapping(value = "/webflux",produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public Flux chatWebFlux() { /* Demo text/event-stream endpoint: emits the pieces of static/test.txt with a 50 ms delay between elements. */ + String string = ResourceUtil.readUtf8Str("classpath:static/test.txt"); + String[] stringArray = StrUtil.split(string, 1); /* presumably hutool's split-by-length, i.e. chunks of length 1 - TODO confirm */ + return Flux.just(stringArray).delayElements(Duration.ofMillis(50)); + } } diff --git a/src/main/java/com/supervision/model/dify/DifyChatReqVO.java b/src/main/java/com/supervision/model/dify/DifyChatReqVO.java index d61d541..9d35bf1 100644 --- a/src/main/java/com/supervision/model/dify/DifyChatReqVO.java +++
b/src/main/java/com/supervision/model/dify/DifyChatReqVO.java @@ -2,12 +2,12 @@ package com.supervision.model.dify; import lombok.Data; -import static com.supervision.common.constant.DifyConstants.CHAT_RESPONSE_MODE_BLOCKING; +import static com.supervision.common.constant.DifyConstants.CHAT_RESPONSE_MODE_STREAMING; @Data public class DifyChatReqVO { private String user; - private String response_mode = CHAT_RESPONSE_MODE_BLOCKING; + private String response_mode = CHAT_RESPONSE_MODE_STREAMING; /* NOTE(review): this default applies to every DifyChatReqVO, including requests passed to the existing DifyApiUtil.chat(DifyChatReqVO) - confirm that path sets or tolerates the streaming mode. */ private DIFYChatReqInputVO inputs = new DIFYChatReqInputVO(); private String query; private String conversation_id; diff --git a/src/main/java/com/supervision/model/dify/OutputsData.java b/src/main/java/com/supervision/model/dify/OutputsData.java new file mode 100644 index 0000000..6d4e304 --- /dev/null +++ b/src/main/java/com/supervision/model/dify/OutputsData.java @@ -0,0 +1,10 @@ +package com.supervision.model.dify; + +import lombok.Data; + +import java.io.Serializable; + +@Data +public class OutputsData implements Serializable { /* Lombok @Data DTO with a single answer field; used as the type of StreamResponseData.outputs. */ + private String answer; +} \ No newline at end of file diff --git a/src/main/java/com/supervision/model/dify/StreamResponse.java b/src/main/java/com/supervision/model/dify/StreamResponse.java new file mode 100644 index 0000000..a8e8726 --- /dev/null +++ b/src/main/java/com/supervision/model/dify/StreamResponse.java @@ -0,0 +1,44 @@ +package com.supervision.model.dify; + +import lombok.Data; + +import java.io.Serializable; + +@Data +public class StreamResponse implements Serializable { /* Lombok @Data DTO that the streaming callers decode each response element into (bodyToFlux(StreamResponse.class)). NOTE(review): multi-word fields are camelCase here but snake_case in StreamResponseData - confirm both match the JSON keys actually received. */ + + /** + * Event type; differs by mode. + */ + private String event; + + /** + * agent_thought id. + */ + private String id; + + /** + * Task ID. + */ + private String taskId; + + /** + * Unique message ID. + */ + private String messageId; + + /** + * Text chunk returned by the LLM. + */ + private String answer; + + /** + * Creation timestamp. + */ + private Long createdAt; + + /** + * Conversation ID.
+ */ + private String conversationId; +} \ No newline at end of file diff --git a/src/main/java/com/supervision/model/dify/StreamResponseData.java b/src/main/java/com/supervision/model/dify/StreamResponseData.java new file mode 100644 index 0000000..a067c03 --- /dev/null +++ b/src/main/java/com/supervision/model/dify/StreamResponseData.java @@ -0,0 +1,15 @@ +package com.supervision.model.dify; + +import lombok.Data; + +import java.io.Serializable; + +@Data +public class StreamResponseData implements Serializable { /* Lombok @Data DTO with snake_case field names; not referenced by any other code visible in this diff. */ + private String id; + private String workflow_id; + private String status; + private Long created_at; + private Long finished_at; + private OutputsData outputs; +} \ No newline at end of file diff --git a/src/main/java/com/supervision/service/DifyService.java b/src/main/java/com/supervision/service/DifyService.java new file mode 100644 index 0000000..a0958cf --- /dev/null +++ b/src/main/java/com/supervision/service/DifyService.java @@ -0,0 +1,88 @@ +package com.supervision.service; + +import com.alibaba.fastjson.JSON; +import com.supervision.dto.paddlespeech.res.TtsResultDTO; +import com.supervision.dto.robot.AnswerInfo; +import com.supervision.model.dify.DIFYChatReqInputVO; +import com.supervision.model.dify.DifyChatReqVO; +import com.supervision.model.dify.StreamResponse; +import com.supervision.util.TtsUtil; +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.stereotype.Service; +import org.springframework.web.client.RestTemplate; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; + +import java.util.HashMap; +import java.util.Map; +import java.util.UUID; + +@Slf4j +@Service +@RequiredArgsConstructor +public class DifyService { + @Value("${dify.url}") + private String difyUrl; +
@Value("${dify.app-auth}") + private String difyAppAuth; + private final WebClient webClient; + Map voiceCache = new HashMap<>(); /* NOTE(review): written by streamingMessage but never read or evicted in this class, and a plain HashMap held by a @Service bean - confirm who consumes it and whether concurrent requests can reach it. */ + + /** + * Calls Dify in streaming mode. + * + * @param query query text sent to Dify + * @return Flux of server-sent events, one per decoded StreamResponse; the data map holds "event" and, for "message" events with an answer, "audioId" + */ + public Flux>> streamingMessage(String query) { + DifyChatReqVO difyChatReqVO = new DifyChatReqVO(); + difyChatReqVO.setUser("admin"); + DIFYChatReqInputVO inputs = new DIFYChatReqInputVO(); + difyChatReqVO.setQuery(query); +// difyChatReqVO.setQuery("尽可能详细的介绍一下勐赫小镇的医疗服务"); + difyChatReqVO.setInputs(inputs); + StringBuilder sentence = new StringBuilder(); /* only referenced by the commented-out code further down */ + + return webClient.post() + .uri(difyUrl) + .headers(httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setBearerAuth(difyAppAuth); + }) + .bodyValue(JSON.toJSONString(difyChatReqVO)) + .retrieve() + .bodyToFlux(StreamResponse.class) + .map(response -> { + Map map = new HashMap<>(); + map.put("event", response.getEvent()); + if ("message".equals(response.getEvent()) && response.getAnswer() != null) { /* constant-first equals: an element decoded without an event no longer throws NullPointerException */ + TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(response.getAnswer()); + String voiceBaseId = UUID.randomUUID().toString(); + voiceCache.put(voiceBaseId, ttsResultDTO.getAudio()); + map.put("audioId", voiceBaseId); + } + return ServerSentEvent.builder(map).build(); + +// if (response.getEvent().equals("message") && response.getAnswer() != null) { +// // Walk the answer character by character; a punctuation mark ends the current sentence, which is logged and cleared; text after the mark starts the next sentence +// for (char ch : response.getAnswer().toCharArray()) { +// sentence.append(ch); +// if (ch == '。' || ch == '!' || ch == '?'
|| ch == ',' || ch == '、' || ch == '‘' || ch == '’' || ch == '“' || ch == '”') { // Check for punctuation marks +// log.info(sentence.toString()); +// sentence.setLength(0); // Clear the sentence +// } +// } +// } +// if (response.getEvent().equals("message_end") && !sentence.isEmpty()) { +// log.info(sentence.toString()); +// } +// Map map; +// map = Map.of("event", response.getEvent(), "audioId", sentence.toString()); +// return ServerSentEvent.builder(map).build(); + }); + } +} diff --git a/src/main/java/com/supervision/service/IChatService.java b/src/main/java/com/supervision/service/IChatService.java index ed0e329..3e16e8e 100644 --- a/src/main/java/com/supervision/service/IChatService.java +++ b/src/main/java/com/supervision/service/IChatService.java @@ -3,11 +3,18 @@ package com.supervision.service; import com.supervision.dto.robot.RobotTalkDTO; import com.supervision.model.RobotTalkReq; import jakarta.servlet.http.HttpServletResponse; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.web.multipart.MultipartFile; +import reactor.core.publisher.Flux; import java.io.IOException; +import java.util.Map; public interface IChatService { + Flux>> streamingMessage(String query); /* Streams the reply for query as server-sent events. */ + + String asr(MultipartFile file) throws IOException; /* Returns the text produced for the uploaded file; ChatServiceImpl passes the file bytes to AsrUtil. */ + RobotTalkDTO talk(MultipartFile file, RobotTalkReq robotTalkReq); void getAudio(HttpServletResponse response, String audioId) throws IOException; diff --git a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java index ddbd951..8bd1946 100644 --- a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java @@ -1,6 +1,7 @@ package com.supervision.service.impl; import cn.hutool.core.codec.Base64; +import com.alibaba.fastjson.JSON; import com.supervision.dto.dify.ChatResDTO; import com.supervision.dto.paddlespeech.res.TtsResultDTO; import
com.supervision.dto.robot.AnswerInfo; @@ -9,21 +10,27 @@ import com.supervision.dto.robot.RobotTalkDTO; import com.supervision.model.RobotTalkReq; import com.supervision.model.dify.DIFYChatReqInputVO; import com.supervision.model.dify.DifyChatReqVO; +import com.supervision.model.dify.StreamResponse; import com.supervision.service.IChatService; import com.supervision.util.AsrUtil; import com.supervision.util.DifyApiUtil; import com.supervision.util.TtsUtil; -import jakarta.annotation.Resource; import jakarta.servlet.http.HttpServletResponse; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import net.sourceforge.pinyin4j.PinyinHelper; import net.sourceforge.pinyin4j.format.HanyuPinyinCaseType; import net.sourceforge.pinyin4j.format.HanyuPinyinOutputFormat; import net.sourceforge.pinyin4j.format.HanyuPinyinToneType; import net.sourceforge.pinyin4j.format.exception.BadHanyuPinyinOutputFormatCombination; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; import org.springframework.stereotype.Service; import org.springframework.util.StopWatch; import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import java.io.IOException; import java.util.HashMap; @@ -34,14 +41,70 @@ import java.util.regex.Pattern; @Slf4j @Service +@RequiredArgsConstructor public class ChatServiceImpl implements IChatService { - @Resource - private DifyApiUtil difyApiUtil; - + @Value("${dify.url}") + private String difyUrl; + @Value("${dify.app-auth}") + private String difyAppAuth; + private final WebClient webClient; + private final DifyApiUtil difyApiUtil; Map voiceCache = new HashMap<>(); + @Override + public Flux>> streamingMessage(String query) { /* Posts query to Dify and maps each decoded StreamResponse to a server-sent event whose data map holds "event" and, when speech was synthesized for that event, "audioId" (the key stored in voiceCache). */ + DifyChatReqVO difyChatReqVO = new DifyChatReqVO(); + difyChatReqVO.setUser("admin"); + DIFYChatReqInputVO inputs
= new DIFYChatReqInputVO(); + difyChatReqVO.setQuery(query); +// difyChatReqVO.setQuery("尽可能详细的介绍一下勐赫小镇的医疗服务"); + difyChatReqVO.setInputs(inputs); + StringBuilder sentence = new StringBuilder(); + log.info("query:{}", query); + return webClient.post() + .uri(difyUrl) + .headers(httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setBearerAuth(difyAppAuth); + }) + .bodyValue(JSON.toJSONString(difyChatReqVO)) + .retrieve() + .bodyToFlux(StreamResponse.class) + .map(response -> { /* NOTE(review): TtsUtil.ttsTransform is called inside this operator; if it blocks on a remote call, confirm it should not be moved off the WebClient thread. */ + Map map = new HashMap<>(); + map.put("event", response.getEvent()); + StringBuilder ready = new StringBuilder(); /* text completed while handling this event; synthesized once below so one audioId covers all of it */ + if ("message".equals(response.getEvent()) && response.getAnswer() != null) { + for (char ch : response.getAnswer().toCharArray()) { /* Walk the answer character by character: a punctuation mark ends the pending sentence, which is logged, queued for synthesis and cleared; text after the mark starts the next sentence. */ + sentence.append(ch); + if (ch == '。' || ch == '!' || ch == '?' || ch == ',' || ch == '、' || ch == '‘' || ch == '’' || ch == '“' || ch == '”') { /* Check for punctuation marks */ + log.info(sentence.toString()); + ready.append(sentence); + sentence.setLength(0); /* Clear the sentence */ + } + } + } else if ("message_end".equals(response.getEvent()) && !sentence.isEmpty()) { /* flush the unterminated tail; this check used to sit inside the "message" branch, where it could never be true */ + log.info(sentence.toString()); + ready.append(sentence); + sentence.setLength(0); + } + if (!ready.isEmpty()) { + String voiceBaseId = UUID.randomUUID().toString(); + TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(ready.toString()); + voiceCache.put(voiceBaseId, ttsResultDTO.getAudio()); + map.put("audioId", voiceBaseId); + } + return ServerSentEvent.builder(map).build(); + }); + } + + @Override + public String asr(MultipartFile file) throws IOException { + return replaceTown(AsrUtil.asrTransformByBytes(file.getBytes())); + } + @Override public RobotTalkDTO talk(MultipartFile file, RobotTalkReq robotTalkReq) { log.info("robotTalkReq:{}", robotTalkReq); diff --git
a/src/main/java/com/supervision/util/DifyApiUtil.java b/src/main/java/com/supervision/util/DifyApiUtil.java index b489ba7..20f8141 100644 --- a/src/main/java/com/supervision/util/DifyApiUtil.java +++ b/src/main/java/com/supervision/util/DifyApiUtil.java @@ -1,8 +1,14 @@ package com.supervision.util; import cn.hutool.json.JSONUtil; +import com.alibaba.fastjson.JSON; import com.supervision.dto.dify.ChatResDTO; +import com.supervision.model.dify.DIFYChatReqInputVO; import com.supervision.model.dify.DifyChatReqVO; +import com.supervision.model.dify.StreamResponse; +import io.micrometer.common.util.StringUtils; +import jakarta.annotation.Resource; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.hc.client5.http.ClientProtocolException; import org.apache.hc.client5.http.classic.methods.HttpPost; @@ -14,18 +20,54 @@ import org.apache.hc.core5.http.HttpStatus; import org.apache.hc.core5.http.io.entity.EntityUtils; import org.apache.hc.core5.http.io.entity.StringEntity; import org.springframework.beans.factory.annotation.Value; +import org.springframework.http.MediaType; import org.springframework.stereotype.Component; +import org.springframework.stereotype.Service; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import java.nio.charset.StandardCharsets; @Component @Slf4j +@Service /* NOTE(review): the class already carries @Component; confirm the second stereotype annotation is intended. */ public class DifyApiUtil { @Value("${dify.url}") private String difyUrl; @Value("${dify.app-auth}") private String difyAppAuth; + @Resource + private WebClient webClient; + /** + * Calls Dify in streaming mode.
+ * + * @param query query text sent to Dify + * @return Flux of the StreamResponse elements decoded from the response body + */ + public Flux streamingMessage(String query) { + DifyChatReqVO difyChatReqVO = new DifyChatReqVO(); + difyChatReqVO.setUser("admin"); + DIFYChatReqInputVO inputs = new DIFYChatReqInputVO(); + difyChatReqVO.setQuery(query); /* forward the caller's query; a fixed greeting used to be sent here and the parameter was ignored */ + difyChatReqVO.setInputs(inputs); + + return webClient.post() + .uri(difyUrl) + .headers(httpHeaders -> { + httpHeaders.setContentType(MediaType.APPLICATION_JSON); + httpHeaders.setBearerAuth(difyAppAuth); /* setBearerAuth adds the "Bearer " prefix itself, so dify.app-auth must hold the bare key. NOTE(review): confirm chat(...) below does not rely on the old "Bearer ..." property value. */ + }) + .bodyValue(JSON.toJSONString(difyChatReqVO)) + .retrieve() + .bodyToFlux(StreamResponse.class); + } + + private boolean shouldInclude(StreamResponse streamResponse) { + /* Example filter: keep only "message" and "message_end" events; constant-first equals so a missing event yields false instead of a NullPointerException. */ + return "message".equals(streamResponse.getEvent()) + || "message_end".equals(streamResponse.getEvent()); + } public ChatResDTO chat(DifyChatReqVO difyChatReqVO) { ChatResDTO execute; diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 871404c..2abf7ee 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -7,7 +7,7 @@ server: context-path: /speech-demo-service dify: url: http://192.168.10.138/v1/chat-messages - app-auth: Bearer app-79ABpRQWayX0bK9C2m7vecXe + app-auth: app-79ABpRQWayX0bK9C2m7vecXe paddle-speech: tts: http://192.168.10.96:8090/paddlespeech/tts asr: http://192.168.10.96:8090/paddlespeech/asr diff --git a/src/main/resources/static/test.txt b/src/main/resources/static/test.txt new file mode 100644 index 0000000..66b1b21 --- /dev/null +++ b/src/main/resources/static/test.txt @@ -0,0 +1,3 @@ +SpringBoot+WebFlux通过流式响应实现类似ChatGPT的打字机效果 +突然间想用Java实现一下像ChatGPT一样的打字机输出效果,但是网上搜了相关教程感觉都不够满意。 +这里贴一下自己的实现,为中文互联网做一点小小的贡献 \ No newline at end of file diff --git a/src/test/java/com/supervision/SpeechDemoServiceApplicationTests.java b/src/test/java/com/supervision/SpeechDemoServiceApplicationTests.java index c424d90..c0fad56 100644 --- a/src/test/java/com/supervision/SpeechDemoServiceApplicationTests.java +++
b/src/test/java/com/supervision/SpeechDemoServiceApplicationTests.java @@ -23,7 +23,7 @@ class SpeechDemoServiceApplicationTests { @Test void testTtsTransform() { - TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform("你好,我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学我是小爱同学"); + TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform("欢迎来电,我是您的康养顾问小苏。关于您的问题,我们已经为您查询到了相关信息。"); System.out.println(JSONUtil.toJsonStr(ttsResultDTO)); // https://www.toolfk.com/zh-cn/tools/base64-to-audio.html base64转音频