优化TTS_Config的代码逻辑 (#2536)

* 优化TTS_Config的代码逻辑

* 在载入vits权重之后保存tts_config
main
ChasonJiang 2 weeks ago committed by GitHub
parent cefafee32c
commit b9211657d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -304,10 +304,10 @@ class TTS_Config:
configs: dict = self._load_configs(self.configs_path) configs: dict = self._load_configs(self.configs_path)
assert isinstance(configs, dict) assert isinstance(configs, dict)
version = configs.get("version", "v2").lower() configs_ = deepcopy(self.default_configs)
assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"] configs_.update(configs)
self.default_configs[version] = configs.get(version, self.default_configs[version]) self.configs: dict = configs_.get("custom", configs_["v2"])
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version])) self.default_configs = deepcopy(configs_)
self.device = self.configs.get("device", torch.device("cpu")) self.device = self.configs.get("device", torch.device("cpu"))
if "cuda" in str(self.device) and not torch.cuda.is_available(): if "cuda" in str(self.device) and not torch.cuda.is_available():
@ -315,11 +315,13 @@ class TTS_Config:
self.device = torch.device("cpu") self.device = torch.device("cpu")
self.is_half = self.configs.get("is_half", False) self.is_half = self.configs.get("is_half", False)
# if str(self.device) == "cpu" and self.is_half: if str(self.device) == "cpu" and self.is_half:
# print(f"Warning: Half precision is not supported on CPU, set is_half to False.") print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
# self.is_half = False self.is_half = False
version = self.configs.get("version", None)
self.version = version self.version = version
assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
self.vits_weights_path = self.configs.get("vits_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None)
self.bert_base_path = self.configs.get("bert_base_path", None) self.bert_base_path = self.configs.get("bert_base_path", None)
@ -576,6 +578,10 @@ class TTS:
if self.configs.is_half and str(self.configs.device) != "cpu": if self.configs.is_half and str(self.configs.device) != "cpu":
self.vits_model = self.vits_model.half() self.vits_model = self.vits_model.half()
self.configs.save_configs()
def init_t2s_weights(self, weights_path: str): def init_t2s_weights(self, weights_path: str):
print(f"Loading Text2Semantic weights from {weights_path}") print(f"Loading Text2Semantic weights from {weights_path}")
self.configs.t2s_weights_path = weights_path self.configs.t2s_weights_path = weights_path

@ -1,4 +1,3 @@
version: v2ProPlus
custom: custom:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base

@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device tts_config.device = device
tts_config.is_half = is_half tts_config.is_half = is_half
tts_config.version = version # tts_config.version = version
tts_config.update_version(version)
if gpt_path is not None: if gpt_path is not None:
if "" in gpt_path or "!" in gpt_path: if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path] gpt_path = name2gpt_path[gpt_path]

Loading…
Cancel
Save