diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index da6a6df..0c1d248 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -304,10 +304,10 @@ class TTS_Config:
 
         configs: dict = self._load_configs(self.configs_path)
         assert isinstance(configs, dict)
-        version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3", "v4", "v2pro", "v2proplus"]
-        self.default_configs[version] = configs.get(version, self.default_configs[version])
-        self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
+        configs_ = deepcopy(self.default_configs)
+        configs_.update(configs)
+        self.configs: dict = configs_.get("custom", configs_["v2"])
+        self.default_configs = deepcopy(configs_)
 
         self.device = self.configs.get("device", torch.device("cpu"))
         if "cuda" in str(self.device) and not torch.cuda.is_available():
@@ -315,11 +315,13 @@ class TTS_Config:
             self.device = torch.device("cpu")
 
         self.is_half = self.configs.get("is_half", False)
-        # if str(self.device) == "cpu" and self.is_half:
-        #     print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
-        #     self.is_half = False
+        if str(self.device) == "cpu" and self.is_half:
+            print(f"Warning: Half precision is not supported on CPU, set is_half to False.")
+            self.is_half = False
 
+        version = self.configs.get("version", None)
         self.version = version
+        assert self.version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
         self.t2s_weights_path = self.configs.get("t2s_weights_path", None)
         self.vits_weights_path = self.configs.get("vits_weights_path", None)
         self.bert_base_path = self.configs.get("bert_base_path", None)
@@ -576,6 +578,10 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.vits_model = self.vits_model.half()
 
+        self.configs.save_configs()
+
+
+
     def init_t2s_weights(self, weights_path: str):
         print(f"Loading Text2Semantic weights from {weights_path}")
         self.configs.t2s_weights_path = weights_path
diff --git a/GPT_SoVITS/configs/tts_infer.yaml b/GPT_SoVITS/configs/tts_infer.yaml
index 531aeb5..f31061c 100644
--- a/GPT_SoVITS/configs/tts_infer.yaml
+++ b/GPT_SoVITS/configs/tts_infer.yaml
@@ -1,4 +1,3 @@
-version: v2ProPlus
 custom:
   bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
   cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index 2c159d8..51a120f 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -125,7 +125,8 @@ is_exist_s2gv4 = os.path.exists(path_sovits_v4)
 tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
 tts_config.device = device
 tts_config.is_half = is_half
-tts_config.version = version
+# tts_config.version = version
+tts_config.update_version(version)
 if gpt_path is not None:
     if "!" in gpt_path or "！" in gpt_path:
         gpt_path = name2gpt_path[gpt_path]
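A minimal, self-contained sketch (not repository code) of the config-merge behaviour the TTS_Config hunk above introduces: defaults are deep-copied, the loaded YAML is merged over them, the active config comes from the "custom" block (falling back to the "v2" defaults), and the version is then read from that active config. The default_configs and loaded_yaml dictionaries below are simplified placeholders, not the real contents of tts_infer.yaml.

# sketch of the merge behaviour; placeholder data, assumed values
from copy import deepcopy

default_configs = {
    "v2": {"device": "cpu", "is_half": False, "version": "v2"},
    "v2ProPlus": {"device": "cpu", "is_half": False, "version": "v2ProPlus"},
}

loaded_yaml = {  # stand-in for what _load_configs() returns when the file has a "custom" block
    "custom": {"device": "cuda", "is_half": True, "version": "v2ProPlus"},
}

configs_ = deepcopy(default_configs)
configs_.update(loaded_yaml)                      # YAML keys override the built-in defaults
active = configs_.get("custom", configs_["v2"])   # fall back to the v2 defaults when no "custom" block exists

version = active.get("version", None)             # version now lives inside the active config, not at the YAML top level
assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"], "Invalid version!"
print(active["device"], version)                  # -> cuda v2ProPlus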