diff --git a/GPT_SoVITS/inference_webui.py b/GPT_SoVITS/inference_webui.py
index 69b21cc..21ae83a 100644
--- a/GPT_SoVITS/inference_webui.py
+++ b/GPT_SoVITS/inference_webui.py
@@ -31,9 +31,11 @@ warnings.simplefilter(action="ignore", category=FutureWarning)
 
 version = model_version = os.environ.get("version", "v2")
 
-from config import name2sovits_path,name2gpt_path,change_choices,get_weights_names
+from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
+
 SoVITS_names, GPT_names = get_weights_names()
 from config import pretrained_sovits_name
+
 path_sovits_v3 = pretrained_sovits_name["v3"]
 path_sovits_v4 = pretrained_sovits_name["v4"]
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@@ -108,6 +110,7 @@ from peft import LoraConfig, get_peft_model
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
+from tools.assets import css, js, top_html
 from tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
@@ -208,8 +211,11 @@ else:
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 
 v3v4set = {"v3", "v4"}
+
+
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
-    if "!"in sovits_path:sovits_path=name2sovits_path[sovits_path]
+    if "!" in sovits_path:
+        sovits_path = name2sovits_path[sovits_path]
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
     print(sovits_path, version, model_version, if_lora_v3)
@@ -272,7 +278,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         version = hps.model.version
     # print("sovits version:", hps.model.version)
     if model_version not in v3v4set:
-        if "Pro"not in model_version:
+        if "Pro" not in model_version:
             model_version = version
     else:
         hps.model.version = model_version
@@ -355,7 +361,8 @@ except:
 
 
 def change_gpt_weights(gpt_path):
-    if "!"in gpt_path:gpt_path=name2gpt_path[gpt_path]
+    if "!" in gpt_path:
+        gpt_path = name2gpt_path[gpt_path]
     global hz, max_sec, t2s_model, config
     hz = 50
     dict_s1 = torch.load(gpt_path, map_location="cpu", weights_only=False)
@@ -383,6 +390,7 @@ import torch
 
 now_dir = os.getcwd()
 
+
 def clean_hifigan_model():
     global hifigan_model
     if hifigan_model:
@@ -392,6 +400,8 @@ def clean_hifigan_model():
             torch.cuda.empty_cache()
         except:
             pass
+
+
 def clean_bigvgan_model():
     global bigvgan_model
     if bigvgan_model:
@@ -401,6 +411,8 @@ def clean_bigvgan_model():
             torch.cuda.empty_cache()
         except:
             pass
+
+
 def clean_sv_cn_model():
     global sv_cn_model
     if sv_cn_model:
@@ -411,8 +423,9 @@ def clean_sv_cn_model():
         except:
             pass
 
+
 def init_bigvgan():
-    global bigvgan_model, hifigan_model,sv_cn_model
+    global bigvgan_model, hifigan_model, sv_cn_model
     from BigVGAN import bigvgan
 
     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -429,8 +442,9 @@ def init_bigvgan():
     else:
         bigvgan_model = bigvgan_model.to(device)
 
+
 def init_hifigan():
-    global hifigan_model, bigvgan_model,sv_cn_model
+    global hifigan_model, bigvgan_model, sv_cn_model
     hifigan_model = Generator(
         initial_channel=100,
         resblock="1",
@@ -445,7 +459,9 @@ def init_hifigan():
     hifigan_model.eval()
     hifigan_model.remove_weight_norm()
     state_dict_g = torch.load(
-        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
+        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
+        map_location="cpu",
+        weights_only=False,
     )
     print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
     clean_bigvgan_model()
@@ -455,9 +471,12 @@ def init_hifigan():
     else:
         hifigan_model = hifigan_model.to(device)
 
+
 from sv import SV
+
+
 def init_sv_cn():
-    global hifigan_model, bigvgan_model,sv_cn_model
+    global hifigan_model, bigvgan_model, sv_cn_model
     sv_cn_model = SV(device, is_half)
     clean_bigvgan_model()
     clean_hifigan_model()
@@ -468,34 +487,37 @@ if model_version == "v3":
     init_bigvgan()
 if model_version == "v4":
     init_hifigan()
-if model_version in {"v2Pro","v2ProPlus"}:
+if model_version in {"v2Pro", "v2ProPlus"}:
     init_sv_cn()
 
-resample_transform_dict={}
-def resample(audio_tensor, sr0,sr1,device):
+resample_transform_dict = {}
+
+
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    key="%s-%s-%s"%(sr0,sr1,str(device))
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
     if key not in resample_transform_dict:
-        resample_transform_dict[key] = torchaudio.transforms.Resample(
-            sr0, sr1
-        ).to(device)
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
     return resample_transform_dict[key](audio_tensor)
 
-def get_spepc(hps, filename,dtype,device,is_v2pro=False):
+
+def get_spepc(hps, filename, dtype, device, is_v2pro=False):
     # audio = load_audio(filename, int(hps.data.sampling_rate))
     # audio, sampling_rate = librosa.load(filename, sr=int(hps.data.sampling_rate))
     # audio = torch.FloatTensor(audio)
-    sr1=int(hps.data.sampling_rate)
-    audio, sr0=torchaudio.load(filename)
-    if sr0!=sr1:
-        audio=audio.to(device)
-        if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
-        audio=resample(audio,sr0,sr1,device)
+    sr1 = int(hps.data.sampling_rate)
+    audio, sr0 = torchaudio.load(filename)
+    if sr0 != sr1:
+        audio = audio.to(device)
+        if audio.shape[0] == 2:
+            audio = audio.mean(0).unsqueeze(0)
+        audio = resample(audio, sr0, sr1, device)
     else:
-        audio=audio.to(device)
-        if(audio.shape[0]==2):audio=audio.mean(0).unsqueeze(0)
+        audio = audio.to(device)
+        if audio.shape[0] == 2:
+            audio = audio.mean(0).unsqueeze(0)
     maxx = audio.abs().max()
     if maxx > 1:
@@ -508,10 +530,10 @@ def get_spepc(hps, filename,dtype,device,is_v2pro=False):
         hps.data.win_length,
         center=False,
     )
-    spec=spec.to(dtype)
-    if is_v2pro==True:
-        audio=resample(audio,sr1,16000,device).to(dtype)
-    return spec,audio
+    spec = spec.to(dtype)
+    if is_v2pro == True:
+        audio = resample(audio, sr1, 16000, device).to(dtype)
+    return spec, audio
 
 
 def clean_text_inf(text, language, version):
@@ -744,7 +766,7 @@ def get_tts_wav(
         ref_free = False  # s2v3 does not support ref_free yet
     else:
         if_sr = False
-        if model_version not in {"v3","v4","v2Pro","v2ProPlus"}:
+        if model_version not in {"v3", "v4", "v2Pro", "v2ProPlus"}:
             clean_bigvgan_model()
             clean_hifigan_model()
             clean_sv_cn_model()
@@ -851,35 +873,39 @@ def get_tts_wav(
                 pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
                 cache[i_text] = pred_semantic
         t3 = ttime()
-        is_v2pro=model_version in {"v2Pro","v2ProPlus"}
+        is_v2pro = model_version in {"v2Pro", "v2ProPlus"}
        # print(23333,is_v2pro,model_version)
        ### v3 has no inp_refs and does not use the logic below
        if model_version not in v3v4set:
            refers = []
            if is_v2pro:
-                sv_emb=[]
+                sv_emb = []
                if sv_cn_model == None:
                    init_sv_cn()
            if inp_refs:
                for path in inp_refs:
-                    try:##### sv extraction goes here: either a batch of sv embeddings with a batch of refers, or a single sv with a single refer
-                        refer,audio_tensor = get_spepc(hps, path.name,dtype,device,is_v2pro)
+                    try:  ##### sv extraction goes here: either a batch of sv embeddings with a batch of refers, or a single sv with a single refer
+                        refer, audio_tensor = get_spepc(hps, path.name, dtype, device, is_v2pro)
                        refers.append(refer)
                        if is_v2pro:
                            sv_emb.append(sv_cn_model.compute_embedding3(audio_tensor))
                    except:
                        traceback.print_exc()
            if len(refers) == 0:
-                refers,audio_tensor = get_spepc(hps, ref_wav_path,dtype,device,is_v2pro)
-                refers=[refers]
+                refers, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device, is_v2pro)
+                refers = [refers]
                if is_v2pro:
-                    sv_emb=[sv_cn_model.compute_embedding3(audio_tensor)]
+                    sv_emb = [sv_cn_model.compute_embedding3(audio_tensor)]
            if is_v2pro:
-                audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed,sv_emb=sv_emb)[0][0]
+                audio = vq_model.decode(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed, sv_emb=sv_emb
+                )[0][0]
            else:
-                audio = vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed)[0][0]
+                audio = vq_model.decode(
+                    pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refers, speed=speed
+                )[0][0]
        else:
-            refer,audio_tensor = get_spepc(hps, ref_wav_path,dtype,device)
+            refer, audio_tensor = get_spepc(hps, ref_wav_path, dtype, device)
            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
@@ -889,7 +915,7 @@ def get_tts_wav(
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
             tgt_sr = 24000 if model_version == "v3" else 32000
             if sr != tgt_sr:
-                ref_audio = resample(ref_audio, sr, tgt_sr,device)
+                ref_audio = resample(ref_audio, sr, tgt_sr, device)
             # print("ref_audio",ref_audio.abs().mean())
             mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
             mel2 = norm_spec(mel2)
@@ -1076,6 +1102,7 @@ def process_text(texts):
             _text.append(text)
     return _text
 
+
 def html_center(text, label="p"):
     return f"""<div style="text-align: center; margin: 100; padding: 50;">
                 <{label} style="margin: 0; padding: 0;">{text}</{label}>
@@ -1088,11 +1115,13 @@ def html_left(text, label="p"):
                 </div>"""
 
 
-with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False) as app:
-    gr.Markdown(
-        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
-        + "<br>"
-        + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
+    gr.HTML(
+        top_html.format(
+            i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+            + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
+        ),
+        elem_classes="markdown",
     )
     with gr.Group():
         gr.Markdown(html_center(i18n("模型切换"), "h3"))
diff --git a/GPT_SoVITS/inference_webui_fast.py b/GPT_SoVITS/inference_webui_fast.py
index 342e7ee..4484ba4 100644
--- a/GPT_SoVITS/inference_webui_fast.py
+++ b/GPT_SoVITS/inference_webui_fast.py
@@ -47,6 +47,7 @@ import gradio as gr
 from TTS_infer_pack.text_segmentation_method import get_method
 from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
+from tools.assets import css, js, top_html
 from tools.i18n.i18n import I18nAuto, scan_language_list
 
 language = os.environ.get("language", "Auto")
@@ -98,9 +99,11 @@ cut_method = {
     i18n("按标点符号切"): "cut5",
 }
 
-from config import name2sovits_path,name2gpt_path,change_choices,get_weights_names
+from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
+
 SoVITS_names, GPT_names = get_weights_names()
 from config import pretrained_sovits_name
+
 path_sovits_v3 = pretrained_sovits_name["v3"]
 path_sovits_v4 = pretrained_sovits_name["v4"]
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
@@ -111,10 +114,12 @@ tts_config.device = device
 tts_config.is_half = is_half
 tts_config.version = version
 if gpt_path is not None:
-    if "!"in gpt_path:gpt_path=name2gpt_path[gpt_path]
+    if "!" in gpt_path:
+        gpt_path = name2gpt_path[gpt_path]
     tts_config.t2s_weights_path = gpt_path
 if sovits_path is not None:
-    if "!"in sovits_path:sovits_path=name2sovits_path[sovits_path]
+    if "!" in sovits_path:
+        sovits_path = name2sovits_path[sovits_path]
     tts_config.vits_weights_path = sovits_path
 if cnhubert_base_path is not None:
     tts_config.cnhuhbert_base_path = cnhubert_base_path
@@ -189,6 +194,7 @@ def custom_sort_key(s):
     parts = [int(part) if part.isdigit() else part for part in parts]
     return parts
 
+
 if os.path.exists("./weight.json"):
     pass
 else:
@@ -206,9 +212,13 @@ with open("./weight.json", "r", encoding="utf-8") as file:
             sovits_path = sovits_path[0]
 
 from process_ckpt import get_sovits_version_from_path_fast
+
 v3v4set = {"v3", "v4"}
+
+
 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
-    if "!"in sovits_path:sovits_path=name2sovits_path[sovits_path]
+    if "!" in sovits_path:
+        sovits_path = name2sovits_path[sovits_path]
     global version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
     # print(sovits_path,version, model_version, if_lora_v3)
@@ -273,11 +283,13 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
         f.write(json.dumps(data))
 
 
-with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False) as app:
-    gr.Markdown(
-        value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
-        + "<br>"
" - + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") +with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app: + gr.HTML( + top_html.format( + i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ), + elem_classes="markdown", ) with gr.Column(): diff --git a/tools/assets.py b/tools/assets.py new file mode 100644 index 0000000..cca0efc --- /dev/null +++ b/tools/assets.py @@ -0,0 +1,112 @@ +js = """ +function createGradioAnimation() { + + const params = new URLSearchParams(window.location.search); + if (params.get('__theme') !== 'light') { + params.set('__theme', 'light'); // 仅当 __theme 不是 'light' 时设置为 'light' + window.location.search = params.toString(); // 更新 URL,触发页面刷新 + } + + var container = document.createElement('div'); + container.id = 'gradio-animation'; + container.style.fontSize = '2em'; + container.style.fontWeight = '500'; + container.style.textAlign = 'center'; + container.style.marginBottom = '20px'; + container.style.fontFamily = '-apple-system, sans-serif, Arial, Calibri'; + + var text = 'Welcome to GPT-SoVITS !'; + for (var i = 0; i < text.length; i++) { + (function(i){ + setTimeout(function(){ + var letter = document.createElement('span'); + letter.style.opacity = '0'; + letter.style.transition = 'opacity 0.5s'; + letter.innerText = text[i]; + + container.appendChild(letter); + + setTimeout(function() { + letter.style.opacity = '1'; + }, 50); + }, i * 250); + })(i); + } + return 'Animation created'; +} +""" + + +css = """ +/* CSSStyleRule */ + +.markdown { + background-color: lightblue; + padding: 10px; +} + +.checkbox_info { + color: var(--block-title-text-color) !important; + font-size: var(--block-title-text-size) !important; + font-weight: var(--block-title-text-weight) !important; + height: 22px; + margin-bottom: 8px !important; +} + +::selection { + background: #ffc078; !important; +} + +#checkbox_train_dpo input[type="checkbox"]{ + margin-top: 6px; +} + +#checkbox_train_dpo span { + margin-top: 6px; +} + +#checkbox_align_train { + padding-top: 18px; + padding-bottom: 18px; +} + +#checkbox_align_infer input[type="checkbox"] { + margin-top: 10px; +} + +#checkbox_align_infer span { + margin-top: 10px; +} + +footer { + height: 50px !important; /* 设置页脚高度 */ + background-color: transparent !important; /* 背景透明 */ + display: flex; + justify-content: center; /* 居中对齐 */ + align-items: center; /* 垂直居中 */ +} + +footer * { + display: none !important; /* 隐藏所有子元素 */ +} + +""" +top_html = """ +
+<div align="center">
+    <div>{}</div>
+    <!-- badge/link row markup not recoverable from this diff -->
+</div>
+""" diff --git a/webui.py b/webui.py index 878cce8..26e0fe6 100644 --- a/webui.py +++ b/webui.py @@ -60,6 +60,7 @@ import shutil import subprocess from subprocess import Popen +from tools.assets import css, js, top_html from tools.i18n.i18n import I18nAuto, scan_language_list language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else "Auto" @@ -1299,14 +1300,13 @@ def sync(text): return {"__type__": "update", "value": text} -with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False) as app: - gr.Markdown( - value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") - + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") - + "
" - + i18n("中文教程文档") - + ": " - + "https://www.yuque.com/baicaigongchang1145haoyuangong/ib3g1e" +with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app: + gr.HTML( + top_html.format( + i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.") + ), + elem_classes="markdown", ) with gr.Tabs():