|
|
@ -304,10 +304,10 @@ class TTS_Config:
|
|
|
|
configs: dict = self._load_configs(self.configs_path)
|
|
|
|
configs: dict = self._load_configs(self.configs_path)
|
|
|
|
|
|
|
|
|
|
|
|
assert isinstance(configs, dict)
|
|
|
|
assert isinstance(configs, dict)
|
|
|
|
version = configs.get("version", "v2").lower()
|
|
|
|
configs_ = deepcopy(self.default_configs)
|
|
|
|
assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"]
|
|
|
|
configs_.update(configs)
|
|
|
|
self.default_configs[version] = configs.get(version, self.default_configs[version])
|
|
|
|
self.configs: dict = configs_.get("custom", configs_["v2"])
|
|
|
|
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
|
|
|
|
self.default_configs = deepcopy(configs_)
|
|
|
|
|
|
|
|
|
|
|
|
self.device = self.configs.get("device", torch.device("cpu"))
|
|
|
|
self.device = self.configs.get("device", torch.device("cpu"))
|
|
|
|
if "cuda" in str(self.device) and not torch.cuda.is_available():
|
|
|
|
if "cuda" in str(self.device) and not torch.cuda.is_available():
|
|
|
@ -315,11 +315,13 @@ class TTS_Config:
|
|
|
|
self.device = torch.device("cpu")
|
|
|
|
self.device = torch.device("cpu")
|
|
|
|
|
|
|
|
|
|
|
|
self.is_half = self.configs.get("is_half", False)
|
|
|
|
self.is_half = self.configs.get("is_half", False)
|
|
|
|
# if str(self.device) == "cpu" and self.is_half:
|
|
|
|
if str(self.device) == "cpu" and self.is_half:
|
|
|
|
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
|
|
|
print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
|
|
|
|
# self.is_half = False
|
|
|
|
self.is_half = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
version = self.configs.get("version", None)
|
|
|
|
self.version = version
|
|
|
|
self.version = version
|
|
|
|
|
|
|
|
assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
|
|
|
|
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
|
|
|
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
|
|
|
|
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
|
|
|
self.vits_weights_path = self.configs.get("vits_weights_path", None)
|
|
|
|
self.bert_base_path = self.configs.get("bert_base_path", None)
|
|
|
|
self.bert_base_path = self.configs.get("bert_base_path", None)
|
|
|
@ -576,6 +578,10 @@ class TTS:
|
|
|
|
if self.configs.is_half and str(self.configs.device) != "cpu":
|
|
|
|
if self.configs.is_half and str(self.configs.device) != "cpu":
|
|
|
|
self.vits_model = self.vits_model.half()
|
|
|
|
self.vits_model = self.vits_model.half()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.configs.save_configs()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_t2s_weights(self, weights_path: str):
|
|
|
|
def init_t2s_weights(self, weights_path: str):
|
|
|
|
print(f"Loading Text2Semantic weights from {weights_path}")
|
|
|
|
print(f"Loading Text2Semantic weights from {weights_path}")
|
|
|
|
self.configs.t2s_weights_path = weights_path
|
|
|
|
self.configs.t2s_weights_path = weights_path
|
|
|
|