|
|
@ -281,7 +281,7 @@ def train_and_evaluate(
|
|
|
|
# text,
|
|
|
|
# text,
|
|
|
|
# text_lengths,
|
|
|
|
# text_lengths,
|
|
|
|
# ) in enumerate(tqdm(train_loader)):
|
|
|
|
# ) in enumerate(tqdm(train_loader)):
|
|
|
|
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in tqdm(enumerate(train_loader)):
|
|
|
|
for batch_idx, (ssl, spec, mel, ssl_lengths, spec_lengths, text, text_lengths, mel_lengths) in enumerate(tqdm(train_loader)):
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
|
|
|
spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(
|
|
|
|
rank, non_blocking=True
|
|
|
|
rank, non_blocking=True
|
|
|
|