From 567fbc106bfe2ded15836909cb2df973ed1d889b Mon Sep 17 00:00:00 2001 From: lipku Date: Sat, 5 Jul 2025 13:08:12 +0800 Subject: [PATCH] add doubao tts --- README.md | 1 + basereal.py | 6 +- requirements.txt | 1 + ttsreal.py | 141 +++++++++++++++++++++++++++++++++++++++++++++-- 4 files changed, 142 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 5230d94..a3e5f5d 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ - 2025.3.16 支持mac gpu推理,感谢[@GcsSloop](https://github.com/GcsSloop) - 2025.5.1 精简运行参数,ernerf模型移至git分支ernerf-rtmp - 2025.6.7 添加虚拟摄像头输出 +- 2025.7.5 添加豆包语音合成, 感谢[@ELK-milu](https://github.com/ELK-milu) ## Features 1. 支持多种数字人模型: ernerf、musetalk、wav2lip、Ultralight-Digital-Human diff --git a/basereal.py b/basereal.py index aacfa42..3f8b245 100644 --- a/basereal.py +++ b/basereal.py @@ -38,7 +38,7 @@ from av import AudioFrame, VideoFrame import av from fractions import Fraction -from ttsreal import EdgeTTS,SovitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS +from ttsreal import EdgeTTS,SovitsTTS,XTTS,CosyVoiceTTS,FishTTS,TencentTTS,DoubaoTTS from logger import logger from tqdm import tqdm @@ -86,6 +86,8 @@ class BaseReal: self.tts = FishTTS(opt,self) elif opt.tts == "tencent": self.tts = TencentTTS(opt,self) + elif opt.tts == "doubao": + self.tts = DoubaoTTS(opt,self) self.speaking = False @@ -363,6 +365,7 @@ class BaseReal: else: combine_frame = current_frame + cv2.putText(combine_frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1) if self.opt.transport=='virtualcam': if vircam==None: height, width,_= combine_frame.shape @@ -370,7 +373,6 @@ class BaseReal: vircam.send(combine_frame) else: #webrtc image = combine_frame - image[0,:] &= 0xFE new_frame = VideoFrame.from_ndarray(image, format="bgr24") asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop) self.record_video_data(combine_frame) diff --git a/requirements.txt b/requirements.txt index 9b5d366..b89493d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,3 +41,4 @@ accelerate librosa openai +websockets==12.0 diff --git a/ttsreal.py b/ttsreal.py index 280440a..186ff59 100644 --- a/ttsreal.py +++ b/ttsreal.py @@ -36,6 +36,8 @@ import requests import queue from queue import Queue from io import BytesIO +import copy,websockets,gzip + from threading import Thread, Event from enum import Enum @@ -233,11 +235,11 @@ class SovitsTTS(BaseTTS): text,textevent = msg self.stream_tts( self.gpt_sovits( - text, - self.opt.REF_FILE, - self.opt.REF_TEXT, - "zh", #en args.language, - self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, + text=text, + reffile=self.opt.REF_FILE, + reftext=self.opt.REF_TEXT, + language="zh", #en args.language, + server_url=self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url, ), msg ) @@ -516,6 +518,135 @@ class TencentTTS(BaseTTS): ########################################################################################### + +class DoubaoTTS(BaseTTS): + def __init__(self, opt, parent): + super().__init__(opt, parent) + # 从配置中读取火山引擎参数 + self.appid = os.getenv("DOUBAO_APPID") + self.token = os.getenv("DOUBAO_TOKEN") + _cluster = 'volcano_tts' + _host = "openspeech.bytedance.com" + self.api_url = f"wss://{_host}/api/v1/tts/ws_binary" + + self.request_json = { + "app": { + "appid": self.appid, + "token": "access_token", + "cluster": _cluster + }, + "user": { + "uid": "xxx" + }, + "audio": { + "voice_type": "xxx", + "encoding": "pcm", + "rate": 16000, + "speed_ratio": 1.0, + "volume_ratio": 1.0, + "pitch_ratio": 1.0, + }, + "request": { + "reqid": "xxx", + "text": "字节跳动语音合成。", + "text_type": "plain", + "operation": "xxx" + } + } + + async def doubao_voice(self, text): # -> Iterator[bytes]: + start = time.perf_counter() + voice_type = self.opt.REF_FILE + + try: + # 创建请求对象 + default_header = bytearray(b'\x11\x10\x11\x00') + submit_request_json = copy.deepcopy(self.request_json) + submit_request_json["user"]["uid"] = self.parent.sessionid + submit_request_json["audio"]["voice_type"] = voice_type + submit_request_json["request"]["text"] = text + submit_request_json["request"]["reqid"] = str(uuid.uuid4()) + submit_request_json["request"]["operation"] = "submit" + payload_bytes = str.encode(json.dumps(submit_request_json)) + payload_bytes = gzip.compress(payload_bytes) # if no compression, comment this line + full_client_request = bytearray(default_header) + full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + full_client_request.extend(payload_bytes) # payload + + header = {"Authorization": f"Bearer; {self.token}"} + first = True + async with websockets.connect(self.api_url, extra_headers=header, ping_interval=None) as ws: + await ws.send(full_client_request) + while True: + res = await ws.recv() + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + payload = res[header_size*4:] + + if message_type == 0xb: # audio-only server response + if message_type_specific_flags == 0: # no sequence number as ACK + #print(" Payload size: 0") + continue + else: + if first: + end = time.perf_counter() + logger.info(f"doubao tts Time to first chunk: {end-start}s") + first = False + sequence_number = int.from_bytes(payload[:4], "big", signed=True) + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload = payload[8:] + yield payload + if sequence_number < 0: + break + else: + break + except Exception as e: + logger.exception('doubao') + # # 检查响应状态码 + # if response.status_code == 200: + # # 处理响应数据 + # audio_data = base64.b64decode(response.json().get('data')) + # yield audio_data + # else: + # logger.error(f"请求失败,状态码: {response.status_code}") + # return + + def txt_to_audio(self, msg): + text, textevent = msg + asyncio.new_event_loop().run_until_complete( + self.stream_tts( + self.doubao_voice(text), + msg + ) + ) + + async def stream_tts(self, audio_stream, msg): + text, textevent = msg + first = True + last_stream = np.array([],dtype=np.float32) + async for chunk in audio_stream: + if chunk is not None and len(chunk) > 0: + stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767 + stream = np.concatenate((last_stream,stream)) + #stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate) + # byte_stream=BytesIO(buffer) + # stream = self.__create_bytes_stream(byte_stream) + streamlen = stream.shape[0] + idx = 0 + while streamlen >= self.chunk: + eventpoint = None + if first: + eventpoint = {'status': 'start', 'text': text, 'msgenvent': textevent} + first = False + self.parent.put_audio_frame(stream[idx:idx + self.chunk], eventpoint) + streamlen -= self.chunk + idx += self.chunk + last_stream = stream[idx:] #get the remain stream + eventpoint = {'status': 'end', 'text': text, 'msgenvent': textevent} + self.parent.put_audio_frame(np.zeros(self.chunk, np.float32), eventpoint) + +########################################################################################### class XTTS(BaseTTS): def __init__(self, opt, parent): super().__init__(opt,parent)