|
|
|
@ -89,6 +89,25 @@ def load_avatar(avatar_id):
|
|
|
|
|
mask_list_cycle = read_imgs(input_mask_list)
|
|
|
|
|
return frame_list_cycle,mask_list_cycle,coord_list_cycle,mask_coords_list_cycle,input_latent_list_cycle
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def warm_up(batch_size,model):
|
|
|
|
|
# 预热函数
|
|
|
|
|
print('warmup model...')
|
|
|
|
|
vae, unet, pe, timesteps, audio_processor = model
|
|
|
|
|
#batch_size = 16
|
|
|
|
|
#timesteps = torch.tensor([0], device=unet.device)
|
|
|
|
|
whisper_batch = np.ones((batch_size, 50, 384), dtype=np.uint8)
|
|
|
|
|
latent_batch = torch.ones(batch_size, 8, 32, 32).to(unet.device)
|
|
|
|
|
|
|
|
|
|
audio_feature_batch = torch.from_numpy(whisper_batch)
|
|
|
|
|
audio_feature_batch = audio_feature_batch.to(device=unet.device, dtype=unet.model.dtype)
|
|
|
|
|
audio_feature_batch = pe(audio_feature_batch)
|
|
|
|
|
latent_batch = latent_batch.to(dtype=unet.model.dtype)
|
|
|
|
|
pred_latents = unet.model(latent_batch,
|
|
|
|
|
timesteps,
|
|
|
|
|
encoder_hidden_states=audio_feature_batch).sample
|
|
|
|
|
vae.decode_latents(pred_latents)
|
|
|
|
|
|
|
|
|
|
def read_imgs(img_list):
|
|
|
|
|
frames = []
|
|
|
|
|
print('reading images...')
|
|
|
|
@ -206,7 +225,6 @@ class MuseReal(BaseReal):
|
|
|
|
|
|
|
|
|
|
self.asr = MuseASR(opt,self,self.audio_processor)
|
|
|
|
|
self.asr.warm_up()
|
|
|
|
|
#self.__warm_up()
|
|
|
|
|
|
|
|
|
|
self.render_event = mp.Event()
|
|
|
|
|
|
|
|
|
|