|
|
|
@ -116,7 +116,7 @@ class ASR:
|
|
|
|
|
self.att_feats = [torch.zeros(self.audio_dim, 16, dtype=torch.float32, device=self.device)] * 4 # 4 zero padding...
|
|
|
|
|
|
|
|
|
|
# warm up steps needed: mid + right + window_size + attention_size
|
|
|
|
|
self.warm_up_steps = self.context_size + self.stride_right_size + 8 + 2 * 3
|
|
|
|
|
self.warm_up_steps = self.context_size + self.stride_right_size + self.stride_left_size #+ 8 + 2 * 3
|
|
|
|
|
|
|
|
|
|
self.listening = False
|
|
|
|
|
self.playing = False
|
|
|
|
@ -204,7 +204,6 @@ class ASR:
|
|
|
|
|
else:
|
|
|
|
|
self.frames.append(frame)
|
|
|
|
|
# put to output
|
|
|
|
|
#if self.play:
|
|
|
|
|
self.output_queue.put(frame)
|
|
|
|
|
# context not enough, do not run network.
|
|
|
|
|
if len(self.frames) < self.stride_left_size + self.context_size + self.stride_right_size:
|
|
|
|
@ -217,7 +216,9 @@ class ASR:
|
|
|
|
|
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
|
|
|
|
|
|
|
|
|
|
print(f'[INFO] frame_to_text... ')
|
|
|
|
|
#t = time.time()
|
|
|
|
|
logits, labels, text = self.frame_to_text(inputs)
|
|
|
|
|
#print(f'-------wav2vec time:{time.time()-t:.4f}s')
|
|
|
|
|
feats = logits # better lips-sync than labels
|
|
|
|
|
|
|
|
|
|
# save feats
|
|
|
|
@ -257,6 +258,7 @@ class ASR:
|
|
|
|
|
np.save(output_path, unfold_feats.cpu().numpy())
|
|
|
|
|
print(f"[INFO] saved logits to {output_path}")
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def create_file_stream(self):
|
|
|
|
|
|
|
|
|
|
stream, sample_rate = sf.read(self.opt.asr_wav) # [T*sample_rate,] float64
|
|
|
|
@ -302,7 +304,7 @@ class ASR:
|
|
|
|
|
frames_per_buffer=self.chunk)
|
|
|
|
|
|
|
|
|
|
return audio, stream
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
def get_audio_frame(self):
|
|
|
|
|
|
|
|
|
@ -351,8 +353,8 @@ class ASR:
|
|
|
|
|
|
|
|
|
|
# print(frame.shape, inputs.input_values.shape, logits.shape)
|
|
|
|
|
|
|
|
|
|
predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
|
|
transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
|
|
|
|
#predicted_ids = torch.argmax(logits, dim=-1)
|
|
|
|
|
#transcription = self.processor.batch_decode(predicted_ids)[0].lower()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# for esperanto
|
|
|
|
@ -363,7 +365,7 @@ class ASR:
|
|
|
|
|
# print(predicted_ids[0])
|
|
|
|
|
# print(transcription)
|
|
|
|
|
|
|
|
|
|
return logits[0], predicted_ids[0], transcription # [N,]
|
|
|
|
|
return logits[0], None,None #predicted_ids[0], transcription # [N,]
|
|
|
|
|
|
|
|
|
|
def create_bytes_stream(self,byte_stream):
|
|
|
|
|
#byte_stream=BytesIO(buffer)
|
|
|
|
@ -404,8 +406,8 @@ class ASR:
|
|
|
|
|
self.queue.put(stream[idx:idx+self.chunk])
|
|
|
|
|
streamlen -= self.chunk
|
|
|
|
|
idx += self.chunk
|
|
|
|
|
if streamlen>0:
|
|
|
|
|
self.queue.put(stream[idx:])
|
|
|
|
|
#if streamlen>0: #skip last frame(not 20ms)
|
|
|
|
|
# self.queue.put(stream[idx:])
|
|
|
|
|
self.input_stream.seek(0)
|
|
|
|
|
self.input_stream.truncate()
|
|
|
|
|
|
|
|
|
|