@@ -2,7 +2,7 @@ import time
 import numpy as np
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCTC, AutoProcessor
+from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 
 #import pyaudio
 import soundfile as sf
@@ -52,6 +52,8 @@ class ASR:
             self.audio_dim = 44
         elif 'deepspeech' in self.opt.asr_model:
             self.audio_dim = 29
+        elif 'hubert' in self.opt.asr_model:
+            self.audio_dim = 1024
         else:
             self.audio_dim = 32
 
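The 1024 here is the hidden size of HuBERT-large checkpoints (base-sized ones use 768), so `audio_dim` has to match the checkpoint being loaded. A quick way to check, using `facebook/hubert-large-ls960-ft` as an illustrative checkpoint name that this patch itself does not pin down:

```python
from transformers import AutoConfig

# Illustrative checkpoint; any HuBERT checkpoint can be queried the same way.
cfg = AutoConfig.from_pretrained('facebook/hubert-large-ls960-ft')
print(cfg.hidden_size)  # 1024 for -large models, 768 for -base
```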
@@ -96,6 +98,10 @@ class ASR:
 
         # create wav2vec model
         print(f'[INFO] loading ASR model {self.opt.asr_model}...')
-        self.processor = AutoProcessor.from_pretrained(opt.asr_model)
-        self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        if 'hubert' in self.opt.asr_model:
+            self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+            self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
+        else:
+            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
+            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
 
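For context on this hunk: `HubertModel` is a bare encoder without a CTC head, so it is paired with `Wav2Vec2Processor` (feature extractor plus tokenizer) and its forward pass returns hidden states rather than character logits. A minimal standalone sketch of the two load paths, with the checkpoint name as an assumption standing in for `opt.asr_model`:

```python
import torch
from transformers import (AutoModelForCTC, AutoProcessor,
                          Wav2Vec2Processor, HubertModel)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
asr_model = 'facebook/hubert-large-ls960-ft'  # illustrative stand-in for opt.asr_model

if 'hubert' in asr_model:
    # Bare encoder: outputs per-frame hidden states, no vocabulary head.
    processor = Wav2Vec2Processor.from_pretrained(asr_model)
    model = HubertModel.from_pretrained(asr_model).to(device)
else:
    # CTC model (e.g. wav2vec 2.0): outputs per-frame character logits.
    processor = AutoProcessor.from_pretrained(asr_model)
    model = AutoModelForCTC.from_pretrained(asr_model).to(device)
```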
@@ -339,7 +345,11 @@ class ASR:
 
         with torch.no_grad():
             result = self.model(inputs.input_values.to(self.device))
-            logits = result.logits # [1, N - 1, 32]
+            if 'hubert' in self.opt.asr_model:
+                logits = result.last_hidden_state # [B=1, T=pts//320, hid=1024]
+            else:
+                logits = result.logits # [1, N - 1, 32]
 
+        #print('logits.shape:',logits.shape)
         # cut off stride
         left = max(0, self.stride_left_size)
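The shape comment in this last hunk follows from HuBERT's convolutional front end, which downsamples 16 kHz audio by a total stride of 320 (one feature vector per 20 ms). A sketch of the inference step under that assumption, with dummy audio in place of the streamed input chunk:

```python
import numpy as np
import torch
from transformers import Wav2Vec2Processor, HubertModel

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Illustrative checkpoint; the patch only assumes 'hubert' appears in the name.
processor = Wav2Vec2Processor.from_pretrained('facebook/hubert-large-ls960-ft')
model = HubertModel.from_pretrained('facebook/hubert-large-ls960-ft').to(device)

# One second of silent 16 kHz audio standing in for a real chunk.
audio = np.zeros(16000, dtype=np.float32)
inputs = processor(audio, sampling_rate=16000, return_tensors='pt')

with torch.no_grad():
    result = model(inputs.input_values.to(device))

# [B=1, T ~= samples // 320, hidden=1024], matching the comment in the diff.
print(result.last_hidden_state.shape)  # torch.Size([1, 49, 1024])
```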