diff --git a/README.md b/README.md
index 7f64ec4..69cb1f6 100644
--- a/README.md
+++ b/README.md
@@ -133,9 +133,9 @@ Run srs and nginx as in 2.1 and 2.3
 Tested on a Tesla T4 GPU, the overall fps is around 18; without audio/video encoding and streaming it is around 20. A 4090 GPU reaches 40+ fps.
 Optimization: run audio/video encoding and streaming in a separate thread.
 2. Latency
-Overall latency is a little over 5s
-(1) TTS latency is about 2s. edgetts is currently used, and each sentence must be fully synthesized before it is fed in; this can be improved by switching TTS to streaming input.
-(2) wav2vec latency is a little over 1s, because 50 audio frames have to be buffered for the computation; the latency can be reduced by setting context_size with -m.
+Overall latency is about 3s
+(1) TTS latency is about 1.7s. edgetts is currently used, and each sentence must be fully synthesized before it is fed in; this can be improved by switching TTS to streaming input.
+(2) wav2vec latency is 0.4s, since only 18 audio frames need to be buffered for the computation.
 (3) SRS forwarding latency: configure the SRS server to reduce buffering. See https://ossrs.net/lts/zh-cn/docs/v5/doc/low-latency for details; a low-latency build is provided:
 ```bash
 docker run --rm -it -p 1935:1935 -p 1985:1985 -p 8080:8080 registry.cn-hangzhou.aliyuncs.com/lipku/srs:v1.1
diff --git a/app.py b/app.py
index 0503d8c..68cdcf2 100644
--- a/app.py
+++ b/app.py
@@ -37,7 +37,11 @@ async def main(voicename: str, text: str, render):
     communicate = edge_tts.Communicate(text, voicename)

     #with open(OUTPUT_FILE, "wb") as file:
+    first = True
     async for chunk in communicate.stream():
+        if first:
+            #render.before_push_audio()
+            first = False
         if chunk["type"] == "audio":
             render.push_audio(chunk["data"])
             #file.write(chunk["data"])
@@ -258,7 +262,7 @@ if __name__ == '__main__':
     parser.add_argument('--fps', type=int, default=50)
     # sliding window left-middle-right length (unit: 20ms)
     parser.add_argument('-l', type=int, default=10)
-    parser.add_argument('-m', type=int, default=50)
+    parser.add_argument('-m', type=int, default=8)
     parser.add_argument('-r', type=int, default=10)
     parser.add_argument('--fullbody', action='store_true', help="fullbody human")
diff --git a/asrreal.py b/asrreal.py
index 3020021..c8b1e5f 100644
--- a/asrreal.py
+++ b/asrreal.py
@@ -122,58 +122,34 @@ class ASR:
         self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...

         # warm up steps needed: mid + right + window_size + attention_size
-        self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
+        self.warm_up_steps = self.context_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3

         self.listening = False
         self.playing = False

-    def listen(self):
-        # start
-        if self.mode == 'live' and not self.listening:
-            print(f'[INFO] starting read frame thread...')
-            self.process_read_frame.start()
-            self.listening = True
-
-        if self.play and not self.playing:
-            print(f'[INFO] starting play frame thread...')
-            self.process_play_frame.start()
-            self.playing = True
-
-    def stop(self):
-
-        self.exit_event.set()
-
-        if self.play:
-            self.output_stream.stop_stream()
-            self.output_stream.close()
-            if self.playing:
-                self.process_play_frame.join()
-                self.playing = False
-
-        if self.mode == 'live':
-            #self.input_stream.stop_stream() todo
-            self.input_stream.close()
-            if self.listening:
-                self.process_read_frame.join()
-                self.listening = False
-
+    def get_next_feat(self): #get audio embedding to nerf
+        # return a [1/8, 16] window, for the next input to nerf side.
+        if self.opt.att>0:
+            while len(self.att_feats) < 8:
+                # [------f+++t-----]
+                if self.front < self.tail:
+                    feat = self.feat_queue[self.front:self.tail]
+                # [++t-----------f+]
+                else:
+                    feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
-    def __enter__(self):
-        return self
+                self.front = (self.front + 2) % self.feat_queue.shape[0]
+                self.tail = (self.tail + 2) % self.feat_queue.shape[0]
-    def __exit__(self, exc_type, exc_value, traceback):
-
-        self.stop()
+                # print(self.front, self.tail, feat.shape)
-        if self.mode == 'live':
-            # live mode: also print the result text.
-            self.text += '\n[END]'
-            print(self.text)
+                self.att_feats.append(feat.permute(1, 0))
+
+            att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
-    def get_next_feat(self):
-        # return a [1/8, 16] window, for the next input to nerf side.
-
-        while len(self.att_feats) < 8:
+            # discard old
+            self.att_feats = self.att_feats[1:]
+        else:
             # [------f+++t-----]
             if self.front < self.tail:
                 feat = self.feat_queue[self.front:self.tail]
@@ -184,14 +160,8 @@ class ASR:
             self.front = (self.front + 2) % self.feat_queue.shape[0]
             self.tail = (self.tail + 2) % self.feat_queue.shape[0]

-            # print(self.front, self.tail, feat.shape)
-
-            self.att_feats.append(feat.permute(1, 0))
-
-        att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
+            att_feat = feat.permute(1, 0).unsqueeze(0)

-        # discard old
-        self.att_feats = self.att_feats[1:]
         return att_feat
@@ -201,7 +171,7 @@ class ASR:
             return

         # get a frame of audio
-        frame = self.get_audio_frame()
+        frame = self.__get_audio_frame()

         # the last frame
         if frame is None:
@@ -223,7 +193,7 @@ class ASR:
             print(f'[INFO] frame_to_text... ')

         #t = time.time()
-        logits, labels, text = self.frame_to_text(inputs)
+        logits, labels, text = self.__frame_to_text(inputs)
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels
@@ -264,68 +234,17 @@ class ASR:
             np.save(output_path, unfold_feats.cpu().numpy())
             print(f"[INFO] saved logits to {output_path}")

-    '''
-    def create_file_stream(self):
-
-        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
-        stream = stream.astype(np.float32)
-
-        if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
-            stream = stream[:, 0]
-
-        if sample_rate != self.sample_rate:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
-            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
-
-        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
-
-        return stream
-
-
-    def create_pyaudio_stream(self):
-
-        import pyaudio
-
-        print(f'[INFO] creating live audio stream ...')
-
-        audio = pyaudio.PyAudio()
-
-        # get devices
-        info = audio.get_host_api_info_by_index(0)
-        n_devices = info.get('deviceCount')
-
-        for i in range(0, n_devices):
-            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
-                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
-                print(f'[INFO] choose audio device {name}, id {i}')
-                break
-
-        # get stream
-        stream = audio.open(input_device_index=i,
-                            format=pyaudio.paInt16,
-                            channels=1,
-                            rate=self.sample_rate,
-                            input=True,
-                            frames_per_buffer=self.chunk)
-
-        return audio, stream
-    '''
-
-    def get_audio_frame(self):
-
+    def __get_audio_frame(self):
         if self.inwarm: # warm up
             return np.zeros(self.chunk, dtype=np.float32)

         if self.mode == 'file':
-
             if self.idx < self.file_stream.shape[0]:
                 frame = self.file_stream[self.idx: self.idx + self.chunk]
                 self.idx = self.idx + self.chunk
                 return frame
             else:
-                return None
-
+                return None
         else:
             try:
                 frame = self.queue.get(block=False)
@@ -338,7 +257,7 @@ class ASR:

         return frame

-    def frame_to_text(self, frame):
+    def __frame_to_text(self, frame):
         # frame: [N * 320], N = (context_size + 2 * stride_size)

         inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
@@ -377,7 +296,7 @@ class ASR:

         return logits[0], None,None #predicted_ids[0], transcription # [N,]

-    def create_bytes_stream(self,byte_stream):
+    def __create_bytes_stream(self,byte_stream):
         #byte_stream=BytesIO(buffer)
         stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
         print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
@@ -393,14 +312,14 @@ class ASR:

         return stream

-    def push_audio(self,buffer):
+    def push_audio(self,buffer): #push audio pcm from tts
         print(f'[INFO] push_audio {len(buffer)}')
         if self.opt.tts == "xtts":
             if len(buffer)>0:
                 stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
                 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
                 #byte_stream=BytesIO(buffer)
-                #stream = self.create_bytes_stream(byte_stream)
+                #stream = self.__create_bytes_stream(byte_stream)
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
@@ -413,7 +332,7 @@ class ASR:
             self.input_stream.write(buffer)
             if len(buffer)<=0:
                 self.input_stream.seek(0)
-                stream = self.create_bytes_stream(self.input_stream)
+                stream = self.__create_bytes_stream(self.input_stream)
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
@@ -425,8 +344,21 @@ class ASR:
             self.input_stream.seek(0)
             self.input_stream.truncate()

-    def get_audio_out(self):
+    def get_audio_out(self): #get origin audio pcm to nerf
         return self.output_queue.get()
+
+    def __init_queue(self):
+        self.frames = []
+        self.queue.queue.clear()
+        self.output_queue.queue.clear()
+        self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
+        self.tail = 8
+        # attention window...
+        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
+
+    def before_push_audio(self):
+        self.__init_queue()
+        self.warm_up()

     def run(self):
@@ -450,19 +382,110 @@
         self.inwarm = True
         print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
         t = time.time()
+        for _ in range(self.stride_left_size):
+            self.frames.append(np.zeros(self.chunk, dtype=np.float32))
         for _ in range(self.warm_up_steps):
             self.run_step()
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
         t = time.time() - t
         print(f'[INFO] warm-up done, actual latency = {t:.6f}s')

         self.inwarm = False
         #self.clear_queue()
-
+    '''
+    def create_file_stream(self):
+
+        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
+        stream = stream.astype(np.float32)
+
+        if stream.ndim > 1:
+            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            stream = stream[:, 0]
+
+        if sample_rate != self.sample_rate:
+            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
+
+        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
+
+        return stream
+
+
+    def create_pyaudio_stream(self):
+
+        import pyaudio
+
+        print(f'[INFO] creating live audio stream ...')
+
+        audio = pyaudio.PyAudio()
+
+        # get devices
+        info = audio.get_host_api_info_by_index(0)
+        n_devices = info.get('deviceCount')
+
+        for i in range(0, n_devices):
+            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
+                print(f'[INFO] choose audio device {name}, id {i}')
+                break
+
+        # get stream
+        stream = audio.open(input_device_index=i,
+                            format=pyaudio.paInt16,
+                            channels=1,
+                            rate=self.sample_rate,
+                            input=True,
+                            frames_per_buffer=self.chunk)
+
+        return audio, stream
+    '''
+    #####not used function#####################################
+    def listen(self):
+        # start
+        if self.mode == 'live' and not self.listening:
+            print(f'[INFO] starting read frame thread...')
+            self.process_read_frame.start()
+            self.listening = True
+
+        if self.play and not self.playing:
+            print(f'[INFO] starting play frame thread...')
+            self.process_play_frame.start()
+            self.playing = True
+
+    def stop(self):
+
+        self.exit_event.set()
+
+        if self.play:
+            self.output_stream.stop_stream()
+            self.output_stream.close()
+            if self.playing:
+                self.process_play_frame.join()
+                self.playing = False
+
+        if self.mode == 'live':
+            #self.input_stream.stop_stream() todo
+            self.input_stream.close()
+            if self.listening:
+                self.process_read_frame.join()
+                self.listening = False
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+
+        self.stop()
+
+        if self.mode == 'live':
+            # live mode: also print the result text.
+            self.text += '\n[END]'
+            print(self.text)
+    #########################################################
+
 if __name__ == '__main__':
     import argparse
diff --git a/nerfreal.py b/nerfreal.py
index bcf774d..426f08f 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -108,6 +108,9 @@ class NeRFReal:

     def push_audio(self,chunk):
         self.asr.push_audio(chunk)
+
+    def before_push_audio(self):
+        self.asr.before_push_audio()

     def prepare_buffer(self, outputs):
         if self.mode == 'image':
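Note (illustrative only, not part of the patch): the sketch below is a quick sanity check on how the new sliding-window defaults (`--fps 50`, `-l 10 -m 8 -r 10`) and the updated `warm_up_steps = context_size + stride_right_size` in asrreal.py map to the ~0.4s wav2vec latency quoted in the README; the variable names are ad hoc.

```python
# Back-of-the-envelope latency check for the new ASR defaults (illustrative).
FPS = 50                   # --fps: one ASR frame every 20 ms of audio
FRAME_SEC = 1.0 / FPS

stride_left = 10           # -l (no longer counted in warm_up_steps after this patch)
context_size = 8           # -m (was 50 before this patch)
stride_right = 10          # -r

# asrreal.py now buffers context_size + stride_right_size frames before output
warm_up_steps = context_size + stride_right
print(f"buffered frames : {warm_up_steps}")                    # 18
print(f"wav2vec latency : {warm_up_steps * FRAME_SEC:.2f} s")  # 0.36 s, ~0.4 s as in the README

# the old README figure: a 50-frame buffer, i.e. about 1 s
print(f"old latency     : {50 * FRAME_SEC:.2f} s")             # 1.00 s
```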