diff --git a/app.py b/app.py
index 85130cf..f665c7c 100644
--- a/app.py
+++ b/app.py
@@ -127,6 +127,9 @@ def build_nerfreal(sessionid):
     elif opt.model == 'ernerf':
         from nerfreal import NeRFReal
         nerfreal = NeRFReal(opt,model,avatar)
+    elif opt.model == 'ultralight':
+        from lightreal import LightReal
+        nerfreal = LightReal(opt,model,avatar)
     return nerfreal
 
 #@app.route('/offer', methods=['POST'])
@@ -480,6 +483,12 @@ if __name__ == '__main__':
     #         opt.sessionid=k
     #         nerfreal = LipReal(opt,model)
     #         nerfreals.append(nerfreal)
+    elif opt.model == 'ultralight':
+        from lightreal import LightReal,load_model,load_avatar,warm_up
+        print(opt)
+        model = load_model(opt)
+        avatar = load_avatar(opt.avatar_id)
+        warm_up(opt.batch_size,model,160)
 
     if opt.transport=='rtmp':
         thread_quit = Event()
@@ -539,4 +548,4 @@ if __name__ == '__main__':
     # server = pywsgi.WSGIServer(('0.0.0.0', 8000), app, handler_class=WebSocketHandler)
     # server.serve_forever()
-    
\ No newline at end of file
+    
diff --git a/lightasr.py b/lightasr.py
new file mode 100644
index 0000000..08020f3
--- /dev/null
+++ b/lightasr.py
@@ -0,0 +1,34 @@
+import time
+import torch
+import numpy as np
+from baseasr import BaseASR
+
+
+class LightASR(BaseASR):
+    def __init__(self, opt, parent, audio_processor):
+        super().__init__(opt, parent)
+        self.audio_processor = audio_processor
+        self.stride_left_size = 32
+        self.stride_right_size = 32
+
+
+    def run_step(self):
+        start_time = time.time()
+
+        for _ in range(self.batch_size * 2):
+            audio_frame, type_ = self.get_audio_frame()
+            self.frames.append(audio_frame)
+            self.output_queue.put((audio_frame, type_))
+
+        if len(self.frames) <= self.stride_left_size + self.stride_right_size:
+            return
+
+        inputs = np.concatenate(self.frames)  # [N * chunk]
+
+        mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
+        mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2)
+
+        self.feat_queue.put(mel_chunks)
+        self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
+        print(f"Processing audio costs {(time.time() - start_time) * 1000}ms")
+
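# Illustrative sketch, not part of the patch: how the sliding window in
# LightASR.run_step above behaves. It assumes the 20 ms / 16 kHz audio chunks
# that BaseASR produces elsewhere in LiveTalking and a batch size of 16; all
# constants here are assumptions for illustration, not values from the diff.
import numpy as np

CHUNK = 320                      # 16000 Hz * 0.02 s per audio frame (assumed)
batch_size = 16
stride_left = stride_right = 32  # context kept between steps, as in LightASR

frames = []
for step in range(3):
    # each run_step pulls batch_size*2 new 20 ms chunks
    frames.extend(np.zeros(CHUNK, dtype=np.float32) for _ in range(batch_size * 2))
    if len(frames) <= stride_left + stride_right:
        continue                 # same early return as run_step: not enough context yet
    window = np.concatenate(frames)
    print(f"step {step}: HuBERT sees {len(frames)} chunks = {len(window) / 16000:.2f} s of audio")
    # keep only the stride context so consecutive windows overlap
    frames = frames[-(stride_left + stride_right):]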
diff --git a/lightreal.py b/lightreal.py
new file mode 100644
index 0000000..61e02c4
--- /dev/null
+++ b/lightreal.py
@@ -0,0 +1,350 @@
+###############################################################################
+#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+#  email: lipku@foxmail.com
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+###############################################################################
+
+import math
+import torch
+import numpy as np
+
+#from .utils import *
+import os
+import time
+import cv2
+import glob
+import pickle
+import copy
+
+import queue
+from queue import Queue
+from threading import Thread, Event
+import torch.multiprocessing as mp
+
+
+from lightasr import LightASR
+import asyncio
+from av import AudioFrame, VideoFrame
+from basereal import BaseReal
+
+#from imgcache import ImgCache
+
+from tqdm import tqdm
+
+# ultralight-specific imports
+import torch.nn as nn
+from torch import optim
+from transformers import Wav2Vec2Processor, HubertModel
+from torch.utils.data import DataLoader
+from ultralight.unet import Model
+from ultralight.audio2feature import Audio2Feature
+
+
+
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+print('Using {} for inference.'.format(device))
+
+
+
+def load_model(opt):
+    audio_processor = Audio2Feature()
+    model = Model(6, 'hubert').to(device)  # Model is the custom UNet defined in ultralight/unet.py
+    model.load_state_dict(torch.load('./models/ultralight.pth'))
+    model.eval()
+
+    return model,audio_processor
+
+def load_avatar(avatar_id):
+    avatar_path = f"./data/avatars/{avatar_id}"
+    full_imgs_path = f"{avatar_path}/full_body_img"
+    land_marks_path = f"{avatar_path}/landmarks"
+
+    input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    frame_list_cycle = read_imgs(input_img_list)
+    #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
+    land_marks_list = glob.glob(os.path.join(land_marks_path, '*.lms'))
+    land_marks_list = sorted(land_marks_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    lms_list_cycle = read_lms(land_marks_list)
+    lms_list_cycle = np.array(lms_list_cycle, dtype=np.int32)
+    return frame_list_cycle,lms_list_cycle
+
+
+@torch.no_grad()
+def warm_up(batch_size,model,modelres):
+    # warm up the model with one dummy batch
+    print('warmup model...')
+    model1, audio_processor = model
+    img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
+    mel_batch = torch.ones(batch_size, 32, 32, 32).to(device)
+    model1(img_batch, mel_batch)
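# Illustrative sketch, not part of the patch: the avatar data that load_avatar
# assumes. Each avatar ships full_body_img/<n>.jpg plus landmarks/<n>.lms, one
# whitespace-separated "x y" pair per line; read_lms (defined further down)
# parses the floats and load_avatar truncates them to int32 pixel coordinates.
# The values below are fabricated for illustration.
import numpy as np

fake_lms_text = "501.2 398.7\n660.9 421.3\n523.4 431.0\n"

file_landmarks = []
for line in fake_lms_text.splitlines():
    arr = list(filter(None, line.split(" ")))
    if arr:
        file_landmarks.append(np.array(arr, dtype=np.float32))

lms = np.array(file_landmarks, dtype=np.int32)   # same int truncation as load_avatar
print(lms.shape, lms[0])                         # (3, 2) [501 398]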
+
+def read_imgs(img_list):
+    frames = []
+    print('reading images...')
+    for img_path in tqdm(img_list):
+        frame = cv2.imread(img_path)
+        frames.append(frame)
+    return frames
+
+def get_audio_features(features, index):
+    left = index - 8
+    right = index + 8
+    pad_left = 0
+    pad_right = 0
+    if left < 0:
+        pad_left = -left
+        left = 0
+    if right > features.shape[0]:
+        pad_right = right - features.shape[0]
+        right = features.shape[0]
+    auds = torch.from_numpy(features[left:right])
+    if pad_left > 0:
+        auds = torch.cat([torch.zeros_like(auds[:pad_left]), auds], dim=0)
+    if pad_right > 0:
+        auds = torch.cat([auds, torch.zeros_like(auds[:pad_right])], dim=0) # [8, 16]
+    return auds
+
+
+def read_lms(lms_list):
+    land_marks = []
+    print('reading lms...')
+    for lms_path in tqdm(lms_list):
+        file_landmarks = []  # Store landmarks for this file
+        with open(lms_path, "r") as f:
+            lines = f.read().splitlines()
+            for line in lines:
+                arr = list(filter(None, line.split(" ")))
+                if arr:
+                    arr = np.array(arr, dtype=np.float32)
+                    file_landmarks.append(arr)
+        land_marks.append(file_landmarks)  # Add the file's landmarks to the overall list
+    return land_marks
+
+def __mirror_index(size, index):
+    #size = len(self.coord_list_cycle)
+    turn = index // size
+    res = index % size
+    if turn % 2 == 0:
+        return res
+    else:
+        return size - res - 1
+
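# Illustrative sketch, not part of the patch: the ping-pong frame order that
# __mirror_index produces, so the avatar loops forward and backward instead of
# jumping from the last frame to the first. The frame count of 5 is made up.
def mirror_index(size, index):
    turn, res = divmod(index, size)
    return res if turn % 2 == 0 else size - res - 1

print([mirror_index(5, i) for i in range(12)])
# [0, 1, 2, 3, 4, 4, 3, 2, 1, 0, 0, 1]  -> forward, backward, forward again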
+
+def inference(quit_event, batch_size, frame_list_cycle, lms_list_cycle, audio_feat_queue, audio_out_queue, res_frame_queue, model):
+    length = len(lms_list_cycle)
+    index = 0
+    count = 0
+    counttime = 0
+    print('start inference')
+
+    while not quit_event.is_set():
+        starttime=time.perf_counter()
+        try:
+            mel_batch = audio_feat_queue.get(block=True, timeout=1)
+        except queue.Empty:
+            continue
+        is_all_silence=True
+        audio_frames = []
+        for _ in range(batch_size*2):
+            frame,type_ = audio_out_queue.get()
+            audio_frames.append((frame,type_))
+            if type_==0:
+                is_all_silence=False
+        if is_all_silence:
+            for i in range(batch_size):
+                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
+        else:
+            t = time.perf_counter()
+            img_batch = []
+
+            for i in range(batch_size):
+                idx = __mirror_index(length, index + i)
+                face = frame_list_cycle[idx]
+                lms = lms_list_cycle[idx]
+                xmin, ymin = lms[1][0], lms[52][1]
+                xmax = lms[31][0]
+                width = xmax - xmin
+                ymax = ymin + width
+                crop_img = face[ymin:ymax, xmin:xmax]
+#                h, w = crop_img.shape[:2]
+                crop_img = cv2.resize(crop_img, (168, 168), interpolation=cv2.INTER_AREA)
+                crop_img_ori = crop_img.copy()
+                img_real_ex = crop_img[4:164, 4:164].copy()
+                img_real_ex_ori = img_real_ex.copy()
+                img_masked = cv2.rectangle(img_real_ex_ori,(5,5,150,145),(0,0,0),-1)
+
+                img_masked = img_masked.transpose(2,0,1).astype(np.float32)
+                img_real_ex = img_real_ex.transpose(2,0,1).astype(np.float32)
+
+                img_real_ex_T = torch.from_numpy(img_real_ex / 255.0)
+                img_masked_T = torch.from_numpy(img_masked / 255.0)
+                img_concat_T = torch.cat([img_real_ex_T, img_masked_T], axis=0)[None]
+                img_batch.append(img_concat_T)
+
+
+            reshaped_mel_batch = [arr.reshape(32, 32, 32) for arr in mel_batch]
+            mel_batch = torch.stack([torch.from_numpy(arr) for arr in reshaped_mel_batch])
+            img_batch = torch.stack(img_batch).squeeze(1)
+
+
+            with torch.no_grad():
+                pred = model(img_batch.cuda(),mel_batch.cuda())
+            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+
+            counttime += (time.perf_counter() - t)
+            count += batch_size
+            if count >= 100:
+                print(f"------actual avg infer fps:{count / counttime:.4f}")
+                count = 0
+                counttime = 0
+            for i,res_frame in enumerate(pred):
+                #self.__pushmedia(res_frame,loop,audio_track,video_track)
+                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
+
+#            for i, pred_frame in enumerate(pred):
+#                pred_frame_uint8 = np.array(pred_frame, dtype=np.uint8)
+#                res_frame_queue.put((pred_frame_uint8, __mirror_index(length, index), audio_frames[i * 2:i * 2 + 2]))
+#                index = (index + 1) % length
+
+        #print('total batch time:', time.perf_counter() - starttime)
+
+    print('lightreal inference processor stop')
+
+
+class LightReal(BaseReal):
+    @torch.no_grad()
+    def __init__(self, opt, model, avatar):
+        super().__init__(opt)
+        #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
+        self.W = opt.W
+        self.H = opt.H
+
+        self.fps = opt.fps # 20 ms per frame
+
+        self.batch_size = opt.batch_size
+        self.idx = 0
+        self.res_frame_queue = Queue(self.batch_size*2)  #mp.Queue
+        #self.__loadavatar()
+        self.model,audio_processor = model
+        self.frame_list_cycle,self.lms_list_cycle = avatar
+
+        self.asr = LightASR(opt,self,audio_processor)
+        self.asr.warm_up()
+        #self.__warm_up()
+
+        self.render_event = mp.Event()
+
+    def __del__(self):
+        print(f'lightreal({self.sessionid}) delete')
+
+
+    def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
+
+        while not quit_event.is_set():
+            try:
+                res_frame,idx,audio_frames = self.res_frame_queue.get(block=True, timeout=1)
+            except queue.Empty:
+                continue
+            if audio_frames[0][1]!=0 and audio_frames[1][1]!=0:  # all frames in this pair are silence, just use the full reference image
+                self.speaking = False
+                audiotype = audio_frames[0][1]
+                if self.custom_index.get(audiotype) is not None:  # a custom video is configured for this audio type
+                    mirindex = self.mirror_index(len(self.custom_img_cycle[audiotype]),self.custom_index[audiotype])
+                    combine_frame = self.custom_img_cycle[audiotype][mirindex]
+                    self.custom_index[audiotype] += 1
+                    # if not self.custom_opt[audiotype].loop and self.custom_index[audiotype]>=len(self.custom_img_cycle[audiotype]):
+                    #     self.curr_state = 1  # the current video does not loop, switch back to the silent state
+                else:
+                    combine_frame = self.frame_list_cycle[idx]
+                    #combine_frame = self.imagecache.get_img(idx)
+            else:
+                self.speaking = True
+                lms = self.lms_list_cycle[idx]
+                combine_frame = copy.deepcopy(self.frame_list_cycle[idx])
+                xmin = lms[1][0]
+                ymin = lms[52][1]
+
+                xmax = lms[31][0]
+                width = xmax - xmin
+                ymax = ymin + width
+                crop_img = combine_frame[ymin:ymax, xmin:xmax]
+                h, w = crop_img.shape[:2]
+                crop_img_ori = cv2.resize(crop_img, (168, 168), interpolation=cv2.INTER_AREA).copy()
+                #combine_frame = copy.deepcopy(self.imagecache.get_img(idx))
+                res_frame = np.array(res_frame, dtype=np.uint8)
+                crop_img_ori[4:164, 4:164] = res_frame
+                crop_img_ori = cv2.resize(crop_img_ori, (w, h))
+                combine_frame[ymin:ymax, xmin:xmax] = crop_img_ori
+                #print('blending time:',time.perf_counter()-t)
+
+            new_frame = VideoFrame.from_ndarray(combine_frame, format="bgr24")
+            asyncio.run_coroutine_threadsafe(video_track._queue.put(new_frame), loop)
+            self.record_video_data(combine_frame)
+
+            for audio_frame in audio_frames:
+                frame,type_ = audio_frame
+                frame = (frame * 32767).astype(np.int16)
+                new_frame = AudioFrame(format='s16', layout='mono', samples=frame.shape[0])
+                new_frame.planes[0].update(frame.tobytes())
+                new_frame.sample_rate=16000
+                # if audio_track._queue.qsize()>10:
+                #     time.sleep(0.1)
+                asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
+                self.record_audio_data(frame)
+        print('lightreal process_frames thread stop')
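# Illustrative sketch, not part of the patch: the square mouth crop that both
# inference() and process_frames() above rebuild from landmark points 1, 31 and
# 52. The coordinates below are fabricated; only the point indices and the box
# arithmetic come from the diff.
import numpy as np

face = np.zeros((720, 1280, 3), dtype=np.uint8)   # stand-in for a full-body frame
lms = np.zeros((68, 2), dtype=np.int32)           # stand-in landmark set
lms[1]  = [500, 400]   # -> xmin
lms[31] = [660, 420]   # -> xmax
lms[52] = [520, 430]   # -> ymin

xmin, ymin = lms[1][0], lms[52][1]
xmax = lms[31][0]
width = xmax - xmin               # the crop is forced to be square
ymax = ymin + width

crop = face[ymin:ymax, xmin:xmax]
print(crop.shape)                 # (160, 160, 3) here; the real code resizes to 168x168,
                                  # feeds the inner 160x160 to the model, pastes the
                                  # prediction back into [4:164, 4:164], then resizes to (w, h)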
+
+    def render(self,quit_event,loop=None,audio_track=None,video_track=None):
+        #if self.opt.asr:
+        #     self.asr.warm_up()
+
+        self.tts.render(quit_event)
+        self.init_customindex()
+        process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
+        process_thread.start()
+        Thread(target=inference, args=(quit_event,self.batch_size,self.frame_list_cycle,self.lms_list_cycle,self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+            self.model,)).start()  #mp.Process
+
+
+        #self.render_event.set() #start infer process render
+        count=0
+        totaltime=0
+        _starttime=time.perf_counter()
+        #_totalframe=0
+        while not quit_event.is_set():
+            # update texture every frame
+            # audio stream thread...
+            t = time.perf_counter()
+            self.asr.run_step()
+
+            # if video_track._queue.qsize()>=2*self.opt.batch_size:
+            #     print('sleep qsize=',video_track._queue.qsize())
+            #     time.sleep(0.04*video_track._queue.qsize()*0.8)
+            if video_track._queue.qsize()>=5:
+                print('sleep qsize=',video_track._queue.qsize())
+                time.sleep(0.04*video_track._queue.qsize()*0.8)
+
+            # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
+            # if delay > 0:
+            #     time.sleep(delay)
+        #self.render_event.clear() #end infer process render
+        print('lightreal thread stop')
+
+
diff --git a/ultralight/audio2feature.py b/ultralight/audio2feature.py
new file mode 100644
index 0000000..dae7b08
--- /dev/null
+++ b/ultralight/audio2feature.py
@@ -0,0 +1,96 @@
+from transformers import Wav2Vec2Processor, HubertModel
+import torch
+import numpy as np
+
+
+class Audio2Feature():
+    def __init__(self):
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.processor = Wav2Vec2Processor.from_pretrained('./models/hubert-large-ls960-ft')
+        self.model = HubertModel.from_pretrained('./models/hubert-large-ls960-ft').to(self.device)
+
+
+    @torch.no_grad()
+    def get_hubert_from_16k_speech(self, speech):
+        if speech.ndim == 2:
+            speech = speech[:, 0]  # [T, 2] ==> [T,]
+        input_values_all = self.processor(speech, return_tensors="pt", sampling_rate=16000).input_values  # [1, T]
+        input_values_all = input_values_all.to(self.device)
+
+        kernel = 400
+        stride = 320
+        clip_length = stride * 1000
+        num_iter = input_values_all.shape[1] // clip_length
+        expected_T = (input_values_all.shape[1] - (kernel-stride)) // stride
+        res_lst = []
+        for i in range(num_iter):
+            if i == 0:
+                start_idx = 0
+                end_idx = clip_length - stride + kernel
+            else:
+                start_idx = clip_length * i
+                end_idx = start_idx + (clip_length - stride + kernel)
+            input_values = input_values_all[:, start_idx: end_idx]
+            hidden_states = self.model.forward(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
+            res_lst.append(hidden_states[0])
+        if num_iter > 0:
+            input_values = input_values_all[:, clip_length * num_iter:]
+        else:
+            input_values = input_values_all
+        if input_values.shape[1] >= kernel:  # if the last batch is shorter than kernel_size, skip it
+            hidden_states = self.model(input_values).last_hidden_state  # [B=1, T=pts//320, hid=1024]
+            res_lst.append(hidden_states[0])
+        ret = torch.cat(res_lst, dim=0).cpu()  # [T, 1024]
+        assert abs(ret.shape[0] - expected_T) <= 1
+        if ret.shape[0] < expected_T:
+            ret = torch.nn.functional.pad(ret, (0,0,0,expected_T-ret.shape[0]))
+        else:
+            ret = ret[:expected_T]
+        return ret
+
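# Illustrative sketch, not part of the patch: the frame-rate arithmetic behind
# get_hubert_from_16k_speech. A 400-sample receptive field moved in 320-sample
# steps over 16 kHz audio yields ~50 feature rows per second, i.e. two rows per
# 20 ms audio chunk, which is why LightASR calls feature2chunks with fps=self.fps/2.
sr = 16000
kernel, stride = 400, 320

samples = sr * 1                                    # assume exactly 1 s of audio
expected_T = (samples - (kernel - stride)) // stride
print(expected_T)                                   # 49 -> roughly 50 HuBERT rows per second
print(stride / sr * 1000, "ms per row")             # 20.0 ms hop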
+    def get_sliced_feature(self,
+                           feature_array,
+                           vid_idx,
+                           audio_feat_length=[8,8],
+                           fps=25):
+        """
+        Get sliced features based on a given index
+        :param feature_array:
+        :param vid_idx: the video frame index the slice is centred on
+        :param audio_feat_length:
+        :return:
+        """
+        length = len(feature_array)
+        selected_feature = []
+        selected_idx = []
+
+        center_idx = int(vid_idx*50/fps)
+        left_idx = center_idx-audio_feat_length[0]*2
+        right_idx = center_idx + (audio_feat_length[1])*2
+
+        for idx in range(left_idx,right_idx):
+            idx = max(0, idx)
+            idx = min(length-1, idx)
+            x = feature_array[idx]
+            selected_feature.append(x)
+            selected_idx.append(idx)
+
+        selected_feature = np.concatenate(selected_feature, axis=0)
+        selected_feature = selected_feature.reshape(-1, 1024)
+        return selected_feature,selected_idx
+
+    def feature2chunks(self,feature_array,fps,batch_size,audio_feat_length = [8,8],start=0):
+        whisper_chunks = []
+        whisper_idx_multiplier = 50./fps
+        i = 0
+        #print(f"video in {fps} FPS, audio idx in 50FPS")
+        for _ in range(batch_size):
+            # start_idx = int(i * whisper_idx_multiplier)
+            # if start_idx>=len(feature_array):
+            #     break
+            selected_feature,selected_idx = self.get_sliced_feature(feature_array= feature_array,vid_idx = i+start,audio_feat_length=audio_feat_length,fps=fps)
+            #print(f"i:{i},selected_idx {selected_idx}")
+            whisper_chunks.append(selected_feature)
+            i += 1
+
+        return whisper_chunks
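# Illustrative sketch, not part of the patch: which HuBERT rows a given video
# frame pulls in, mirroring get_sliced_feature with audio_feat_length=[8,8] and
# fps=25. The feature array size is made up; the 32 rows x 1024 dims it returns
# are exactly what the inference loop later reshapes to (32, 32, 32).
import numpy as np

feature_array = np.zeros((200, 1024), dtype=np.float32)   # fake 4 s of HuBERT rows
fps, audio_feat_length = 25, [8, 8]

def sliced_indices(vid_idx):
    center = int(vid_idx * 50 / fps)              # 2 HuBERT rows per video frame
    left = center - audio_feat_length[0] * 2
    right = center + audio_feat_length[1] * 2
    return [min(max(i, 0), len(feature_array) - 1) for i in range(left, right)]

idx = sliced_indices(10)
print(len(idx), idx[0], idx[-1])                  # 32 4 35 -> a 32-row window centred on row 20
chunk = feature_array[idx].reshape(32, 32, 32)    # 32 * 1024 == 32 * 32 * 32
print(chunk.shape)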
diff --git a/ultralight/unet.py b/ultralight/unet.py
new file mode 100644
index 0000000..d60f51f
--- /dev/null
+++ b/ultralight/unet.py
@@ -0,0 +1,283 @@
+import time
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class InvertedResidual(nn.Module):
+    def __init__(self, inp, oup, stride, use_res_connect, expand_ratio=6):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2]
+
+        self.use_res_connect = use_res_connect
+
+        self.conv = nn.Sequential(
+            nn.Conv2d(inp, inp * expand_ratio, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(inp * expand_ratio),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(inp * expand_ratio,
+                      inp * expand_ratio,
+                      3,
+                      stride,
+                      1,
+                      groups=inp * expand_ratio,
+                      bias=False),
+            nn.BatchNorm2d(inp * expand_ratio),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(inp * expand_ratio, oup, 1, 1, 0, bias=False),
+            nn.BatchNorm2d(oup),
+        )
+
+    def forward(self, x):
+        if self.use_res_connect:
+            return x + self.conv(x)
+        else:
+            return self.conv(x)
+
+class DoubleConvDW(nn.Module):
+
+    def __init__(self, in_channels, out_channels, stride=2):
+
+        super(DoubleConvDW, self).__init__()
+        self.double_conv = nn.Sequential(
+            InvertedResidual(in_channels, out_channels, stride=stride, use_res_connect=False, expand_ratio=2),
+            InvertedResidual(out_channels, out_channels, stride=1, use_res_connect=True, expand_ratio=2)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+class InConvDw(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(InConvDw, self).__init__()
+        self.inconv = nn.Sequential(
+            InvertedResidual(in_channels, out_channels, stride=1, use_res_connect=False, expand_ratio=2)
+        )
+    def forward(self, x):
+        return self.inconv(x)
+
+class Down(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+
+        super(Down, self).__init__()
+        self.maxpool_conv = nn.Sequential(
+            DoubleConvDW(in_channels, out_channels, stride=2)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+class Up(nn.Module):
+
+    def __init__(self, in_channels, out_channels):
+        super(Up, self).__init__()
+        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+        self.conv = DoubleConvDW(in_channels, out_channels, stride=1)
+
+    def forward(self, x1, x2):
+
+        x1 = self.up(x1)
+        diffY = x2.shape[2] - x1.shape[2]
+        diffX = x2.shape[3] - x1.shape[3]
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
+        x = torch.cat([x1, x2], axis=1)
+
+        return self.conv(x)
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+    def forward(self, x):
+        return self.conv(x)
+
+class AudioConvWenet(nn.Module):
+    def __init__(self):
+        super(AudioConvWenet, self).__init__()
+        # ch = [16, 32, 64, 128, 256]   # if you want to run this model on a mobile device, use this.
+        ch = [32, 64, 128, 256, 512]
+        self.conv1 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
+        self.conv2 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
+
+        self.conv3 = nn.Conv2d(ch[3], ch[3], kernel_size=3, padding=1, stride=(1,2))
+        self.bn3 = nn.BatchNorm2d(ch[3])
+
+        self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
+
+        self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
+        self.bn5 = nn.BatchNorm2d(ch[4])
+        self.relu = nn.ReLU()
+
+        self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
+        self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
+
+    def forward(self, x):
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        x = self.relu(self.bn3(self.conv3(x)))
+
+        x = self.conv4(x)
+
+        x = self.relu(self.bn5(self.conv5(x)))
+
+        x = self.conv6(x)
+        x = self.conv7(x)
+
+        return x
+
+class AudioConvHubert(nn.Module):
+    def __init__(self):
+        super(AudioConvHubert, self).__init__()
+        # ch = [16, 32, 64, 128, 256]   # if you want to run this model on a mobile device, use this.
+        ch = [32, 64, 128, 256, 512]
+        self.conv1 = InvertedResidual(32, ch[1], stride=1, use_res_connect=False, expand_ratio=2)
+        self.conv2 = InvertedResidual(ch[1], ch[2], stride=1, use_res_connect=False, expand_ratio=2)
+
+        self.conv3 = nn.Conv2d(ch[2], ch[3], kernel_size=3, padding=1, stride=(2,2))
+        self.bn3 = nn.BatchNorm2d(ch[3])
+
+        self.conv4 = InvertedResidual(ch[3], ch[3], stride=1, use_res_connect=True, expand_ratio=2)
+
+        self.conv5 = nn.Conv2d(ch[3], ch[4], kernel_size=3, padding=3, stride=2)
+        self.bn5 = nn.BatchNorm2d(ch[4])
+        self.relu = nn.ReLU()
+
+        self.conv6 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
+        self.conv7 = InvertedResidual(ch[4], ch[4], stride=1, use_res_connect=True, expand_ratio=2)
+
+    def forward(self, x):
+
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        x = self.relu(self.bn3(self.conv3(x)))
+
+        x = self.conv4(x)
+
+        x = self.relu(self.bn5(self.conv5(x)))
+
+        x = self.conv6(x)
+        x = self.conv7(x)
+
+        return x
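# Illustrative sketch, not part of the patch, assuming this diff is applied and
# ultralight/ is importable: the audio branch compresses a (B, 32, 32, 32) HuBERT
# chunk to the same (B, 512, 10, 10) resolution as the image encoder's bottleneck,
# which is why Model.forward (defined just below) can torch.cat the two before fuse_conv.
import torch
from ultralight.unet import AudioConvHubert, Model

with torch.no_grad():
    audio = torch.zeros(1, 32, 32, 32)
    print(AudioConvHubert()(audio).shape)         # torch.Size([1, 512, 10, 10])

    net = Model(6, 'hubert').eval()
    img = torch.zeros(1, 6, 160, 160)             # reference + masked crop stacked on channels
    print(net(img, audio).shape)                  # torch.Size([1, 3, 160, 160])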
+
+class Model(nn.Module):
+    def __init__(self,n_channels=6, mode='hubert'):
+        super(Model, self).__init__()
+        self.n_channels = n_channels   #BGR
+        # ch = [16, 32, 64, 128, 256]  # if you want to run this model on a mobile device, use this.
+        ch = [32, 64, 128, 256, 512]
+
+        if mode=='hubert':
+            self.audio_model = AudioConvHubert()
+        if mode=='wenet':
+            self.audio_model = AudioConvWenet()
+
+        self.fuse_conv = nn.Sequential(
+            DoubleConvDW(ch[4]*2, ch[4], stride=1),
+            DoubleConvDW(ch[4], ch[3], stride=1)
+        )
+
+        self.inc = InConvDw(n_channels, ch[0])
+        self.down1 = Down(ch[0], ch[1])
+        self.down2 = Down(ch[1], ch[2])
+        self.down3 = Down(ch[2], ch[3])
+        self.down4 = Down(ch[3], ch[4])
+
+        self.up1 = Up(ch[4], ch[3]//2)
+        self.up2 = Up(ch[3], ch[2]//2)
+        self.up3 = Up(ch[2], ch[1]//2)
+        self.up4 = Up(ch[1], ch[0])
+
+        self.outc = OutConv(ch[0], 3)
+
+    def forward(self, x, audio_feat):
+
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+
+        audio_feat = self.audio_model(audio_feat)
+        x5 = torch.cat([x5, audio_feat], axis=1)
+        x5 = self.fuse_conv(x5)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        out = self.outc(x)
+        out = F.sigmoid(out)
+        return out
+
+if __name__ == '__main__':
+    import time
+    import copy
+    import onnx
+    import numpy as np
+    onnx_path = "./unet.onnx"
+
+    from thop import profile, clever_format
+
+    def reparameterize_model(model: torch.nn.Module) -> torch.nn.Module:
+        """ Method returns a model where a multi-branched structure
+            used in training is re-parameterized into a single branch
+            for inference.
+        :param model: MobileOne model in train mode.
+        :return: MobileOne model in inference mode.
+        """
+        # Avoid editing original graph
+        model = copy.deepcopy(model)
+        for module in model.modules():
+            if hasattr(module, 'reparameterize'):
+                module.reparameterize()
+        return model
+
+    device = torch.device("cuda")
+
+    def check_onnx(torch_out, torch_in, audio):
+        onnx_model = onnx.load(onnx_path)
+        onnx.checker.check_model(onnx_model)
+        import onnxruntime
+        providers = ["CUDAExecutionProvider"]
+        ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
+        print(ort_session.get_providers())
+        ort_inputs = {ort_session.get_inputs()[0].name: torch_in.cpu().numpy(), ort_session.get_inputs()[1].name: audio.cpu().numpy()}
+        ort_outs = ort_session.run(None, ort_inputs)
+        np.testing.assert_allclose(torch_out[0].cpu().numpy(), ort_outs[0][0], rtol=1e-03, atol=1e-05)
+        print("Exported model has been tested with ONNXRuntime, and the result looks good!")
+
+    net = Model(6).eval().to(device)
+    img = torch.zeros([1, 6, 160, 160]).to(device)
+    audio = torch.zeros([1, 32, 32, 32]).to(device)   # hubert mode expects a 32x32x32 chunk (see warm_up in lightreal.py)
+    # net = reparameterize_model(net)
+    flops, params = profile(net, (img,audio))
+    macs, params = clever_format([flops, params], "%3f")
+    print(macs, params)
+    # dynamic_axes= {'input':[2, 3], 'output':[2, 3]}
+
+    input_dict = {"input": img, "audio": audio}
+
+    with torch.no_grad():
+        torch_out = net(img, audio)
+        print(torch_out.shape)
+        torch.onnx.export(net, (img, audio), onnx_path, input_names=['input', "audio"],
+                          output_names=['output'],
+                          # dynamic_axes=dynamic_axes,
+                          # example_outputs=torch_out,
+                          opset_version=11,
+                          export_params=True)
+    check_onnx(torch_out, img, audio)
+
+    # img = torch.zeros([1, 6, 160, 160]).to(device)
+    # audio = torch.zeros([1, 32, 32, 32]).to(device)
+    # with torch.no_grad():
+    #     for i in range(100000):
+    #         t1 = time.time()
+    #         out = net(img, audio)
+    #         t2 = time.time()
+    #         # print(out.shape)
+    #         print('time cost::', t2-t1)
+    # torch.save(net.state_dict(), '1.pth')
\ No newline at end of file