@@ -44,9 +44,12 @@ global_step = 0
 
 
 def main():
     """Assume Single Node Multi GPUs Training Only"""
-    assert torch.cuda.is_available(), "CPU training is not allowed."
+    assert torch.cuda.is_available() or torch.backends.mps.is_available(), "Only GPU training is allowed."
 
-    n_gpus = torch.cuda.device_count()
+    if torch.backends.mps.is_available():
+        n_gpus = 1
+    else:
+        n_gpus = torch.cuda.device_count()
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
@@ -70,13 +73,14 @@ def run(rank, n_gpus, hps):
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.s2_ckpt_dir, "eval"))
 
     dist.init_process_group(
-        backend="gloo" if os.name == "nt" else "nccl",
+        backend="gloo" if os.name == "nt" or torch.backends.mps.is_available() else "nccl",
         init_method="env://",
         world_size=n_gpus,
         rank=rank,
     )
     torch.manual_seed(hps.train.seed)
-    torch.cuda.set_device(rank)
+    if torch.cuda.is_available():
+        torch.cuda.set_device(rank)
 
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
@@ -128,9 +132,14 @@ def run(rank, n_gpus, hps):
         hps.train.segment_size // hps.data.hop_length,
         n_speakers=hps.data.n_speakers,
         **hps.model,
-    ).cuda(rank)
+    ).cuda(rank) if torch.cuda.is_available() else SynthesizerTrn(
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to("mps")
 
-    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) if torch.cuda.is_available() else MultiPeriodDiscriminator(hps.model.use_spectral_norm).to("mps")
     for name, param in net_g.named_parameters():
         if not param.requires_grad:
             print(name, "not requires_grad")
@@ -174,8 +183,12 @@ def run(rank, n_gpus, hps):
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+    if torch.cuda.is_available():
+        net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
+        net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+    else:
+        net_g = net_g.to("mps")
+        net_d = net_d.to("mps")
 
     try:  # auto-resume if a checkpoint can be loaded
         _, _, _, epoch_str = utils.load_checkpoint(
@@ -205,6 +218,9 @@ def run(rank, n_gpus, hps):
                 net_g.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
                     strict=False,
-                )
+                ) if torch.cuda.is_available() else net_g.load_state_dict(
+                    torch.load(hps.train.pretrained_s2G, map_location="cpu")["weight"],
+                    strict=False,
+                )
             )  ## test: don't load the optimizer
         if hps.train.pretrained_s2D != "":
@@ -213,6 +229,8 @@ def run(rank, n_gpus, hps):
             print(
                 net_d.module.load_state_dict(
                     torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
-                )
+                ) if torch.cuda.is_available() else net_d.load_state_dict(
+                    torch.load(hps.train.pretrained_s2D, map_location="cpu")["weight"]
+                )
             )
 
@@ -288,18 +306,26 @@ def train_and_evaluate(
         text,
         text_lengths,
     ) in tqdm(enumerate(train_loader)):
-        spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
-            rank, non_blocking=True
-        )
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-            rank, non_blocking=True
-        )
-        ssl = ssl.cuda(rank, non_blocking=True)
-        ssl.requires_grad = False
-        # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
-        text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
-            rank, non_blocking=True
-        )
+        if torch.cuda.is_available():
+            spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
+                rank, non_blocking=True
+            )
+            y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
+                rank, non_blocking=True
+            )
+            ssl = ssl.cuda(rank, non_blocking=True)
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.cuda(rank, non_blocking=True), text_lengths.cuda(
+                rank, non_blocking=True
+            )
+        else:
+            spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
+            y, y_lengths = y.to("mps"), y_lengths.to("mps")
+            ssl = ssl.to("mps")
+            ssl.requires_grad = False
+            # ssl_lengths = ssl_lengths.cuda(rank, non_blocking=True)
+            text, text_lengths = text.to("mps"), text_lengths.to("mps")
 
         with autocast(enabled=hps.train.fp16_run):
             (
@@ -500,13 +526,21 @@ def evaluate(hps, generator, eval_loader, writer_eval):
             text,
             text_lengths,
         ) in enumerate(eval_loader):
             print(111)
-            spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
-            y, y_lengths = y.cuda(), y_lengths.cuda()
-            ssl = ssl.cuda()
-            text, text_lengths = text.cuda(), text_lengths.cuda()
+            if torch.cuda.is_available():
+                spec, spec_lengths = spec.cuda(), spec_lengths.cuda()
+                y, y_lengths = y.cuda(), y_lengths.cuda()
+                ssl = ssl.cuda()
+                text, text_lengths = text.cuda(), text_lengths.cuda()
+            else:
+                spec, spec_lengths = spec.to("mps"), spec_lengths.to("mps")
+                y, y_lengths = y.to("mps"), y_lengths.to("mps")
+                ssl = ssl.to("mps")
+                text, text_lengths = text.to("mps"), text_lengths.to("mps")
             for test in [0, 1]:
                 y_hat, mask, *_ = generator.module.infer(
                     ssl, spec, spec_lengths, text, text_lengths, test=test
-                )
+                ) if torch.cuda.is_available() else generator.infer(
+                    ssl, spec, spec_lengths, text, text_lengths, test=test
+                )
                 y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length
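For reference, the pattern these hunks repeat (.cuda(rank) under CUDA, .to("mps") otherwise, and .module only when the network is DDP-wrapped) could be factored as the sketch below. It is illustrative only; target_device and unwrap are hypothetical helpers, not code from the patch.

import torch

def target_device(rank: int = 0) -> torch.device:
    # Mirror the branch used throughout the patch: a per-rank CUDA device, else MPS.
    if torch.cuda.is_available():
        return torch.device("cuda", rank)
    return torch.device("mps")

def unwrap(model: torch.nn.Module) -> torch.nn.Module:
    # DDP wraps the network on CUDA, so attributes such as infer() live under
    # .module; on MPS the plain model is used directly.
    return model.module if hasattr(model, "module") else model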