diff --git a/app.py b/app.py index 3cd81d3..85130cf 100644 --- a/app.py +++ b/app.py @@ -461,19 +461,21 @@ if __name__ == '__main__': # nerfreal = NeRFReal(opt, trainer, test_loader,audio_processor,audio_model) # nerfreals.append(nerfreal) elif opt.model == 'musetalk': - from musereal import MuseReal,load_model,load_avatar + from musereal import MuseReal,load_model,load_avatar,warm_up print(opt) model = load_model() - avatar = load_avatar(opt.avatar_id) + avatar = load_avatar(opt.avatar_id) + warm_up(opt.batch_size,model) # for k in range(opt.max_session): # opt.sessionid=k # nerfreal = MuseReal(opt,audio_processor,vae, unet, pe,timesteps) # nerfreals.append(nerfreal) elif opt.model == 'wav2lip': - from lipreal import LipReal,load_model,load_avatar + from lipreal import LipReal,load_model,load_avatar,warm_up print(opt) model = load_model("./models/wav2lip.pth") avatar = load_avatar(opt.avatar_id) + warm_up(opt.batch_size,model,96) # for k in range(opt.max_session): # opt.sessionid=k # nerfreal = LipReal(opt,model) diff --git a/lipreal.py b/lipreal.py index b7a6d52..847f99b 100644 --- a/lipreal.py +++ b/lipreal.py @@ -85,6 +85,14 @@ def load_avatar(avatar_id): return frame_list_cycle,face_list_cycle,coord_list_cycle +@torch.no_grad() +def warm_up(batch_size,model,modelres): + # 预热函数 + print('warmup model...') + img_batch = torch.ones(batch_size, 6, modelres, modelres).to(device) + mel_batch = torch.ones(batch_size, 1, 80, 16).to(device) + model(mel_batch, img_batch) + def read_imgs(img_list): frames = [] print('reading images...') @@ -191,7 +199,6 @@ class LipReal(BaseReal): self.asr = LipASR(opt,self) self.asr.warm_up() - #self.__warm_up() self.render_event = mp.Event() diff --git a/musereal.py b/musereal.py index ebb25a7..ba6009b 100644 --- a/musereal.py +++ b/musereal.py @@ -89,6 +89,25 @@ def load_avatar(avatar_id): mask_list_cycle = read_imgs(input_mask_list) return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle +@torch.no_grad() +def warm_up(batch_size,model): + # 预热函数 + print('warmup model...') + vae, unet, pe, timesteps, audio_processor = model + #batch_size = 16 + #timesteps = torch.tensor([0], device=unet.device) + whisper_batch = np.ones((batch_size, 50, 384), dtype=np.uint8) + latent_batch = torch.ones(batch_size, 8, 32, 32).to(unet.device) + + audio_feature_batch = torch.from_numpy(whisper_batch) + audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype) + audio_feature_batch = pe(audio_feature_batch) + latent_batch = latent_batch.to(dtype=unet.model.dtype) + pred_latents = unet.model(latent_batch, + timesteps, + encoder_hidden_states=audio_feature_batch).sample + vae.decode_latents(pred_latents) + def read_imgs(img_list): frames = [] print('reading images...') @@ -206,7 +225,6 @@ class MuseReal(BaseReal): self.asr = MuseASR(opt,self,self.audio_processor) self.asr.warm_up() - #self.__warm_up() self.render_event = mp.Event() diff --git a/web/client.js b/web/client.js index b23ecc1..8a51113 100644 --- a/web/client.js +++ b/web/client.js @@ -75,3 +75,23 @@ function stop() { pc.close(); }, 500); } + +window.onunload = function(event) { + // 在这里执行你想要的操作 + setTimeout(() => { + pc.close(); + }, 500); +}; + +window.onbeforeunload = function (e) { + setTimeout(() => { + pc.close(); + }, 500); + e = e || window.event + // 兼容IE8和Firefox 4之前的版本 + if (e) { + e.returnValue = '关闭提示' + } + // Chrome, Safari, Firefox 4+, Opera 12+ , IE 9+ + return '关闭提示' + } \ No newline at end of file