From 132f6e7b8bbf30fe0b0ff5625552318b59ddc27e Mon Sep 17 00:00:00 2001
From: XXXXRT666 <157766680+XXXXRT666@users.noreply.github.com>
Date: Thu, 5 Jun 2025 18:37:19 +0800
Subject: [PATCH] Fix Bugs, Modified Layout (#2434)

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
---
 GPT_SoVITS/inference_webui.py      | 119 ++++++++++++++++++-----------
 GPT_SoVITS/inference_webui_fast.py |  30 +++++---
 tools/assets.py                    | 112 +++++++++++++++++++++++++++
 webui.py                           |  16 ++--
 4 files changed, 215 insertions(+), 62 deletions(-)
 create mode 100644 tools/assets.py

diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 69b21cc..21ae83a 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -31,9 +31,11 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 
 version = model_version = os.environ.get("version", "v2")
 
-from config import name2sovits_path,name2gpt_path,change_choices,get_weights_names
+from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
+
 SoVITS_names, GPT_names = get_weights_names()
 from config import pretrained_sovits_name
+
 path_sovits_v3 = pretrained_sovits_name["v3"]
 path_sovits_v4 = pretrained_sovits_name["v4"]
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@@ -108,6 +110,7 @@ from peft import LoraConfig, get_peft_model
 
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
+from tools.assets import css, js, top_html
 from tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
@@ -208,8 +211,11 @@ else:
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
 v3v4set = {"v3", "v4"}
+
+
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
-    if "!"in sovits_path:sovits_path=name2sovits_path[sovits_path]
+    if "!" in sovits_path:
+        sovits_path = name2sovits_path[sovits_path]
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
     print(sovits_path, version, model_version, if_lora_v3)
@@ -272,7 +278,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         version = hps.model.version
     # print("sovits版本:",hps.model.version)
     if model_version not in v3v4set:
-        if "Pro"not in model_version:
+        if "Pro" not in model_version:
             model_version = version
         else:
             hps.model.version = model_version
@@ -355,7 +361,8 @@ except:
 
 
 def change_gpt_weights(gpt_path):
-    if "!"in gpt_path:gpt_path=name2gpt_path[gpt_path]
+    if "!" in gpt_path:
+        gpt_path = name2gpt_path[gpt_path]
     global hz, max_sec, t2s_model, config
     hz = 50
     dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
@@ -383,6 +390,7 @@ import torch
 
 now_dir = os.getcwd()
 
+
 def clean_hifigan_model():
     global hifigan_model
     if hifigan_model:
@@ -392,6 +400,8 @@ def clean_hifigan_model():
             torch.cuda.empty_cache()
         except:
             pass
+
+
 def clean_bigvgan_model():
     global bigvgan_model
     if bigvgan_model:
@@ -401,6 +411,8 @@ def clean_bigvgan_model():
             torch.cuda.empty_cache()
         except:
             pass
+
+
 def clean_sv_cn_model():
     global sv_cn_model
     if sv_cn_model:
@@ -411,8 +423,9 @@ def clean_sv_cn_model():
         except:
             pass
 
+
 def init_bigvgan():
-    global bigvgan_model, hifigan_model,sv_cn_model
+    global bigvgan_model, hifigan_model, sv_cn_model
     from BigVGAN import bigvgan
 
     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -429,8 +442,9 @@ def init_bigvgan():
     else:
         bigvgan_model = bigvgan_model.to(device)
 
+
 def init_hifigan():
-    global hifigan_model, bigvgan_model,sv_cn_model
+    global hifigan_model, bigvgan_model, sv_cn_model
     hifigan_model = Generator(
         initial_channel=100,
         resblock="1",
@@ -445,7 +459,9 @@ def init_hifigan():
     hifigan_model.eval()
     hifigan_model.remove_weight_norm()
     state_dict_g = torch.load(
-        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
+        map_location="cpu",
+        weights_only=False,
     )
     print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
     clean_bigvgan_model()
@@ -455,9 +471,12 @@ def init_hifigan():
     else:
         hifigan_model = hifigan_model.to(device)
 
+
 from sv import SV
+
+
 def init_sv_cn():
-    global hifigan_model, bigvgan_model,sv_cn_model
+    global hifigan_model, bigvgan_model, sv_cn_model
     sv_cn_model = SV(device, is_half)
     clean_bigvgan_model()
     clean_hifigan_model()
@@ -468,34 +487,37 @@ if model_version == "v3":
     init_bigvgan()
 if model_version == "v4":
     init_hifigan()
-if model_version in {"v2Pro","v2ProPlus"}:
+if model_version in {"v2Pro", "v2ProPlus"}:
     init_sv_cn()
 
-resample_transform_dict={}
-def resample(audio_tensor, sr0,sr1,device):
+resample_transform_dict = {}
+
+
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    key="%s-%s-%s"%(sr0,sr1,str(device))
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
     if key not in resample_transform_dict:
-        resample_transform_dict[key] = torchaudio.transforms.Resample(
-            sr0, sr1
-        ).to(device)
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
     return resample_transform_dict[key](audio_tensor)
 
-def get_spepc(hps, filename,dtype,device,is_v2pro=False):
+
+def get_spepc(hps, filename, dtype, device, is_v2pro=False):
     # audio = load_audio(filename, int(hps.data.sampling_rate))
 
     # audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
     # audio = torch.FloatTensor(audio)
 
-    sr1=int(hps.data.sampling_rate)
-    audio, sr0=torchaudio.load(filename)
-    if sr0!=sr1:
-        audio=audio.to(device)
-        if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
-        audio=resample(audio,sr0,sr1,device)
+    sr1 = int(hps.data.sampling_rate)
+    audio, sr0 = torchaudio.load(filename)
+    if sr0 != sr1:
+        audio = audio.to(device)
+        if audio.shape[0] == 2:
+            audio = audio.mean(0).unsqueeze(0)
+        audio = resample(audio, sr0, sr1, device)
     else:
-        audio=audio.to(device)
-        if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
+        audio = audio.to(device)
+        if audio.shape[0] == 2:
+            audio = audio.mean(0).unsqueeze(0)
 
     maxx = audio.abs().max()
     if maxx > 1:
@@ -508,10 +530,10 @@ def get_spepc(hps, filename,dtype,device,is_v2pro=False):
         hps.data.win_length,
         center=False,
     )
-    spec=spec.to(dtype)
-    if is_v2pro==True:
-        audio=resample(audio,sr1,16000,device).to(dtype)
-    return spec,audio
+    spec = spec.to(dtype)
+    if is_v2pro == True:
+        audio = resample(audio, sr1, 16000, device).to(dtype)
+    return spec, audio
 
 
 def clean_text_inf(text, language, version):
@@ -744,7 +766,7 @@ def get_tts_wav(
             ref_free = False  # s2v3暂不支持ref_free
     else:
         if_sr = False
-    if model_version not in {"v3","v4","v2Pro","v2ProPlus"}:
+    if model_version not in {"v3", "v4", "v2Pro", "v2ProPlus"}:
         clean_bigvgan_model()
         clean_hifigan_model()
         clean_sv_cn_model()
@@ -851,35 +873,39 @@ def get_tts_wav(
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
                 cache[i_text] = pred_semantic
         t3 = ttime()
-        is_v2pro=model_version in {"v2Pro","v2ProPlus"}
+        is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
         # print(23333,is_v2pro,model_version)
         ###v3不存在以下逻辑和inp_refs
         if model_version not in v3v4set:
             refers = []
             if is_v2pro:
-                sv_emb=[]
+                sv_emb = []
                 if sv_cn_model == None:
                     init_sv_cn()
             if inp_refs:
                 for path in inp_refs:
-                    try:#####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
-                        refer,audio_tensor = get_spepc(hps, path.name,dtype,device,is_v2pro)
+                    try:  #####这里加上提取sv的逻辑,要么一堆sv一堆refer,要么单个sv单个refer
+                        refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
                         refers.append(refer)
                         if is_v2pro:
                             sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
                     except:
                         traceback.print_exc()
             if len(refers) == 0:
-                refers,audio_tensor = get_spepc(hps, ref_wav_path,dtype,device,is_v2pro)
-                refers=[refers]
+                refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
+                refers = [refers]
                 if is_v2pro:
-                    sv_emb=[sv_cn_model.compute_embedding3(audio_tensor)]
+                    sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
             if is_v2pro:
-                audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)[0][0]
+                audio = vq_model.decode(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
+                )[0][0]
             else:
-                audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0]
+                audio = vq_model.decode(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
+                )[0][0]
         else:
-            refer,audio_tensor = get_spepc(hps, ref_wav_path,dtype,device)
+            refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
             phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
             phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
             fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
@@ -889,7 +915,7 @@ def get_tts_wav(
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
             tgt_sr = 24000 if model_version == "v3" else 32000
             if sr != tgt_sr:
-                ref_audio = resample(ref_audio, sr, tgt_sr,device)
+                ref_audio = resample(ref_audio, sr, tgt_sr, device)
             # print("ref_audio",ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
             mel2 = norm_spec(mel2)
@@ -1076,6 +1102,7 @@ def process_text(texts):
             _text.append(text)
     return _text
 
+
 def html_center(text, label="p"):
     return f"""