@@ -4,29 +4,19 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 
-#import pyaudio
-import soundfile as sf
-import resampy
 
 import queue
 from queue import Queue
 #from collections import deque
 from threading import Thread, Event
 from io import BytesIO
 
-class ASR:
-    def __init__(self, opt):
-        self.opt = opt
+from baseasr import BaseASR
 
-        self.play = opt.asr_play #false
+class ASR(BaseASR):
+    def __init__(self, opt):
+        super().__init__(opt)
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
-        self.fps = opt.fps # 20 ms per frame
-        self.sample_rate = 16000
-        self.chunk = self.sample_rate // self.fps # 320 samples per chunk (20ms * 16000 / 1000)
-        self.mode = 'live' if opt.asr_wav == '' else 'file'
 
         if 'esperanto' in self.opt.asr_model:
            self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
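Note: the deleted capture state does not disappear, it moves behind `from baseasr import BaseASR`. The base module is not part of this diff, so the following is only a minimal sketch, assuming `BaseASR` now owns the fields removed in this hunk and the next:

```python
# Hypothetical sketch of the BaseASR side of the refactor; the real
# baseasr module is not shown in this diff.
from queue import Queue

class BaseASR:
    def __init__(self, opt):
        self.opt = opt
        self.fps = opt.fps                         # e.g. 50, i.e. 20 ms per frame
        self.sample_rate = 16000
        self.chunk = self.sample_rate // self.fps  # 320 samples per 20 ms chunk
        self.queue = Queue()                       # incoming pcm chunks
        self.output_queue = Queue()                # pcm chunks echoed back out
        self.frames = []                           # sliding window of recent chunks
```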
@@ -41,30 +31,11 @@ class ASR:
         self.context_size = opt.m
         self.stride_left_size = opt.l
         self.stride_right_size = opt.r
-        self.text = '[START]\n'
-        self.terminated = False
-        self.frames = []
-        self.inwarm = False
 
         # pad left frames
         if self.stride_left_size > 0:
             self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
 
-        self.exit_event = Event()
-        #self.audio_instance = pyaudio.PyAudio() #not need
-
-        # create input stream
-        self.queue = Queue()
-        self.output_queue = Queue()
-        # start a background process to read frames
-        #self.input_stream = self.audio_instance.open(format=pyaudio.paInt16, channels=1, rate=self.sample_rate, input=True, output=False, frames_per_buffer=self.chunk)
-        #self.queue = Queue()
-        #self.process_read_frame = Thread(target=_read_frame, args=(self.input_stream, self.exit_event, self.queue, self.chunk))
-
-        # current location of audio
-        self.idx = 0
-
         # create wav2vec model
         print(f'[INFO] loading ASR model {self.opt.asr_model}...')
         if 'hubert' in self.opt.asr_model:
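The surviving context/stride setup drives all of the window arithmetic later in the file. A small worked example, with hypothetical values standing in for `opt.l` / `opt.m` / `opt.r`:

```python
# Worked example of the windowing arithmetic (l, m, r are hypothetical;
# in the real code they come from the command line).
fps = 50                        # 20 ms per frame
chunk = 16000 // fps            # 320 samples per chunk
l, m, r = 10, 8, 10             # stride_left, context, stride_right
segment_ms = (l + m + r) * 1000 // fps
print(segment_ms)               # 560 -> each inference sees 560 ms of audio
# left padding: l zero-chunks, so the first real frame has full left context
pad = [0.0] * (l * chunk)
print(len(pad))                 # 3200 samples of leading silence
```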
@@ -74,10 +45,6 @@ class ASR:
             self.processor = AutoProcessor.from_pretrained(opt.asr_model)
             self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
 
-        # prepare to save logits
-        if self.opt.asr_save_feats:
-            self.all_feats = []
-
         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
         self.feat_buffer_size = 4
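The loop-queue allocation itself falls between the shown hunks. A plausible shape, inferred from the `start`/`end` arithmetic in a later hunk, might be:

```python
# Assumed allocation for the circular feature buffer; the actual line is
# outside this diff, so treat the shapes as illustrative only.
import torch
feat_buffer_size = 4    # as set above
context_size = 8        # hypothetical opt.m
audio_dim = 44          # e.g. the esperanto dim set above
feat_queue = torch.zeros(feat_buffer_size * context_size, audio_dim)
# four fixed slots, reused round-robin -> constant memory however long it runs
```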
@@ -93,8 +60,16 @@ class ASR:
         # warm up steps needed: mid + right + window_size + attention_size
         self.warm_up_steps = self.context_size + self.stride_left_size + self.stride_right_size #+ self.stride_left_size #+ 8 + 2 * 3
 
-        self.listening = False
-        self.playing = False
+    def get_audio_frame(self):
+        try:
+            frame = self.queue.get(block=False)
+            type = 0
+            #print(f'[INFO] get frame {frame.shape}')
+        except queue.Empty:
+            frame = np.zeros(self.chunk, dtype=np.float32)
+            type = 1
+
+        return frame,type
 
     def get_next_feat(self): #get audio embedding to nerf
         # return a [1/8, 16] window, for the next input to nerf side.
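The new public `get_audio_frame` replaces the private `__get_audio_frame` deleted further down. It never blocks: when the producer (presumably `put_audio_frame`, moved to the base class per the deletion below) falls behind, it substitutes a silent chunk tagged type 1. A self-contained sketch of that behavior:

```python
# Minimal sketch of the non-blocking read with silence fallback.
import numpy as np
import queue
from queue import Queue

q, chunk = Queue(), 320

def get_audio_frame():
    try:
        return q.get(block=False), 0                  # type 0: real audio
    except queue.Empty:
        return np.zeros(chunk, dtype=np.float32), 1   # type 1: silence filler

q.put(np.ones(chunk, dtype=np.float32))
print(get_audio_frame()[1], get_audio_frame()[1])     # 0 1
```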
@@ -136,17 +111,8 @@ class ASR:
 
     def run_step(self):
 
-        if self.terminated:
-            return
-
         # get a frame of audio
-        frame,type = self.__get_audio_frame()
-
-        # the last frame
-        if frame is None:
-            # terminate, but always run the network for the left frames
-            self.terminated = True
-        else:
-            self.frames.append(frame)
+        frame,type = self.get_audio_frame()
+        self.frames.append(frame)
         # put to output
         self.output_queue.put((frame,type))
@@ -157,7 +123,6 @@ class ASR:
         inputs = np.concatenate(self.frames) # [N * chunk]
 
         # discard the old part to save memory
-        if not self.terminated:
-            self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
 
         #print(f'[INFO] frame_to_text... ')
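With the `terminated` flag gone, the trim now runs unconditionally: after each inference the window keeps only the last `stride_left + stride_right` chunks, which become the left context for the next step. A quick constant-memory check (sizes hypothetical):

```python
# Constant-memory check for the window trim above.
l, m, r = 10, 8, 10
frames = list(range(l + m + r))   # 28 chunks gathered for one step
frames = frames[-(l + r):]        # keep 20 as the next left context
print(len(frames))                # 20 -> m new chunks re-fill it to 28
```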
@@ -166,10 +131,6 @@ class ASR:
         #print(f'-------wav2vec time:{time.time()-t:.4f}s')
         feats = logits # better lips-sync than labels
 
-        # save feats
-        if self.opt.asr_save_feats:
-            self.all_feats.append(feats)
-
         # record the feats efficiently.. (no concat, constant memory)
         start = self.feat_buffer_idx * self.context_size
         end = start + feats.shape[0]
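These `start`/`end` indices feed the round-robin write into the loop queue. The modulo increment lives outside the shown lines, so the wrap-around below is an assumption about the surrounding code:

```python
# Sketch of the round-robin write this hunk feeds into; the modulo
# increment is assumed, not shown in the diff.
import torch
buf_size, ctx, dim = 4, 8, 44
feat_queue = torch.zeros(buf_size * ctx, dim)
feat_buffer_idx = 0
for _ in range(10):                      # endless stream, constant memory
    feats = torch.randn(ctx, dim)        # one context worth of new features
    start = feat_buffer_idx * ctx
    feat_queue[start:start + feats.shape[0]] = feats
    feat_buffer_idx = (feat_buffer_idx + 1) % buf_size
```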
@@ -203,24 +164,6 @@ class ASR:
             # np.save(output_path, unfold_feats.cpu().numpy())
             # print(f"[INFO] saved logits to {output_path}")
 
-    def put_audio_frame(self,audio_chunk): #16khz 20ms pcm
-        self.queue.put(audio_chunk)
-
-    def __get_audio_frame(self):
-        if self.inwarm: # warm up
-            return np.zeros(self.chunk, dtype=np.float32),1
-
-        try:
-            frame = self.queue.get(block=False)
-            type = 0
-            print(f'[INFO] get frame {frame.shape}')
-        except queue.Empty:
-            frame = np.zeros(self.chunk, dtype=np.float32)
-            type = 1
-
-        self.idx = self.idx + self.chunk
-
-        return frame,type
-
     def __frame_to_text(self, frame):
@@ -241,8 +184,8 @@ class ASR:
         right = min(logits.shape[1], logits.shape[1] - self.stride_right_size + 1) # +1 to make sure output is the same length as input.
 
         # do not cut right if terminated.
-        if self.terminated:
-            right = logits.shape[1]
+        # if self.terminated:
+        #     right = logits.shape[1]
 
         logits = logits[:, left:right]
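A numeric check of the stride cut. The computation of `left` sits outside this hunk, so `left = stride_left_size` is an assumption here:

```python
# Numeric demo of the stride trimming (left is assumed, right as shown).
import torch
l, r = 10, 10
logits = torch.randn(1, 28, 44)                         # l + m + r = 28 frames
left = l                                                # assumed, not in the hunk
right = min(logits.shape[1], logits.shape[1] - r + 1)   # 19
logits = logits[:, left:right]
print(logits.shape[1])                                  # 9 frames survive the cut
```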
@@ -263,9 +206,22 @@ class ASR:
         return logits[0], None,None #predicted_ids[0], transcription # [N,]
 
+    def get_audio_out(self): #get origin audio pcm to nerf
+        return self.output_queue.get()
+
+    def warm_up(self):
+        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
+        t = time.time()
+        #for _ in range(self.stride_left_size):
+        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
+        for _ in range(self.warm_up_steps):
+            self.run_step()
+        #if torch.cuda.is_available():
+        #    torch.cuda.synchronize()
+        t = time.time() - t
+        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
+
+        #self.clear_queue()
+
+    #####not used function#####################################
+    '''
     def __init_queue(self):
         self.frames = []
         self.queue.queue.clear()
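The new `warm_up` drops the `inwarm` flag entirely: with the silence fallback in `get_audio_frame`, warm-up steps simply consume zero frames. Its expected-latency figure follows directly from the formula set in `__init__` (stride values hypothetical):

```python
# Expected warm-up latency from the formula above.
fps = 50
l, m, r = 10, 8, 10
warm_up_steps = m + l + r       # 28 silent steps before real audio
print(warm_up_steps / fps)      # 0.56 -> roughly 560 ms of pipeline latency
```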
@@ -290,26 +246,6 @@ class ASR:
         if self.play:
             self.output_queue.queue.clear()
 
-    def warm_up(self):
-
-        #self.listen()
-
-        self.inwarm = True
-        print(f'[INFO] warm up ASR live model, expected latency = {self.warm_up_steps / self.fps:.6f}s')
-        t = time.time()
-        #for _ in range(self.stride_left_size):
-        #    self.frames.append(np.zeros(self.chunk, dtype=np.float32))
-        for _ in range(self.warm_up_steps):
-            self.run_step()
-        #if torch.cuda.is_available():
-        #    torch.cuda.synchronize()
-        t = time.time() - t
-        print(f'[INFO] warm-up done, actual latency = {t:.6f}s')
-        self.inwarm = False
-
-        #self.clear_queue()
-
-    #####not used function#####################################
     def listen(self):
         # start
         if self.mode == 'live' and not self.listening:
@@ -405,3 +341,4 @@ if __name__ == '__main__':
 
     with ASR(opt) as asr:
         asr.run()
+'''