|
|
@ -304,7 +304,7 @@ class TTS:
|
|
|
|
def init_vits_weights(self, weights_path: str):
|
|
|
|
def init_vits_weights(self, weights_path: str):
|
|
|
|
print(f"Loading VITS weights from {weights_path}")
|
|
|
|
print(f"Loading VITS weights from {weights_path}")
|
|
|
|
self.configs.vits_weights_path = weights_path
|
|
|
|
self.configs.vits_weights_path = weights_path
|
|
|
|
dict_s2 = torch.load(weights_path, map_location=self.configs.device)
|
|
|
|
dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
|
|
|
|
hps = dict_s2["config"]
|
|
|
|
hps = dict_s2["config"]
|
|
|
|
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
|
|
|
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
|
|
|
self.configs.update_version("v1")
|
|
|
|
self.configs.update_version("v1")
|
|
|
|