|
|
@ -116,9 +116,9 @@ def main(args):
|
|
|
|
devices=-1,
|
|
|
|
devices=-1,
|
|
|
|
benchmark=False,
|
|
|
|
benchmark=False,
|
|
|
|
fast_dev_run=False,
|
|
|
|
fast_dev_run=False,
|
|
|
|
strategy=DDPStrategy(
|
|
|
|
strategy = "auto" if torch.mps.is_available() else DDPStrategy(
|
|
|
|
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
|
|
|
|
process_group_backend="nccl" if platform.system() != "Windows" else "gloo"
|
|
|
|
),
|
|
|
|
), # mps 不支持多节点训练
|
|
|
|
precision=config["train"]["precision"],
|
|
|
|
precision=config["train"]["precision"],
|
|
|
|
logger=logger,
|
|
|
|
logger=logger,
|
|
|
|
num_sanity_val_steps=0,
|
|
|
|
num_sanity_val_steps=0,
|
|
|
|