From dbe508cb65fce4ad859ab9f04d86b420306557ef Mon Sep 17 00:00:00 2001
From: lipku
Date: Sun, 8 Dec 2024 16:49:06 +0800
Subject: [PATCH] CUDA memory does not increase with the number of concurrent
 sessions

---
 README.md                   |   4 +
 app.py                      |  51 ++++++++--
 baseasr.py                  |  19 +++-
 basereal.py                 |  20 +++-
 lipasr.py                   |  19 +++-
 lipreal.py                  | 174 +++++++++++++++++----------
 museasr.py                  |  19 +++-
 musereal.py                 | 186 ++++++++++++++++++++----------
 musetalk/simple_musetalk.py |   1 +
 nerfasr.py                  |  38 +++++---
 nerfreal.py                 |  23 ++++-
 ttsreal.py                  |  17 ++++
 wav2lip/genavatar.py        |   1 +
 webrtc.py                   |  16 ++++
 14 files changed, 396 insertions(+), 192 deletions(-)

diff --git a/README.md b/README.md
index bb6e87d..2dd9015 100644
--- a/README.md
+++ b/README.md
@@ -5,6 +5,9 @@ Real time interactive streaming digital human, realize audio video synchronous
 
 ## 为避免与3d数字人混淆，原项目metahuman-stream改名为livetalking，原有链接地址继续可用
 
+## News
+- 2024.12.8 完善多并发，显存不随并发数增加
+
 ## Features
 1. 支持多种数字人模型: ernerf、musetalk、wav2lip
 2. 支持声音克隆
@@ -12,6 +15,7 @@ Real time interactive streaming digital human, realize audio video synchronous
 4. 支持全身视频拼接
 5. 支持rtmp和webrtc
 6. 支持视频编排：不说话时播放自定义视频
+7. 支持多并发
 
 ## 1. Installation
 
diff --git a/app.py b/app.py
index 791bc55..05a696b 100644
--- a/app.py
+++ b/app.py
@@ -1,3 +1,20 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 # server.py
 from flask import Flask, render_template,send_from_directory,request, jsonify
 from flask_sockets import Sockets
@@ -11,7 +28,8 @@ import os
 import re
 import numpy as np
 from threading import Thread,Event
-import multiprocessing
+#import multiprocessing
+import torch.multiprocessing as mp
 from aiohttp import web
 import aiohttp
 
@@ -24,7 +42,7 @@ import argparse
 import shutil
 import asyncio
 
-import string
+import torch
 
 app = Flask(__name__)
 
@@ -302,7 +320,7 @@ async def run(push_url):
 # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
 if __name__ == '__main__':
-    multiprocessing.set_start_method('spawn')
+    mp.set_start_method('spawn')
     parser = argparse.ArgumentParser()
     parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
     parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
@@ -452,6 +470,7 @@ if __name__ == '__main__':
         from ernerf.nerf_triplane.provider import NeRFDataset_Test
         from ernerf.nerf_triplane.utils import *
         from ernerf.nerf_triplane.network import NeRFNetwork
+        from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
         from nerfreal import NeRFReal
         # assert test mode
         opt.test = True
@@ -493,24 +512,42 @@ if __name__ == '__main__':
         model.aud_features = test_loader._data.auds
         model.eye_areas = test_loader._data.eye_area
 
+        print(f'[INFO] loading ASR model {opt.asr_model}...')
+        if 'hubert' in opt.asr_model:
+            audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+            audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
+        else:
+            audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
+            audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
+
         # we still need test_loader to provide audio features for testing.
         for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = NeRFReal(opt, trainer, test_loader)
+            nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
             nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
         from musereal import MuseReal
+        from musetalk.utils.utils import load_all_model
         print(opt)
+        audio_processor,vae, unet, pe = load_all_model()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        timesteps = torch.tensor([0], device=device)
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        #vae.vae.share_memory()
+        unet.model = unet.model.half()
+        #unet.model.share_memory()
        for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = MuseReal(opt)
+            nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
             nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
-        from lipreal import LipReal
+        from lipreal import LipReal,load_model
         print(opt)
+        model = load_model("./models/wav2lip.pth")
         for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = LipReal(opt)
+            nerfreal = LipReal(opt,model)
             nerfreals.append(nerfreal)
 
     for _ in range(opt.max_session):
diff --git a/baseasr.py b/baseasr.py
index 7c370a3..827bf90 100644
--- a/baseasr.py
+++ b/baseasr.py
@@ -1,9 +1,26 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import time
 import numpy as np
 
 import queue
 from queue import Queue
-import multiprocessing as mp
+import torch.multiprocessing as mp
 
 
 class BaseASR:
diff --git a/basereal.py b/basereal.py
index 4cb0bf7..edc07c7 100644
--- a/basereal.py
+++ b/basereal.py
@@ -1,3 +1,20 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import math
 import torch
 import numpy as np
@@ -7,8 +24,6 @@ import os
 import time
 import cv2
 import glob
-import pickle
-import copy
 import resampy
 import queue
 
@@ -36,6 +51,7 @@ class BaseReal:
         self.opt = opt
         self.sample_rate = 16000
         self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
+        self.sessionid = self.opt.sessionid
 
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)
diff --git a/lipasr.py b/lipasr.py
index 29948ac..0868c97 100644
--- a/lipasr.py
+++ b/lipasr.py
@@ -1,10 +1,27 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import time
 import torch
 import numpy as np
 
 import queue
 from queue import Queue
-import multiprocessing as mp
+#import multiprocessing as mp
 
 from baseasr import BaseASR
 from wav2lip import audio
diff --git a/lipreal.py b/lipreal.py
index 968eabc..05beade 100644
--- a/lipreal.py
+++ b/lipreal.py
@@ -1,9 +1,25 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import math
 import torch
 import numpy as np
 
 #from .utils import *
-import subprocess
 import os
 import time
 import cv2
@@ -14,11 +30,8 @@ import copy
 import queue
 from queue import Queue
 from threading import Thread, Event
-from io import BytesIO
-import multiprocessing as mp
-
+import torch.multiprocessing as mp
 
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
 from lipasr import LipASR
 
 import asyncio
@@ -35,7 +48,7 @@ print('Using {} for inference.'.format(device))
 
 def _load(checkpoint_path):
     if device == 'cuda':
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path,weights_only=True)
     else:
         checkpoint = torch.load(checkpoint_path,
                                 map_location=lambda storage, loc: storage)
@@ -71,12 +84,12 @@ def __mirror_index(size, index):
     else:
         return size - res - 1
 
-def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_queue,res_frame_queue):
+def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,model):
 
-    model = load_model("./models/wav2lip.pth")
-    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-    face_list_cycle = read_imgs(input_face_list)
+    #model = load_model("./models/wav2lip.pth")
+    # input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    # input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    # face_list_cycle = read_imgs(input_face_list)
     #input_latent_list_cycle = torch.load(latents_out_path)
     length = len(face_list_cycle)
 
@@ -84,69 +97,66 @@ def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_
     count=0
     counttime=0
     print('start inference')
-    while True:
-        if render_event.is_set():
-            starttime=time.perf_counter()
-            mel_batch = []
-            try:
-                mel_batch = audio_feat_queue.get(block=True, timeout=1)
-            except queue.Empty:
-                continue
-
-            is_all_silence=True
-            audio_frames = []
-            for _ in range(batch_size*2):
-                frame,type = audio_out_queue.get()
-                audio_frames.append((frame,type))
-                if type==0:
-                    is_all_silence=False
+    while not quit_event.is_set():
+        starttime=time.perf_counter()
+        mel_batch = []
+        try:
+            mel_batch = audio_feat_queue.get(block=True, timeout=1)
+        except queue.Empty:
+            continue
+
+        is_all_silence=True
+        audio_frames = []
+        for _ in range(batch_size*2):
+            frame,type = audio_out_queue.get()
+            audio_frames.append((frame,type))
+            if type==0:
+                is_all_silence=False
 
-            if is_all_silence:
-                for i in range(batch_size):
-                    res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
-                    index = index + 1
-            else:
-                # print('infer=======')
-                t=time.perf_counter()
-                img_batch = []
-                for i in range(batch_size):
-                    idx = __mirror_index(length,index+i)
-                    face = face_list_cycle[idx]
-                    img_batch.append(face)
-                img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
+        if is_all_silence:
+            for i in range(batch_size):
+                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
+        else:
+            # print('infer=======')
+            t=time.perf_counter()
+            img_batch = []
+            for i in range(batch_size):
+                idx = __mirror_index(length,index+i)
+                face = face_list_cycle[idx]
+                img_batch.append(face)
+            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
 
-                img_masked = img_batch.copy()
-                img_masked[:, face.shape[0]//2:] = 0
+            img_masked = img_batch.copy()
+            img_masked[:, face.shape[0]//2:] = 0
 
-                img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
-                mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
-
-                img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
-                mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
+            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
+            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
+
+            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
+            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
 
-                with torch.no_grad():
-                    pred = model(mel_batch, img_batch)
-                pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
+            with torch.no_grad():
+                pred = model(mel_batch, img_batch)
+            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
 
-                counttime += (time.perf_counter() - t)
-                count += batch_size
-                #_totalframe += 1
-                if count>=100:
-                    print(f"------actual avg infer fps:{count/counttime:.4f}")
-                    count=0
-                    counttime=0
-                for i,res_frame in enumerate(pred):
-                    #self.__pushmedia(res_frame,loop,audio_track,video_track)
-                    res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
-                    index = index + 1
-                #print('total batch time:',time.perf_counter()-starttime)
-        else:
-            time.sleep(1)
-    print('musereal inference processor stop')
+            counttime += (time.perf_counter() - t)
+            count += batch_size
+            #_totalframe += 1
+            if count>=100:
+                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                count=0
+                counttime=0
+            for i,res_frame in enumerate(pred):
+                #self.__pushmedia(res_frame,loop,audio_track,video_track)
+                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
+            #print('total batch time:',time.perf_counter()-starttime)
+    print('lipreal inference processor stop')
 
-@torch.no_grad()
 class LipReal(BaseReal):
-    def __init__(self, opt):
+    @torch.no_grad()
+    def __init__(self, opt, model):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -162,7 +172,7 @@ class LipReal(BaseReal):
         self.coords_path = f"{self.avatar_path}/coords.pkl"
         self.batch_size = opt.batch_size
         self.idx = 0
-        self.res_frame_queue = mp.Queue(self.batch_size*2)
+        self.res_frame_queue = Queue(self.batch_size*2)  #mp.Queue
         #self.__loadmodels()
         self.__loadavatar()
 
@@ -170,19 +180,8 @@ class LipReal(BaseReal):
         self.asr.warm_up()
         #self.__warm_up()
 
+        self.model = model
         self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.face_imgs_path,
-                                           self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
-                                           )).start()
-
-    # def __loadmodels(self):
-    #     # load model weights
-    #     self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
-    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    #     self.timesteps = torch.tensor([0], device=device)
-    #     self.pe = self.pe.half()
-    #     self.vae.vae = self.vae.vae.half()
-    #     self.unet.model = self.unet.model.half()
 
     def __loadavatar(self):
         with open(self.coords_path, 'rb') as f:
@@ -191,6 +190,9 @@ class LipReal(BaseReal):
         input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
         self.frame_list_cycle = read_imgs(input_img_list)
         #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
+        input_face_list = glob.glob(os.path.join(self.face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+        input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+        self.face_list_cycle = read_imgs(input_face_list)
 
 
     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -242,7 +244,7 @@ class LipReal(BaseReal):
                 # time.sleep(0.1)
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
                 self.record_audio_data(frame)
-        print('musereal process_frames thread stop')
+        print('lipreal process_frames thread stop')
 
     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -253,7 +255,11 @@ class LipReal(BaseReal):
         process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
         process_thread.start()
 
-        self.render_event.set() #start infer process render
+        Thread(target=inference, args=(quit_event,self.batch_size,self.face_list_cycle,
+                                       self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+                                       self.model,)).start() #mp.Process
+
+        #self.render_event.set() #start infer process render
         count=0
         totaltime=0
         _starttime=time.perf_counter()
@@ -274,6 +280,6 @@ class LipReal(BaseReal):
 
             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
             #     time.sleep(delay)
-        self.render_event.clear() #end infer process render
-        print('musereal thread stop')
+        #self.render_event.clear() #end infer process render
+        print('lipreal thread stop')
\ No newline at end of file
diff --git a/museasr.py b/museasr.py
index c5c767c..b8be556 100644
--- a/museasr.py
+++ b/museasr.py
@@ -1,9 +1,26 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import time
 import numpy as np
 
 import queue
 from queue import Queue
-import multiprocessing as mp
+#import multiprocessing as mp
 
 from baseasr import BaseASR
 from musetalk.whisper.audio2feature import Audio2Feature
diff --git a/musereal.py b/musereal.py
index 9ec6b0f..6891705 100644
--- a/musereal.py
+++ b/musereal.py
@@ -1,3 +1,20 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import math
 import torch
 import numpy as np
@@ -15,14 +32,13 @@ import copy
 import queue
 from queue import Queue
 from threading import Thread, Event
-from io import BytesIO
-import multiprocessing as mp
+import torch.multiprocessing as mp
 
 from musetalk.utils.utils import get_file_type,get_video_fps,datagen
 #from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
 from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
+from musetalk.whisper.audio2feature import Audio2Feature
 from museasr import MuseASR
 
 import asyncio
@@ -46,88 +62,90 @@ def __mirror_index(size, index):
         return res
     else:
         return size - res - 1
+
 @torch.no_grad()
-def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
-                ): #vae, unet, pe,timesteps
+def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
+                vae, unet, pe,timesteps): #vae, unet, pe,timesteps
 
-    vae, unet, pe = load_diffusion_model()
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    timesteps = torch.tensor([0], device=device)
-    pe = pe.half()
-    vae.vae = vae.vae.half()
-    unet.model = unet.model.half()
+    # vae, unet, pe = load_diffusion_model()
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # timesteps = torch.tensor([0], device=device)
+    # pe = pe.half()
+    # vae.vae = vae.vae.half()
+    # unet.model = unet.model.half()
 
-    input_latent_list_cycle = torch.load(latents_out_path)
     length = len(input_latent_list_cycle)
     index = 0
     count=0
     counttime=0
     print('start inference')
-    while True:
-        if render_event.is_set():
-            starttime=time.perf_counter()
-            try:
-                whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
-            except queue.Empty:
-                continue
-            is_all_silence=True
-            audio_frames = []
-            for _ in range(batch_size*2):
-                frame,type = audio_out_queue.get()
-                audio_frames.append((frame,type))
-                if type==0:
-                    is_all_silence=False
-            if is_all_silence:
-                for i in range(batch_size):
-                    res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
-                    index = index + 1
-            else:
-                # print('infer=======')
-                t=time.perf_counter()
-                whisper_batch = np.stack(whisper_chunks)
-                latent_batch = []
-                for i in range(batch_size):
-                    idx = __mirror_index(length,index+i)
-                    latent = input_latent_list_cycle[idx]
-                    latent_batch.append(latent)
-                latent_batch = torch.cat(latent_batch, dim=0)
-
-                # for i, (whisper_batch,latent_batch) in enumerate(gen):
-                audio_feature_batch = torch.from_numpy(whisper_batch)
-                audio_feature_batch = audio_feature_batch.to(device=unet.device,
-                                                             dtype=unet.model.dtype)
-                audio_feature_batch = pe(audio_feature_batch)
-                latent_batch = latent_batch.to(dtype=unet.model.dtype)
-                # print('prepare time:',time.perf_counter()-t)
-                # t=time.perf_counter()
-
-                pred_latents = unet.model(latent_batch,
-                                          timesteps,
-                                          encoder_hidden_states=audio_feature_batch).sample
-                # print('unet time:',time.perf_counter()-t)
-                # t=time.perf_counter()
-                recon = vae.decode_latents(pred_latents)
-                # print('vae time:',time.perf_counter()-t)
-                #print('diffusion len=',len(recon))
-                counttime += (time.perf_counter() - t)
-                count += batch_size
-                #_totalframe += 1
-                if count>=100:
-                    print(f"------actual avg infer fps:{count/counttime:.4f}")
-                    count=0
-                    counttime=0
-                for i,res_frame in enumerate(recon):
-                    #self.__pushmedia(res_frame,loop,audio_track,video_track)
-                    res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
-                    index = index + 1
-                #print('total batch time:',time.perf_counter()-starttime)
+    while render_event.is_set():
+        starttime=time.perf_counter()
+        try:
+            whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
+        except queue.Empty:
+            continue
+        is_all_silence=True
+        audio_frames = []
+        for _ in range(batch_size*2):
+            frame,type = audio_out_queue.get()
+            audio_frames.append((frame,type))
+            if type==0:
+                is_all_silence=False
+        if is_all_silence:
+            for i in range(batch_size):
+                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
         else:
-            time.sleep(1)
+            # print('infer=======')
+            t=time.perf_counter()
+            whisper_batch = np.stack(whisper_chunks)
+            latent_batch = []
+            for i in range(batch_size):
+                idx = __mirror_index(length,index+i)
+                latent = input_latent_list_cycle[idx]
+                latent_batch.append(latent)
+            latent_batch = torch.cat(latent_batch, dim=0)
+
+            # for i, (whisper_batch,latent_batch) in enumerate(gen):
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype)
+            audio_feature_batch = pe(audio_feature_batch)
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
+            # print('prepare time:',time.perf_counter()-t)
+            # t=time.perf_counter()
+
+            pred_latents = unet.model(latent_batch,
+                                      timesteps,
+                                      encoder_hidden_states=audio_feature_batch).sample
+            # print('unet time:',time.perf_counter()-t)
+            # t=time.perf_counter()
+            recon = vae.decode_latents(pred_latents)
+            # infer_inqueue.put((whisper_batch,latent_batch,sessionid))
+            # recon,outsessionid = infer_outqueue.get()
+            # if outsessionid != sessionid:
+            #     print('outsessionid:',outsessionid,' mysessionid:',sessionid)
+
+            # print('vae time:',time.perf_counter()-t)
+            #print('diffusion len=',len(recon))
+            counttime += (time.perf_counter() - t)
+            count += batch_size
+            #_totalframe += 1
+            if count>=100:
+                print(f"------actual avg infer fps:{count/counttime:.4f}")
+                count=0
+                counttime=0
+            for i,res_frame in enumerate(recon):
+                #self.__pushmedia(res_frame,loop,audio_track,video_track)
+                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
+                index = index + 1
+            #print('total batch time:',time.perf_counter()-starttime)
     print('musereal inference processor stop')
 
-@torch.no_grad()
 class MuseReal(BaseReal):
-    def __init__(self, opt):
+    @torch.no_grad()
+    def __init__(self, opt, audio_processor:Audio2Feature,vae, unet, pe,timesteps):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -155,7 +173,8 @@ class MuseReal(BaseReal):
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = mp.Queue(self.batch_size*2)
-        self.__loadmodels()
+        #self.__loadmodels()
+        self.audio_processor= audio_processor
         self.__loadavatar()
 
         self.asr = MuseASR(opt,self,self.audio_processor)
@@ -163,13 +182,15 @@ class MuseReal(BaseReal):
         #self.__warm_up()
 
         self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
-                                           self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
-                                           )).start() #self.vae, self.unet, self.pe,self.timesteps
+        self.vae = vae
+        self.unet = unet
+        self.pe = pe
+        self.timesteps = timesteps
+
 
-    def __loadmodels(self):
-        # load model weights
-        self.audio_processor= load_audio_model()
+    # def __loadmodels(self):
+    #     # load model weights
+    #     self.audio_processor= load_audio_model()
         # self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
         # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # self.timesteps = torch.tensor([0], device=device)
@@ -178,7 +199,7 @@ class MuseReal(BaseReal):
         # self.unet.model = self.unet.model.half()
 
     def __loadavatar(self):
-        #self.input_latent_list_cycle = torch.load(self.latents_out_path)
+        self.input_latent_list_cycle = torch.load(self.latents_out_path,weights_only=True)
         with open(self.coords_path, 'rb') as f:
             self.coord_list_cycle = pickle.load(f)
         input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -287,6 +308,9 @@ class MuseReal(BaseReal):
         process_thread.start()
 
         self.render_event.set() #start infer process render
+        Thread(target=inference, args=(self.render_event,self.batch_size,self.input_latent_list_cycle,
+                                       self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+                                       self.vae, self.unet, self.pe,self.timesteps)).start() #mp.Process
         count=0
         totaltime=0
         _starttime=time.perf_counter()
diff --git a/musetalk/simple_musetalk.py b/musetalk/simple_musetalk.py
index 7ed0b46..4008cb0 100644
--- a/musetalk/simple_musetalk.py
+++ b/musetalk/simple_musetalk.py
@@ -30,6 +30,7 @@ def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
             break
         ret, frame = cap.read()
         if ret:
+            cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
             cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
             count += 1
         else:
diff --git a/nerfasr.py b/nerfasr.py
index b74b199..131fca1 100644
--- a/nerfasr.py
+++ b/nerfasr.py
@@ -1,19 +1,33 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import time
 import numpy as np
 import torch
 import torch.nn.functional as F
 
-from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
-
 import queue
 from queue import Queue
 #from collections import deque
-from threading import Thread, Event
 
 from baseasr import BaseASR
 
 class NerfASR(BaseASR):
-    def __init__(self, opt, parent):
+    def __init__(self, opt, parent, audio_processor,audio_model):
         super().__init__(opt,parent)
 
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -37,13 +51,15 @@ class NerfASR(BaseASR):
         self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
 
         # create wav2vec model
-        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
-        if 'hubert' in self.opt.asr_model:
-            self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
-            self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
-        else:
-            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
-            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        # print(f'[INFO] loading ASR model {self.opt.asr_model}...')
+        # if 'hubert' in self.opt.asr_model:
+        #     self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+        #     self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
+        # else:
+        #     self.processor = AutoProcessor.from_pretrained(opt.asr_model)
+        #     self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        self.processor = audio_processor
+        self.model = audio_model
 
         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
diff --git a/nerfreal.py b/nerfreal.py
index 04805ae..8c34851 100644
--- a/nerfreal.py
+++ b/nerfreal.py
@@ -1,9 +1,25 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import math
 import torch
 import numpy as np
 
 #from .utils import *
-import subprocess
 import os
 import time
 import torch.nn.functional as F
@@ -11,7 +27,6 @@ import cv2
 import glob
 
 from nerfasr import NerfASR
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
 
 import asyncio
 from av import AudioFrame, VideoFrame
@@ -29,7 +44,7 @@ def read_imgs(img_list):
     return frames
 
 class NeRFReal(BaseReal):
-    def __init__(self, opt, trainer, data_loader, debug=True):
+    def __init__(self, opt, trainer, data_loader, audio_processor,audio_model, debug=True):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -79,7 +94,7 @@ class NeRFReal(BaseReal):
         #self.customimg_index = 0
 
         # build asr
-        self.asr = NerfASR(opt,self)
+        self.asr = NerfASR(opt,self,audio_processor,audio_model)
         self.asr.warm_up()
 
         '''
diff --git a/ttsreal.py b/ttsreal.py
index b4da862..636b306 100644
--- a/ttsreal.py
+++ b/ttsreal.py
@@ -1,3 +1,20 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
+
 import time
 import numpy as np
 import soundfile as sf
diff --git a/wav2lip/genavatar.py b/wav2lip/genavatar.py
index ed2315c..bd7c475 100644
--- a/wav2lip/genavatar.py
+++ b/wav2lip/genavatar.py
@@ -36,6 +36,7 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
             break
         ret, frame = cap.read()
         if ret:
+            cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
             cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
             count += 1
         else:
diff --git a/webrtc.py b/webrtc.py
index d1a37ca..4692048 100644
--- a/webrtc.py
+++ b/webrtc.py
@@ -1,3 +1,19 @@
+###############################################################################
+# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+# email: lipku@foxmail.com
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+###############################################################################
 import asyncio
 import json