From 92819d0b318c3566caa8861df34241860189b93b Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Wed, 4 Jun 2025 15:19:20 +0800
Subject: [PATCH] support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
---
 GPT_SoVITS/TTS_infer_pack/TTS.py | 123 +++++++++++++++++++++----------
 1 file changed, 84 insertions(+), 39 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 6ef46eb..aa5e9b8 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -35,7 +35,16 @@ from tools.i18n.i18n import I18nAuto, scan_language_list
 from tools.my_utils import load_audio
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
-
+from sv import SV
+resample_transform_dict = {}
+def resample(audio_tensor, sr0, sr1, device):
+    global resample_transform_dict
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(
+            sr0, sr1
+        ).to(device)
+    return resample_transform_dict[key](audio_tensor)
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
@@ -102,18 +111,6 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
 
     return processed_audio
 
-
-resample_transform_dict = {}
-
-
-def resample(audio_tensor, sr0, sr1, device):
-    global resample_transform_dict
-    key = "%s-%s" % (sr0, sr1)
-    if key not in resample_transform_dict:
-        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
-    return resample_transform_dict[key](audio_tensor)
-
-
 class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
         super().__init__(input_dict)
@@ -252,6 +249,24 @@
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
+        "v2Pro": {
+            "device": "cpu",
+            "is_half": False,
+            "version": "v2Pro",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro_pre1.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
+        "v2ProPlus": {
+            "device": "cpu",
+            "is_half": False,
+            "version": "v2ProPlus",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus_pre1.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
     }
     configs: dict = None
     v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -287,7 +302,7 @@
 
         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3", "v4"]
+        assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]
         self.default_configs[version] = configs.get(version, self.default_configs[version])
         self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
 
@@ -403,6 +418,7 @@
         self.cnhuhbert_model: CNHubert = None
         self.vocoder = None
         self.sr_model: AP_BWE = None
+        self.sv_model = None
         self.sr_model_not_exist: bool = False
 
         self.vocoder_configs: dict = {
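Note on the first two hunks above: resample() moves above the i18n setup, and its cache key now includes the target device ("%s-%s-%s" % (sr0, sr1, str(device))); the old helper removed in the second hunk keyed only on the sample-rate pair, so a Resample module created on one device could be returned for tensors living on another. A minimal runnable sketch of the cached-resampler pattern (the 48 kHz input rate and CPU device are illustrative assumptions, not values from the patch):

    import torch
    import torchaudio

    resample_transform_dict = {}

    def resample(audio_tensor, sr0, sr1, device):
        # one cached torchaudio Resample module per (src_sr, dst_sr, device) triple
        key = "%s-%s-%s" % (sr0, sr1, str(device))
        if key not in resample_transform_dict:
            resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
        return resample_transform_dict[key](audio_tensor)

    wav = torch.randn(1, 48000)                                # 1 s of mono audio at 48 kHz
    wav16k = resample(wav, 48000, 16000, torch.device("cpu"))  # builds and caches the transform
    print(wav16k.shape)                                        # torch.Size([1, 16000]); later calls with the same key reuse the module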
@@ -463,6 +479,8 @@
     def init_vits_weights(self, weights_path: str):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
+        if "Pro" in model_version:
+            self.init_sv_model()
         path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
 
         if if_lora_v3 == True and os.path.exists(path_sovits) == False:
@@ -472,7 +490,6 @@
         # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
         dict_s2 = load_sovits_new(weights_path)
         hps = dict_s2["config"]
-        hps["model"]["semantic_frame_rate"] = "25hz"
         if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
             hps["model"]["version"] = "v2"  # v3 model, v2 symbols
         elif dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
@@ -480,7 +497,15 @@
             hps["model"]["version"] = "v1"
         else:
             hps["model"]["version"] = "v2"
-        # version = hps["model"]["version"]
+        version = hps["model"]["version"]
+        v3v4set = {"v3", "v4"}
+        if model_version not in v3v4set:
+            if "Pro" not in model_version:
+                model_version = version
+            else:
+                hps["model"]["version"] = model_version
+        else:
+            hps["model"]["version"] = model_version
 
         self.configs.filter_length = hps["data"]["filter_length"]
         self.configs.segment_size = hps["train"]["segment_size"]
@@ -496,7 +521,7 @@
         # print(f"model_version:{model_version}")
         # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
 
-        if model_version not in {"v3", "v4"}:
+        if model_version not in v3v4set:
             vits_model = SynthesizerTrn(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
@@ -517,6 +542,8 @@
         if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
             del vits_model.enc_q
 
+        self.is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
+
         if if_lora_v3 == False:
             print(
                 f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
@@ -551,7 +578,7 @@
         self.configs.t2s_weights_path = weights_path
         self.configs.save_configs()
         self.configs.hz = 50
-        dict_s1 = torch.load(weights_path, map_location=self.configs.device)
+        dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
         config = dict_s1["config"]
         self.configs.max_sec = config["data"]["max_sec"]
         t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
@@ -605,7 +632,7 @@
         )
         self.vocoder.remove_weight_norm()
         state_dict_g = torch.load(
-            "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+            "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
         )
         print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
 
@@ -631,6 +658,11 @@
             print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
             self.sr_model_not_exist = True
 
+    def init_sv_model(self):
+        if self.sv_model is not None:
+            return
+        self.sv_model = SV(self.configs.device, self.configs.is_half)
+
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         """
         To enable half precision for the TTS model.
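Note on init_vits_weights: the added branch reconciles the requested model_version with the symbol version detected in the checkpoint and writes the result back into hps["model"]["version"]; init_sv_model() is idempotent, so the speaker-verification model is loaded once and only for the Pro variants. A standalone restatement of the version decision (resolve_versions is a hypothetical helper name used only for illustration; the behaviour mirrors the hunk above):

    def resolve_versions(model_version: str, ckpt_version: str):
        # Returns (effective model_version, hps version) per the patched logic.
        v3v4set = {"v3", "v4"}
        if model_version not in v3v4set:
            if "Pro" not in model_version:
                # plain v1/v2 weights: trust the symbol version found in the checkpoint
                return ckpt_version, ckpt_version
            # v2Pro / v2ProPlus: keep the requested version and stamp it into hps
            return model_version, model_version
        # v3 / v4: always stamp the requested version into hps
        return model_version, model_version

    # e.g. resolve_versions("v2Pro", "v2") -> ("v2Pro", "v2Pro")
    #      resolve_versions("v2", "v1")    -> ("v1", "v1")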
@@ -706,11 +738,11 @@
             self.prompt_cache["ref_audio_path"] = ref_audio_path
 
     def _set_ref_spec(self, ref_audio_path):
-        spec = self._get_ref_spec(ref_audio_path)
+        spec_audio = self._get_ref_spec(ref_audio_path)
         if self.prompt_cache["refer_spec"] in [[], None]:
-            self.prompt_cache["refer_spec"] = [spec]
+            self.prompt_cache["refer_spec"] = [spec_audio]
         else:
-            self.prompt_cache["refer_spec"][0] = spec
+            self.prompt_cache["refer_spec"][0] = spec_audio
 
     def _get_ref_spec(self, ref_audio_path):
         raw_audio, raw_sr = torchaudio.load(ref_audio_path)
@@ -718,25 +750,33 @@
         self.prompt_cache["raw_audio"] = raw_audio
         self.prompt_cache["raw_sr"] = raw_sr
 
-        audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
-        audio = torch.FloatTensor(audio)
+        if raw_sr != self.configs.sampling_rate:
+            audio = raw_audio.to(self.configs.device)
+            if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+            audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
+        else:
+            audio = raw_audio.to(self.configs.device)
+            if audio.shape[0] == 2: audio = audio.mean(0).unsqueeze(0)
+
         maxx = audio.abs().max()
         if maxx > 1:
             audio /= min(2, maxx)
-        audio_norm = audio
-        audio_norm = audio_norm.unsqueeze(0)
         spec = spectrogram_torch(
-            audio_norm,
+            audio,
             self.configs.filter_length,
             self.configs.sampling_rate,
             self.configs.hop_length,
             self.configs.win_length,
             center=False,
         )
-        spec = spec.to(self.configs.device)
         if self.configs.is_half:
             spec = spec.half()
-        return spec
+        if self.is_v2pro == True:
+            audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
+            if self.configs.is_half:
+                audio = audio.half()
+        else: audio = None
+        return spec, audio
 
     def _set_prompt_semantic(self, ref_wav_path: str):
         zero_wav = np.zeros(
@@ -1171,10 +1211,13 @@
             t4 = time.perf_counter()
             t_34 += t4 - t3
 
-            refer_audio_spec: torch.Tensor = [
-                item.to(dtype=self.precision, device=self.configs.device)
-                for item in self.prompt_cache["refer_spec"]
-            ]
+            refer_audio_spec = []
+            if self.is_v2pro: sv_emb = []
+            for spec, audio_tensor in self.prompt_cache["refer_spec"]:
+                spec = spec.to(dtype=self.precision, device=self.configs.device)
+                refer_audio_spec.append(spec)
+                if self.is_v2pro:
+                    sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
 
             batch_audio_fragment = []
 
@@ -1206,9 +1249,10 @@
                     torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                 )
                 _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                _batch_audio_fragment = self.vits_model.decode(
-                    all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                ).detach()[0, 0, :]
+                if self.is_v2pro != True:
+                    _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                else:
+                    _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb).detach()[0, 0, :]
                 audio_frag_end_idx.insert(0, 0)
                 batch_audio_fragment = [
                     _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@@ -1221,9 +1265,10 @@
                     _pred_semantic = (
                         pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                     )  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                    audio_fragment = self.vits_model.decode(
-                        _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                    ).detach()[0, 0, :]
+                    if self.is_v2pro != True:
+                        audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                    else:
+                        audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb).detach()[0, 0, :]
                     batch_audio_fragment.append(audio_fragment)  ### try reconstructing without the prompt part
         else:
             if parallel_infer:
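Note on the inference changes: _get_ref_spec now returns a (spec, audio) tuple, where audio is the 16 kHz reference waveform for v2Pro/v2ProPlus and None otherwise, and the decode calls gain an sv_emb keyword carrying one speaker-verification embedding per reference clip. A runnable toy of the consumer loop (SVStub and all tensor shapes are illustrative stand-ins for the real sv.SV model and its compute_embedding3 output):

    import torch

    class SVStub:
        def compute_embedding3(self, audio_16k: torch.Tensor) -> torch.Tensor:
            # the real model returns a speaker embedding; a mean is enough for the sketch
            return audio_16k.mean(dim=-1, keepdim=True)

    # one cached reference: (linear spectrogram, 1 s of 16 kHz audio)
    prompt_cache = {"refer_spec": [(torch.randn(1, 1025, 40), torch.randn(1, 16000))]}
    sv_model, is_v2pro, precision = SVStub(), True, torch.float32

    refer_audio_spec, sv_emb = [], []
    for spec, audio_tensor in prompt_cache["refer_spec"]:
        refer_audio_spec.append(spec.to(dtype=precision))
        if is_v2pro:
            sv_emb.append(sv_model.compute_embedding3(audio_tensor))

    print(len(refer_audio_spec), len(sv_emb))  # 1 1; sv_emb is then passed as decode(..., sv_emb=sv_emb)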