From 3d9d16a2fb354b64f5557d8894e71d5acfaf204f Mon Sep 17 00:00:00 2001
From: lipku
Date: Sun, 2 Feb 2025 16:06:12 +0800
Subject: [PATCH] add eventpoint sync with audio

---
 baseasr.py   | 20 ++++++-------
 basereal.py  | 11 +++++---
 hubertasr.py |  4 +--
 lightreal.py | 11 ++++----
 lipasr.py    |  4 +--
 lipreal.py   | 11 ++++----
 museasr.py   |  4 +--
 musereal.py  | 11 ++++----
 nerfasr.py   | 32 ++++++++++-----------
 nerfreal.py  | 10 +++----
 ttsreal.py   | 79 +++++++++++++++++++++++++++++++++++++---------------
 webrtc.py    |  6 +++-
 12 files changed, 124 insertions(+), 79 deletions(-)

diff --git a/baseasr.py b/baseasr.py
index d105cee..8b353fd 100644
--- a/baseasr.py
+++ b/baseasr.py
@@ -47,12 +47,13 @@ class BaseASR:
     def flush_talk(self):
         self.queue.queue.clear()
 
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
+    def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
+        self.queue.put((audio_chunk,eventpoint))
 
+    #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
     def get_audio_frame(self):
         try:
-            frame = self.queue.get(block=True,timeout=0.01)
+            frame,eventpoint = self.queue.get(block=True,timeout=0.01)
             type = 0
             #print(f'[INFO] get frame {frame.shape}')
         except queue.Empty:
@@ -62,20 +63,19 @@ class BaseASR:
             else:
                 frame = np.zeros(self.chunk, dtype=np.float32)
                 type = 1
+            eventpoint = None
 
-        return frame,type
+        return frame,type,eventpoint
 
-    def is_audio_frame_empty(self)->bool:
-        return self.queue.empty()
-
-    def get_audio_out(self): #get origin audio pcm to nerf
+    #return frame:audio pcm; type: 0-normal speak, 1-silence; eventpoint:custom event sync with audio
+    def get_audio_out(self):
         return self.output_queue.get()
 
     def warm_up(self):
         for _ in range(self.stride_left_size + self.stride_right_size):
-            audio_frame,type=self.get_audio_frame()
+            audio_frame,type,eventpoint=self.get_audio_frame()
             self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
+            self.output_queue.put((audio_frame,type,eventpoint))
         for _ in range(self.stride_left_size):
             self.output_queue.get()
diff --git a/basereal.py b/basereal.py
index 1a23657..5136aba 100644
--- a/basereal.py
+++ b/basereal.py
@@ -77,11 +77,11 @@ class BaseReal:
         self.custom_opt = {}
         self.__loadcustom()
 
-    def put_msg_txt(self,msg):
-        self.tts.put_msg_txt(msg)
+    def put_msg_txt(self,msg,eventpoint=None):
+        self.tts.put_msg_txt(msg,eventpoint)
 
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.asr.put_audio_frame(audio_chunk)
+    def put_audio_frame(self,audio_chunk,eventpoint=None): #16khz 20ms pcm
+        self.asr.put_audio_frame(audio_chunk,eventpoint)
 
     def put_audio_file(self,filebyte):
         input_stream = BytesIO(filebyte)
@@ -134,6 +134,9 @@ class BaseReal:
         for key in self.custom_index:
             self.custom_index[key]=0
 
+    def notify(self,eventpoint):
+        print("notify:",eventpoint)
+
     def start_recording(self):
         """开始录制视频"""
         if self.recording:
diff --git a/hubertasr.py b/hubertasr.py
index 6e37e37..8fb4b22 100644
--- a/hubertasr.py
+++ b/hubertasr.py
@@ -18,9 +18,9 @@ class HubertASR(BaseASR):
 
         start_time = time.time()
         for _ in range(self.batch_size * 2):
-            audio_frame, type_ = self.get_audio_frame()
+            audio_frame, type,eventpoint = self.get_audio_frame()
             self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame, type_))
+            self.output_queue.put((audio_frame, type,eventpoint))
 
         if len(self.frames) <= self.stride_left_size + self.stride_right_size:
             return
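Note: the changes above establish the eventpoint plumbing at the ASR layer. put_audio_frame() now queues (chunk, eventpoint) pairs, get_audio_frame()/get_audio_out() return (frame, type, eventpoint) triples, and BaseReal.notify() is the hook that later fires when a tagged frame reaches playback. The eventpoint object itself is opaque to this plumbing and is handed back unchanged. A minimal sketch of how a caller might tag its own 20 ms PCM chunks (the helper name and payload are illustrative, not part of the patch):

def feed_tagged_pcm(real, pcm_16k, chunk=320):  # 320 samples = 20 ms at 16 kHz
    """Push 16 kHz float32 PCM in 20 ms chunks and tag only the first one.
    Illustrative helper; only put_audio_frame(chunk, eventpoint) comes from the patch."""
    for i in range(0, len(pcm_16k) - chunk + 1, chunk):
        tag = {'kind': 'pcm-start'} if i == 0 else None  # hypothetical payload
        real.put_audio_frame(pcm_16k[i:i + chunk], tag)

Untagged chunks and the generated silence frames carry eventpoint=None, so they never trigger notify().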
diff --git a/lightreal.py b/lightreal.py
index 9a3ec03..b79d5a3 100644
--- a/lightreal.py
+++ b/lightreal.py
@@ -163,8 +163,8 @@ def inference(quit_event, batch_size, face_list_cycle, audio_feat_queue, audio_o
         is_all_silence=True
         audio_frames = []
         for _ in range(batch_size*2):
-            frame,type_ = audio_out_queue.get()
-            audio_frames.append((frame,type_))
+            frame,type_,eventpoint = audio_out_queue.get()
+            audio_frames.append((frame,type_,eventpoint))
             if type_==0:
                 is_all_silence=False
         if is_all_silence:
@@ -288,19 +288,20 @@ class LightReal(BaseReal):
                 #print('blending time:',time.perf_counter()-t)
 
             new_frame = VideoFrame.from_ndarray(combine_frame, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             self.record_video_data(combine_frame)
 
             for audio_frame in audio_frames:
-                frame,type_ = audio_frame
+                frame,type_,eventpoint = audio_frame
                 frame = (frame * 32767).astype(np.int16)
                 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                 new_frame.planes[0].update(frame.tobytes())
                 new_frame.sample_rate=16000
                 # if audio_track._queue.qsize()>10:
                 #     time.sleep(0.1)
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
+                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
+                #self.notify(eventpoint)
         print('lightreal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
diff --git a/lipasr.py b/lipasr.py
index 0868c97..b28785d 100644
--- a/lipasr.py
+++ b/lipasr.py
@@ -32,10 +32,10 @@ class LipASR(BaseASR):
         ############################################## extract audio feature ##############################################
         # get a frame of audio
         for _ in range(self.batch_size*2):
-            frame,type = self.get_audio_frame()
+            frame,type,eventpoint = self.get_audio_frame()
             self.frames.append(frame)
             # put to output
-            self.output_queue.put((frame,type))
+            self.output_queue.put((frame,type,eventpoint))
         # context not enough, do not run network.
         if len(self.frames) <= self.stride_left_size + self.stride_right_size:
             return
diff --git a/lipreal.py b/lipreal.py
index 847f99b..b569182 100644
--- a/lipreal.py
+++ b/lipreal.py
@@ -134,8 +134,8 @@ def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_q
         is_all_silence=True
         audio_frames = []
         for _ in range(batch_size*2):
-            frame,type = audio_out_queue.get()
-            audio_frames.append((frame,type))
+            frame,type,eventpoint = audio_out_queue.get()
+            audio_frames.append((frame,type,eventpoint))
             if type==0:
                 is_all_silence=False
 
@@ -242,19 +242,20 @@ class LipReal(BaseReal):
             image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
             new_frame = VideoFrame.from_ndarray(image, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             self.record_video_data(image)
 
             for audio_frame in audio_frames:
-                frame,type = audio_frame
+                frame,type,eventpoint = audio_frame
                 frame = (frame * 32767).astype(np.int16)
                 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                 new_frame.planes[0].update(frame.tobytes())
                 new_frame.sample_rate=16000
                 # if audio_track._queue.qsize()>10:
                 #     time.sleep(0.1)
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
+                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
+                #self.notify(eventpoint)
         print('lipreal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
diff --git a/museasr.py b/museasr.py
index b8be556..b08f3f8 100644
--- a/museasr.py
+++ b/museasr.py
@@ -33,9 +33,9 @@ class MuseASR(BaseASR):
         ############################################## extract audio feature ##############################################
         start_time = time.time()
         for _ in range(self.batch_size*2):
-            audio_frame,type=self.get_audio_frame()
+            audio_frame,type,eventpoint = self.get_audio_frame()
             self.frames.append(audio_frame)
-            self.output_queue.put((audio_frame,type))
+            self.output_queue.put((audio_frame,type,eventpoint))
 
         if len(self.frames) <= self.stride_left_size + self.stride_right_size:
             return
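Note: in each renderer above, frames pushed to the WebRTC track queues are now (frame, eventpoint) tuples: video frames are paired with None, while audio frames keep the eventpoint they picked up from the TTS layer, so notifications stay in sync with audible playback. A sketch of that convention for a custom process_frames loop (the helper is illustrative; the _queue it fills is PlayerStreamTrack._queue, changed in webrtc.py further down):

import asyncio

def forward_to_tracks(loop, video_track, audio_track, video_frame, audio_pairs):
    # Video frames carry no event marker.
    asyncio.run_coroutine_threadsafe(video_track._queue.put((video_frame, None)), loop)
    # Audio frames keep the eventpoint attached by the TTS layer so that
    # notify() fires when the chunk is actually handed to the player.
    for av_frame, eventpoint in audio_pairs:
        asyncio.run_coroutine_threadsafe(audio_track._queue.put((av_frame, eventpoint)), loop)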
diff --git a/musereal.py b/musereal.py
index ba6009b..ed88ec5 100644
--- a/musereal.py
+++ b/musereal.py
@@ -150,8 +150,8 @@ def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,a
         is_all_silence=True
         audio_frames = []
         for _ in range(batch_size*2):
-            frame,type = audio_out_queue.get()
-            audio_frames.append((frame,type))
+            frame,type,eventpoint = audio_out_queue.get()
+            audio_frames.append((frame,type,eventpoint))
             if type==0:
                 is_all_silence=False
         if is_all_silence:
@@ -301,20 +301,21 @@ class MuseReal(BaseReal):
             image = combine_frame #(outputs['image'] * 255).astype(np.uint8)
             new_frame = VideoFrame.from_ndarray(image, format="bgr24")
-            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+            asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             self.record_video_data(image)
             #self.recordq_video.put(new_frame)
 
             for audio_frame in audio_frames:
-                frame,type = audio_frame
+                frame,type,eventpoint = audio_frame
                 frame = (frame * 32767).astype(np.int16)
                 new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                 new_frame.planes[0].update(frame.tobytes())
                 new_frame.sample_rate=16000
                 # if audio_track._queue.qsize()>10:
                 #     time.sleep(0.1)
-                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
+                asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
                 self.record_audio_data(frame)
+                #self.notify(eventpoint)
                 #self.recordq_audio.put(new_frame)
         print('musereal process_frames thread stop')
diff --git a/nerfasr.py b/nerfasr.py
index 131fca1..d869ed2 100644
--- a/nerfasr.py
+++ b/nerfasr.py
@@ -76,20 +76,20 @@ class NerfASR(BaseASR):
         # warm up steps needed: mid + right + window_size + attention_size
         self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
 
-    def get_audio_frame(self):
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            #print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            if self.parent and self.parent.curr_state>1: #播放自定义音频
-                frame = self.parent.get_audio_stream(self.parent.curr_state)
-                type = self.parent.curr_state
-            else:
-                frame = np.zeros(self.chunk, dtype=np.float32)
-                type = 1
-
-        return frame,type
+    # def get_audio_frame(self):
+    #     try:
+    #         frame = self.queue.get(block=False)
+    #         type = 0
+    #         #print(f'[INFO] get frame {frame.shape}')
+    #     except queue.Empty:
+    #         if self.parent and self.parent.curr_state>1: #播放自定义音频
+    #             frame = self.parent.get_audio_stream(self.parent.curr_state)
+    #             type = self.parent.curr_state
+    #         else:
+    #             frame = np.zeros(self.chunk, dtype=np.float32)
+    #             type = 1
+
+    #     return frame,type
 
     def get_next_feat(self): #get audio embedding to nerf
         # return a [1/8, 16] window, for the next input to nerf side.
@@ -132,10 +132,10 @@ class NerfASR(BaseASR):
 
     def run_step(self):
         # get a frame of audio
-        frame,type = self.get_audio_frame()
+        frame,type,eventpoint = self.get_audio_frame()
         self.frames.append(frame)
         # put to output
-        self.output_queue.put((frame,type))
+        self.output_queue.put((frame,type,eventpoint))
         # context not enough, do not run network.
         if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
             return
diff --git a/nerfreal.py b/nerfreal.py
index b7252ba..f75da52 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -235,7 +235,7 @@ class NeRFReal(BaseReal):
             audiotype2 = 0
             #send audio
             for i in range(2):
-                frame,type = self.asr.get_audio_out()
+                frame,type,eventpoint = self.asr.get_audio_out()
                 if i==0:
                     audiotype1 = type
                 else:
@@ -248,7 +248,7 @@ class NeRFReal(BaseReal):
                     new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
                     new_frame.planes[0].update(frame.tobytes())
                     new_frame.sample_rate=16000
-                    asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
+                    asyncio.run_coroutine_threadsafe(audio_track._queue.put((new_frame,eventpoint)), loop)
 
         # if self.opt.transport=='rtmp':
         #     for _ in range(2):
@@ -285,7 +285,7 @@ class NeRFReal(BaseReal):
                     self.streamer.stream_frame(image)
                 else:
                     new_frame = VideoFrame.from_ndarray(image, format="rgb24")
-                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+                    asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             else: #推理视频+贴回
                 outputs = self.trainer.test_gui_with_data(data, self.W, self.H)
                 #print('-------ernerf time: ',time.time()-t)
@@ -296,7 +296,7 @@ class NeRFReal(BaseReal):
                     self.streamer.stream_frame(image)
                 else:
                     new_frame = VideoFrame.from_ndarray(image, format="rgb24")
-                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+                    asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
         else: #fullbody human
             #print("frame index:",data['index'])
             #image_fullbody = cv2.imread(os.path.join(self.opt.fullbody_img, str(data['index'][0])+'.jpg'))
@@ -310,7 +310,7 @@ class NeRFReal(BaseReal):
                     self.streamer.stream_frame(image_fullbody)
                 else:
                     new_frame = VideoFrame.from_ndarray(image_fullbody, format="rgb24")
-                    asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+                    asyncio.run_coroutine_threadsafe(video_track._queue.put((new_frame,None)), loop)
             #self.pipe.stdin.write(image.tostring())
         #ender.record()
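Note: every inference/render loop consumes the same 3-tuple contract from get_audio_out(): type 0 marks synthesized speech, type 1 silence filler, and only speech chunks can carry an eventpoint. A compact sketch of that consumer side, mirroring the loops above (the function name and batching are illustrative):

def drain_audio_batch(asr, batch_size):
    audio_frames = []
    is_all_silence = True
    for _ in range(batch_size * 2):               # two 20 ms chunks per video frame
        frame, type_, eventpoint = asr.get_audio_out()
        audio_frames.append((frame, type_, eventpoint))
        if type_ == 0:                            # 0 = speech, 1 = silence
            is_all_silence = False
    return audio_frames, is_all_silence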
diff --git a/ttsreal.py b/ttsreal.py
index a885394..ed3a1da 100644
--- a/ttsreal.py
+++ b/ttsreal.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 ###############################################################################
-
+import os
 import time
 import numpy as np
 import soundfile as sf
@@ -53,9 +53,9 @@ class BaseTTS:
         self.msgqueue.queue.clear()
         self.state = State.PAUSE
 
-    def put_msg_txt(self,msg):
+    def put_msg_txt(self,msg,eventpoint=None):
         if len(msg)>0:
-            self.msgqueue.put(msg)
+            self.msgqueue.put((msg,eventpoint))
 
     def render(self,quit_event):
         process_thread = Thread(target=self.process_tts, args=(quit_event,))
@@ -79,7 +79,7 @@ class BaseTTS:
 ###########################################################################################
 class EdgeTTS(BaseTTS):
     def txt_to_audio(self,msg):
         voicename = "zh-CN-YunxiaNeural"
-        text = msg
+        text,textevent = msg
         t = time.time()
         asyncio.new_event_loop().run_until_complete(self.__main(voicename,text))
         print(f'-------edge tts time:{time.time()-t:.4f}s')
@@ -92,8 +92,13 @@ class EdgeTTS(BaseTTS):
         streamlen = stream.shape[0]
         idx=0
         while streamlen >= self.chunk and self.state==State.RUNNING:
-            self.parent.put_audio_frame(stream[idx:idx+self.chunk])
+            eventpoint=None
             streamlen -= self.chunk
+            if idx==0:
+                eventpoint={'status':'start','text':text,'msgenvent':textevent}
+            elif streamlen<self.chunk:
+                eventpoint={'status':'end','text':text,'msgenvent':textevent}
+            self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
             idx += self.chunk
         #if streamlen>0:  #skip last frame(not 20ms)
         #    self.queue.put(stream[idx:])
@@ -137,14 +142,16 @@ class EdgeTTS(BaseTTS):
 ###########################################################################################
 class VoitsTTS(BaseTTS):
     def txt_to_audio(self,msg):
+        text,textevent = msg
         self.stream_tts(
             self.gpt_sovits(
-                msg,
+                text,
                 self.opt.REF_FILE,
                 self.opt.REF_TEXT,
                 "zh", #en args.language,
                 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
-            )
+            ),
+            msg
         )
 
     def gpt_sovits(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
@@ -207,7 +214,9 @@ class VoitsTTS(BaseTTS):
 
         return stream
 
-    def stream_tts(self,audio_stream):
+    def stream_tts(self,audio_stream,msg):
+        text,textevent = msg
+        first = True
         for chunk in audio_stream:
             if chunk is not None and len(chunk)>0:
                 #stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
@@ -217,21 +226,29 @@ class VoitsTTS(BaseTTS):
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
-                    self.parent.put_audio_frame(stream[idx:idx+self.chunk])
+                    eventpoint=None
+                    if first:
+                        eventpoint={'status':'start','text':text,'msgenvent':textevent}
+                        first = False
+                    self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
                     streamlen -= self.chunk
-                    idx += self.chunk
+                    idx += self.chunk
+        eventpoint={'status':'end','text':text,'msgenvent':textevent}
+        self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
 
 ###########################################################################################
 class CosyVoiceTTS(BaseTTS):
-    def txt_to_audio(self,msg):
+    def txt_to_audio(self,msg):
+        text,textevent = msg
         self.stream_tts(
             self.cosy_voice(
-                msg,
+                text,
                 self.opt.REF_FILE,
                 self.opt.REF_TEXT,
                 "zh", #en args.language,
                 self.opt.TTS_SERVER, #"http://127.0.0.1:5000", #args.server_url,
-            )
+            ),
+            msg
         )
 
     def cosy_voice(self, text, reffile, reftext,language, server_url) -> Iterator[bytes]:
@@ -263,7 +280,9 @@ class CosyVoiceTTS(BaseTTS):
         except Exception as e:
             print(e)
 
-    def stream_tts(self,audio_stream):
+    def stream_tts(self,audio_stream,msg):
+        text,textevent = msg
+        first = True
         for chunk in audio_stream:
             if chunk is not None and len(chunk)>0:
                 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
@@ -273,9 +292,15 @@ class CosyVoiceTTS(BaseTTS):
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
-                    self.parent.put_audio_frame(stream[idx:idx+self.chunk])
+                    eventpoint=None
+                    if first:
+                        eventpoint={'status':'start','text':text,'msgenvent':textevent}
+                        first = False
+                    self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
                     streamlen -= self.chunk
-                    idx += self.chunk
+                    idx += self.chunk
+        eventpoint={'status':'end','text':text,'msgenvent':textevent}
+        self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
 
 ###########################################################################################
 class XTTS(BaseTTS):
@@ -283,15 +308,17 @@ class XTTS(BaseTTS):
         super().__init__(opt,parent)
         self.speaker = self.get_speaker(opt.REF_FILE, opt.TTS_SERVER)
 
-    def txt_to_audio(self,msg):
+    def txt_to_audio(self,msg):
+        text,textevent = msg
         self.stream_tts(
             self.xtts(
-                msg,
+                text,
                 self.speaker,
                 "zh-cn", #en args.language,
                 self.opt.TTS_SERVER, #"http://localhost:9000", #args.server_url,
                 "20" #args.stream_chunk_size
-            )
+            ),
+            msg
         )
 
     def get_speaker(self,ref_audio,server_url):
@@ -329,7 +356,9 @@ class XTTS(BaseTTS):
         except Exception as e:
             print(e)
 
-    def stream_tts(self,audio_stream):
+    def stream_tts(self,audio_stream,msg):
+        text,textevent = msg
+        first = True
         for chunk in audio_stream:
             if chunk is not None and len(chunk)>0:
                 stream = np.frombuffer(chunk, dtype=np.int16).astype(np.float32) / 32767
@@ -339,6 +368,12 @@ class XTTS(BaseTTS):
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
-                    self.parent.put_audio_frame(stream[idx:idx+self.chunk])
+                    eventpoint=None
+                    if first:
+                        eventpoint={'status':'start','text':text,'msgenvent':textevent}
+                        first = False
+                    self.parent.put_audio_frame(stream[idx:idx+self.chunk],eventpoint)
                     streamlen -= self.chunk
-                    idx += self.chunk
\ No newline at end of file
+                    idx += self.chunk
+        eventpoint={'status':'end','text':text,'msgenvent':textevent}
+        self.parent.put_audio_frame(np.zeros(self.chunk,np.float32),eventpoint)
\ No newline at end of file
diff --git a/webrtc.py b/webrtc.py
index 4692048..81779e8 100644
--- a/webrtc.py
+++ b/webrtc.py
@@ -122,10 +122,12 @@ class PlayerStreamTrack(MediaStreamTrack):
             # frame = await self._queue.get()
             # else:
             # frame = await self._queue.get()
-            frame = await self._queue.get()
+            frame,eventpoint = await self._queue.get()
             pts, time_base = await self.next_timestamp()
             frame.pts = pts
             frame.time_base = time_base
+            if eventpoint:
+                self._player.notify(eventpoint)
             if frame is None:
                 self.stop()
                 raise Exception
@@ -172,6 +174,8 @@ class HumanPlayer:
 
         self.__container = nerfreal
 
+    def notify(self,eventpoint):
+        self.__container.notify(eventpoint)
 
     @property
     def audio(self) -> MediaStreamTrack:
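Note: taken together, the event flow is put_msg_txt(msg, eventpoint) → the TTS class tags the first and last audio chunk with {'status': 'start'/'end', 'text': ..., 'msgenvent': <caller's eventpoint>} → the tag rides along with the 20 ms frames through the ASR and render queues → PlayerStreamTrack.recv() unpacks (frame, eventpoint) and calls HumanPlayer.notify(), which forwards to BaseReal.notify() (default: just a print). A hedged end-to-end sketch; only the method names and payload keys come from the patch, while the subclass, base class choice and constructor wiring are illustrative:

from lipreal import LipReal  # any BaseReal subclass behaves the same way

class NotifyingReal(LipReal):
    def notify(self, eventpoint):
        # Fires from PlayerStreamTrack.recv() via HumanPlayer.notify() when the
        # tagged audio frame is handed to the WebRTC track for playback.
        status = eventpoint.get('status') if isinstance(eventpoint, dict) else None
        if status == 'start':
            print('speaking started:', eventpoint['text'], eventpoint['msgenvent'])
        elif status == 'end':
            print('speaking finished:', eventpoint['text'], eventpoint['msgenvent'])

# Typical use (constructor arguments depend on the chosen avatar backend):
#   real = NotifyingReal(opt, model, avatar)
#   real.put_msg_txt("Hello, welcome", eventpoint={'msgid': 42})
# The start/end notifications then arrive with eventpoint['msgenvent'] == {'msgid': 42}.

Note that the payload key is spelled 'msgenvent' throughout this patch, so consumers must match that spelling.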