优化TTS_Config的代码逻辑 (#2536)

* 优化TTS_Config的代码逻辑

* 在载入vits权重之后保存tts_config
main
ChasonJiang 2 weeks ago committed by GitHub
parent cefafee32c
commit b9211657d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -304,10 +304,10 @@ class TTS_Config:
configs: dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict)
version = configs.get("version", "v2").lower()
assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"]
self.default_configs[version] = configs.get(version, self.default_configs[version])
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
configs_ = deepcopy(self.default_configs)
configs_.update(configs)
self.configs: dict = configs_.get("custom", configs_["v2"])
self.default_configs = deepcopy(configs_)
self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available():
@ -315,11 +315,13 @@ class TTS_Config:
self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False
if str(self.device) == "cpu" and self.is_half:
print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
self.is_half = False
version = self.configs.get("version", None)
self.version = version
assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None)
@ -576,6 +578,10 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu":
self.vits_model = self.vits_model.half()
self.configs.save_configs()
def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path

@ -1,4 +1,3 @@
version: v2ProPlus
custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base

@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
# tts_config.version = version
tts_config.update_version(version)
if gpt_path is not None:
if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]

Loading…
Cancel
Save