diff --git a/GPT_SoVITS/export_torch_script.py b/GPT_SoVITS/export_torch_script.py index a3a4827..e4406f2 100644 --- a/GPT_SoVITS/export_torch_script.py +++ b/GPT_SoVITS/export_torch_script.py @@ -474,6 +474,10 @@ class T2SModel(nn.Module): bert = bert.unsqueeze(0) x = self.ar_text_embedding(all_phoneme_ids) + + # avoid dtype inconsistency when exporting + bert = bert.to(dtype=self.bert_proj.weight.dtype) + x = x + self.bert_proj(bert.transpose(1, 2)) x: torch.Tensor = self.ar_text_position(x)