Update export_torch_script.py (#2494)

Avoid dtype inconsistency when exporting
main
Yixiao Chen 1 month ago committed by GitHub
parent 6df61f58e4
commit 8c579d46dd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -474,6 +474,10 @@ class T2SModel(nn.Module):
bert = bert.unsqueeze(0)
x = self.ar_text_embedding(all_phoneme_ids)
# avoid dtype inconsistency when exporting
bert = bert.to(dtype=self.bert_proj.weight.dtype)
x = x + self.bert_proj(bert.transpose(1, 2))
x: torch.Tensor = self.ar_text_position(x)

Loading…
Cancel
Save