From 839ff9ce5bbad4c4df0c90f3f87c7b40b3f9ed1e Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Mon, 21 Apr 2025 22:43:46 +0800
Subject: [PATCH] Adapt v4 for parallel inference (not finished yet)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 GPT_SoVITS/TTS_infer_pack/TTS.py | 211 ++++++++++++++++---------------
 1 file changed, 112 insertions(+), 99 deletions(-)

diff --git a/GPT_SoVITS/TTS_infer_pack/TTS.py b/GPT_SoVITS/TTS_infer_pack/TTS.py
index 1b7ad11..f49970c 100644
--- a/GPT_SoVITS/TTS_infer_pack/TTS.py
+++ b/GPT_SoVITS/TTS_infer_pack/TTS.py
@@ -25,7 +25,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from BigVGAN.bigvgan import BigVGAN
 from feature_extractor.cnhubert import CNHubert
 from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 from peft import LoraConfig, get_peft_model
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -66,6 +66,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 
 def speed_change(input_audio: np.ndarray, speed: float, sr: int):
@@ -92,11 +105,12 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
 resample_transform_dict = {}
 
 
-def resample(audio_tensor, sr0, device):
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s" % (sr0, sr1)
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)
 
 
 class DictToAttrRecursive(dict):
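[Reviewer note, not part of the diff] `resample()` now takes both the source and the target rate because the reference audio no longer has a single destination: the v3 mel front end expects 24 kHz input while the new `mel_fn_v4` works on 32 kHz audio, so a cache keyed only on the source rate could hand back the wrong transform once both model families are loaded in one process. A minimal standalone sketch of the same caching idea (names here are illustrative, not the repo's):

```python
import torch
import torchaudio

_resamplers = {}

def resample(audio: torch.Tensor, sr0: int, sr1: int, device: str = "cpu") -> torch.Tensor:
    key = f"{sr0}-{sr1}"  # source and target rate together, as in the patched function
    if key not in _resamplers:
        _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
    return _resamplers[key](audio.to(device))

ref = torch.randn(1, 44100)                # 1 s of 44.1 kHz reference audio
print(resample(ref, 44100, 24000).shape)   # torch.Size([1, 24000]) -> v3 mel front end
print(resample(ref, 44100, 32000).shape)   # torch.Size([1, 32000]) -> v4 mel front end
```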
@@ -130,44 +144,6 @@ class DictToAttrRecursive(dict):
 class NO_PROMPT_ERROR(Exception):
     pass
 
-
-# configs/tts_infer.yaml
-"""
-custom:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-  version: v2
-default:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-  version: v1
-default_v2:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-  version: v2
-default_v3:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
-  version: v3
-"""
-
-
 def set_seed(seed: int):
     seed = int(seed)
     seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
@@ -220,6 +196,15 @@ class TTS_Config:
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
+        "v4": {
+            "device": "cpu",
+            "is_half": False,
+            "version": "v4",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
     }
     configs: dict = None
     v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -255,7 +240,7 @@ class TTS_Config:
 
         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3"]
+        assert version in ["v1", "v2", "v3", "v4"]
         self.default_configs[version] = configs.get(version, self.default_configs[version])
         self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
@@ -356,7 +341,7 @@ class TTS_Config:
     def __eq__(self, other):
         return isinstance(other, TTS_Config) and self.configs_path == other.configs_path
 
-
+from inference_webui import v3v4set
 class TTS:
     def __init__(self, configs: Union[dict, str, TTS_Config]):
         if isinstance(configs, TTS_Config):
@@ -369,7 +354,7 @@ class TTS:
         self.bert_tokenizer: AutoTokenizer = None
         self.bert_model: AutoModelForMaskedLM = None
         self.cnhuhbert_model: CNHubert = None
-        self.bigvgan_model: BigVGAN = None
+        self.vocoder_model = None
         self.sr_model: AP_BWE = None
         self.sr_model_not_exist: bool = False
@@ -423,10 +408,11 @@ class TTS:
     def init_vits_weights(self, weights_path: str):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
-        path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]
+        print(self.configs.default_configs)
+        path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
 
-        if if_lora_v3 == True and os.path.exists(path_sovits_v3) == False:
-            info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+        if if_lora_v3 == True and os.path.exists(path_sovits) == False:
+            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
             raise FileExistsError(info)
 
         # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
@@ -456,7 +442,7 @@ class TTS:
         # print(f"model_version:{model_version}")
         # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
 
-        if model_version != "v3":
+        if model_version not in v3v4set:
             vits_model = SynthesizerTrn(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
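[Reviewer note, not part of the diff] The base-checkpoint lookup for LoRA weights now follows the detected `model_version` instead of being hard-wired to the v3 entry, so a v4 LoRA is validated against `default_configs["v4"]`. A small sketch of that behavior (plain Python, not repo code; the v4 path below is copied from the hunk above and still points at the v3 file in this work-in-progress):

```python
import os

default_configs = {
    "v3": {"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth"},
    "v4": {"vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth"},  # placeholder value from the diff
}

def check_lora_base(model_version: str, if_lora: bool) -> str:
    path_sovits = default_configs[model_version]["vits_weights_path"]
    if if_lora and not os.path.exists(path_sovits):
        # mirrors the FileExistsError raised above: the base SoVITS checkpoint must exist
        # before the corresponding LoRA weights can be loaded on top of it
        raise FileExistsError(f"SoVITS {model_version} base model missing: {path_sovits}")
    return path_sovits

print(check_lora_base("v4", if_lora=False))
```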
@@ -465,14 +451,14 @@ class TTS:
             )
             self.configs.is_v3_synthesizer = False
         else:
+            self.configs.is_v3_synthesizer = kwargs["version"] = model_version
             vits_model = SynthesizerTrnV3(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
                 n_speakers=self.configs.n_speakers,
                 **kwargs,
             )
-            self.configs.is_v3_synthesizer = True
-            self.init_bigvgan()
+            self.init_vocoder()
 
         if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
             del vits_model.enc_q
@@ -482,7 +468,7 @@ class TTS:
             )
         else:
             print(
-                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}"
+                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
             )
             lora_rank = dict_s2["lora_rank"]
             lora_config = LoraConfig(
@@ -521,20 +507,36 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()
 
-    def init_bigvgan(self):
-        if self.bigvgan_model is not None:
+    def init_vocoder(self):
+        if self.vocoder_model is not None:
             return
-        self.bigvgan_model = BigVGAN.from_pretrained(
-            "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
-            use_cuda_kernel=False,
-        )  # if True, RuntimeError: Ninja is required to load C++ extensions
-        # remove weight norm in the model and set to eval mode
-        self.bigvgan_model.remove_weight_norm()
-        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_v3_synthesizer == "v3":
+            self.vocoder_model = BigVGAN.from_pretrained(
+                "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+                use_cuda_kernel=False,
+            )  # if True, RuntimeError: Ninja is required to load C++ extensions
+            # remove weight norm in the model and set to eval mode
+            self.vocoder_model.remove_weight_norm()
+            self.vocoder_model = self.vocoder_model.eval()
+        else:
+            self.vocoder_model = Generator(
+                initial_channel=100,
+                resblock="1",
+                resblock_kernel_sizes=[3, 7, 11],
+                resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                upsample_rates=[10, 6, 2, 2, 2],
+                upsample_initial_channel=512,
+                upsample_kernel_sizes=[20, 12, 4, 4, 4],
+                gin_channels=0, is_bias=True
+            )
+            self.vocoder_model.eval()
+            self.vocoder_model.remove_weight_norm()
+            state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+            print("loading v4 vocoder", self.vocoder_model.load_state_dict(state_dict_g))
         if self.configs.is_half == True:
-            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+            self.vocoder_model = self.vocoder_model.half().to(self.configs.device)
         else:
-            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+            self.vocoder_model = self.vocoder_model.to(self.configs.device)
 
     def init_sr_model(self):
         if self.sr_model is not None:
@@ -570,8 +572,8 @@ class TTS:
                 self.bert_model = self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.half()
+            if self.vocoder_model is not None:
+                self.vocoder_model = self.vocoder_model.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -581,8 +583,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.float()
+            if self.vocoder_model is not None:
+                self.vocoder_model = self.vocoder_model.float()
 
     def set_device(self, device: torch.device, save: bool = True):
         """
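[Reviewer note, not part of the diff] A quick sanity check on the v4 vocoder settings in `init_vocoder()`: `mel_fn_v4` hops 320 samples at 32 kHz, i.e. 100 mel frames per second, and the `Generator` upsample rates multiply out to 480 samples per frame, so the v4 branch naturally emits 48 kHz audio (which is what the `run()` hunk below sets as `output_sr` for v4):

```python
from math import prod

mel_sr, hop = 32000, 320                  # from mel_fn_v4
upsample_rates = [10, 6, 2, 2, 2]         # from the Generator config above

frames_per_second = mel_sr / hop          # 100.0
samples_per_frame = prod(upsample_rates)  # 480
print(frames_per_second * samples_per_frame)  # 48000.0
```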
self.cnhuhbert_model = self.cnhuhbert_model.to(device) - if self.bigvgan_model is not None: - self.bigvgan_model = self.bigvgan_model.to(device) + if self.vocoder_model is not None: + self.vocoder_model = self.vocoder_model.to(device) if self.sr_model is not None: self.sr_model = self.sr_model.to(device) @@ -913,13 +915,13 @@ class TTS: split_bucket = False print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理")) - if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer and parallel_infer): + if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer!=False and parallel_infer): print(i18n("分桶处理模式已开启")) elif speed_factor != 1.0: print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理")) split_bucket = False - elif self.configs.is_v3_synthesizer and parallel_infer: - print(i18n("当开启并行推理模式时,SoVits V3模型不支持分桶处理,已自动关闭分桶处理")) + elif self.configs.is_v3_synthesizer!=False and parallel_infer: + print(i18n("当开启并行推理模式时,SoVits V3V4模型不支持分桶处理,已自动关闭分桶处理")) split_bucket = False else: print(i18n("分桶处理模式已关闭")) @@ -936,7 +938,7 @@ class TTS: if not no_prompt_text: assert prompt_lang in self.configs.languages - if no_prompt_text and self.configs.is_v3_synthesizer: + if no_prompt_text and self.configs.is_v3_synthesizer!=False: raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3") if ref_audio_path in [None, ""] and ( @@ -1044,7 +1046,12 @@ class TTS: t_34 = 0.0 t_45 = 0.0 audio = [] - output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000 + if self.configs.is_v3_synthesizer==False: + output_sr = 32000 + elif self.configs.is_v3_synthesizer == "v3": + output_sr = 24000 + else: + output_sr = 48000 # v4 for item in data: t3 = time.perf_counter() if return_fragment: @@ -1144,7 +1151,7 @@ class TTS: if parallel_infer: print(f"{i18n('并行合成中')}...") audio_fragments = self.v3_synthesis_batched_infer( - idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps + idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps,model_version=self.configs.is_v3_synthesizer ) batch_audio_fragment.extend(audio_fragments) else: @@ -1154,7 +1161,7 @@ class TTS: pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0) ) # .unsqueeze(0)#mq要多unsqueeze一次 audio_fragment = self.v3_synthesis( - _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps + _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps,model_version=self.configs.is_v3_synthesizer ) batch_audio_fragment.append(audio_fragment) @@ -1169,7 +1176,7 @@ class TTS: speed_factor, False, fragment_interval, - super_sampling if self.configs.is_v3_synthesizer else False, + super_sampling if self.configs.is_v3_synthesizer=="v3" else False, ) else: audio.append(batch_audio_fragment) @@ -1190,7 +1197,7 @@ class TTS: speed_factor, split_bucket, fragment_interval, - super_sampling if self.configs.is_v3_synthesizer else False, + super_sampling if self.configs.is_v3_synthesizer=="v3" else False, ) except Exception as e: @@ -1273,7 +1280,7 @@ class TTS: return sr, audio def v3_synthesis( - self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32 + self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32,model_version="v4" ): prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device) prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device) @@ -1285,19 +1292,22 @@ 
@@ -1285,19 +1292,22 @@ class TTS:
         ref_audio = ref_audio.to(self.configs.device).float()
         if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
-        if ref_sr != 24000:
-            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        tgt_sr = 24000 if model_version == "v3" else 32000
+        if ref_sr != tgt_sr:
+            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
 
-        mel2 = mel_fn(ref_audio)
+        mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if T_min > 468:
-            mel2 = mel2[:, :, -468:]
-            fea_ref = fea_ref[:, :, -468:]
-            T_min = 468
-        chunk_len = 934 - T_min
+        Tref = 468 if model_version == "v3" else 500
+        Tchunk = 934 if model_version == "v3" else 1000
+        if T_min > Tref:
+            mel2 = mel2[:, :, -Tref:]
+            fea_ref = fea_ref[:, :, -Tref:]
+            T_min = Tref
+        chunk_len = Tchunk - T_min
         mel2 = mel2.to(self.precision)
 
         fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
@@ -1324,7 +1334,7 @@ class TTS:
         cfm_res = denorm_spec(cfm_res)
 
         with torch.inference_mode():
-            wav_gen = self.bigvgan_model(cfm_res)
+            wav_gen = self.vocoder_model(cfm_res)
             audio = wav_gen[0][0]  # .cpu().detach().numpy()
 
         return audio
@@ -1335,7 +1345,7 @@ class TTS:
         semantic_tokens_list: List[torch.Tensor],
         batch_phones: List[torch.Tensor],
         speed: float = 1.0,
-        sample_steps: int = 32,
+        sample_steps: int = 32, model_version="v4"
     ) -> List[torch.Tensor]:
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1347,19 +1357,22 @@ class TTS:
         ref_audio = ref_audio.to(self.configs.device).float()
         if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
-        if ref_sr != 24000:
-            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        tgt_sr = 24000 if model_version == "v3" else 32000
+        if ref_sr != tgt_sr:
+            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
 
-        mel2 = mel_fn(ref_audio)
+        mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if T_min > 468:
-            mel2 = mel2[:, :, -468:]
-            fea_ref = fea_ref[:, :, -468:]
-            T_min = 468
-        chunk_len = 934 - T_min
+        Tref = 468 if model_version == "v3" else 500
+        Tchunk = 934 if model_version == "v3" else 1000
+        if T_min > Tref:
+            mel2 = mel2[:, :, -Tref:]
+            fea_ref = fea_ref[:, :, -Tref:]
+            T_min = Tref
+        chunk_len = Tchunk - T_min
 
         mel2 = mel2.to(self.precision)
@@ -1413,7 +1426,7 @@ class TTS:
         pred_spec = denorm_spec(pred_spec)
 
         with torch.no_grad():
-            wav_gen = self.bigvgan_model(pred_spec)
+            wav_gen = self.vocoder_model(pred_spec)
             audio = wav_gen[0][0]  # .cpu().detach().numpy()
 
         audio_fragments = []
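[Reviewer note, not part of the diff] The 468/934 and 500/1000 constants describe the CFM decode window: the reference mel is capped at `Tref` frames and each decode step adds `chunk_len = Tchunk - T_min` new frames, so reference plus new content stays within `Tchunk`. The chunking loop itself is outside this diff, so the sketch below is an assumption about how those numbers are consumed, shown only to make the arithmetic concrete:

```python
def plan_chunks(ref_frames: int, todo_frames: int, t_ref: int, t_chunk: int):
    t_min = min(ref_frames, t_ref)   # reference kept for conditioning, capped at Tref
    chunk_len = t_chunk - t_min      # new mel frames generated per CFM call
    spans, start = [], 0
    while start < todo_frames:
        spans.append((start, min(start + chunk_len, todo_frames)))
        start += chunk_len
    return spans

# v3 caps: 468-frame reference in a 934-frame window; v4: 500 in 1000.
print(plan_chunks(ref_frames=700, todo_frames=2300, t_ref=468, t_chunk=934))
print(plan_chunks(ref_frames=700, todo_frames=2300, t_ref=500, t_chunk=1000))
```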