@ -41,15 +41,15 @@ torch.set_float32_matmul_precision("medium") # 最低精度但最快(也就
# from config import pretrained_s2G,pretrained_s2D
# from config import pretrained_s2G,pretrained_s2D
global_step = 0
global_step = 0
device = " cpu " # cuda以外的设备, 等mps优化后加入
def main ( ) :
def main ( ) :
""" Assume Single Node Multi GPUs Training Only """
assert torch . cuda . is_available ( ) or torch . backends . mps . is_available ( ) , " Only GPU training is allowed. "
if torch . backends . mps . is_available ( ) :
if torch . cuda . is_available ( ) :
n_gpus = 1
else :
n_gpus = torch . cuda . device_count ( )
n_gpus = torch . cuda . device_count ( )
else :
n_gpus = 1
os . environ [ " MASTER_ADDR " ] = " localhost "
os . environ [ " MASTER_ADDR " ] = " localhost "
os . environ [ " MASTER_PORT " ] = str ( randint ( 20000 , 55555 ) )
os . environ [ " MASTER_PORT " ] = str ( randint ( 20000 , 55555 ) )
@ -73,7 +73,7 @@ def run(rank, n_gpus, hps):
writer_eval = SummaryWriter ( log_dir = os . path . join ( hps . s2_ckpt_dir , " eval " ) )
writer_eval = SummaryWriter ( log_dir = os . path . join ( hps . s2_ckpt_dir , " eval " ) )
dist . init_process_group (
dist . init_process_group (
backend = " gloo " if os . name == " nt " or torch . backends . mps . is_available ( ) else " nccl " ,
backend = " gloo " if os . name == " nt " or not torch . cuda . is_available ( ) else " nccl " ,
init_method = " env:// " ,
init_method = " env:// " ,
world_size = n_gpus ,
world_size = n_gpus ,
rank = rank ,
rank = rank ,
@ -137,9 +137,9 @@ def run(rank, n_gpus, hps):
hps . train . segment_size / / hps . data . hop_length ,
hps . train . segment_size / / hps . data . hop_length ,
n_speakers = hps . data . n_speakers ,
n_speakers = hps . data . n_speakers ,
* * hps . model ,
* * hps . model ,
) . to ( " mps " )
) . to ( device )
net_d = MultiPeriodDiscriminator ( hps . model . use_spectral_norm ) . cuda ( rank ) if torch . cuda . is_available ( ) else MultiPeriodDiscriminator ( hps . model . use_spectral_norm ) . to ( " mps " )
net_d = MultiPeriodDiscriminator ( hps . model . use_spectral_norm ) . cuda ( rank ) if torch . cuda . is_available ( ) else MultiPeriodDiscriminator ( hps . model . use_spectral_norm ) . to ( device )
for name , param in net_g . named_parameters ( ) :
for name , param in net_g . named_parameters ( ) :
if not param . requires_grad :
if not param . requires_grad :
print ( name , " not requires_grad " )
print ( name , " not requires_grad " )
@ -187,8 +187,8 @@ def run(rank, n_gpus, hps):
net_g = DDP ( net_g , device_ids = [ rank ] , find_unused_parameters = True )
net_g = DDP ( net_g , device_ids = [ rank ] , find_unused_parameters = True )
net_d = DDP ( net_d , device_ids = [ rank ] , find_unused_parameters = True )
net_d = DDP ( net_d , device_ids = [ rank ] , find_unused_parameters = True )
else :
else :
net_g = net_g . to ( " mps " )
net_g = net_g . to ( device )
net_d = net_d . to ( " mps " )
net_d = net_d . to ( device )
try : # 如果能加载自动resume
try : # 如果能加载自动resume
_ , _ , _ , epoch_str = utils . load_checkpoint (
_ , _ , _ , epoch_str = utils . load_checkpoint (
@ -320,12 +320,12 @@ def train_and_evaluate(
rank , non_blocking = True
rank , non_blocking = True
)
)
else :
else :
spec , spec_lengths = spec . to ( " mps " ) , spec_lengths . to ( " mps " )
spec , spec_lengths = spec . to ( device ) , spec_lengths . to ( device )
y , y_lengths = y . to ( " mps " ) , y_lengths . to ( " mps " )
y , y_lengths = y . to ( device ) , y_lengths . to ( device )
ssl = ssl . to ( " mps " )
ssl = ssl . to ( device )
ssl . requires_grad = False
ssl . requires_grad = False
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
# ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
text , text_lengths = text . to ( " mps " ) , text_lengths . to ( " mps " )
text , text_lengths = text . to ( device ) , text_lengths . to ( device )
with autocast ( enabled = hps . train . fp16_run ) :
with autocast ( enabled = hps . train . fp16_run ) :
(
(
@ -532,10 +532,10 @@ def evaluate(hps, generator, eval_loader, writer_eval):
ssl = ssl . cuda ( )
ssl = ssl . cuda ( )
text , text_lengths = text . cuda ( ) , text_lengths . cuda ( )
text , text_lengths = text . cuda ( ) , text_lengths . cuda ( )
else :
else :
spec , spec_lengths = spec . to ( " mps " ) , spec_lengths . to ( " mps " )
spec , spec_lengths = spec . to ( device ) , spec_lengths . to ( device )
y , y_lengths = y . to ( " mps " ) , y_lengths . to ( " mps " )
y , y_lengths = y . to ( device ) , y_lengths . to ( device )
ssl = ssl . to ( " mps " )
ssl = ssl . to ( device )
text , text_lengths = text . to ( " mps " ) , text_lengths . to ( " mps " )
text , text_lengths = text . to ( device ) , text_lengths . to ( device )
for test in [ 0 , 1 ] :
for test in [ 0 , 1 ] :
y_hat , mask , * _ = generator . module . infer (
y_hat , mask , * _ = generator . module . infer (
ssl , spec , spec_lengths , text , text_lengths , test = test
ssl , spec , spec_lengths , text , text_lengths , test = test