From b9211657d8dfe8cd46f6b6eb9cfc55d5989e6548 Mon Sep 17 00:00:00 2001 From: ChasonJiang <46401978+ChasonJiang@users.noreply.github.com> Date: Fri, 18 Jul 2025 11:54:40 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96TTS=5FConfig=E7=9A=84?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E9=80=BB=E8=BE=91=20(#2536)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 优化TTS_Config的代码逻辑 * 在载入vits权重之后保存tts_config --- GPT_SoVITS/TTS_infer_pack/TTS.py | 20 +++++++++++++------- GPT_SoVITS/configs/tts_infer.yaml | 1 - GPT_SoVITS/inference_webui_fast.py | 3 ++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py index da6a6df..0c1d248 100644 --- a/GPT_SoVITS/TTS_infer_pack/TTS.py +++ b/GPT_SoVITS/TTS_infer_pack/TTS.py @@ -304,10 +304,10 @@ class TTS_Config: configs: dict = self._load_configs(self.configs_path) assert isinstance(configs, dict) - version = configs.get("version", "v2").lower() - assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"] - self.default_configs[version] = configs.get(version, self.default_configs[version]) - self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version])) + configs_ = deepcopy(self.default_configs) + configs_.update(configs) + self.configs: dict = configs_.get("custom", configs_["v2"]) + self.default_configs = deepcopy(configs_) self.device = self.configs.get("device", torch.device("cpu")) if "cuda" in str(self.device) and not torch.cuda.is_available(): @@ -315,11 +315,13 @@ class TTS_Config: self.device = torch.device("cpu") self.is_half = self.configs.get("is_half", False) - # if str(self.device) == "cpu" and self.is_half: - # print(f"Warning: Half precision is not supported on CPU, set is_half to False.") - # self.is_half = False + if str(self.device) == "cpu" and self.is_half: + print(f"Warning: Half precision is not supported on CPU, set is_half to False.") + self.is_half = False + version = self.configs.get("version", None) self.version = version + assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!" self.t2s_weights_path = self.configs.get("t2s_weights_path", None) self.vits_weights_path = self.configs.get("vits_weights_path", None) self.bert_base_path = self.configs.get("bert_base_path", None) @@ -576,6 +578,10 @@ class TTS: if self.configs.is_half and str(self.configs.device) != "cpu": self.vits_model = self.vits_model.half() + self.configs.save_configs() + + + def init_t2s_weights(self, weights_path: str): print(f"Loading Text2Semantic weights from {weights_path}") self.configs.t2s_weights_path = weights_path diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml index 531aeb5..f31061c 100644 --- a/GPT_SoVITS/configs/tts_infer.yaml +++ b/GPT_SoVITS/configs/tts_infer.yaml @@ -1,4 +1,3 @@ -version: v2ProPlus custom: bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py index 2c159d8..51a120f 100644 --- a/GPT_SoVITS/inference_webui_fast.py +++ b/GPT_SoVITS/inference_webui_fast.py @@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4) tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml") tts_config.device = device tts_config.is_half = is_half -tts_config.version = version +# tts_config.version = version +tts_config.update_version(version) if gpt_path is not None: if "!" in gpt_path or "!" in gpt_path: gpt_path = name2gpt_path[gpt_path]