From 2cdf382897dd0ac1da463922da40d2e7ee1032a6 Mon Sep 17 00:00:00 2001 From: sloop Date: Sun, 16 Mar 2025 07:16:55 +0800 Subject: [PATCH] support mac m1-m4 (#376) Co-authored-by: yuheng --- ernerf/data_utils/face_tracking/face_tracker.py | 7 +++++-- ernerf/data_utils/face_tracking/render_3dmm.py | 2 +- ernerf/main.py | 2 +- ernerf/nerf_triplane/utils.py | 8 ++++++-- lightreal.py | 6 ++---- lipreal.py | 4 ++-- musereal.py | 4 ++-- musetalk/models/unet.py | 2 +- musetalk/models/vae.py | 2 +- musetalk/simple_musetalk.py | 2 +- musetalk/utils/face_parsing/resnet.py | 2 +- musetalk/utils/preprocessing.py | 6 +++--- musetalk/whisper/whisper/__init__.py | 2 +- musetalk/whisper/whisper/transcribe.py | 4 +++- nerfasr.py | 2 +- nerfreal.py | 2 +- ultralight/unet.py | 2 +- 17 files changed, 33 insertions(+), 26 deletions(-) diff --git a/ernerf/data_utils/face_tracking/face_tracker.py b/ernerf/data_utils/face_tracking/face_tracker.py index e978856..a581828 100644 --- a/ernerf/data_utils/face_tracking/face_tracker.py +++ b/ernerf/data_utils/face_tracking/face_tracker.py @@ -179,8 +179,11 @@ print(f'[INFO] fitting light...') batch_size = 32 -device_default = torch.device("cuda:0") -device_render = torch.device("cuda:0") +device_default = torch.device("cuda:0" if torch.cuda.is_available() else ( + "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) +device_render = torch.device("cuda:0" if torch.cuda.is_available() else ( + "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) + renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] diff --git a/ernerf/data_utils/face_tracking/render_3dmm.py b/ernerf/data_utils/face_tracking/render_3dmm.py index 6a29e19..920c318 100644 --- a/ernerf/data_utils/face_tracking/render_3dmm.py +++ b/ernerf/data_utils/face_tracking/render_3dmm.py @@ -83,7 +83,7 @@ class Render_3DMM(nn.Module): img_h=500, img_w=500, batch_size=1, - device=torch.device("cuda:0"), + device=torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")), ): super(Render_3DMM, self).__init__() diff --git a/ernerf/main.py b/ernerf/main.py index a2ad228..8789d65 100644 --- a/ernerf/main.py +++ b/ernerf/main.py @@ -147,7 +147,7 @@ if __name__ == '__main__': seed_everything(opt.seed) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) model = NeRFNetwork(opt) diff --git a/ernerf/nerf_triplane/utils.py b/ernerf/nerf_triplane/utils.py index fa0b562..cc048f1 100644 --- a/ernerf/nerf_triplane/utils.py +++ b/ernerf/nerf_triplane/utils.py @@ -442,7 +442,7 @@ class LPIPSMeter: self.N = 0 self.net = net - self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu')) self.fn = lpips.LPIPS(net=net).eval().to(self.device) def clear(self): @@ -618,7 +618,11 @@ class Trainer(object): self.flip_init_lips = self.opt.init_lips self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") self.scheduler_update_every_step = scheduler_update_every_step - self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') + self.device = device if device is not None else torch.device( + f'cuda:{local_rank}' if torch.cuda.is_available() else ( + 'mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu' + ) + ) self.console = Console() model.to(self.device) diff --git a/lightreal.py b/lightreal.py index 51f544e..adadbf1 100644 --- a/lightreal.py +++ b/lightreal.py @@ -56,10 +56,8 @@ from ultralight.unet import Model from ultralight.audio2feature import Audio2Feature from logger import logger - -device = 'cuda' if torch.cuda.is_available() else 'cpu' -logger.info('Using {} for inference.'.format(device)) - +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") +print('Using {} for inference.'.format(device)) def load_model(opt): audio_processor = Audio2Feature() diff --git a/lipreal.py b/lipreal.py index 97a5626..4f18a91 100644 --- a/lipreal.py +++ b/lipreal.py @@ -44,8 +44,8 @@ from basereal import BaseReal from tqdm import tqdm from logger import logger -device = 'cuda' if torch.cuda.is_available() else 'cpu' -logger.info('Using {} for inference.'.format(device)) +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") +print('Using {} for inference.'.format(device)) def _load(checkpoint_path): if device == 'cuda': diff --git a/musereal.py b/musereal.py index bfbb155..4bad64e 100644 --- a/musereal.py +++ b/musereal.py @@ -51,7 +51,7 @@ from logger import logger def load_model(): # load model weights audio_processor,vae, unet, pe = load_all_model() - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) timesteps = torch.tensor([0], device=device) pe = pe.half() vae.vae = vae.vae.half() @@ -77,7 +77,7 @@ def load_avatar(avatar_id): # "bbox_shift":self.bbox_shift # } - input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True + input_latent_list_cycle = torch.load(latents_out_path, map_location=torch.device('mps')) #,weights_only=True with open(coords_path, 'rb') as f: coord_list_cycle = pickle.load(f) input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')) diff --git a/musetalk/models/unet.py b/musetalk/models/unet.py index 2bcc2b0..9977ccd 100755 --- a/musetalk/models/unet.py +++ b/musetalk/models/unet.py @@ -36,7 +36,7 @@ class UNet(): unet_config = json.load(f) self.model = UNet2DConditionModel(**unet_config) self.pe = PositionalEncoding(d_model=384) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) self.model.load_state_dict(weights) if use_float16: diff --git a/musetalk/models/vae.py b/musetalk/models/vae.py index 51efef4..963d4e3 100755 --- a/musetalk/models/vae.py +++ b/musetalk/models/vae.py @@ -23,7 +23,7 @@ class VAE(): self.model_path = model_path self.vae = AutoencoderKL.from_pretrained(self.model_path) - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) self.vae.to(self.device) if use_float16: diff --git a/musetalk/simple_musetalk.py b/musetalk/simple_musetalk.py index 4008cb0..7bf9b21 100644 --- a/musetalk/simple_musetalk.py +++ b/musetalk/simple_musetalk.py @@ -325,7 +325,7 @@ def create_musetalk_human(file, avatar_id): # initialize the mmpose model -device = "cuda" if torch.cuda.is_available() else "cpu" +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") fa = FaceAlignment(1, flip_input=False, device=device) config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) diff --git a/musetalk/utils/face_parsing/resnet.py b/musetalk/utils/face_parsing/resnet.py index e2e5d87..a306abb 100755 --- a/musetalk/utils/face_parsing/resnet.py +++ b/musetalk/utils/face_parsing/resnet.py @@ -80,7 +80,7 @@ class Resnet18(nn.Module): return feat8, feat16, feat32 def init_weight(self, model_path): - state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url) + state_dict = torch.load(model_path, weights_only=False) #modelzoo.load_url(resnet18_url) self_state_dict = self.state_dict() for k, v in state_dict.items(): if 'fc' in k: continue diff --git a/musetalk/utils/preprocessing.py b/musetalk/utils/preprocessing.py index 1d2f024..dc6985b 100644 --- a/musetalk/utils/preprocessing.py +++ b/musetalk/utils/preprocessing.py @@ -13,14 +13,14 @@ import torch from tqdm import tqdm # initialize the mmpose model -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' model = init_model(config_file, checkpoint_file, device=device) # initialize the face detection model -device = "cuda" if torch.cuda.is_available() else "cpu" -fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device) +device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") +fa = FaceAlignment(LandmarksType._2D, flip_input=False, device=device) # maker if the bbox is not sufficient coord_placeholder = (0.0,0.0,0.0,0.0) diff --git a/musetalk/whisper/whisper/__init__.py b/musetalk/whisper/whisper/__init__.py index b925553..cca868a 100644 --- a/musetalk/whisper/whisper/__init__.py +++ b/musetalk/whisper/whisper/__init__.py @@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow """ if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") if download_root is None: download_root = os.getenv( "XDG_CACHE_HOME", diff --git a/musetalk/whisper/whisper/transcribe.py b/musetalk/whisper/whisper/transcribe.py index 745e775..04ad420 100644 --- a/musetalk/whisper/whisper/transcribe.py +++ b/musetalk/whisper/whisper/transcribe.py @@ -78,6 +78,8 @@ def transcribe( if dtype == torch.float16: warnings.warn("FP16 is not supported on CPU; using FP32 instead") dtype = torch.float32 + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + warnings.warn("Performing inference on CPU when MPS is available") if dtype == torch.float32: decode_options["fp16"] = False @@ -135,7 +137,7 @@ def cli(): parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") - parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") + parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "mps", help="device to use for PyTorch inference") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") diff --git a/nerfasr.py b/nerfasr.py index d869ed2..ba8b2ac 100644 --- a/nerfasr.py +++ b/nerfasr.py @@ -30,7 +30,7 @@ class NerfASR(BaseASR): def __init__(self, opt, parent, audio_processor,audio_model): super().__init__(opt,parent) - self.device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu") if 'esperanto' in self.opt.asr_model: self.audio_dim = 44 elif 'deepspeech' in self.opt.asr_model: diff --git a/nerfreal.py b/nerfreal.py index 6ae0c63..ab022b4 100644 --- a/nerfreal.py +++ b/nerfreal.py @@ -77,7 +77,7 @@ def load_model(opt): seed_everything(opt.seed) logger.info(opt) - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu')) model = NeRFNetwork(opt) criterion = torch.nn.MSELoss(reduction='none') diff --git a/ultralight/unet.py b/ultralight/unet.py index d60f51f..e09a8fb 100644 --- a/ultralight/unet.py +++ b/ultralight/unet.py @@ -236,7 +236,7 @@ if __name__ == '__main__': if hasattr(module, 'reparameterize'): module.reparameterize() return model - device = torch.device("cuda") + device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")) def check_onnx(torch_out, torch_in, audio): onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model)