添加对话列表

main
xueqingkun 2 months ago
parent 2e170833de
commit 139c11cb1c

@ -19,6 +19,7 @@ import com.supervision.service.IChatService;
import com.supervision.util.AsrUtil;
import com.supervision.util.DifyApiUtil;
import com.supervision.util.TtsUtil;
import com.supervision.util.WavUtil;
import jakarta.servlet.http.HttpServletResponse;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
@ -103,9 +104,9 @@ public class ChatServiceImpl implements IChatService {
log.info(sentence.toString());
TtsResultDTO ttsResultDTO = TtsUtil.ttsTransform(sentence.toString());
String voiceBaseId = UUID.randomUUID().toString();
audioList.add(voiceBaseId);
audioCache.put(voiceBaseId, ttsResultDTO.getAudio());
map.put("audioId", voiceBaseId);
audioList.add(voiceBaseId);
sentence.setLength(0);
return ServerSentEvent.builder(map).build();
}
@ -121,9 +122,12 @@ public class ChatServiceImpl implements IChatService {
audioCache.put(voiceBaseId, ttsResultDTO.getAudio());
map.put("audioId", voiceBaseId);
}
String fullAnswer = audioList.stream().map(audioId -> map.get("audioId")).filter(Objects::nonNull).collect(Collectors.joining());
String uuid = cn.hutool.core.lang.UUID.randomUUID().toString();
List<String> collect = audioList.stream().map(audioCache::get).filter(Objects::nonNull).collect(Collectors.toList());
String fullAnswer = WavUtil.mergeWavFilesToBase64(collect);
audioCache.put(uuid,fullAnswer);
builder.answerInfo(AnswerInfo.builder().contentType(2).message(fullAskString.toString()).voiceBaseId(uuid).build());
builder.sessionId(response.getConversation_id());
builder.askInfo(AskInfo.builder().contentType(2).message(query).audioLength(100L).build());

@ -0,0 +1,134 @@
package com.supervision.util;
import java.io.ByteArrayOutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
/**
 * Utility for concatenating Base64-encoded WAV (RIFF/WAVE) clips into one WAV file.
 *
 * <p>Only the {@code fmt } and {@code data} chunks are interpreted; every other chunk is
 * skipped and is not carried over into the merged output.
 */
public final class WavUtil {

    /** Size of the leading {@code "RIFF" <size> "WAVE"} header. */
    private static final int RIFF_HEADER_SIZE = 12;

    /** Size of a chunk header: 4-byte id followed by a 4-byte little-endian body size. */
    private static final int CHUNK_HEADER_SIZE = 8;

    /** Smallest {@code fmt } body that still contains every field read by {@link #parseWav}. */
    private static final int MIN_FMT_BODY_SIZE = 16;

    private static final String INVALID_WAV = "Invalid WAV file";

    private WavUtil() {
        // static utility, not meant to be instantiated
    }

    /**
     * Concatenates the audio payload of several WAV files into a single WAV file.
     *
     * <p>The {@code fmt } chunk of the first file is reused for the result, and the
     * {@code data} payloads are appended in list order. The RIFF and data sizes are
     * recomputed for the merged payload.
     *
     * @param base64WavList Base64-encoded WAV files, in playback order
     * @return the merged WAV file, Base64-encoded
     * @throws IllegalArgumentException if the list is null or empty, an element is null,
     *         is not valid Base64 or is not a parsable WAV file, the files do not share
     *         the same format, or the merged file would exceed the 32-bit RIFF size
     */
    public static String mergeWavFilesToBase64(List<String> base64WavList) {
        if (base64WavList == null || base64WavList.isEmpty()) {
            throw new IllegalArgumentException("No WAV files provided");
        }

        List<byte[]> allAudioData = new ArrayList<>(base64WavList.size());
        WavFormat format = null;
        long totalDataSize = 0;
        for (int i = 0; i < base64WavList.size(); i++) {
            String encoded = base64WavList.get(i);
            if (encoded == null) {
                throw new IllegalArgumentException("WAV file at index " + i + " is null");
            }
            WavFormat current = parseWav(Base64.getDecoder().decode(encoded));
            if (format == null) {
                // The first file defines the format every later file must match.
                format = current;
            } else if (!current.equals(format)) {
                throw new IllegalArgumentException("WAV format mismatch");
            }
            allAudioData.add(current.audioData);
            totalDataSize += current.audioData.length;
        }

        byte[] fmtChunk = format.fmtChunk;
        // RIFF chunks are word-aligned: an odd-sized body is followed by one pad byte
        // that is not counted in the chunk's own size but is counted in the RIFF size.
        int fmtPad = fmtChunk.length & 1;
        int dataPad = (int) (totalDataSize & 1);
        // RIFF size = "WAVE" + fmt chunk (+pad) + data header + data body (+pad)
        long riffSize = 4L + fmtChunk.length + fmtPad + CHUNK_HEADER_SIZE + totalDataSize + dataPad;
        if (riffSize > Integer.MAX_VALUE - CHUNK_HEADER_SIZE) {
            throw new IllegalArgumentException("Merged WAV data is too large");
        }

        ByteArrayOutputStream output = new ByteArrayOutputStream((int) riffSize + CHUNK_HEADER_SIZE);
        // RIFF header
        output.writeBytes("RIFF".getBytes(StandardCharsets.US_ASCII));
        writeLittleEndianInt(output, (int) riffSize);
        output.writeBytes("WAVE".getBytes(StandardCharsets.US_ASCII));
        // fmt chunk, copied verbatim (header included) from the first file
        output.writeBytes(fmtChunk);
        if (fmtPad != 0) {
            output.write(0);
        }
        // data chunk
        output.writeBytes("data".getBytes(StandardCharsets.US_ASCII));
        writeLittleEndianInt(output, (int) totalDataSize);
        for (byte[] audioData : allAudioData) {
            output.writeBytes(audioData);
        }
        if (dataPad != 0) {
            output.write(0);
        }
        return Base64.getEncoder().encodeToString(output.toByteArray());
    }

    /**
     * Walks the chunk list of a WAV file and extracts its {@code fmt } and {@code data} chunks.
     *
     * <p>If a {@code data} chunk declares more bytes than the file actually contains
     * (including the {@code 0xFFFFFFFF} placeholder, which reads as a negative int), the
     * payload is clamped to the bytes that are present instead of being zero-padded.
     * If several {@code data} chunks exist, the last one wins.
     *
     * @param wavBytes raw bytes of a WAV file
     * @return the parsed format fields, the raw fmt chunk and the audio payload
     * @throws IllegalArgumentException if the RIFF/WAVE header is missing, the fmt chunk is
     *         truncated or too short, or either the fmt or data chunk is absent
     */
    private static WavFormat parseWav(byte[] wavBytes) {
        if (wavBytes.length < RIFF_HEADER_SIZE
                || !hasTag(wavBytes, 0, "RIFF")
                || !hasTag(wavBytes, 8, "WAVE")) {
            throw new IllegalArgumentException(INVALID_WAV);
        }

        WavFormat format = new WavFormat();
        int pos = RIFF_HEADER_SIZE;
        // "<=" so that a zero-length chunk sitting at the very end of the file is still seen.
        while (pos + CHUNK_HEADER_SIZE <= wavBytes.length) {
            String chunkId = new String(wavBytes, pos, 4, StandardCharsets.US_ASCII);
            int declaredSize = readLittleEndianInt(wavBytes, pos + 4);
            int bodyStart = pos + CHUNK_HEADER_SIZE;
            int available = wavBytes.length - bodyStart;
            boolean overruns = declaredSize < 0 || declaredSize > available;

            if ("fmt ".equals(chunkId)) {
                if (overruns || declaredSize < MIN_FMT_BODY_SIZE) {
                    throw new IllegalArgumentException(INVALID_WAV);
                }
                format.fmtChunk = Arrays.copyOfRange(wavBytes, pos, bodyStart + declaredSize);
                format.audioFormat = readLittleEndianUnsignedShort(wavBytes, bodyStart);
                format.channels = readLittleEndianUnsignedShort(wavBytes, bodyStart + 2);
                format.sampleRate = readLittleEndianInt(wavBytes, bodyStart + 4);
                format.byteRate = readLittleEndianInt(wavBytes, bodyStart + 8);
                format.blockAlign = readLittleEndianUnsignedShort(wavBytes, bodyStart + 12);
                format.bitsPerSample = readLittleEndianUnsignedShort(wavBytes, bodyStart + 14);
            } else if ("data".equals(chunkId)) {
                int size = overruns ? available : declaredSize;
                format.audioData = Arrays.copyOfRange(wavBytes, bodyStart, bodyStart + size);
            }

            if (overruns) {
                // Nothing reliable can follow a chunk that runs past the end of the file.
                break;
            }
            // Skip the body plus the pad byte that follows an odd-sized body.
            pos = bodyStart + declaredSize + (declaredSize & 1);
        }

        if (format.fmtChunk == null || format.audioData == null) {
            throw new IllegalArgumentException(INVALID_WAV);
        }
        return format;
    }

    /** Returns whether the four ASCII bytes at {@code offset} spell {@code tag}. */
    private static boolean hasTag(byte[] bytes, int offset, String tag) {
        return tag.equals(new String(bytes, offset, 4, StandardCharsets.US_ASCII));
    }

    /** Reads a 32-bit little-endian integer starting at {@code offset}. */
    private static int readLittleEndianInt(byte[] bytes, int offset) {
        return (bytes[offset] & 0xFF)
                | ((bytes[offset + 1] & 0xFF) << 8)
                | ((bytes[offset + 2] & 0xFF) << 16)
                | ((bytes[offset + 3] & 0xFF) << 24);
    }

    /** Reads an unsigned 16-bit little-endian integer starting at {@code offset}. */
    private static int readLittleEndianUnsignedShort(byte[] bytes, int offset) {
        return (bytes[offset] & 0xFF) | ((bytes[offset + 1] & 0xFF) << 8);
    }

    /** Writes {@code value} to {@code out} as four bytes, least significant byte first. */
    private static void writeLittleEndianInt(ByteArrayOutputStream out, int value) {
        out.write(value & 0xFF);
        out.write((value >> 8) & 0xFF);
        out.write((value >> 16) & 0xFF);
        out.write((value >> 24) & 0xFF);
    }

    /**
     * Parsed view of one WAV file: the format fields of its {@code fmt } chunk, the raw
     * fmt chunk bytes (chunk header included) and the body of its {@code data} chunk.
     */
    static class WavFormat {
        byte[] fmtChunk;
        byte[] audioData;
        int audioFormat;
        int channels;
        int sampleRate;
        int byteRate;
        int blockAlign;
        int bitsPerSample;

        /**
         * Two instances are equal when their six format fields match; the raw fmt bytes
         * and the audio payload are deliberately not compared.
         */
        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof WavFormat)) {
                return false;
            }
            WavFormat other = (WavFormat) obj;
            return audioFormat == other.audioFormat
                    && channels == other.channels
                    && sampleRate == other.sampleRate
                    && byteRate == other.byteRate
                    && blockAlign == other.blockAlign
                    && bitsPerSample == other.bitsPerSample;
        }

        /** Hashes the same six format fields that {@link #equals(Object)} compares. */
        @Override
        public int hashCode() {
            return Arrays.hashCode(new int[] {
                audioFormat, channels, sampleRate, byteRate, blockAlign, bitsPerSample
            });
        }
    }
}
Loading…
Cancel
Save