CUDA memory does not increase with the number of concurrent sessions

main
lipku 8 months ago
parent e772acb3fc
commit dbe508cb65

@@ -5,6 +5,9 @@ Real time interactive streaming digital human realize audio video synchronous
 ## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed livetalking; the original links remain valid
+## News
+- 2024.12.8 Improved multi-session concurrency: GPU memory no longer grows with the number of concurrent sessions
 ## Features
 1. Supports multiple digital-human models: ernerf, musetalk, wav2lip
 2. Supports voice cloning
@@ -12,6 +15,7 @@ Real time interactive streaming digital human realize audio video synchronous
 4. Supports full-body video stitching
 5. Supports RTMP and WebRTC
 6. Supports video orchestration: plays a custom video while not speaking
+7. Supports multiple concurrent sessions
 ## 1. Installation
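
The memory behaviour announced in the News entry comes from loading each model's weights a single time and handing the same instance to every session, instead of letting each session (or a per-session worker process) load its own copy. A minimal sketch of that pattern — the `Session` class and `load_weights` callable below are illustrative, not this repo's API:

```python
import torch

class Session:
    """One concurrent digital-human session; it only keeps a reference to the shared model."""
    def __init__(self, sessionid, model):
        self.sessionid = sessionid
        self.model = model            # shared reference, no additional CUDA allocation per session

    @torch.no_grad()
    def infer(self, batch):
        return self.model(batch)

def build_sessions(max_session, load_weights):
    model = load_weights()            # the weights reach the GPU exactly once
    return [Session(k, model) for k in range(max_session)]
```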

@@ -1,3 +1,20 @@
+###############################################################################
+#  Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
+#  email: lipku@foxmail.com
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+###############################################################################
+
 # server.py
 from flask import Flask, render_template,send_from_directory,request, jsonify
 from flask_sockets import Sockets
@@ -11,7 +28,8 @@ import os
 import re
 import numpy as np
 from threading import Thread,Event
-import multiprocessing
+#import multiprocessing
+import torch.multiprocessing as mp
 from aiohttp import web
 import aiohttp
@@ -24,7 +42,7 @@ import argparse
 import shutil
 import asyncio
-import string
+import torch
 app = Flask(__name__)
@@ -302,7 +320,7 @@ async def run(push_url):
 # os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
 # os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
 if __name__ == '__main__':
-    multiprocessing.set_start_method('spawn')
+    mp.set_start_method('spawn')
     parser = argparse.ArgumentParser()
     parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
    parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
@@ -452,6 +470,7 @@ if __name__ == '__main__':
         from ernerf.nerf_triplane.provider import NeRFDataset_Test
         from ernerf.nerf_triplane.utils import *
         from ernerf.nerf_triplane.network import NeRFNetwork
+        from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
         from nerfreal import NeRFReal
         # assert test mode
         opt.test = True
@@ -493,24 +512,42 @@ if __name__ == '__main__':
         model.aud_features = test_loader._data.auds
         model.eye_areas = test_loader._data.eye_area
+        print(f'[INFO] loading ASR model {opt.asr_model}...')
+        if 'hubert' in opt.asr_model:
+            audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+            audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
+        else:
+            audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
+            audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
         # we still need test_loader to provide audio features for testing.
         for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = NeRFReal(opt, trainer, test_loader)
+            nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
             nerfreals.append(nerfreal)
     elif opt.model == 'musetalk':
         from musereal import MuseReal
+        from musetalk.utils.utils import load_all_model
         print(opt)
+        audio_processor,vae, unet, pe = load_all_model()
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        timesteps = torch.tensor([0], device=device)
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        #vae.vae.share_memory()
+        unet.model = unet.model.half()
+        #unet.model.share_memory()
         for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = MuseReal(opt)
+            nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
             nerfreals.append(nerfreal)
     elif opt.model == 'wav2lip':
-        from lipreal import LipReal
+        from lipreal import LipReal,load_model
         print(opt)
+        model = load_model("./models/wav2lip.pth")
         for k in range(opt.max_session):
             opt.sessionid=k
-            nerfreal = LipReal(opt)
+            nerfreal = LipReal(opt,model)
             nerfreals.append(nerfreal)
     for _ in range(opt.max_session):
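
This hunk is where the sharing happens: the ASR model, the MuseTalk vae/unet/pe, or the wav2lip weights are loaded once in `__main__` and the same objects are passed into every `NeRFReal`/`MuseReal`/`LipReal` instance. A quick, hypothetical way to confirm that extra sessions no longer allocate extra GPU memory (not part of the repo, just a measurement sketch):

```python
import torch
import torch.nn as nn

def report(tag):
    print(f"{tag}: {torch.cuda.memory_allocated() / 2**20:.1f} MiB")

if torch.cuda.is_available():
    shared = nn.Linear(4096, 4096).cuda().half()        # stand-in for wav2lip/unet/vae weights
    report("after loading weights once")
    sessions = [{"sessionid": k, "model": shared} for k in range(8)]
    report("after creating 8 sessions")                  # essentially unchanged
```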

@@ -1,9 +1,26 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import time
 import numpy as np
 import queue
 from queue import Queue
-import multiprocessing as mp
+import torch.multiprocessing as mp
 class BaseASR:
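
Switching `multiprocessing` to `torch.multiprocessing` matters when a queue might carry CUDA tensors: PyTorch's wrapper registers reducers so a tensor put on an `mp.Queue` is shared between processes through CUDA IPC handles instead of being copied. A small illustration (assumes the spawn start method, as app.py sets; on a CPU-only machine the tensor simply stays on the CPU):

```python
import torch
import torch.multiprocessing as mp

def worker(q):
    t = q.get()                       # receives a handle to the same storage, not a private copy
    print(t.device, t.sum().item())

if __name__ == '__main__':
    mp.set_start_method('spawn', force=True)
    q = mp.Queue()
    t = torch.ones(4, device='cuda' if torch.cuda.is_available() else 'cpu')
    p = mp.Process(target=worker, args=(q,))
    p.start()
    q.put(t)                          # the sender must keep t alive while the worker uses it
    p.join()
```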

@@ -1,3 +1,20 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import math
 import torch
 import numpy as np
@@ -7,8 +24,6 @@ import os
 import time
 import cv2
 import glob
-import pickle
-import copy
 import resampy
 import queue
@@ -36,6 +51,7 @@ class BaseReal:
         self.opt = opt
         self.sample_rate = 16000
         self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
+        self.sessionid = self.opt.sessionid
         if opt.tts == "edgetts":
             self.tts = EdgeTTS(opt,self)

@@ -1,10 +1,27 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import time
 import torch
 import numpy as np
 import queue
 from queue import Queue
-import multiprocessing as mp
+#import multiprocessing as mp
 from baseasr import BaseASR
 from wav2lip import audio

@@ -1,9 +1,25 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import math
 import torch
 import numpy as np
 #from .utils import *
 import subprocess
 import os
 import time
 import cv2
@@ -14,11 +30,8 @@ import copy
 import queue
 from queue import Queue
 from threading import Thread, Event
-from io import BytesIO
-import multiprocessing as mp
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
+import torch.multiprocessing as mp
 from lipasr import LipASR
 import asyncio
@@ -35,7 +48,7 @@ print('Using {} for inference.'.format(device))
 def _load(checkpoint_path):
     if device == 'cuda':
-        checkpoint = torch.load(checkpoint_path)
+        checkpoint = torch.load(checkpoint_path,weights_only=True)
     else:
         checkpoint = torch.load(checkpoint_path,
                                 map_location=lambda storage, loc: storage)
@@ -71,12 +84,12 @@ def __mirror_index(size, index):
     else:
         return size - res - 1
-def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_queue,res_frame_queue):
-    model = load_model("./models/wav2lip.pth")
-    input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
-    input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
-    face_list_cycle = read_imgs(input_face_list)
+def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,model):
+    #model = load_model("./models/wav2lip.pth")
+    # input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+    # input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+    # face_list_cycle = read_imgs(input_face_list)
     #input_latent_list_cycle = torch.load(latents_out_path)
     length = len(face_list_cycle)
@@ -84,69 +97,66 @@ def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_
     count=0
     counttime=0
     print('start inference')
-    while True:
-        if render_event.is_set():
+    while not quit_event.is_set():
        starttime=time.perf_counter()
        mel_batch = []
        try:
            mel_batch = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        is_all_silence=True
        audio_frames = []
        for _ in range(batch_size*2):
            frame,type = audio_out_queue.get()
            audio_frames.append((frame,type))
            if type==0:
                is_all_silence=False
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            # print('infer=======')
            t=time.perf_counter()
            img_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length,index+i)
                face = face_list_cycle[idx]
                img_batch.append(face)
            img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
            img_masked = img_batch.copy()
            img_masked[:, face.shape[0]//2:] = 0
            img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
            mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
            img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
            mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
            with torch.no_grad():
                pred = model(mel_batch, img_batch)
            pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
            counttime += (time.perf_counter() - t)
            count += batch_size
            #_totalframe += 1
            if count>=100:
                print(f"------actual avg infer fps:{count/counttime:.4f}")
                count=0
                counttime=0
            for i,res_frame in enumerate(pred):
                #self.__pushmedia(res_frame,loop,audio_track,video_track)
                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
            #print('total batch time:',time.perf_counter()-starttime)
-        else:
-            time.sleep(1)
-    print('musereal inference processor stop')
+    print('lipreal inference processor stop')

-@torch.no_grad()
 class LipReal(BaseReal):
-    def __init__(self, opt):
+    @torch.no_grad()
+    def __init__(self, opt, model):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -162,7 +172,7 @@ class LipReal(BaseReal):
         self.coords_path = f"{self.avatar_path}/coords.pkl"
         self.batch_size = opt.batch_size
         self.idx = 0
-        self.res_frame_queue = mp.Queue(self.batch_size*2)
+        self.res_frame_queue = Queue(self.batch_size*2) #mp.Queue
         #self.__loadmodels()
         self.__loadavatar()
@@ -170,19 +180,8 @@ class LipReal(BaseReal):
         self.asr.warm_up()
         #self.__warm_up()
+        self.model = model
         self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.face_imgs_path,
-                    self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
-                    )).start()
-
-    # def __loadmodels(self):
-    #     # load model weights
-    #     self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
-    #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    #     self.timesteps = torch.tensor([0], device=device)
-    #     self.pe = self.pe.half()
-    #     self.vae.vae = self.vae.vae.half()
-    #     self.unet.model = self.unet.model.half()

     def __loadavatar(self):
         with open(self.coords_path, 'rb') as f:
@@ -191,6 +190,9 @@ class LipReal(BaseReal):
         input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
         self.frame_list_cycle = read_imgs(input_img_list)
         #self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
+        input_face_list = glob.glob(os.path.join(self.face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
+        input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
+        self.face_list_cycle = read_imgs(input_face_list)

     def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@@ -242,7 +244,7 @@ class LipReal(BaseReal):
                 # time.sleep(0.1)
                 asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
                 self.record_audio_data(frame)
-        print('musereal process_frames thread stop')
+        print('lipreal process_frames thread stop')

     def render(self,quit_event,loop=None,audio_track=None,video_track=None):
         #if self.opt.asr:
@@ -253,7 +255,11 @@ class LipReal(BaseReal):
         process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
         process_thread.start()
-        self.render_event.set() #start infer process render
+        Thread(target=inference, args=(quit_event,self.batch_size,self.face_list_cycle,
+                    self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+                    self.model,)).start() #mp.Process
+        #self.render_event.set() #start infer process render
         count=0
         totaltime=0
         _starttime=time.perf_counter()
@@ -274,6 +280,6 @@ class LipReal(BaseReal):
             # delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
             # if delay > 0:
             #     time.sleep(delay)
-        self.render_event.clear() #end infer process render
-        print('musereal thread stop')
+        #self.render_event.clear() #end infer process render
+        print('lipreal thread stop')
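
The net effect in lipreal.py: inference now runs in a Thread inside the same process, driven by `quit_event` and a plain `Queue`, so the wav2lip model and the face frames are shared rather than re-loaded by a child process. A stripped-down sketch of that worker pattern (names are illustrative; only the structure mirrors the hunk above):

```python
import queue
from queue import Queue
from threading import Thread, Event

def inference_worker(quit_event, feat_queue, res_queue, model):
    # Drain features until the session asks us to stop; the model lives in the parent process.
    while not quit_event.is_set():
        try:
            batch = feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        res_queue.put(model(batch))
    print('inference worker stop')

if __name__ == '__main__':
    quit_event, feat_q, res_q = Event(), Queue(), Queue()
    model = lambda b: f"rendered({b})"      # stand-in for the shared wav2lip model
    Thread(target=inference_worker, args=(quit_event, feat_q, res_q, model), daemon=True).start()
    feat_q.put('mel-batch')
    print(res_q.get())
    quit_event.set()
```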

@@ -1,9 +1,26 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import time
 import numpy as np
 import queue
 from queue import Queue
-import multiprocessing as mp
+#import multiprocessing as mp
 from baseasr import BaseASR
 from musetalk.whisper.audio2feature import Audio2Feature

@@ -1,3 +1,20 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import math
 import torch
 import numpy as np
@@ -15,14 +32,13 @@ import copy
 import queue
 from queue import Queue
 from threading import Thread, Event
-from io import BytesIO
-import multiprocessing as mp
+import torch.multiprocessing as mp
 from musetalk.utils.utils import get_file_type,get_video_fps,datagen
 #from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
 from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
+from musetalk.whisper.audio2feature import Audio2Feature
 from museasr import MuseASR
 import asyncio
@@ -46,88 +62,90 @@ def __mirror_index(size, index):
         return res
     else:
         return size - res - 1
 @torch.no_grad()
-def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
-              ): #vae, unet, pe,timesteps
-    vae, unet, pe = load_diffusion_model()
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    timesteps = torch.tensor([0], device=device)
-    pe = pe.half()
-    vae.vae = vae.vae.half()
-    unet.model = unet.model.half()
-    input_latent_list_cycle = torch.load(latents_out_path)
+def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
+              vae, unet, pe,timesteps): #vae, unet, pe,timesteps
+    # vae, unet, pe = load_diffusion_model()
+    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    # timesteps = torch.tensor([0], device=device)
+    # pe = pe.half()
+    # vae.vae = vae.vae.half()
+    # unet.model = unet.model.half()
     length = len(input_latent_list_cycle)
     index = 0
     count=0
     counttime=0
     print('start inference')
-    while True:
-        if render_event.is_set():
+    while render_event.is_set():
        starttime=time.perf_counter()
        try:
            whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        is_all_silence=True
        audio_frames = []
        for _ in range(batch_size*2):
            frame,type = audio_out_queue.get()
            audio_frames.append((frame,type))
            if type==0:
                is_all_silence=False
        if is_all_silence:
            for i in range(batch_size):
                res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
        else:
            # print('infer=======')
            t=time.perf_counter()
            whisper_batch = np.stack(whisper_chunks)
            latent_batch = []
            for i in range(batch_size):
                idx = __mirror_index(length,index+i)
                latent = input_latent_list_cycle[idx]
                latent_batch.append(latent)
            latent_batch = torch.cat(latent_batch, dim=0)
            # for i, (whisper_batch,latent_batch) in enumerate(gen):
            audio_feature_batch = torch.from_numpy(whisper_batch)
            audio_feature_batch = audio_feature_batch.to(device=unet.device,
                                                         dtype=unet.model.dtype)
            audio_feature_batch = pe(audio_feature_batch)
            latent_batch = latent_batch.to(dtype=unet.model.dtype)
            # print('prepare time:',time.perf_counter()-t)
            # t=time.perf_counter()
            pred_latents = unet.model(latent_batch,
                                      timesteps,
                                      encoder_hidden_states=audio_feature_batch).sample
            # print('unet time:',time.perf_counter()-t)
            # t=time.perf_counter()
            recon = vae.decode_latents(pred_latents)
+            # infer_inqueue.put((whisper_batch,latent_batch,sessionid))
+            # recon,outsessionid = infer_outqueue.get()
+            # if outsessionid != sessionid:
+            #     print('outsessionid:',outsessionid,' mysessionid:',sessionid)
            # print('vae time:',time.perf_counter()-t)
            #print('diffusion len=',len(recon))
            counttime += (time.perf_counter() - t)
            count += batch_size
            #_totalframe += 1
            if count>=100:
                print(f"------actual avg infer fps:{count/counttime:.4f}")
                count=0
                counttime=0
            for i,res_frame in enumerate(recon):
                #self.__pushmedia(res_frame,loop,audio_track,video_track)
                res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
                index = index + 1
            #print('total batch time:',time.perf_counter()-starttime)
-        else:
-            time.sleep(1)
     print('musereal inference processor stop')

-@torch.no_grad()
 class MuseReal(BaseReal):
-    def __init__(self, opt):
+    @torch.no_grad()
+    def __init__(self, opt, audio_processor:Audio2Feature,vae, unet, pe,timesteps):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -155,7 +173,8 @@ class MuseReal(BaseReal):
         self.batch_size = opt.batch_size
         self.idx = 0
         self.res_frame_queue = mp.Queue(self.batch_size*2)
-        self.__loadmodels()
+        #self.__loadmodels()
+        self.audio_processor= audio_processor
         self.__loadavatar()
         self.asr = MuseASR(opt,self,self.audio_processor)
@@ -163,13 +182,15 @@ class MuseReal(BaseReal):
         #self.__warm_up()
         self.render_event = mp.Event()
-        mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
-                    self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
-                    )).start() #self.vae, self.unet, self.pe,self.timesteps
+        self.vae = vae
+        self.unet = unet
+        self.pe = pe
+        self.timesteps = timesteps

-    def __loadmodels(self):
-        # load model weights
-        self.audio_processor= load_audio_model()
+    # def __loadmodels(self):
+    #     # load model weights
+    #     self.audio_processor= load_audio_model()
     #     self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
     #     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     #     self.timesteps = torch.tensor([0], device=device)
@@ -178,7 +199,7 @@ class MuseReal(BaseReal):
     #     self.unet.model = self.unet.model.half()

     def __loadavatar(self):
-        #self.input_latent_list_cycle = torch.load(self.latents_out_path)
+        self.input_latent_list_cycle = torch.load(self.latents_out_path,weights_only=True)
         with open(self.coords_path, 'rb') as f:
             self.coord_list_cycle = pickle.load(f)
         input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@@ -287,6 +308,9 @@ class MuseReal(BaseReal):
         process_thread.start()
         self.render_event.set() #start infer process render
+        Thread(target=inference, args=(self.render_event,self.batch_size,self.input_latent_list_cycle,
+                    self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
+                    self.vae, self.unet, self.pe,self.timesteps)).start() #mp.Process
         count=0
         totaltime=0
         _starttime=time.perf_counter()
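
For MuseTalk the shared state is larger: app.py now loads the audio processor, VAE, UNet and positional encoder once, casts them to fp16, and every MuseReal instance stores references plus its own latent cycle loaded with `weights_only=True`. A condensed sketch of that setup under the assumption that `load_all_model` is available as the diff imports it (the session dicts are illustrative stand-ins for MuseReal objects):

```python
import torch
from musetalk.utils.utils import load_all_model   # helper referenced by the diff

audio_processor, vae, unet, pe = load_all_model()  # the heavyweight load happens exactly once
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe, vae.vae, unet.model = pe.half(), vae.vae.half(), unet.model.half()   # fp16 halves the weight footprint

sessions = []
for k in range(4):                                 # analogue of opt.max_session
    sessions.append({"sessionid": k, "vae": vae, "unet": unet, "pe": pe,
                     "timesteps": timesteps, "audio_processor": audio_processor})
```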

@@ -30,6 +30,7 @@ def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
             break
         ret, frame = cap.read()
         if ret:
+            cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
             cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
             count += 1
         else:

@@ -1,19 +1,33 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import time
 import numpy as np
 import torch
 import torch.nn.functional as F
-from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
 import queue
 from queue import Queue
 #from collections import deque
-from threading import Thread, Event
 from baseasr import BaseASR

 class NerfASR(BaseASR):
-    def __init__(self, opt, parent):
+    def __init__(self, opt, parent, audio_processor,audio_model):
         super().__init__(opt,parent)
         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -37,13 +51,15 @@ class NerfASR(BaseASR):
             self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
         # create wav2vec model
-        print(f'[INFO] loading ASR model {self.opt.asr_model}...')
-        if 'hubert' in self.opt.asr_model:
-            self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
-            self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
-        else:
-            self.processor = AutoProcessor.from_pretrained(opt.asr_model)
-            self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        # print(f'[INFO] loading ASR model {self.opt.asr_model}...')
+        # if 'hubert' in self.opt.asr_model:
+        #     self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
+        #     self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
+        # else:
+        #     self.processor = AutoProcessor.from_pretrained(opt.asr_model)
+        #     self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
+        self.processor = audio_processor
+        self.model = audio_model

         # the extracted features
         # use a loop queue to efficiently record endless features: [f--t---][-------][-------]
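
The ernerf path uses the same injection pattern: the wav2vec/HuBERT weights are created once (see the app.py hunk earlier) and each NerfASR only stores references to them. A standalone sketch of that loader, with an example checkpoint name that is an assumption rather than the project's default:

```python
import torch
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel

def load_asr(asr_model, device):
    # Mirrors the branch now living in app.py: HuBERT checkpoints pair with the wav2vec2 processor.
    if 'hubert' in asr_model:
        return Wav2Vec2Processor.from_pretrained(asr_model), HubertModel.from_pretrained(asr_model).to(device)
    return AutoProcessor.from_pretrained(asr_model), AutoModelForCTC.from_pretrained(asr_model).to(device)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
processor, model = load_asr('facebook/hubert-large-ls960-ft', device)   # example checkpoint, loaded once
asr_instances = [{'sessionid': k, 'processor': processor, 'model': model} for k in range(3)]
```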

@@ -1,9 +1,25 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import math
 import torch
 import numpy as np
 #from .utils import *
 import subprocess
 import os
 import time
 import torch.nn.functional as F
@@ -11,7 +27,6 @@ import cv2
 import glob
 from nerfasr import NerfASR
-from ttsreal import EdgeTTS,VoitsTTS,XTTS
 import asyncio
 from av import AudioFrame, VideoFrame
@@ -29,7 +44,7 @@ def read_imgs(img_list):
     return frames

 class NeRFReal(BaseReal):
-    def __init__(self, opt, trainer, data_loader, debug=True):
+    def __init__(self, opt, trainer, data_loader, audio_processor,audio_model, debug=True):
         super().__init__(opt)
         #self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
         self.W = opt.W
@@ -79,7 +94,7 @@ class NeRFReal(BaseReal):
         #self.customimg_index = 0
         # build asr
-        self.asr = NerfASR(opt,self)
+        self.asr = NerfASR(opt,self,audio_processor,audio_model)
         self.asr.warm_up()
         '''

@@ -1,3 +1,20 @@
+[Apache License 2.0 header added here — the same 16-line block shown in the first hunk above]
 import time
 import numpy as np
 import soundfile as sf

@@ -36,6 +36,7 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
             break
         ret, frame = cap.read()
         if ret:
+            cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
             cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
             count += 1
         else:

@@ -1,3 +1,19 @@
+[Apache License 2.0 header added here — the same block shown in the first hunk above]
 import asyncio
 import json
