add wav2lip and musetalk model warmup

main
lipku 7 months ago
parent d35652ef67
commit 7a57e5a891

@@ -461,19 +461,21 @@ if __name__ == '__main__':
# nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model)
# nerfreals.append(nerfreal)
elif opt.model == 'musetalk':
from musereal import MuseReal,load_model,load_avatar
from musereal import MuseReal,load_model,load_avatar,warm_up
print(opt)
model = load_model()
avatar = load_avatar(opt.avatar_id)
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps)
# nerfreals.append(nerfreal)
elif opt.model == 'wav2lip':
from lipreal import LipReal,load_model,load_avatar
from lipreal import LipReal,load_model,load_avatar,warm_up
print(opt)
model = load_model("./models/wav2lip.pth")
avatar = load_avatar(opt.avatar_id)
warm_up(opt.batch_size,model,96)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = LipReal(opt,model)

@@ -85,6 +85,14 @@ def load_avatar(avatar_id):
return frame_list_cycle,face_list_cycle,coord_list_cycle
@torch.no_grad()
def warm_up(batch_size,model,modelres):
# 预热函数
print('warmup model...')
img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device)
mel_batch = torch.ones(batch_size, 1, 80, 16).to(device)
model(mel_batch, img_batch)
def read_imgs(img_list):
frames = []
print('reading images...')
@@ -191,7 +199,6 @@ class LipReal(BaseReal):
self.asr = LipASR(opt,self)
self.asr.warm_up()
#self.__warm_up()
self.render_event = mp.Event()

@@ -89,6 +89,25 @@ def load_avatar(avatar_id):
mask_list_cycle = read_imgs(input_mask_list)
return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle
@torch.no_grad()
def warm_up(batch_size, model):
    """Warm up the MuseTalk pipeline (pe -> unet -> vae) with one dummy batch.

    Args:
        batch_size: number of dummy samples to push through the models.
        model: the (vae, unet, pe, timesteps, audio_processor) tuple
            produced by load_model().
    """
    print('warmup model...')
    vae, unet, pe, timesteps, audio_processor = model
    # Dummy whisper features, (batch, 50, 384). float32 (not uint8) so the
    # warmup exercises the same dtype path as real audio features before the
    # cast to unet.model.dtype below.
    whisper_batch = np.ones((batch_size, 50, 384), dtype=np.float32)
    latent_batch = torch.ones(batch_size, 8, 32, 32).to(unet.device)

    audio_feature_batch = torch.from_numpy(whisper_batch)
    audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype)
    audio_feature_batch = pe(audio_feature_batch)
    latent_batch = latent_batch.to(dtype=unet.model.dtype)

    pred_latents = unet.model(latent_batch,
                              timesteps,
                              encoder_hidden_states=audio_feature_batch).sample
    # Decode just to trigger vae kernel/cache initialization; result discarded.
    vae.decode_latents(pred_latents)
def read_imgs(img_list):
frames = []
print('reading images...')
@@ -206,7 +225,6 @@ class MuseReal(BaseReal):
self.asr = MuseASR(opt,self,self.audio_processor)
self.asr.warm_up()
#self.__warm_up()
self.render_event = mp.Event()

@@ -75,3 +75,23 @@ function stop() {
pc.close();
}, 500);
}
window.onunload = function (event) {
    // On page unload, give pending work a moment, then close the peer connection.
    setTimeout(function () {
        pc.close();
    }, 500);
};
window.onbeforeunload = function (e) {
    // Schedule the peer-connection close, then prompt the user before leaving.
    setTimeout(function () {
        pc.close();
    }, 500);
    var evt = e || window.event;
    // Legacy path: IE8 and Firefox < 4 read the message from returnValue.
    if (evt) {
        evt.returnValue = '关闭提示';
    }
    // Modern path: Chrome, Safari, Firefox 4+, Opera 12+, IE 9+ use the return value.
    return '关闭提示';
};
Loading…
Cancel
Save