diff --git a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java index 6e1ae66..4877c75 100644 --- a/src/main/java/com/supervision/service/impl/ChatServiceImpl.java +++ b/src/main/java/com/supervision/service/impl/ChatServiceImpl.java @@ -19,6 +19,7 @@ import com.supervision.service.IChatService; import com.supervision.util.AsrUtil; import com.supervision.util.DifyApiUtil; import com.supervision.util.TtsUtil; +import com.supervision.util.WavUtil; import jakarta.servlet.http.HttpServletResponse; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; @@ -103,9 +104,9 @@ public class ChatServiceImpl implements IChatService { log.info(sentence.toString()); TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString()); String voiceBaseId = UUID.randomUUID().toString(); + audioList.add(voiceBaseId); audioCache.put(voiceBaseId, ttsResultDTO.getAudio()); map.put("audioId", voiceBaseId); - audioList.add(voiceBaseId); sentence.setLength(0); return ServerSentEvent.builder(map).build(); } @@ -121,9 +122,12 @@ public class ChatServiceImpl implements IChatService { audioCache.put(voiceBaseId, ttsResultDTO.getAudio()); map.put("audioId", voiceBaseId); } - String fullAnswer = audioList.stream().map(audioId -> map.get("audioId")).filter(Objects::nonNull).collect(Collectors.joining()); + String uuid = cn.hutool.core.lang.UUID.randomUUID().toString(); + List collect = audioList.stream().map(audioCache::get).filter(Objects::nonNull).collect(Collectors.toList()); + String fullAnswer = WavUtil.mergeWavFilesToBase64(collect); audioCache.put(uuid,fullAnswer); + builder.answerInfo(AnswerInfo.builder().contentType(2).message(fullAskString.toString()).voiceBaseId(uuid).build()); builder.sessionId(response.getConversation_id()); builder.askInfo(AskInfo.builder().contentType(2).message(query).audioLength(100L).build()); diff --git a/src/main/java/com/supervision/util/WavUtil.java 
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;

/**
 * Concatenates Base64-encoded WAV clips into a single Base64-encoded WAV file.
 *
 * <p>Each clip is parsed as a RIFF container: the {@code fmt } chunk describes the audio
 * format and the {@code data} chunk carries the samples. The merged file reuses the
 * {@code fmt } chunk of the first clip and appends the sample data of every clip in order.
 */
public class WavUtil {

    /** Bytes occupied by "RIFF" + 4-byte size + "WAVE" before the first chunk. */
    private static final int RIFF_HEADER_SIZE = 12;

    /** Bytes occupied by a chunk header: 4-byte id followed by a 4-byte little-endian size. */
    private static final int CHUNK_HEADER_SIZE = 8;

    /** Number of {@code fmt } payload bytes read by {@link #parseWav(byte[])}. */
    private static final int MIN_FMT_PAYLOAD_SIZE = 16;

    /**
     * Merges several Base64-encoded WAV files into one Base64-encoded WAV file.
     *
     * @param base64WavList Base64 strings, each holding one complete WAV file; all clips
     *                      must share the same audio format
     * @return the merged WAV file, Base64-encoded
     * @throws IllegalArgumentException if the list is {@code null} or empty, contains a
     *                                  {@code null} entry, an entry is not valid Base64 or
     *                                  not a parsable WAV file, the formats differ, or the
     *                                  merged file would not fit a 32-bit RIFF size
     */
    public static String mergeWavFilesToBase64(List<String> base64WavList) {
        if (base64WavList == null || base64WavList.isEmpty()) {
            throw new IllegalArgumentException("No WAV files provided");
        }

        List<byte[]> audioSegments = new ArrayList<>(base64WavList.size());
        WavFormat reference = null;
        // Accumulate in long so that many large clips cannot silently overflow an int.
        long totalDataSize = 0;
        for (String encodedWav : base64WavList) {
            if (encodedWav == null) {
                throw new IllegalArgumentException("WAV entry must not be null");
            }
            WavFormat current = parseWav(Base64.getDecoder().decode(encodedWav));
            if (reference == null) {
                // The first clip supplies the fmt chunk of the merged file.
                reference = current;
            } else if (!current.sameFormatAs(reference)) {
                throw new IllegalArgumentException("WAV format mismatch");
            }
            audioSegments.add(current.audioData);
            totalDataSize += current.audioData.length;
        }

        byte[] fmtChunk = reference.fmtChunk;
        // RIFF chunks are word-aligned: an odd-sized payload is followed by one pad byte.
        int fmtPadding = fmtChunk.length & 1;
        int dataPadding = (int) (totalDataSize & 1);
        // RIFF size: 4 ("WAVE") + fmt chunk + 8 (data chunk header) + sample data (+ padding).
        long riffSize = 4L + fmtChunk.length + fmtPadding
                + CHUNK_HEADER_SIZE + totalDataSize + dataPadding;
        if (riffSize > Integer.MAX_VALUE - CHUNK_HEADER_SIZE) {
            throw new IllegalArgumentException("Merged WAV is too large");
        }

        // Write the segments straight into the output instead of merging them first.
        ByteArrayOutputStream output =
                new ByteArrayOutputStream((int) riffSize + CHUNK_HEADER_SIZE);
        writeAscii(output, "RIFF");
        writeLittleEndianInt(output, (int) riffSize);
        writeAscii(output, "WAVE");
        output.write(fmtChunk, 0, fmtChunk.length);
        if (fmtPadding != 0) {
            output.write(0);
        }
        writeAscii(output, "data");
        writeLittleEndianInt(output, (int) totalDataSize);
        for (byte[] segment : audioSegments) {
            output.write(segment, 0, segment.length);
        }
        if (dataPadding != 0) {
            output.write(0);
        }

        return Base64.getEncoder().encodeToString(output.toByteArray());
    }

    /**
     * Extracts the {@code fmt } chunk and the sample data from a WAV file.
     *
     * <p>Chunks are walked starting right after the 12-byte RIFF header. A chunk whose
     * declared size is negative or larger than the remaining bytes ends the walk; for a
     * {@code data} chunk the remaining bytes are taken as the sample data, for a
     * {@code fmt } chunk the file is rejected.
     *
     * @param wavBytes raw bytes of one WAV file
     * @return the parsed format fields together with the raw fmt chunk and sample data
     * @throws IllegalArgumentException if no usable {@code fmt } or {@code data} chunk is found
     */
    private static WavFormat parseWav(byte[] wavBytes) {
        WavFormat format = new WavFormat();
        int pos = RIFF_HEADER_SIZE; // skip the RIFF header

        // "<=" so that a zero-length chunk whose header ends exactly at EOF is still seen.
        while (pos <= wavBytes.length - CHUNK_HEADER_SIZE) {
            String chunkId = new String(wavBytes, pos, 4, StandardCharsets.US_ASCII);
            int declaredSize = readLittleEndianInt(wavBytes, pos + 4);
            int payloadStart = pos + CHUNK_HEADER_SIZE;
            int available = wavBytes.length - payloadStart;

            if (declaredSize < 0 || declaredSize > available) {
                // The declared size cannot be trusted, so no later chunk can be located.
                if ("data".equals(chunkId)) {
                    format.audioData = Arrays.copyOfRange(wavBytes, payloadStart, wavBytes.length);
                } else if ("fmt ".equals(chunkId)) {
                    throw new IllegalArgumentException("Invalid WAV file");
                }
                break;
            }

            if ("fmt ".equals(chunkId)) {
                if (declaredSize < MIN_FMT_PAYLOAD_SIZE) {
                    throw new IllegalArgumentException("Invalid WAV file");
                }
                format.fmtChunk = Arrays.copyOfRange(wavBytes, pos, payloadStart + declaredSize);
                format.audioFormat = readLittleEndianShort(wavBytes, payloadStart);
                format.channels = readLittleEndianShort(wavBytes, payloadStart + 2);
                format.sampleRate = readLittleEndianInt(wavBytes, payloadStart + 4);
                format.byteRate = readLittleEndianInt(wavBytes, payloadStart + 8);
                format.blockAlign = readLittleEndianShort(wavBytes, payloadStart + 12);
                format.bitsPerSample = readLittleEndianShort(wavBytes, payloadStart + 14);
            } else if ("data".equals(chunkId)) {
                format.audioData =
                        Arrays.copyOfRange(wavBytes, payloadStart, payloadStart + declaredSize);
            }

            // Advance past the payload and its pad byte (present when the size is odd).
            pos = payloadStart + declaredSize + (declaredSize & 1);
        }

        if (format.fmtChunk == null || format.audioData == null) {
            throw new IllegalArgumentException("Invalid WAV file");
        }
        return format;
    }

    /** Reads a 32-bit little-endian integer starting at {@code offset}. */
    private static int readLittleEndianInt(byte[] bytes, int offset) {
        return (bytes[offset] & 0xFF)
                | ((bytes[offset + 1] & 0xFF) << 8)
                | ((bytes[offset + 2] & 0xFF) << 16)
                | ((bytes[offset + 3] & 0xFF) << 24);
    }

    /** Reads an unsigned 16-bit little-endian value starting at {@code offset}. */
    private static int readLittleEndianShort(byte[] bytes, int offset) {
        return (bytes[offset] & 0xFF) | ((bytes[offset + 1] & 0xFF) << 8);
    }

    /** Writes {@code value} as four bytes, least significant byte first. */
    private static void writeLittleEndianInt(ByteArrayOutputStream out, int value) {
        out.write(value & 0xFF);
        out.write((value >> 8) & 0xFF);
        out.write((value >> 16) & 0xFF);
        out.write((value >> 24) & 0xFF);
    }

    /** Writes the ASCII bytes of {@code text}, e.g. a four-character chunk id. */
    private static void writeAscii(ByteArrayOutputStream out, String text) {
        byte[] bytes = text.getBytes(StandardCharsets.US_ASCII);
        out.write(bytes, 0, bytes.length);
    }

    /** Fields read from a WAV file's {@code fmt } chunk, plus the raw chunk and sample data. */
    static class WavFormat {
        /** The complete {@code fmt } chunk, including its 8-byte header. */
        byte[] fmtChunk;
        /** Payload of the {@code data} chunk. */
        byte[] audioData;
        int audioFormat;
        int channels;
        int sampleRate;
        int byteRate;
        int blockAlign;
        int bitsPerSample;

        /**
         * Tells whether both clips declare the same format fields; {@code fmtChunk} and
         * {@code audioData} are not compared.
         */
        boolean sameFormatAs(WavFormat other) {
            return audioFormat == other.audioFormat
                    && channels == other.channels
                    && sampleRate == other.sampleRate
                    && byteRate == other.byteRate
                    && blockAlign == other.blockAlign
                    && bitsPerSample == other.bitsPerSample;
        }

        /**
         * Kept for existing callers. This is an overload, not an override of
         * {@link Object#equals(Object)}; it delegates to {@link #sameFormatAs(WavFormat)}.
         */
        boolean equals(WavFormat other) {
            return sameFormatAs(other);
        }
    }
}