@@ -122,58 +122,34 @@ class ASR:
         self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
 
         # warm up steps needed: mid + right + window_size + attention_size
-        self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
+        self.warm_up_steps = self.context_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
 
         self.listening = False
         self.playing = False
 
-    def listen(self):
-        # start
-        if self.mode == 'live' and not self.listening:
-            print(f'[INFO] starting read frame thread...')
-            self.process_read_frame.start()
-            self.listening = True
-
-        if self.play and not self.playing:
-            print(f'[INFO] starting play frame thread...')
-            self.process_play_frame.start()
-            self.playing = True
-
-    def stop(self):
-
-        self.exit_event.set()
-
-        if self.play:
-            self.output_stream.stop_stream()
-            self.output_stream.close()
-            if self.playing:
-                self.process_play_frame.join()
-                self.playing = False
-
-        if self.mode == 'live':
-            #self.input_stream.stop_stream() todo
-            self.input_stream.close()
-            if self.listening:
-                self.process_read_frame.join()
-                self.listening = False
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, traceback):
-
-        self.stop()
-
-        if self.mode == 'live':
-            # live mode: also print the result text.
-            self.text += '\n[END]'
-            print(self.text)
-
-    def get_next_feat(self):
-        # return a [1/8, 16] window, for the next input to nerf side.
-
-        while len(self.att_feats) < 8:
+    def get_next_feat(self): #get audio embedding to nerf
+        # return a [1/8, 16] window, for the next input to nerf side.
+        if self.opt.att>0:
+            while len(self.att_feats) < 8:
+                # [------f+++t-----]
+                if self.front < self.tail:
+                    feat = self.feat_queue[self.front:self.tail]
+                # [++t-----------f+]
+                else:
+                    feat = torch.cat([self.feat_queue[self.front:], self.feat_queue[:self.tail]], dim=0)
+
+                self.front = (self.front + 2) % self.feat_queue.shape[0]
+                self.tail = (self.tail + 2) % self.feat_queue.shape[0]
+
+                # print(self.front, self.tail, feat.shape)
+
+                self.att_feats.append(feat.permute(1, 0))
+
+            att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
+
+            # discard old
+            self.att_feats = self.att_feats[1:]
+        else:
             # [------f+++t-----]
             if self.front < self.tail:
                 feat = self.feat_queue[self.front:self.tail]
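
Note on the hunk above: both versions of `get_next_feat` read a sliding window out of the circular `feat_queue`, with `front`/`tail` indices that wrap past the end of the buffer; the change is that the attention path (`opt.att > 0`) now stacks the last 8 windows, while the new `else` branch returns a single window. A minimal sketch of the wraparound read, with toy buffer sizes rather than the model's real dimensions:

```python
import torch

feat_queue = torch.arange(10 * 16, dtype=torch.float32).reshape(10, 16)  # ring buffer [T, 16]
front, tail = 8, 2  # wrapped state: the valid window crosses the end of the buffer

if front < tail:      # [------f+++t-----] contiguous window
    feat = feat_queue[front:tail]
else:                 # [++t-----------f+] window wraps, stitch the two halves
    feat = torch.cat([feat_queue[front:], feat_queue[:tail]], dim=0)

# advance both pointers by the hop size, staying inside the ring
front = (front + 2) % feat_queue.shape[0]
tail = (tail + 2) % feat_queue.shape[0]
print(feat.shape)  # torch.Size([4, 16])
```
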
@@ -184,14 +160,8 @@ class ASR:
 
             self.front = (self.front + 2) % self.feat_queue.shape[0]
             self.tail = (self.tail + 2) % self.feat_queue.shape[0]
 
             # print(self.front, self.tail, feat.shape)
-
-            self.att_feats.append(feat.permute(1, 0))
-
-        att_feat = torch.stack(self.att_feats, dim=0) # [8, 44, 16]
-
-        # discard old
-        self.att_feats = self.att_feats[1:]
+            att_feat = feat.permute(1, 0).unsqueeze(0)
 
         return att_feat
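
This hunk removes the old accumulate-and-stack tail from the non-attention path: instead of appending to `att_feats` and stacking 8 windows, the method now returns the current window alone. A sketch of the resulting shape difference (sizes illustrative, not the real feature dimensions):

```python
import torch

feat = torch.zeros(2, 16)                        # one window slice [W, 16]
att_feats = [feat.permute(1, 0)] * 8             # attention path keeps 8 past windows

with_att = torch.stack(att_feats, dim=0)         # [8, 16, 2]  (opt.att > 0 branch)
without_att = feat.permute(1, 0).unsqueeze(0)    # [1, 16, 2]  (new else branch)
print(with_att.shape, without_att.shape)
```
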
@@ -201,7 +171,7 @@ class ASR:
             return
 
         # get a frame of audio
-        frame = self.get_audio_frame()
+        frame = self.__get_audio_frame()
 
         # the last frame
         if frame is None:
@@ -223,7 +193,7 @@ class ASR:
             print(f'[INFO] frame_to_text... ')
             #t = time.time()
 
-        logits, labels, text = self.frame_to_text(inputs)
+        logits, labels, text = self.__frame_to_text(inputs)
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels
@@ -264,68 +234,17 @@ class ASR:
             np.save(output_path, unfold_feats.cpu().numpy())
             print(f"[INFO] saved logits to {output_path}")
 
-    '''
-    def create_file_stream(self):
-
-        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
-        stream = stream.astype(np.float32)
-
-        if stream.ndim > 1:
-            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
-            stream = stream[:, 0]
-
-        if sample_rate != self.sample_rate:
-            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
-            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
-
-        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
-
-        return stream
-
-    def create_pyaudio_stream(self):
-
-        import pyaudio
-
-        print(f'[INFO] creating live audio stream ...')
-
-        audio = pyaudio.PyAudio()
-
-        # get devices
-        info = audio.get_host_api_info_by_index(0)
-        n_devices = info.get('deviceCount')
-
-        for i in range(0, n_devices):
-            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
-                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
-                print(f'[INFO] choose audio device {name}, id {i}')
-                break
-
-        # get stream
-        stream = audio.open(input_device_index=i,
-                            format=pyaudio.paInt16,
-                            channels=1,
-                            rate=self.sample_rate,
-                            input=True,
-                            frames_per_buffer=self.chunk)
-
-        return audio, stream
-    '''
-
-    def get_audio_frame(self):
+    def __get_audio_frame(self):
         if self.inwarm: # warm up
             return np.zeros(self.chunk, dtype=np.float32)
 
         if self.mode == 'file':
             if self.idx < self.file_stream.shape[0]:
                 frame = self.file_stream[self.idx: self.idx + self.chunk]
                 self.idx = self.idx + self.chunk
                 return frame
             else:
                 return None
         else:
             try:
                 frame = self.queue.get(block=False)
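
The `get_audio_frame` → `__get_audio_frame` rename here (and the matching `__frame_to_text` / `__create_bytes_stream` renames below) is more than a convention: a leading double underscore triggers Python name mangling, so code outside the class can no longer call the method under its old name. A quick self-contained demonstration:

```python
class ASR:
    def __get_audio_frame(self):
        # stored on the class as _ASR__get_audio_frame
        return 'frame'

asr = ASR()
try:
    asr.__get_audio_frame()           # old external call style now fails
except AttributeError as e:
    print(e)
print(asr._ASR__get_audio_frame())    # reachable only via the mangled name
```
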
@@ -338,7 +257,7 @@ class ASR:
             return frame
 
 
-    def frame_to_text(self, frame):
+    def __frame_to_text(self, frame):
         # frame: [N * 320], N = (context_size + 2 * stride_size)
 
         inputs = self.processor(frame, sampling_rate=self.sample_rate, return_tensors="pt", padding=True)
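
For context on what `__frame_to_text` feeds the model: the processor call kept as context above is the standard `transformers` wav2vec 2.0 front end, which turns `N * 320` raw samples into roughly one logit row per 320-sample step. A hedged sketch of that half of the method; the checkpoint name is a common public default, not necessarily the one this repo loads:

```python
import numpy as np
import torch
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC

processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')  # assumed checkpoint
model = Wav2Vec2ForCTC.from_pretrained('facebook/wav2vec2-base-960h')

frame = np.zeros(10 * 320, dtype=np.float32)  # N = 10 steps of 320 samples at 16 kHz
inputs = processor(frame, sampling_rate=16000, return_tensors='pt', padding=True)
with torch.no_grad():
    logits = model(inputs.input_values).logits
print(logits.shape)  # [1, ~N, vocab_size]
```
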
@@ -377,7 +296,7 @@ class ASR:
 
         return logits[0], None,None #predicted_ids[0], transcription # [N,]
 
-    def create_bytes_stream(self,byte_stream):
+    def __create_bytes_stream(self,byte_stream):
        #byte_stream=BytesIO(buffer)
        stream, sample_rate = sf.read(byte_stream) # [T*sample_rate,] float64
        print(f'[INFO]tts audio stream {sample_rate}: {stream.shape}')
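
`__create_bytes_stream` relies on `soundfile` accepting any file-like object, which is why `push_audio` can hand it the in-memory `self.input_stream` after a `seek(0)`. A minimal round trip under that assumption:

```python
import io
import numpy as np
import soundfile as sf

buf = io.BytesIO()
sf.write(buf, np.zeros(16000, dtype=np.float32), 16000, format='WAV')  # 1 s of silence
buf.seek(0)

stream, sample_rate = sf.read(buf)  # [T,] float64 by default, as the comment notes
print(sample_rate, stream.shape)    # 16000 (16000,)
```
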
@@ -393,14 +312,14 @@ class ASR:
 
         return stream
 
-    def push_audio(self,buffer):
+    def push_audio(self,buffer): #push audio pcm from tts
         print(f'[INFO] push_audio {len(buffer)}')
         if self.opt.tts == "xtts":
             if len(buffer)>0:
                 stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767
                 stream = resampy.resample(x=stream, sr_orig=24000, sr_new=self.sample_rate)
                 #byte_stream=BytesIO(buffer)
-                #stream = self.create_bytes_stream(byte_stream)
+                #stream = self.__create_bytes_stream(byte_stream)
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
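
The xtts branch above converts raw 16-bit PCM to the float range the ASR expects and resamples from the TTS rate down to the ASR rate. A standalone sketch of those two lines (24000 comes from the diff; the 16000 target is an assumed `self.sample_rate`):

```python
import numpy as np
import resampy

# fake 0.1 s of 24 kHz int16 PCM, as push_audio would receive from the TTS
buffer = (np.sin(np.linspace(0, 100, 2400)) * 32767).astype(np.int16).tobytes()

stream = np.frombuffer(buffer, dtype=np.int16).astype(np.float32) / 32767  # -> [-1, 1]
stream = resampy.resample(x=stream, sr_orig=24000, sr_new=16000)
print(stream.shape)  # (1600,): 2400 samples at 24 kHz -> 1600 at 16 kHz
```
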
@@ -413,7 +332,7 @@ class ASR:
             self.input_stream.write(buffer)
             if len(buffer)<=0:
                 self.input_stream.seek(0)
-                stream = self.create_bytes_stream(self.input_stream)
+                stream = self.__create_bytes_stream(self.input_stream)
                 streamlen = stream.shape[0]
                 idx=0
                 while streamlen >= self.chunk:
@@ -425,9 +344,22 @@ class ASR:
                 self.input_stream.seek(0)
                 self.input_stream.truncate()
 
-    def get_audio_out(self):
+    def get_audio_out(self): #get origin audio pcm to nerf
         return self.output_queue.get()
 
+    def __init_queue(self):
+        self.frames = []
+        self.queue.queue.clear()
+        self.output_queue.queue.clear()
+        self.front = self.feat_buffer_size * self.context_size - 8 # fake padding
+        self.tail = 8
+        # attention window...
+        self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4
+
+    def before_push_audio(self):
+        self.__init_queue()
+        self.warm_up()
+
     def run(self):
 
         self.listen()
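
One detail of the new `__init_queue` worth flagging: `queue.Queue` has no public `clear()`, so the code empties the underlying deque directly. That works, but it bypasses the queue's internal lock, so it is only safe while no producer or consumer thread is touching the queue:

```python
import queue

q = queue.Queue()
for i in range(3):
    q.put(i)

q.queue.clear()    # clears the backing deque without taking q.mutex
print(q.qsize())   # 0
```
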
@@ -450,18 +382,109 @@ class ASR:
         self.inwarm = True
         print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
         t = time.time()
+        for _ in range(self.stride_left_size):
+            self.frames.append(np.zeros(self.chunk, dtype=np.float32))
         for _ in range(self.warm_up_steps):
             self.run_step()
-        if torch.cuda.is_available():
-            torch.cuda.synchronize()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
         t = time.time() - t
         print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
         self.inwarm = False
 
         #self.clear_queue()
 
+    '''
+    def create_file_stream(self):
+
+        stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
+        stream = stream.astype(np.float32)
+
+        if stream.ndim > 1:
+            print(f'[WARN] audio has {stream.shape[1]} channels, only use the first.')
+            stream = stream[:, 0]
+
+        if sample_rate != self.sample_rate:
+            print(f'[WARN] audio sample rate is {sample_rate}, resampling into {self.sample_rate}.')
+            stream = resampy.resample(x=stream, sr_orig=sample_rate, sr_new=self.sample_rate)
+
+        print(f'[INFO] loaded audio stream {self.opt.asr_wav}: {stream.shape}')
+
+        return stream
+
+    def create_pyaudio_stream(self):
+
+        import pyaudio
+
+        print(f'[INFO] creating live audio stream ...')
+
+        audio = pyaudio.PyAudio()
+
+        # get devices
+        info = audio.get_host_api_info_by_index(0)
+        n_devices = info.get('deviceCount')
+
+        for i in range(0, n_devices):
+            if (audio.get_device_info_by_host_api_device_index(0, i).get('maxInputChannels')) > 0:
+                name = audio.get_device_info_by_host_api_device_index(0, i).get('name')
+                print(f'[INFO] choose audio device {name}, id {i}')
+                break
+
+        # get stream
+        stream = audio.open(input_device_index=i,
+                            format=pyaudio.paInt16,
+                            channels=1,
+                            rate=self.sample_rate,
+                            input=True,
+                            frames_per_buffer=self.chunk)
+
+        return audio, stream
+    '''
+
+    #####not used function#####################################
+    def listen(self):
+        # start
+        if self.mode == 'live' and not self.listening:
+            print(f'[INFO] starting read frame thread...')
+            self.process_read_frame.start()
+            self.listening = True
+
+        if self.play and not self.playing:
+            print(f'[INFO] starting play frame thread...')
+            self.process_play_frame.start()
+            self.playing = True
+
+    def stop(self):
+
+        self.exit_event.set()
+
+        if self.play:
+            self.output_stream.stop_stream()
+            self.output_stream.close()
+            if self.playing:
+                self.process_play_frame.join()
+                self.playing = False
+
+        if self.mode == 'live':
+            #self.input_stream.stop_stream() todo
+            self.input_stream.close()
+            if self.listening:
+                self.process_read_frame.join()
+                self.listening = False
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+
+        self.stop()
+
+        if self.mode == 'live':
+            # live mode: also print the result text.
+            self.text += '\n[END]'
+            print(self.text)
+    #########################################################
 
 if __name__ == '__main__':
     import argparse
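
On the warm-up change in the final hunk: dropping `stride_left_size` from `warm_up_steps` and instead pre-filling that many zero frames means fewer `run_step` calls, so the printed expected latency (`warm_up_steps / fps`) shrinks. Worked example with illustrative values, not the repo's actual defaults:

```python
context_size, stride_left_size, stride_right_size, fps = 10, 8, 8, 50

old_steps = context_size + stride_right_size + stride_left_size  # 26
new_steps = context_size + stride_right_size                     # 18
print(f'old expected latency = {old_steps / fps:.6f}s')  # 0.520000s
print(f'new expected latency = {new_steps / fps:.6f}s')  # 0.360000s
```
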