Add ultralight digital human support
parent bcc21c3c8a
commit 047f40e302
@@ -0,0 +1,34 @@
import time
import torch
import numpy as np

from baseasr import BaseASR


class LightASR(BaseASR):
    def __init__(self, opt, parent, audio_processor):
        super().__init__(opt, parent)
        self.audio_processor = audio_processor
        self.stride_left_size = 32
        self.stride_right_size = 32


    def run_step(self):
        start_time = time.time()

        for _ in range(self.batch_size * 2):
            audio_frame, type_ = self.get_audio_frame()
            self.frames.append(audio_frame)
            self.output_queue.put((audio_frame, type_))

        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
            return

        inputs = np.concatenate(self.frames)  # [N * chunk]

        mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
        mel_chunks = self.audio_processor.feature2chunks(feature_array=mel, fps=self.fps / 2, batch_size=self.batch_size, start=self.stride_left_size / 2)

        self.feat_queue.put(mel_chunks)
        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
        print(f"Processing audio costs {(time.time() - start_time) * 1000}ms")

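For orientation, below is a standalone sketch (not part of the commit) of the sliding-window bookkeeping that run_step performs: each call buffers batch_size * 2 audio chunks, feature extraction only starts once more than stride_left_size + stride_right_size chunks have accumulated, and the last 64 chunks are kept as context for the next call. The batch_size value is an assumption for illustration; the real value comes from BaseASR, which is not part of this diff.

# Standalone illustration only; batch_size = 16 is assumed, not taken from BaseASR.
batch_size = 16
stride_left, stride_right = 32, 32

frames = []
for step in range(4):
    frames += [0.0] * (batch_size * 2)               # run_step buffers 32 chunks per call
    if len(frames) <= stride_left + stride_right:    # not enough context yet
        print(f"step {step}: {len(frames)} chunks buffered, waiting")
        continue
    print(f"step {step}: {len(frames)} chunks -> extract HuBERT features")
    frames = frames[-(stride_left + stride_right):]  # keep 64 chunks as left/right context
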
@@ -0,0 +1,96 @@
from transformers import Wav2Vec2Processor, HubertModel
import torch
import numpy as np


class Audio2Feature():
    def __init__(self):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.processor = Wav2Vec2Processor.from_pretrained('./models/hubert-large-ls960-ft')
        self.model = HubertModel.from_pretrained('./models/hubert-large-ls960-ft').to(self.device)


    @torch.no_grad()
    def get_hubert_from_16k_speech(self, speech):
        if speech.ndim == 2:
            speech = speech[:, 0]  # [T, 2] ==> [T,]
        input_values_all = self.processor(speech, return_tensors="pt", sampling_rate=16000).input_values  # [1, T]
        input_values_all = input_values_all.to(self.device)

        kernel = 400
        stride = 320
        clip_length = stride * 1000
        num_iter = input_values_all.shape[1] // clip_length
        expected_T = (input_values_all.shape[1] - (kernel - stride)) // stride
        res_lst = []
        for i in range(num_iter):
            if i == 0:
                start_idx = 0
                end_idx = clip_length - stride + kernel
            else:
                start_idx = clip_length * i
                end_idx = start_idx + (clip_length - stride + kernel)
            input_values = input_values_all[:, start_idx: end_idx]
            hidden_states = self.model.forward(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
            res_lst.append(hidden_states[0])
        if num_iter > 0:
            input_values = input_values_all[:, clip_length * num_iter:]
        else:
            input_values = input_values_all
        if input_values.shape[1] >= kernel:  # if the last batch is shorter than kernel_size, skip it
            hidden_states = self.model(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
            res_lst.append(hidden_states[0])
        ret = torch.cat(res_lst, dim=0).cpu()  # [T, 1024]
        assert abs(ret.shape[0] - expected_T) <= 1
        if ret.shape[0] < expected_T:
            ret = torch.nn.functional.pad(ret, (0, 0, 0, expected_T - ret.shape[0]))
        else:
            ret = ret[:expected_T]
        return ret

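    # Note added for clarity (not in the original commit): with kernel=400 and
    # stride=320 at 16 kHz, HuBERT emits roughly one 1024-dim vector every 20 ms,
    # i.e. about 50 feature frames per second. For 1 s of speech,
    # expected_T = (16000 - 80) // 320 = 49. Long inputs are pushed through the
    # model in clips of stride * 1000 samples to bound memory, and the result is
    # padded or truncated to exactly expected_T frames.
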
    def get_sliced_feature(self,
                           feature_array,
                           vid_idx,
                           audio_feat_length=[8, 8],
                           fps=25):
        """
        Get sliced features based on a given index
        :param feature_array: the full [T, 1024] feature sequence
        :param vid_idx: the video frame index to fetch audio features for
        :param audio_feat_length: half-window sizes; the slice spans audio_feat_length[i] * 2 features on each side of the center
        :return:
        """
        length = len(feature_array)
        selected_feature = []
        selected_idx = []

        center_idx = int(vid_idx * 50 / fps)
        left_idx = center_idx - audio_feat_length[0] * 2
        right_idx = center_idx + audio_feat_length[1] * 2

        for idx in range(left_idx, right_idx):
            idx = max(0, idx)
            idx = min(length - 1, idx)
            x = feature_array[idx]
            selected_feature.append(x)
            selected_idx.append(idx)

        selected_feature = np.concatenate(selected_feature, axis=0)
        selected_feature = selected_feature.reshape(-1, 1024)
        return selected_feature, selected_idx

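    # Note added for clarity (not in the original commit): HuBERT features run at
    # 50 frames per second, so a video frame index is mapped onto the feature
    # timeline with center_idx = int(vid_idx * 50 / fps); at fps=25 that is simply
    # vid_idx * 2. With audio_feat_length=[8, 8] the slice covers 16 features on
    # each side of center_idx, so every call returns a (32, 1024) array.
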
    def feature2chunks(self, feature_array, fps, batch_size, audio_feat_length=[8, 8], start=0):
        whisper_chunks = []
        whisper_idx_multiplier = 50. / fps
        i = 0
        # print(f"video in {fps} FPS, audio idx in 50FPS")
        for _ in range(batch_size):
            # start_idx = int(i * whisper_idx_multiplier)
            # if start_idx >= len(feature_array):
            #     break
            selected_feature, selected_idx = self.get_sliced_feature(feature_array=feature_array, vid_idx=i + start, audio_feat_length=audio_feat_length, fps=fps)
            # print(f"i:{i},selected_idx {selected_idx}")
            whisper_chunks.append(selected_feature)
            i += 1

        return whisper_chunks

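For reference, a minimal usage sketch of Audio2Feature (not part of the commit). It assumes the HuBERT checkpoint has already been downloaded to ./models/hubert-large-ls960-ft, uses one second of silence in place of real speech, and picks fps=25 and batch_size=4 purely for illustration; the shapes in the comments are what the code above produces.

import numpy as np

processor = Audio2Feature()
speech = np.zeros(16000, dtype=np.float32)            # 1 s of 16 kHz mono audio
feat = processor.get_hubert_from_16k_speech(speech)   # torch.Tensor, shape [49, 1024]
chunks = processor.feature2chunks(feature_array=feat, fps=25, batch_size=4)
print(feat.shape, len(chunks), chunks[0].shape)       # torch.Size([49, 1024]) 4 (32, 1024)
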
@@ -0,0 +1,283 @@
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, use_res_connect, expand_ratio=6):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        self.use_res_connect = use_res_connect

        self.conv = nn.Sequential(
            nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(inp * expand_ratio,
                      inp * expand_ratio,
                      3,
                      stride,
                      1,
                      groups=inp * expand_ratio,
                      bias=False),
            nn.BatchNorm2d(inp * expand_ratio),
            nn.ReLU(inplace=True),
            nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
            nn.BatchNorm2d(oup),
        )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)

class DoubleConvDW(nn.Module):

    def __init__(self, in_channels, out_channels, stride=2):

        super(DoubleConvDW, self).__init__()
        self.double_conv = nn.Sequential(
            InvertedResidual(in_channels, out_channels, stride=stride, use_res_connect=False, expand_ratio=2),
            InvertedResidual(out_channels, out_channels, stride=1, use_res_connect=True, expand_ratio=2)
        )

    def forward(self, x):
        return self.double_conv(x)

class InConvDw(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(InConvDw, self).__init__()
        self.inconv = nn.Sequential(
            InvertedResidual(in_channels, out_channels, stride=1, use_res_connect=False, expand_ratio=2)
        )
    def forward(self, x):
        return self.inconv(x)

class Down(nn.Module):

    def __init__(self, in_channels, out_channels):

        super(Down, self).__init__()
        self.maxpool_conv = nn.Sequential(
            DoubleConvDW(in_channels, out_channels, stride=2)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):

    def __init__(self, in_channels, out_channels):
        super(Up, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = DoubleConvDW(in_channels, out_channels, stride=1)

    def forward(self, x1, x2):

        x1 = self.up(x1)
        diffY = x2.shape[2] - x1.shape[2]
        diffX = x2.shape[3] - x1.shape[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x1, x2], axis=1)

        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x)

class AudioConvWenet(nn.Module):
    def __init__(self):
        super(AudioConvWenet, self).__init__()
        # ch = [16, 32, 64, 128, 256]  # if you want to run this model on a mobile device, use this.
        ch = [32, 64, 128, 256, 512]
        self.conv1 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
        self.conv2 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)

        self.conv3 = nn.Conv2d(ch[3], ch[3], kernel_size=3, padding=1, stride=(1, 2))
        self.bn3 = nn.BatchNorm2d(ch[3])

        self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)

        self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
        self.bn5 = nn.BatchNorm2d(ch[4])
        self.relu = nn.ReLU()

        self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
        self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)

    def forward(self, x):

        x = self.conv1(x)
        x = self.conv2(x)

        x = self.relu(self.bn3(self.conv3(x)))

        x = self.conv4(x)

        x = self.relu(self.bn5(self.conv5(x)))

        x = self.conv6(x)
        x = self.conv7(x)

        return x

class AudioConvHubert(nn.Module):
    def __init__(self):
        super(AudioConvHubert, self).__init__()
        # ch = [16, 32, 64, 128, 256]  # if you want to run this model on a mobile device, use this.
        ch = [32, 64, 128, 256, 512]
        self.conv1 = InvertedResidual(32, ch[1], stride=1, use_res_connect=False, expand_ratio=2)
        self.conv2 = InvertedResidual(ch[1], ch[2], stride=1, use_res_connect=False, expand_ratio=2)

        self.conv3 = nn.Conv2d(ch[2], ch[3], kernel_size=3, padding=1, stride=(2, 2))
        self.bn3 = nn.BatchNorm2d(ch[3])

        self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)

        self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
        self.bn5 = nn.BatchNorm2d(ch[4])
        self.relu = nn.ReLU()

        self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
        self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)

    def forward(self, x):

        x = self.conv1(x)
        x = self.conv2(x)

        x = self.relu(self.bn3(self.conv3(x)))

        x = self.conv4(x)

        x = self.relu(self.bn5(self.conv5(x)))

        x = self.conv6(x)
        x = self.conv7(x)

        return x

class Model(nn.Module):
    def __init__(self, n_channels=6, mode='hubert'):
        super(Model, self).__init__()
        self.n_channels = n_channels  # BGR
        # ch = [16, 32, 64, 128, 256]  # if you want to run this model on a mobile device, use this.
        ch = [32, 64, 128, 256, 512]

        if mode == 'hubert':
            self.audio_model = AudioConvHubert()
        if mode == 'wenet':
            self.audio_model = AudioConvWenet()

        self.fuse_conv = nn.Sequential(
            DoubleConvDW(ch[4] * 2, ch[4], stride=1),
            DoubleConvDW(ch[4], ch[3], stride=1)
        )

        self.inc = InConvDw(n_channels, ch[0])
        self.down1 = Down(ch[0], ch[1])
        self.down2 = Down(ch[1], ch[2])
        self.down3 = Down(ch[2], ch[3])
        self.down4 = Down(ch[3], ch[4])

        self.up1 = Up(ch[4], ch[3] // 2)
        self.up2 = Up(ch[3], ch[2] // 2)
        self.up3 = Up(ch[2], ch[1] // 2)
        self.up4 = Up(ch[1], ch[0])

        self.outc = OutConv(ch[0], 3)

    def forward(self, x, audio_feat):

        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        audio_feat = self.audio_model(audio_feat)
        x5 = torch.cat([x5, audio_feat], axis=1)
        x5 = self.fuse_conv(x5)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        out = F.sigmoid(out)
        return out

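# Note added for clarity (not in the original commit): Model is a compact U-Net.
# The 6-channel input is presumably two stacked 3-channel (BGR) images; with the
# 160x160 test input used below, four Down stages reduce it to a 512 x 10 x 10
# bottleneck. The audio branch (AudioConvHubert or AudioConvWenet) must produce a
# matching 512-channel map, which is concatenated channel-wise with the image
# bottleneck, fused by fuse_conv, and decoded back to a 3-channel sigmoid output.
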
if __name__ == '__main__':
    import time
    import copy
    import onnx
    import numpy as np
    onnx_path = "./unet.onnx"

    from thop import profile, clever_format

    def reparameterize_model(model: torch.nn.Module) -> torch.nn.Module:
        """ Method returns a model where a multi-branched structure
            used in training is re-parameterized into a single branch
            for inference.
        :param model: MobileOne model in train mode.
        :return: MobileOne model in inference mode.
        """
        # Avoid editing original graph
        model = copy.deepcopy(model)
        for module in model.modules():
            if hasattr(module, 'reparameterize'):
                module.reparameterize()
        return model

    device = torch.device("cuda")

    def check_onnx(torch_out, torch_in, audio):
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
        import onnxruntime
        providers = ["CUDAExecutionProvider"]
        ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        print(ort_session.get_providers())
        ort_inputs = {ort_session.get_inputs()[0].name: torch_in.cpu().numpy(),
                      ort_session.get_inputs()[1].name: audio.cpu().numpy()}
        ort_outs = ort_session.run(None, ort_inputs)
        np.testing.assert_allclose(torch_out[0].cpu().numpy(), ort_outs[0][0], rtol=1e-03, atol=1e-05)
        print("Exported model has been tested with ONNXRuntime, and the result looks good!")

    net = Model(6).eval().to(device)
    img = torch.zeros([1, 6, 160, 160]).to(device)
    # 32 HuBERT feature frames of 1024 dims, packed as (32, 32, 32); the default
    # 'hubert' audio branch expects 32 input channels.
    audio = torch.zeros([1, 32, 32, 32]).to(device)
    # net = reparameterize_model(net)
    flops, params = profile(net, (img, audio))
    macs, params = clever_format([flops, params], "%3f")
    print(macs, params)
    # dynamic_axes = {'input': [2, 3], 'output': [2, 3]}

    input_dict = {"input": img, "audio": audio}

    with torch.no_grad():
        torch_out = net(img, audio)
        print(torch_out.shape)
        torch.onnx.export(net, (img, audio), onnx_path, input_names=['input', "audio"],
                          output_names=['output'],
                          # dynamic_axes=dynamic_axes,
                          # example_outputs=torch_out,
                          opset_version=11,
                          export_params=True)
    check_onnx(torch_out, img, audio)

    # img = torch.zeros([1, 6, 160, 160]).to(device)
    # audio = torch.zeros([1, 32, 32, 32]).to(device)
    # with torch.no_grad():
    #     for i in range(100000):
    #         t1 = time.time()
    #         out = net(img, audio)
    #         t2 = time.time()
    #         # print(out.shape)
    #         print('time cost::', t2 - t1)
    # torch.save(net.state_dict(), '1.pth')

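Finally, a minimal end-to-end sketch (not part of the commit) of how the pieces above are intended to fit together: HuBERT features for one video frame are reshaped into the 32-channel audio tensor that the default 'hubert' branch expects and drive a single forward pass of the U-Net. The reshape to (1, 32, 32, 32), the zero placeholder inputs, and the interpretation of the 6 image channels are assumptions for illustration; Audio2Feature and Model are the classes added in this commit.

import numpy as np
import torch

processor = Audio2Feature()                      # from the second file in this commit
net = Model(6, mode='hubert').eval()             # from the third file in this commit

speech = np.zeros(16000, dtype=np.float32)       # 1 s of 16 kHz audio (placeholder)
feat = processor.get_hubert_from_16k_speech(speech)                            # [49, 1024]
chunk = processor.feature2chunks(feature_array=feat, fps=25, batch_size=1)[0]  # (32, 1024)

audio_feat = torch.from_numpy(chunk).reshape(1, 32, 32, 32)  # assumed packing: 32 x 1024 -> 32 x 32 x 32
img = torch.zeros(1, 6, 160, 160)                # assumed: masked frame + reference frame, stacked
with torch.no_grad():
    out = net(img, audio_feat)
print(out.shape)                                 # torch.Size([1, 3, 160, 160])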