support mac m1-m4 (#376)

Co-authored-by: yuheng <lipku@163.com>
main
sloop 4 months ago committed by GitHub
parent 7b340cc9a2
commit 2cdf382897
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -179,8 +179,11 @@ print(f'[INFO] fitting light...')
batch_size = 32 batch_size = 32
device_default = torch.device("cuda:0") device_default = torch.device("cuda:0" if torch.cuda.is_available() else (
device_render = torch.device("cuda:0") "mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
device_render = torch.device("cuda:0" if torch.cuda.is_available() else (
"mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render) renderer = Render_3DMM(arg_focal, h, w, batch_size, device_render)
sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size] sel_ids = np.arange(0, num_frames, int(num_frames / batch_size))[:batch_size]

@ -83,7 +83,7 @@ class Render_3DMM(nn.Module):
img_h=500, img_h=500,
img_w=500, img_w=500,
batch_size=1, batch_size=1,
device=torch.device("cuda:0"), device=torch.device("cuda:0" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")),
): ):
super(Render_3DMM, self).__init__() super(Render_3DMM, self).__init__()

@ -147,7 +147,7 @@ if __name__ == '__main__':
seed_everything(opt.seed) seed_everything(opt.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
model = NeRFNetwork(opt) model = NeRFNetwork(opt)

@ -442,7 +442,7 @@ class LPIPSMeter:
self.N = 0 self.N = 0
self.net = net self.net = net
self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu') self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else ('mps' if hasattr(torch.backends, "mps") and torch.backends.mps.is_available() else 'cpu'))
self.fn = lpips.LPIPS(net=net).eval().to(self.device) self.fn = lpips.LPIPS(net=net).eval().to(self.device)
def clear(self): def clear(self):
@ -618,7 +618,11 @@ class Trainer(object):
self.flip_init_lips = self.opt.init_lips self.flip_init_lips = self.opt.init_lips
self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
self.scheduler_update_every_step = scheduler_update_every_step self.scheduler_update_every_step = scheduler_update_every_step
self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') self.device = device if device is not None else torch.device(
f'cuda:{local_rank}' if torch.cuda.is_available() else (
'mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu'
)
)
self.console = Console() self.console = Console()
model.to(self.device) model.to(self.device)

@ -56,10 +56,8 @@ from ultralight.unet import Model
from ultralight.audio2feature import Audio2Feature from ultralight.audio2feature import Audio2Feature
from logger import logger from logger import logger
device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
device = 'cuda' if torch.cuda.is_available() else 'cpu' print('Using {} for inference.'.format(device))
logger.info('Using {} for inference.'.format(device))
def load_model(opt): def load_model(opt):
audio_processor = Audio2Feature() audio_processor = Audio2Feature()

@ -44,8 +44,8 @@ from basereal import BaseReal
from tqdm import tqdm from tqdm import tqdm
from logger import logger from logger import logger
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
logger.info('Using {} for inference.'.format(device)) print('Using {} for inference.'.format(device))
def _load(checkpoint_path): def _load(checkpoint_path):
if device == 'cuda': if device == 'cuda':

@ -51,7 +51,7 @@ from logger import logger
def load_model(): def load_model():
# load model weights # load model weights
audio_processor,vae, unet, pe = load_all_model() audio_processor,vae, unet, pe = load_all_model()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
timesteps = torch.tensor([0], device=device) timesteps = torch.tensor([0], device=device)
pe = pe.half() pe = pe.half()
vae.vae = vae.vae.half() vae.vae = vae.vae.half()
@ -77,7 +77,7 @@ def load_avatar(avatar_id):
# "bbox_shift":self.bbox_shift # "bbox_shift":self.bbox_shift
# } # }
input_latent_list_cycle = torch.load(latents_out_path) #,weights_only=True input_latent_list_cycle = torch.load(latents_out_path, map_location=torch.device('mps')) #,weights_only=True
with open(coords_path, 'rb') as f: with open(coords_path, 'rb') as f:
coord_list_cycle = pickle.load(f) coord_list_cycle = pickle.load(f)
input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]')) input_img_list = glob.glob(os.path.join(full_imgs_path, '*.[jpJP][pnPN]*[gG]'))

@ -36,7 +36,7 @@ class UNet():
unet_config = json.load(f) unet_config = json.load(f)
self.model = UNet2DConditionModel(**unet_config) self.model = UNet2DConditionModel(**unet_config)
self.pe = PositionalEncoding(d_model=384) self.pe = PositionalEncoding(d_model=384)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device) weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
self.model.load_state_dict(weights) self.model.load_state_dict(weights)
if use_float16: if use_float16:

@ -23,7 +23,7 @@ class VAE():
self.model_path = model_path self.model_path = model_path
self.vae = AutoencoderKL.from_pretrained(self.model_path) self.vae = AutoencoderKL.from_pretrained(self.model_path)
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
self.vae.to(self.device) self.vae.to(self.device)
if use_float16: if use_float16:

@ -325,7 +325,7 @@ def create_musetalk_human(file, avatar_id):
# initialize the mmpose model # initialize the mmpose model
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
fa = FaceAlignment(1, flip_input=False, device=device) fa = FaceAlignment(1, flip_input=False, device=device)
config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py') config_file = os.path.join(current_dir, 'utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py')
checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth')) checkpoint_file = os.path.abspath(os.path.join(current_dir, '../models/dwpose/dw-ll_ucoco_384.pth'))

@ -80,7 +80,7 @@ class Resnet18(nn.Module):
return feat8, feat16, feat32 return feat8, feat16, feat32
def init_weight(self, model_path): def init_weight(self, model_path):
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url) state_dict = torch.load(model_path, weights_only=False) #modelzoo.load_url(resnet18_url)
self_state_dict = self.state_dict() self_state_dict = self.state_dict()
for k, v in state_dict.items(): for k, v in state_dict.items():
if 'fc' in k: continue if 'fc' in k: continue

@ -13,14 +13,14 @@ import torch
from tqdm import tqdm from tqdm import tqdm
# initialize the mmpose model # initialize the mmpose model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py' config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth' checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
model = init_model(config_file, checkpoint_file, device=device) model = init_model(config_file, checkpoint_file, device=device)
# initialize the face detection model # initialize the face detection model
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device) fa = FaceAlignment(LandmarksType._2D, flip_input=False, device=device)
# marker if the bbox is not sufficient # marker if the bbox is not sufficient
coord_placeholder = (0.0,0.0,0.0,0.0) coord_placeholder = (0.0,0.0,0.0,0.0)

@ -91,7 +91,7 @@ def load_model(name: str, device: Optional[Union[str, torch.device]] = None, dow
""" """
if device is None: if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu" device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
if download_root is None: if download_root is None:
download_root = os.getenv( download_root = os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME",

@ -78,6 +78,8 @@ def transcribe(
if dtype == torch.float16: if dtype == torch.float16:
warnings.warn("FP16 is not supported on CPU; using FP32 instead") warnings.warn("FP16 is not supported on CPU; using FP32 instead")
dtype = torch.float32 dtype = torch.float32
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
warnings.warn("Performing inference on CPU when MPS is available")
if dtype == torch.float32: if dtype == torch.float32:
decode_options["fp16"] = False decode_options["fp16"] = False
@ -135,7 +137,7 @@ def cli():
parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe") parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use") parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default") parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference") parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "mps", help="device to use for PyTorch inference")
parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs") parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages") parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

@ -30,7 +30,7 @@ class NerfASR(BaseASR):
def __init__(self, opt, parent, audio_processor,audio_model): def __init__(self, opt, parent, audio_processor,audio_model):
super().__init__(opt,parent) super().__init__(opt,parent)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = "cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu")
if 'esperanto' in self.opt.asr_model: if 'esperanto' in self.opt.asr_model:
self.audio_dim = 44 self.audio_dim = 44
elif 'deepspeech' in self.opt.asr_model: elif 'deepspeech' in self.opt.asr_model:

@ -77,7 +77,7 @@ def load_model(opt):
seed_everything(opt.seed) seed_everything(opt.seed)
logger.info(opt) logger.info(opt)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else ('mps' if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else 'cpu'))
model = NeRFNetwork(opt) model = NeRFNetwork(opt)
criterion = torch.nn.MSELoss(reduction='none') criterion = torch.nn.MSELoss(reduction='none')

@ -236,7 +236,7 @@ if __name__ == '__main__':
if hasattr(module, 'reparameterize'): if hasattr(module, 'reparameterize'):
module.reparameterize() module.reparameterize()
return model return model
device = torch.device("cuda") device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()) else "cpu"))
def check_onnx(torch_out, torch_in, audio): def check_onnx(torch_out, torch_in, audio):
onnx_model = onnx.load(onnx_path) onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model) onnx.checker.check_model(onnx_model)

Loading…
Cancel
Save