diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 3f9750a..c0af860 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -39,21 +39,25 @@ except:
     ...
 
 version = model_version = os.environ.get("version", "v2")
 
 path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+is_exist_s2gv4 = os.path.exists(path_sovits_v4)
 
 pretrained_sovits_name = [
     "GPT_SoVITS/pretrained_models/s2G488k.pth",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
-    path_sovits_v3,
+    "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+    "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
 ]
 pretrained_gpt_name = [
     "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
     "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+    "GPT_SoVITS/pretrained_models/s1v3.ckpt",
 ]
 
 _ = [[], []]
-for i in range(3):
+for i in range(4):
     if os.path.exists(pretrained_gpt_name[i]):
         _[0].append(pretrained_gpt_name[i])
     if os.path.exists(pretrained_sovits_name[i]):
@@ -102,7 +106,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path
 
 import random
 
-from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 
 
 def set_seed(seed):
@@ -222,23 +226,25 @@ else:
 resample_transform_dict = {}
 
 
-def resample(audio_tensor, sr0):
+def resample(audio_tensor, sr0, sr1):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s" % (sr0, sr1)
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)
 
 
 ###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
 # symbol_version-model_version-if_lora_v3
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
-
+v3v4set = {"v3", "v4"}
+
 
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    # print(sovits_path,version, model_version, if_lora_v3)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
+    print(sovits_path, version, model_version, if_lora_v3)
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+    if if_lora_v3 == True and is_exist == False:
         info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
         gr.Warning(info)
         raise FileExistsError(info)
@@ -257,7 +263,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
     else:
         text_update = {"__type__": "update", "value": ""}
         text_language_update = {"__type__": "update", "value": i18n("中文")}
-    if model_version == "v3":
+    if model_version in v3v4set:
         visible_sample_steps = True
         visible_inp_refs = False
     else:
@@ -270,10 +276,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         prompt_language_update,
         text_update,
         text_language_update,
-        {"__type__": "update", "visible": visible_sample_steps, "value": 32},
+        {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
         {"__type__": "update", "visible": visible_inp_refs},
-        {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
-        {"__type__": "update", "visible": True if model_version == "v3" else False},
+        {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
+        {"__type__": "update", "visible": True if model_version == "v3" else False},
         {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
     )
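Reviewer note: the sample-steps radio is now version-dependent, both in the `yield` above and in the second `yield` further down. A standalone restatement of the mapping for review purposes; the helper name is hypothetical and does not exist in the patch:

```python
def sample_steps_ui(model_version: str):
    """Default value and choices for the sample-steps radio.

    v3 keeps the old default of 32 and gains 64/128 as options;
    v4 defaults to 8 and keeps the original 4..32 range.
    """
    if model_version == "v3":
        return 32, [4, 8, 16, 32, 64, 128]
    return 8, [4, 8, 16, 32]


assert sample_steps_ui("v3") == (32, [4, 8, 16, 32, 64, 128])
assert sample_steps_ui("v4") == (8, [4, 8, 16, 32])
```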
"visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]}, {"__type__": "update", "visible": visible_inp_refs}, - {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False}, - {"__type__": "update", "visible": True if model_version == "v3" else False}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "visible": True if model_version =="v3" else False}, {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False}, ) @@ -289,7 +295,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) hps.model.version = "v2" version = hps.model.version # print("sovits版本:",hps.model.version) - if model_version != "v3": + if model_version not in v3v4set: vq_model = SynthesizerTrn( hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, @@ -317,9 +323,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) if if_lora_v3 == False: print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False)) else: + path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4 print( - "loading sovits_v3pretrained_G", - vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False), + "loading sovits_%spretrained_G"%model_version, + vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False), ) lora_rank = dict_s2["lora_rank"] lora_config = LoraConfig( @@ -329,7 +336,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) init_lora_weights=True, ) vq_model.cfm = get_peft_model(vq_model.cfm, lora_config) - print("loading sovits_v3_lora%s" % (lora_rank)) + print("loading sovits_%s_lora%s" % (model_version,lora_rank)) vq_model.load_state_dict(dict_s2["weight"], strict=False) vq_model.cfm = vq_model.cfm.merge_and_unload() # torch.save(vq_model.state_dict(),"merge_win.pth") @@ -342,10 +349,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None) prompt_language_update, text_update, text_language_update, - {"__type__": "update", "visible": visible_sample_steps, "value": 32}, + {"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]}, {"__type__": "update", "visible": visible_inp_refs}, - {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False}, - {"__type__": "update", "visible": True if model_version == "v3" else False}, + {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False}, + {"__type__": "update", "visible": True if model_version =="v3" else False}, {"__type__": "update", "value": i18n("合成语音"), "interactive": True}, ) with open("./weight.json") as f: @@ -392,7 +399,7 @@ now_dir = os.getcwd() def init_bigvgan(): - global bigvgan_model + global bigvgan_model,hifigan_model from BigVGAN import bigvgan bigvgan_model = bigvgan.BigVGAN.from_pretrained( @@ -402,16 +409,47 @@ def init_bigvgan(): # remove weight norm in the model and set to eval mode bigvgan_model.remove_weight_norm() bigvgan_model = bigvgan_model.eval() + if hifigan_model: + hifigan_model=hifigan_model.cpu() + hifigan_model=None + try:torch.cuda.empty_cache() + except:pass if is_half == True: bigvgan_model = 
@@ -576,6 +614,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 
 def merge_short_text_in_array(texts, threshold):
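Reviewer note: `mel_fn_v4` analyzes the reference at 32 kHz with hop 320, i.e. exactly 100 mel frames per second, whereas the v3 `mel_fn` runs at 24 kHz with hop 256 (93.75 frames per second). The `Tref`/`Tchunk` constants introduced in `get_tts_wav` below (468/934 for v3, 500/1000 for v4) therefore represent the same ~5 s reference window and ~10 s total context in both versions. A quick arithmetic check, plain Python with no project imports:

```python
# Frame rate and window lengths implied by each mel configuration.
for name, sr, hop, tref, tchunk in [("v3", 24000, 256, 468, 934), ("v4", 32000, 320, 500, 1000)]:
    print(f"{name}: {sr / hop:.2f} frames/s, Tref = {tref * hop / sr:.2f} s, Tchunk = {tchunk * hop / sr:.2f} s")
# v3: 93.75 frames/s, Tref = 4.99 s, Tchunk = 9.96 s
# v4: 100.00 frames/s, Tref = 5.00 s, Tchunk = 10.00 s
```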
@@ -647,7 +698,7 @@ def get_tts_wav(
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version == "v3":
+    if model_version in v3v4set:
         ref_free = False  # s2v3暂不支持ref_free
     else:
         if_sr = False
@@ -755,7 +806,7 @@ def get_tts_wav(
                 cache[i_text] = pred_semantic
             t3 = ttime()
             ###v3不存在以下逻辑和inp_refs
-            if model_version != "v3":
+            if model_version not in v3v4set:
                 refers = []
                 if inp_refs:
                     for path in inp_refs:
@@ -779,25 +830,24 @@ def get_tts_wav(
                 ref_audio = ref_audio.to(device).float()
                 if ref_audio.shape[0] == 2:
                     ref_audio = ref_audio.mean(0).unsqueeze(0)
-                if sr != 24000:
-                    ref_audio = resample(ref_audio, sr)
+                tgt_sr = 24000 if model_version == "v3" else 32000
+                if sr != tgt_sr:
+                    ref_audio = resample(ref_audio, sr, tgt_sr)
                 # print("ref_audio",ref_audio.abs().mean())
-                mel2 = mel_fn(ref_audio)
+                mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
                 mel2 = norm_spec(mel2)
                 T_min = min(mel2.shape[2], fea_ref.shape[2])
                 mel2 = mel2[:, :, :T_min]
                 fea_ref = fea_ref[:, :, :T_min]
-                if T_min > 468:
-                    mel2 = mel2[:, :, -468:]
-                    fea_ref = fea_ref[:, :, -468:]
-                    T_min = 468
-                chunk_len = 934 - T_min
-                # print("fea_ref",fea_ref,fea_ref.shape)
-                # print("mel2",mel2)
+                Tref = 468 if model_version == "v3" else 500
+                Tchunk = 934 if model_version == "v3" else 1000
+                if T_min > Tref:
+                    mel2 = mel2[:, :, -Tref:]
+                    fea_ref = fea_ref[:, :, -Tref:]
+                    T_min = Tref
+                chunk_len = Tchunk - T_min
                 mel2 = mel2.to(dtype)
                 fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
-                # print("fea_todo",fea_todo)
-                # print("ge",ge.abs().mean())
                 cfm_resss = []
                 idx = 0
                 while 1:
@@ -806,22 +856,24 @@ def get_tts_wav(
                         break
                     idx += chunk_len
                     fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-                    # set_seed(123)
                     cfm_res = vq_model.cfm.inference(
                         fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
                     )
                     cfm_res = cfm_res[:, :, mel2.shape[2] :]
                     mel2 = cfm_res[:, :, -T_min:]
-                    # print("fea", fea)
-                    # print("mel2in", mel2)
                     fea_ref = fea_todo_chunk[:, :, -T_min:]
                     cfm_resss.append(cfm_res)
-                cmf_res = torch.cat(cfm_resss, 2)
-                cmf_res = denorm_spec(cmf_res)
-                if bigvgan_model == None:
-                    init_bigvgan()
+                cfm_res = torch.cat(cfm_resss, 2)
+                cfm_res = denorm_spec(cfm_res)
+                if model_version == "v3":
+                    if bigvgan_model == None:
+                        init_bigvgan()
+                else:  # v4
+                    if hifigan_model == None:
+                        init_hifigan()
+                vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
                 with torch.inference_mode():
-                    wav_gen = bigvgan_model(cmf_res)
+                    wav_gen = vocoder_model(cfm_res)
                     audio = wav_gen[0][0]  # .cpu().detach().numpy()
             max_audio = torch.abs(audio).max()  # 简单防止16bit爆音
             if max_audio > 1:
@@ -833,16 +885,18 @@ def get_tts_wav(
     t1 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
     audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
-    sr = hps.data.sampling_rate if model_version != "v3" else 24000
-    if if_sr == True and sr == 24000:
+    if model_version in {"v1", "v2"}: opt_sr = 32000
+    elif model_version == "v3": opt_sr = 24000
+    else: opt_sr = 48000  # v4
+    if if_sr == True and opt_sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
         max_audio = np.abs(audio_opt).max()
         if max_audio > 1:
             audio_opt /= max_audio
     else:
         audio_opt = audio_opt.cpu().detach().numpy()
-    yield sr, (audio_opt * 32767).astype(np.int16)
+    yield opt_sr, (audio_opt * 32767).astype(np.int16)
 
 
 def split(todo_text):
@@ -971,8 +1025,8 @@ def change_choices():
     }
 
 
-SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
-GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
+SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
+GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
 for path in SoVITS_weight_root + GPT_weight_root:
     os.makedirs(path, exist_ok=True)
@@ -1039,7 +1093,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
                     + i18n("v3暂不支持该模式,使用了会报错。"),
                     value=False,
-                    interactive=True if model_version != "v3" else False,
+                    interactive=True if model_version not in v3v4set else False,
                     show_label=True,
                     scale=1,
                 )
@@ -1064,7 +1118,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     ),
                     file_count="multiple",
                 )
-                if model_version != "v3"
+                if model_version not in v3v4set
                 else gr.File(
                     label=i18n(
                         "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
@@ -1076,16 +1130,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             sample_steps = (
                 gr.Radio(
                     label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                    value=32,
-                    choices=[4, 8, 16, 32],
+                    value=32 if model_version == "v3" else 8,
+                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                     visible=True,
                 )
-                if model_version == "v3"
+                if model_version in v3v4set
                 else gr.Radio(
                     label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                    choices=[4, 8, 16, 32],
+                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                     visible=False,
-                    value=32,
+                    value=32 if model_version == "v3" else 8,
                 )
             )
             if_sr_Checkbox = gr.Checkbox(
@@ -1093,7 +1147,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 value=False,
                 interactive=True,
                 show_label=True,
-                visible=False if model_version != "v3" else True,
+                visible=False if model_version !="v3" else True,
             )
             gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
             with gr.Row():
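Reviewer note: the output sample rate is now derived from the model family instead of `hps.data.sampling_rate`, and the super-resolution path stays v3-only because v4's vocoder already emits 48 kHz. An illustrative helper (not in the patch) restating the mapping used by `get_tts_wav`:

```python
def output_sample_rate(model_version: str, if_sr: bool = False) -> int:
    """Native output rate per model family, mirroring get_tts_wav."""
    if model_version in {"v1", "v2"}:
        return 32000
    if model_version == "v3":
        # audio_sr optionally upsamples the 24 kHz output to 48 kHz
        return 48000 if if_sr else 24000
    return 48000  # v4: HiFi-GAN decodes straight to 48 kHz
```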
diff --git a/GPT_SoVITS/process_ckpt.py b/GPT_SoVITS/process_ckpt.py
index 147f3bd..4a2a1ba 100644
--- a/GPT_SoVITS/process_ckpt.py
+++ b/GPT_SoVITS/process_ckpt.py
@@ -22,23 +22,24 @@ def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
 01:v2
 02:v3
 03:v3lora
-
+04:v4lora
 """
 
 from io import BytesIO
 
 
-def my_save2(fea, path):
+def my_save2(fea, path, cfm_version):
     bio = BytesIO()
     torch.save(fea, bio)
     bio.seek(0)
     data = bio.getvalue()
-    data = b"03" + data[2:]  ###temp for v3lora only, todo
+    byte = b"03" if cfm_version == "v3" else b"04"
+    data = byte + data[2:]
     with open(path, "wb") as f:
         f.write(data)
 
 
-def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
+def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
     try:
         opt = OrderedDict()
         opt["weight"] = {}
@@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
         if lora_rank:
             opt["lora_rank"] = lora_rank
-            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
         else:
             my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
         return "Success."
@@ -63,11 +64,13 @@ head2version = {
     b"01": ["v2", "v2", False],
     b"02": ["v2", "v3", False],
     b"03": ["v2", "v3", True],
+    b"04": ["v2", "v4", True],
 }
 hash_pretrained_dict = {
     "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False],  # s2G488k.pth#sovits_v1_pretrained
     "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False],  # s2Gv3.pth#sovits_v3_pretrained
     "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False],  # s2G2333K.pth#sovits_v2_pretrained
+    "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False],  # s2Gv4.pth#sovits_v4_pretrained
 }
 import hashlib
 
@@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
     hash = get_hash_from_file(sovits_path)
     if hash in hash_pretrained_dict:
         return hash_pretrained_dict[hash]
-    ###2-new weights or old weights, by head
+    ###2-new weights, by head
     with open(sovits_path, "rb") as f:
         version = f.read(2)
     if version != b"PK":
diff --git a/GPT_SoVITS/s2_train_v3_lora.py b/GPT_SoVITS/s2_train_v3_lora.py
index 42582b4..ddeec4f 100644
--- a/GPT_SoVITS/s2_train_v3_lora.py
+++ b/GPT_SoVITS/s2_train_v3_lora.py
@@ -27,12 +27,11 @@ from random import randint
 from module import commons
 from module.data_utils import (
     DistributedBucketSampler,
-)
-from module.data_utils import (
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-)
-from module.data_utils import (
-    TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
+    TextAudioSpeakerCollateV3,
+    TextAudioSpeakerLoaderV3,
+    TextAudioSpeakerCollateV4,
+    TextAudioSpeakerLoaderV4,
+
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
@@ -89,6 +88,8 @@ def run(rank, n_gpus, hps):
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
 
+    TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4
+    TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
         train_dataset,
@@ -364,7 +365,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                         hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                         epoch,
                         global_step,
-                        hps,
+                        hps, cfm_version=hps.model.version,
                         lora_rank=lora_rank,
                     ),
                 )
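Reviewer note: LoRA checkpoints are tagged by overwriting the first two bytes of the `torch.save` output (`b"03"` for v3, `b"04"` for v4), which is what lets `get_sovits_version_from_path_fast` classify a file from `f.read(2)` without deserializing it; untouched checkpoints still begin with the zip magic `b"PK"`. A minimal sketch of the read side, mirroring `head2version` (the function name is illustrative):

```python
HEAD2VERSION = {
    # two-byte head -> (symbol_version, model_version, if_lora)
    b"01": ("v2", "v2", False),
    b"02": ("v2", "v3", False),
    b"03": ("v2", "v3", True),
    b"04": ("v2", "v4", True),
}


def peek_sovits_version(path: str):
    """Classify a SoVITS checkpoint from its first two bytes only."""
    with open(path, "rb") as f:
        head = f.read(2)
    if head == b"PK":  # ordinary torch.save output (a zip archive)
        return None    # caller falls back to hashing / full inspection
    return HEAD2VERSION.get(head)
```

Note that the overwritten head makes such a file unreadable by a plain `torch.load` until the two bytes are restored, which is why loading goes through `load_sovits_new` rather than `torch.load` directly.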