CUDA memory does not increase with the number of concurrent sessions

main
lipku 8 months ago
parent e772acb3fc
commit dbe508cb65

@ -5,6 +5,9 @@ Real time interactive streaming digital human realize audio video synchronous
## To avoid confusion with 3D digital humans, the original project metahuman-stream has been renamed to livetalking; the original links remain valid
## News
- 2024.12.8 Improved multi-concurrency support: GPU memory no longer grows with the number of concurrent sessions
## Features
1. Supports multiple digital human models: ernerf, musetalk, wav2lip
2. Voice cloning
@ -12,6 +15,7 @@ Real time interactive streaming digital human realize audio video synchronous
4. Full-body video stitching
5. RTMP and WebRTC support
6. Video orchestration: plays a custom video while the avatar is not speaking
7. Multi-concurrency: models are loaded once and shared across all sessions (see the sketch below)
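The flat memory behavior comes from a pattern visible throughout this commit: each model is loaded once in the main process, the same object is handed to every session, and per-session inference runs in threads of that process instead of separate processes. A minimal sketch of the pattern, with purely illustrative class and function names (not the project's API):

```python
import torch
from threading import Thread, Event

def load_model(path):
    # stand-in for wav2lip/musetalk/ernerf model loading; called exactly once
    net = torch.nn.Linear(8, 8)
    return net.to("cuda" if torch.cuda.is_available() else "cpu").eval()

class Session:
    """One concurrent client; it only keeps a reference to the shared model."""
    def __init__(self, sessionid, model):
        self.sessionid = sessionid
        self.model = model          # shared weights: no extra CUDA copy per session

    def render(self, quit_event):
        device = next(self.model.parameters()).device
        while not quit_event.is_set():
            with torch.no_grad():
                _ = self.model(torch.zeros(1, 8, device=device))

model = load_model("./models/example.pth")        # hypothetical path; loaded once
quit_event = Event()
sessions = [Session(k, model) for k in range(3)]  # like opt.max_session
threads = [Thread(target=s.render, args=(quit_event,)) for s in sessions]
for t in threads:
    t.start()
quit_event.set()                                  # stop all session loops
for t in threads:
    t.join()
```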
## 1. Installation

@ -1,3 +1,20 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
# server.py
from flask import Flask, render_template,send_from_directory,request, jsonify
from flask_sockets import Sockets
@ -11,7 +28,8 @@ import os
import re
import numpy as np
from threading import Thread,Event
import multiprocessing
#import multiprocessing
import torch.multiprocessing as mp
from aiohttp import web
import aiohttp
@ -24,7 +42,7 @@ import argparse
import shutil
import asyncio
import string
import torch
app = Flask(__name__)
@ -302,7 +320,7 @@ async def run(push_url):
# os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
# os.environ['MULTIPROCESSING_METHOD'] = 'forkserver'
if __name__ == '__main__':
multiprocessing.set_start_method('spawn')
mp.set_start_method('spawn')
parser = argparse.ArgumentParser()
parser.add_argument('--pose', type=str, default="data/data_kf.json", help="transforms.json, pose source")
parser.add_argument('--au', type=str, default="data/au.csv", help="eye blink area")
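app.py keeps the 'spawn' start method but now sets it through torch.multiprocessing. Spawn (rather than fork) is what lets child processes use CUDA at all, since a forked child cannot re-initialize an already-initialized CUDA context. A generic sketch of that combination, independent of the project's code:

```python
# Sketch: torch.multiprocessing with the 'spawn' start method, which is required
# whenever a child process touches CUDA tensors.
import torch
import torch.multiprocessing as mp

def worker(rank):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.ones(2, device=device) * rank
    print(f"worker {rank}: {x.sum().item()}")

if __name__ == "__main__":
    mp.set_start_method("spawn")
    procs = [mp.Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```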
@ -452,6 +470,7 @@ if __name__ == '__main__':
from ernerf.nerf_triplane.provider import NeRFDataset_Test
from ernerf.nerf_triplane.utils import *
from ernerf.nerf_triplane.network import NeRFNetwork
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
from nerfreal import NeRFReal
# assert test mode
opt.test = True
@ -493,24 +512,42 @@ if __name__ == '__main__':
model.aud_features = test_loader._data.auds
model.eye_areas = test_loader._data.eye_area
print(f'[INFO] loading ASR model {opt.asr_model}...')
if 'hubert' in opt.asr_model:
audio_processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
audio_model = HubertModel.from_pretrained(opt.asr_model).to(device)
else:
audio_processor = AutoProcessor.from_pretrained(opt.asr_model)
audio_model = AutoModelForCTC.from_pretrained(opt.asr_model).to(device)
# we still need test_loader to provide audio features for testing.
for k in range(opt.max_session):
opt.sessionid=k
nerfreal = NeRFReal(opt, trainer, test_loader)
nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
nerfreals.append(nerfreal)
elif opt.model == 'musetalk':
from musereal import MuseReal
from musetalk.utils.utils import load_all_model
print(opt)
audio_processor,vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
#vae.vae.share_memory()
unet.model = unet.model.half()
#unet.model.share_memory()
for k in range(opt.max_session):
opt.sessionid=k
nerfreal = MuseReal(opt)
nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
nerfreals.append(nerfreal)
elif opt.model == 'wav2lip':
from lipreal import LipReal
from lipreal import LipReal,load_model
print(opt)
model = load_model("./models/wav2lip.pth")
for k in range(opt.max_session):
opt.sessionid=k
nerfreal = LipReal(opt)
nerfreal = LipReal(opt,model)
nerfreals.append(nerfreal)
for _ in range(opt.max_session):
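For the ernerf branch, the wav2vec/HuBERT audio model is now loaded once before the session loop, and the same processor/model pair is passed into every NeRFReal (and from there into NerfASR). Roughly, using an example checkpoint name:

```python
import torch
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel

def load_asr(asr_model: str, device: str):
    # HuBERT checkpoints use the Wav2Vec2 processor; everything else goes through
    # the generic CTC auto classes, mirroring the branch in app.py above.
    if "hubert" in asr_model:
        processor = Wav2Vec2Processor.from_pretrained(asr_model)
        model = HubertModel.from_pretrained(asr_model).to(device)
    else:
        processor = AutoProcessor.from_pretrained(asr_model)
        model = AutoModelForCTC.from_pretrained(asr_model).to(device)
    return processor, model.eval()

device = "cuda" if torch.cuda.is_available() else "cpu"
# example checkpoint; the project takes the name from --asr_model
audio_processor, audio_model = load_asr("facebook/hubert-large-ls960-ft", device)
# every NeRFReal / NerfASR instance then receives this same pair instead of loading its own
```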

@ -1,9 +1,26 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
import torch.multiprocessing as mp
class BaseASR:

@ -1,3 +1,20 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import math
import torch
import numpy as np
@ -7,8 +24,6 @@ import os
import time
import cv2
import glob
import pickle
import copy
import resampy
import queue
@ -36,6 +51,7 @@ class BaseReal:
self.opt = opt
self.sample_rate = 16000
self.chunk = self.sample_rate // opt.fps # 320 samples per chunk (20ms * 16000 / 1000)
self.sessionid = self.opt.sessionid
if opt.tts == "edgetts":
self.tts = EdgeTTS(opt,self)

@ -1,10 +1,27 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import torch
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
#import multiprocessing as mp
from baseasr import BaseASR
from wav2lip import audio

@ -1,9 +1,25 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import math
import torch
import numpy as np
#from .utils import *
import subprocess
import os
import time
import cv2
@ -14,11 +30,8 @@ import copy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
import torch.multiprocessing as mp
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from lipasr import LipASR
import asyncio
@ -35,7 +48,7 @@ print('Using {} for inference.'.format(device))
def _load(checkpoint_path):
if device == 'cuda':
checkpoint = torch.load(checkpoint_path)
checkpoint = torch.load(checkpoint_path,weights_only=True)
else:
checkpoint = torch.load(checkpoint_path,
map_location=lambda storage, loc: storage)
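Checkpoint loading now passes weights_only=True, which (in PyTorch 1.13+) restricts unpickling to tensors and plain containers rather than arbitrary Python objects. A self-contained sketch of the same loader shape (the sketch applies the flag to both branches):

```python
import torch

# toy checkpoint so the loader below is runnable
torch.save({"state_dict": {"w": torch.zeros(3)}}, "ckpt.pth")

def _load(checkpoint_path, device):
    if device == "cuda":
        return torch.load(checkpoint_path, weights_only=True)
    # CPU fallback: remap GPU-saved storages onto the CPU
    return torch.load(checkpoint_path, map_location="cpu", weights_only=True)

ckpt = _load("ckpt.pth", "cuda" if torch.cuda.is_available() else "cpu")
print(ckpt["state_dict"]["w"].shape)
```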
@ -71,12 +84,12 @@ def __mirror_index(size, index):
else:
return size - res - 1
def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_queue,res_frame_queue):
def inference(quit_event,batch_size,face_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,model):
model = load_model("./models/wav2lip.pth")
input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
face_list_cycle = read_imgs(input_face_list)
#model = load_model("./models/wav2lip.pth")
# input_face_list = glob.glob(os.path.join(face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
# input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
# face_list_cycle = read_imgs(input_face_list)
#input_latent_list_cycle = torch.load(latents_out_path)
length = len(face_list_cycle)
@ -84,69 +97,66 @@ def inference(render_event,batch_size,face_imgs_path,audio_feat_queue,audio_out_
count=0
counttime=0
print('start inference')
while True:
if render_event.is_set():
starttime=time.perf_counter()
mel_batch = []
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
while not quit_event.is_set():
starttime=time.perf_counter()
mel_batch = []
try:
mel_batch = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t=time.perf_counter()
img_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
face = face_list_cycle[idx]
img_batch.append(face)
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t=time.perf_counter()
img_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
face = face_list_cycle[idx]
img_batch.append(face)
img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch)
img_masked = img_batch.copy()
img_masked[:, face.shape[0]//2:] = 0
img_masked = img_batch.copy()
img_masked[:, face.shape[0]//2:] = 0
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.
mel_batch = np.reshape(mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1])
img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device)
mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device)
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
with torch.no_grad():
pred = model(mel_batch, img_batch)
pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(pred):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
else:
time.sleep(1)
print('musereal inference processor stop')
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(pred):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
print('lipreal inference processor stop')
@torch.no_grad()
class LipReal(BaseReal):
def __init__(self, opt):
@torch.no_grad()
def __init__(self, opt, model):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
@ -162,7 +172,7 @@ class LipReal(BaseReal):
self.coords_path = f"{self.avatar_path}/coords.pkl"
self.batch_size = opt.batch_size
self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size*2)
self.res_frame_queue = Queue(self.batch_size*2) #mp.Queue
#self.__loadmodels()
self.__loadavatar()
@ -170,19 +180,8 @@ class LipReal(BaseReal):
self.asr.warm_up()
#self.__warm_up()
self.model = model
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.face_imgs_path,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start()
# def __loadmodels(self):
# # load model weights
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device)
# self.pe = self.pe.half()
# self.vae.vae = self.vae.vae.half()
# self.unet.model = self.unet.model.half()
def __loadavatar(self):
with open(self.coords_path, 'rb') as f:
@ -191,6 +190,9 @@ class LipReal(BaseReal):
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.frame_list_cycle = read_imgs(input_img_list)
#self.imagecache = ImgCache(len(self.coord_list_cycle),self.full_imgs_path,1000)
input_face_list = glob.glob(os.path.join(self.face_imgs_path, '*.[jpJP][pnPN]*[gG]'))
input_face_list = sorted(input_face_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
self.face_list_cycle = read_imgs(input_face_list)
def process_frames(self,quit_event,loop=None,audio_track=None,video_track=None):
@ -242,7 +244,7 @@ class LipReal(BaseReal):
# time.sleep(0.1)
asyncio.run_coroutine_threadsafe(audio_track._queue.put(new_frame), loop)
self.record_audio_data(frame)
print('musereal process_frames thread stop')
print('lipreal process_frames thread stop')
def render(self,quit_event,loop=None,audio_track=None,video_track=None):
#if self.opt.asr:
@ -253,7 +255,11 @@ class LipReal(BaseReal):
process_thread = Thread(target=self.process_frames, args=(quit_event,loop,audio_track,video_track))
process_thread.start()
self.render_event.set() #start infer process render
Thread(target=inference, args=(quit_event,self.batch_size,self.face_list_cycle,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
self.model,)).start() #mp.Process
#self.render_event.set() #start infer process render
count=0
totaltime=0
_starttime=time.perf_counter()
@ -274,6 +280,6 @@ class LipReal(BaseReal):
# delay = _starttime+_totalframe*0.04-time.perf_counter() #40ms
# if delay > 0:
# time.sleep(delay)
self.render_event.clear() #end infer process render
print('musereal thread stop')
#self.render_event.clear() #end infer process render
print('lipreal thread stop')
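The wav2lip inference worker is no longer a separate mp.Process that reloads the model and face images; it is a thread in the same process that receives the already-loaded model and face list, polls the feature queue, and exits when quit_event is set, which is also why res_frame_queue can be a plain queue.Queue. The control flow, in a stripped-down sketch with generic names:

```python
import queue
from queue import Queue
from threading import Thread, Event

def inference(quit_event, feat_queue, res_queue, model):
    # thread-based worker: shares the caller's model and CUDA context
    while not quit_event.is_set():
        try:
            feats = feat_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        res_queue.put(model(feats))     # stand-in for the batched wav2lip forward pass
    print("inference worker stop")

quit_event, feat_q, res_q = Event(), Queue(), Queue()
worker = Thread(target=inference, args=(quit_event, feat_q, res_q, lambda x: x))
worker.start()
feat_q.put("mel-batch")                 # producer side (the ASR thread in the project)
print(res_q.get())
quit_event.set()                        # loop exits on the next timeout
worker.join()
```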

@ -1,9 +1,26 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import numpy as np
import queue
from queue import Queue
import multiprocessing as mp
#import multiprocessing as mp
from baseasr import BaseASR
from musetalk.whisper.audio2feature import Audio2Feature

@ -1,3 +1,20 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import math
import torch
import numpy as np
@ -15,14 +32,13 @@ import copy
import queue
from queue import Queue
from threading import Thread, Event
from io import BytesIO
import multiprocessing as mp
import torch.multiprocessing as mp
from musetalk.utils.utils import get_file_type,get_video_fps,datagen
#from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
from musetalk.utils.blending import get_image,get_image_prepare_material,get_image_blending
from musetalk.utils.utils import load_all_model,load_diffusion_model,load_audio_model
from ttsreal import EdgeTTS,VoitsTTS,XTTS
from musetalk.whisper.audio2feature import Audio2Feature
from museasr import MuseASR
import asyncio
@ -46,88 +62,90 @@ def __mirror_index(size, index):
return res
else:
return size - res - 1
@torch.no_grad()
def inference(render_event,batch_size,latents_out_path,audio_feat_queue,audio_out_queue,res_frame_queue,
): #vae, unet, pe,timesteps
def inference(render_event,batch_size,input_latent_list_cycle,audio_feat_queue,audio_out_queue,res_frame_queue,
vae, unet, pe,timesteps): #vae, unet, pe,timesteps
vae, unet, pe = load_diffusion_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)
pe = pe.half()
vae.vae = vae.vae.half()
unet.model = unet.model.half()
# vae, unet, pe = load_diffusion_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# timesteps = torch.tensor([0], device=device)
# pe = pe.half()
# vae.vae = vae.vae.half()
# unet.model = unet.model.half()
input_latent_list_cycle = torch.load(latents_out_path)
length = len(input_latent_list_cycle)
index = 0
count=0
counttime=0
print('start inference')
while True:
if render_event.is_set():
starttime=time.perf_counter()
try:
whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
# print('infer=======')
t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
latent = input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter()
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t)
# t=time.perf_counter()
recon = vae.decode_latents(pred_latents)
# print('vae time:',time.perf_counter()-t)
#print('diffusion len=',len(recon))
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(recon):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
while render_event.is_set():
starttime=time.perf_counter()
try:
whisper_chunks = audio_feat_queue.get(block=True, timeout=1)
except queue.Empty:
continue
is_all_silence=True
audio_frames = []
for _ in range(batch_size*2):
frame,type = audio_out_queue.get()
audio_frames.append((frame,type))
if type==0:
is_all_silence=False
if is_all_silence:
for i in range(batch_size):
res_frame_queue.put((None,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
else:
time.sleep(1)
# print('infer=======')
t=time.perf_counter()
whisper_batch = np.stack(whisper_chunks)
latent_batch = []
for i in range(batch_size):
idx = __mirror_index(length,index+i)
latent = input_latent_list_cycle[idx]
latent_batch.append(latent)
latent_batch = torch.cat(latent_batch, dim=0)
# for i, (whisper_batch,latent_batch) in enumerate(gen):
audio_feature_batch = torch.from_numpy(whisper_batch)
audio_feature_batch = audio_feature_batch.to(device=unet.device,
dtype=unet.model.dtype)
audio_feature_batch = pe(audio_feature_batch)
latent_batch = latent_batch.to(dtype=unet.model.dtype)
# print('prepare time:',time.perf_counter()-t)
# t=time.perf_counter()
pred_latents = unet.model(latent_batch,
timesteps,
encoder_hidden_states=audio_feature_batch).sample
# print('unet time:',time.perf_counter()-t)
# t=time.perf_counter()
recon = vae.decode_latents(pred_latents)
# infer_inqueue.put((whisper_batch,latent_batch,sessionid))
# recon,outsessionid = infer_outqueue.get()
# if outsessionid != sessionid:
# print('outsessionid:',outsessionid,' mysessionid:',sessionid)
# print('vae time:',time.perf_counter()-t)
#print('diffusion len=',len(recon))
counttime += (time.perf_counter() - t)
count += batch_size
#_totalframe += 1
if count>=100:
print(f"------actual avg infer fps:{count/counttime:.4f}")
count=0
counttime=0
for i,res_frame in enumerate(recon):
#self.__pushmedia(res_frame,loop,audio_track,video_track)
res_frame_queue.put((res_frame,__mirror_index(length,index),audio_frames[i*2:i*2+2]))
index = index + 1
#print('total batch time:',time.perf_counter()-starttime)
print('musereal inference processor stop')
@torch.no_grad()
class MuseReal(BaseReal):
def __init__(self, opt):
@torch.no_grad()
def __init__(self, opt, audio_processor:Audio2Feature,vae, unet, pe,timesteps):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
@ -155,7 +173,8 @@ class MuseReal(BaseReal):
self.batch_size = opt.batch_size
self.idx = 0
self.res_frame_queue = mp.Queue(self.batch_size*2)
self.__loadmodels()
#self.__loadmodels()
self.audio_processor= audio_processor
self.__loadavatar()
self.asr = MuseASR(opt,self,self.audio_processor)
@ -163,13 +182,15 @@ class MuseReal(BaseReal):
#self.__warm_up()
self.render_event = mp.Event()
mp.Process(target=inference, args=(self.render_event,self.batch_size,self.latents_out_path,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
)).start() #self.vae, self.unet, self.pe,self.timesteps
self.vae = vae
self.unet = unet
self.pe = pe
self.timesteps = timesteps
def __loadmodels(self):
# load model weights
self.audio_processor= load_audio_model()
# def __loadmodels(self):
# # load model weights
# self.audio_processor= load_audio_model()
# self.audio_processor, self.vae, self.unet, self.pe = load_all_model()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# self.timesteps = torch.tensor([0], device=device)
@ -178,7 +199,7 @@ class MuseReal(BaseReal):
# self.unet.model = self.unet.model.half()
def __loadavatar(self):
#self.input_latent_list_cycle = torch.load(self.latents_out_path)
self.input_latent_list_cycle = torch.load(self.latents_out_path,weights_only=True)
with open(self.coords_path, 'rb') as f:
self.coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(self.full_imgs_path, '*.[jpJP][pnPN]*[gG]'))
@ -287,6 +308,9 @@ class MuseReal(BaseReal):
process_thread.start()
self.render_event.set() #start infer process render
Thread(target=inference, args=(self.render_event,self.batch_size,self.input_latent_list_cycle,
self.asr.feat_queue,self.asr.output_queue,self.res_frame_queue,
self.vae, self.unet, self.pe,self.timesteps)).start() #mp.Process
count=0
totaltime=0
_starttime=time.perf_counter()
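In the MuseTalk path the pe/vae/unet modules are converted to half precision once in app.py and the same fp16 objects are injected into every MuseReal, whose inference also now runs as a thread gated by render_event. The one-time fp16 conversion looks roughly like this (a generic module stands in for pe / vae.vae / unet.model):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
timesteps = torch.tensor([0], device=device)   # also created once and shared in app.py

module = torch.nn.Linear(16, 16).to(device)    # stand-in for pe / vae.vae / unet.model
if device.type == "cuda":
    module = module.half()                     # fp16 weights, stored once for all sessions

x = torch.randn(4, 16, device=device, dtype=module.weight.dtype)
with torch.no_grad():
    y = module(x)                              # every session reuses the same weights
print(y.shape, y.dtype)
```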

@ -30,6 +30,7 @@ def video2imgs(vid_path, save_path, ext='.png', cut_frame=10000000):
break
ret, frame = cap.read()
if ret:
cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
count += 1
else:

@ -1,19 +1,33 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import numpy as np
import torch
import torch.nn.functional as F
from transformers import AutoModelForCTC, AutoProcessor, Wav2Vec2Processor, HubertModel
import queue
from queue import Queue
#from collections import deque
from threading import Thread, Event
from baseasr import BaseASR
class NerfASR(BaseASR):
def __init__(self, opt, parent):
def __init__(self, opt, parent, audio_processor,audio_model):
super().__init__(opt,parent)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
@ -37,13 +51,15 @@ class NerfASR(BaseASR):
self.frames.extend([np.zeros(self.chunk, dtype=np.float32)] * self.stride_left_size)
# create wav2vec model
print(f'[INFO] loading ASR model {self.opt.asr_model}...')
if 'hubert' in self.opt.asr_model:
self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
else:
self.processor = AutoProcessor.from_pretrained(opt.asr_model)
self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
# print(f'[INFO] loading ASR model {self.opt.asr_model}...')
# if 'hubert' in self.opt.asr_model:
# self.processor = Wav2Vec2Processor.from_pretrained(opt.asr_model)
# self.model = HubertModel.from_pretrained(opt.asr_model).to(self.device)
# else:
# self.processor = AutoProcessor.from_pretrained(opt.asr_model)
# self.model = AutoModelForCTC.from_pretrained(opt.asr_model).to(self.device)
self.processor = audio_processor
self.model = audio_model
# the extracted features
# use a loop queue to efficiently record endless features: [f--t---][-------][-------]

@ -1,9 +1,25 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import math
import torch
import numpy as np
#from .utils import *
import subprocess
import os
import time
import torch.nn.functional as F
@ -11,7 +27,6 @@ import cv2
import glob
from nerfasr import NerfASR
from ttsreal import EdgeTTS,VoitsTTS,XTTS
import asyncio
from av import AudioFrame, VideoFrame
@ -29,7 +44,7 @@ def read_imgs(img_list):
return frames
class NeRFReal(BaseReal):
def __init__(self, opt, trainer, data_loader, debug=True):
def __init__(self, opt, trainer, data_loader, audio_processor,audio_model, debug=True):
super().__init__(opt)
#self.opt = opt # shared with the trainer's opt to support in-place modification of rendering parameters.
self.W = opt.W
@ -79,7 +94,7 @@ class NeRFReal(BaseReal):
#self.customimg_index = 0
# build asr
self.asr = NerfASR(opt,self)
self.asr = NerfASR(opt,self,audio_processor,audio_model)
self.asr.warm_up()
'''

@ -1,3 +1,20 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import time
import numpy as np
import soundfile as sf

@ -36,6 +36,7 @@ def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
break
ret, frame = cap.read()
if ret:
cv2.putText(frame, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
cv2.imwrite(f"{save_path}/{count:08d}.png", frame)
count += 1
else:

@ -1,3 +1,19 @@
###############################################################################
# Copyright (C) 2024 LiveTalking@lipku https://github.com/lipku/LiveTalking
# email: lipku@foxmail.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
###############################################################################
import asyncio
import json
