support gpt-sovits v4

support gpt-sovits v4
main
RVC-Boss 4 months ago committed by GitHub
parent e0c452f007
commit c6cb6b45f3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -39,21 +39,25 @@ except:
...
version = model_version = os.environ.get("version", "v2")
path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
pretrained_sovits_name = [
"GPT_SoVITS/pretrained_models/s2G488k.pth",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
path_sovits_v3,
"GPT_SoVITS/pretrained_models/s2Gv3.pth",
"GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
]
pretrained_gpt_name = [
"GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
"GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
"GPT_SoVITS/pretrained_models/s1v3.ckpt",
]
_ = [[], []]
for i in range(3):
for i in range(4):
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
@ -102,7 +106,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path
import random
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3,Generator
def set_seed(seed):
@ -222,23 +226,25 @@ else:
resample_transform_dict = {}
def resample(audio_tensor, sr0, sr1=24000):
    """Resample ``audio_tensor`` from ``sr0`` Hz to ``sr1`` Hz.

    One ``torchaudio.transforms.Resample`` module is built per
    (source rate, target rate) pair, moved to ``device`` and cached in the
    module-level ``resample_transform_dict`` so later calls reuse it.

    Args:
        audio_tensor: Waveform tensor passed to the cached transform.
        sr0: Sample rate of ``audio_tensor``.
        sr1: Target sample rate. Defaults to 24000, the rate the earlier
            two-argument form of this function always resampled to, so
            ``resample(audio, sr)`` keeps working.

    Returns:
        The output of the cached ``Resample`` transform for ``audio_tensor``.
    """
    # The cache key carries both rates because the target rate is no longer
    # fixed (the caller picks 24000 for v3 and 32000 otherwise).
    key = "%s-%s" % (sr0, sr1)
    transform = resample_transform_dict.get(key)
    if transform is None:
        transform = torchaudio.transforms.Resample(sr0, sr1).to(device)
        resample_transform_dict[key] = transform
    return transform(audio_tensor)
### TODO: move this logic into process_ckpt and update my_save (which saves SoVITS weights); GPT weight saving should also go through my_save in process_ckpt
# symbol_version-model_version-if_lora_v3
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
v3v4set={"v3","v4"}
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
global vq_model, hps, version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
# print(sovits_path,version, model_version, if_lora_v3)
if if_lora_v3 == True and is_exist_s2gv3 == False:
print(sovits_path,version, model_version, if_lora_v3)
is_exist=is_exist_s2gv3 if model_version=="v3"else is_exist_s2gv4
if if_lora_v3 == True and is_exist == False:
info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
@ -257,7 +263,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
else:
text_update = {"__type__": "update", "value": ""}
text_language_update = {"__type__": "update", "value": i18n("中文")}
if model_version == "v3":
if model_version in v3v4set:
visible_sample_steps = True
visible_inp_refs = False
else:
@ -270,10 +276,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "visible": visible_sample_steps, "value": 32},
{"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
{"__type__": "update", "visible": True if model_version == "v3" else False},
{"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
{"__type__": "update", "visible": True if model_version =="v3" else False},
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
)
@ -289,7 +295,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
hps.model.version = "v2"
version = hps.model.version
# print("sovits版本:",hps.model.version)
if model_version != "v3":
if model_version not in v3v4set:
vq_model = SynthesizerTrn(
hps.data.filter_length // 2 + 1,
hps.train.segment_size // hps.data.hop_length,
@ -317,9 +323,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
if if_lora_v3 == False:
print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
else:
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
print(
"loading sovits_v3pretrained_G",
vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False),
"loading sovits_%spretrained_G"%model_version,
vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
)
lora_rank = dict_s2["lora_rank"]
lora_config = LoraConfig(
@ -329,7 +336,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
init_lora_weights=True,
)
vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
print("loading sovits_v3_lora%s" % (lora_rank))
print("loading sovits_%s_lora%s" % (model_version,lora_rank))
vq_model.load_state_dict(dict_s2["weight"], strict=False)
vq_model.cfm = vq_model.cfm.merge_and_unload()
# torch.save(vq_model.state_dict(),"merge_win.pth")
@ -342,10 +349,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "visible": visible_sample_steps, "value": 32},
{"__type__": "update", "visible": visible_sample_steps, "value":32 if model_version=="v3"else 8,"choices":[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32]},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
{"__type__": "update", "visible": True if model_version == "v3" else False},
{"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
{"__type__": "update", "visible": True if model_version =="v3" else False},
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
)
with open("./weight.json") as f:
@ -392,7 +399,7 @@ now_dir = os.getcwd()
def init_bigvgan():
global bigvgan_model
global bigvgan_model,hifigan_model
from BigVGAN import bigvgan
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@ -402,16 +409,47 @@ def init_bigvgan():
# remove weight norm in the model and set to eval mode
bigvgan_model.remove_weight_norm()
bigvgan_model = bigvgan_model.eval()
if hifigan_model:
hifigan_model=hifigan_model.cpu()
hifigan_model=None
try:torch.cuda.empty_cache()
except:pass
if is_half == True:
bigvgan_model = bigvgan_model.half().to(device)
else:
bigvgan_model = bigvgan_model.to(device)
def init_hifigan():
    """Build the HiFi-GAN vocoder, load its weights and move it to ``device``.

    Rebinds the module-level ``hifigan_model``. If a BigVGAN vocoder is
    currently loaded it is moved to the CPU and dropped, and the CUDA cache
    is cleared on a best-effort basis, before the new vocoder is placed on
    ``device`` (in half precision when ``is_half`` is set).
    """
    global hifigan_model, bigvgan_model
    hifigan_model = Generator(
        initial_channel=100,
        resblock="1",
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates=[10, 6, 2, 2, 2],
        upsample_initial_channel=512,
        upsample_kernel_sizes=[20, 12, 4, 4, 4],
        gin_channels=0,
        is_bias=True,
    )
    hifigan_model.eval()
    # Weight norm is removed *before* the state dict is loaded; presumably
    # the checkpoint stores plain (already de-normalized) weights -- confirm.
    hifigan_model.remove_weight_norm()
    # NOTE(review): torch.load unpickles the file, so this path must only
    # ever point at a trusted checkpoint.
    state_dict_g = torch.load(
        "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,),
        map_location="cpu",
    )
    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
    if bigvgan_model is not None:
        # Only one vocoder is kept loaded: release BigVGAN first.
        bigvgan_model = bigvgan_model.cpu()
        bigvgan_model = None
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best effort only; a failure to clear the cache must not stop
            # the vocoder from loading.
            pass
    if is_half == True:
        hifigan_model = hifigan_model.half().to(device)
    else:
        hifigan_model = hifigan_model.to(device)
if model_version != "v3":
bigvgan_model = None
else:
bigvgan_model=hifigan_model=None
if model_version=="v3":
init_bigvgan()
if model_version=="v4":
init_hifigan()
def get_spepc(hps, filename):
@ -576,6 +614,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
"center": False,
},
)
def mel_fn_v4(x):
    """Return the mel spectrogram of ``x`` using the v4 analysis settings.

    Thin wrapper around ``mel_spectrogram_torch``: 100 mel bins, FFT and
    window size 1280, hop size 320, sampling rate 32000, no centering, and
    the full band (``fmin`` 0, ``fmax`` ``None``).
    """
    v4_mel_settings = {
        "n_fft": 1280,
        "win_size": 1280,
        "hop_size": 320,
        "num_mels": 100,
        "sampling_rate": 32000,
        "fmin": 0,
        "fmax": None,
        "center": False,
    }
    return mel_spectrogram_torch(x, **v4_mel_settings)
def merge_short_text_in_array(texts, threshold):
@ -647,7 +698,7 @@ def get_tts_wav(
t = []
if prompt_text is None or len(prompt_text) == 0:
ref_free = True
if model_version == "v3":
if model_version in v3v4set:
ref_free = False # s2v3暂不支持ref_free
else:
if_sr = False
@ -755,7 +806,7 @@ def get_tts_wav(
cache[i_text] = pred_semantic
t3 = ttime()
###v3不存在以下逻辑和inp_refs
if model_version != "v3":
if model_version not in v3v4set:
refers = []
if inp_refs:
for path in inp_refs:
@ -779,25 +830,24 @@ def get_tts_wav(
ref_audio = ref_audio.to(device).float()
if ref_audio.shape[0] == 2:
ref_audio = ref_audio.mean(0).unsqueeze(0)
if sr != 24000:
ref_audio = resample(ref_audio, sr)
tgt_sr=24000 if model_version=="v3"else 32000
if sr != tgt_sr:
ref_audio = resample(ref_audio, sr,tgt_sr)
# print("ref_audio",ref_audio.abs().mean())
mel2 = mel_fn(ref_audio)
mel2 = mel_fn(ref_audio)if model_version=="v3"else mel_fn_v4(ref_audio)
mel2 = norm_spec(mel2)
T_min = min(mel2.shape[2], fea_ref.shape[2])
mel2 = mel2[:, :, :T_min]
fea_ref = fea_ref[:, :, :T_min]
if T_min > 468:
mel2 = mel2[:, :, -468:]
fea_ref = fea_ref[:, :, -468:]
T_min = 468
chunk_len = 934 - T_min
# print("fea_ref",fea_ref,fea_ref.shape)
# print("mel2",mel2)
Tref=468 if model_version=="v3"else 500
Tchunk=934 if model_version=="v3"else 1000
if T_min > Tref:
mel2 = mel2[:, :, -Tref:]
fea_ref = fea_ref[:, :, -Tref:]
T_min = Tref
chunk_len = Tchunk - T_min
mel2 = mel2.to(dtype)
fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
# print("fea_todo",fea_todo)
# print("ge",ge.abs().mean())
cfm_resss = []
idx = 0
while 1:
@ -806,22 +856,24 @@ def get_tts_wav(
break
idx += chunk_len
fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
# set_seed(123)
cfm_res = vq_model.cfm.inference(
fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
)
cfm_res = cfm_res[:, :, mel2.shape[2] :]
mel2 = cfm_res[:, :, -T_min:]
# print("fea", fea)
# print("mel2in", mel2)
fea_ref = fea_todo_chunk[:, :, -T_min:]
cfm_resss.append(cfm_res)
cmf_res = torch.cat(cfm_resss, 2)
cmf_res = denorm_spec(cmf_res)
if bigvgan_model == None:
init_bigvgan()
cfm_res = torch.cat(cfm_resss, 2)
cfm_res = denorm_spec(cfm_res)
if model_version=="v3":
if bigvgan_model == None:
init_bigvgan()
else:#v4
if hifigan_model == None:
init_hifigan()
vocoder_model=bigvgan_model if model_version=="v3"else hifigan_model
with torch.inference_mode():
wav_gen = bigvgan_model(cmf_res)
wav_gen = vocoder_model(cfm_res)
audio = wav_gen[0][0] # .cpu().detach().numpy()
max_audio = torch.abs(audio).max() # 简单防止16bit爆音
if max_audio > 1:
@ -833,16 +885,18 @@ def get_tts_wav(
t1 = ttime()
print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
audio_opt = torch.cat(audio_opt, 0) # np.concatenate
sr = hps.data.sampling_rate if model_version != "v3" else 24000
if if_sr == True and sr == 24000:
if model_version in {"v1","v2"}:opt_sr=32000
elif model_version=="v3":opt_sr=24000
else:opt_sr=48000#v4
if if_sr == True and opt_sr == 24000:
print(i18n("音频超分中"))
audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
max_audio = np.abs(audio_opt).max()
if max_audio > 1:
audio_opt /= max_audio
else:
audio_opt = audio_opt.cpu().detach().numpy()
yield sr, (audio_opt * 32767).astype(np.int16)
yield opt_sr, (audio_opt * 32767).astype(np.int16)
def split(todo_text):
@ -971,8 +1025,8 @@ def change_choices():
}
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
for path in SoVITS_weight_root + GPT_weight_root:
os.makedirs(path, exist_ok=True)
@ -1039,7 +1093,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
+ i18n("v3暂不支持该模式使用了会报错。"),
value=False,
interactive=True if model_version != "v3" else False,
interactive=True if model_version not in v3v4set else False,
show_label=True,
scale=1,
)
@ -1064,7 +1118,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
),
file_count="multiple",
)
if model_version != "v3"
if model_version not in v3v4set
else gr.File(
label=i18n(
"可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
@ -1076,16 +1130,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
sample_steps = (
gr.Radio(
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
value=32,
choices=[4, 8, 16, 32],
value=32 if model_version=="v3"else 8,
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
visible=True,
)
if model_version == "v3"
if model_version in v3v4set
else gr.Radio(
label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
choices=[4, 8, 16, 32],
choices=[4, 8, 16, 32,64,128]if model_version=="v3"else [4, 8, 16, 32],
visible=False,
value=32,
value=32 if model_version=="v3"else 8,
)
)
if_sr_Checkbox = gr.Checkbox(
@ -1093,7 +1147,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
value=False,
interactive=True,
show_label=True,
visible=False if model_version != "v3" else True,
visible=False if model_version !="v3" else True,
)
gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
with gr.Row():

@ -22,23 +22,24 @@ def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
01:v2
02:v3
03:v3lora
04:v4lora
"""
from io import BytesIO
def my_save2(fea, path, cfm_version):
    """Serialize ``fea`` with ``torch.save`` and stamp a LoRA version tag on it.

    The first two bytes of the serialized payload are replaced by a two-byte
    tag (``b"03"`` for a v3 LoRA checkpoint, ``b"04"`` for a v4 LoRA
    checkpoint), the same tags ``head2version`` maps back to a version.

    Args:
        fea: Object to serialize (the checkpoint dict built by ``savee``).
        path: Destination file path.
        cfm_version: ``"v3"`` or ``"v4"``.

    Raises:
        ValueError: If ``cfm_version`` is neither ``"v3"`` nor ``"v4"``.
            Previously any other value (including ``None``, the default
            ``savee`` passes) was silently written with the v4 tag.
    """
    lora_head_by_version = {"v3": b"03", "v4": b"04"}
    # Validate before serializing so a bad version never leaves a file behind.
    try:
        head = lora_head_by_version[cfm_version]
    except (KeyError, TypeError):
        raise ValueError(
            "unsupported cfm_version %r for LoRA checkpoint, expected 'v3' or 'v4'" % (cfm_version,)
        ) from None
    bio = BytesIO()
    torch.save(fea, bio)
    # Overwrite the two leading bytes of the payload with the version tag.
    data = head + bio.getvalue()[2:]
    with open(path, "wb") as f:
        f.write(data)
def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
def savee(ckpt, name, epoch, steps, hps, cfm_version=None,lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name),cfm_version)
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
@ -63,11 +64,13 @@ head2version = {
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
b"04": ["v2", "v4", True],
}
# Known pretrained SoVITS checkpoints, keyed by the file hash that
# get_sovits_version_from_path_fast looks up. Each value is the triple that
# function returns for a match: [symbol version, model version, is-LoRA flag].
# NOTE(review): keys look like 32-hex-digit MD5 digests -- confirm against
# get_hash_from_file.
# NOTE(review): the entry commented as the v1 pretrained model is tagged
# "v2"/"v2" -- confirm this is intended.
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
}
import hashlib
@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights or old weights, by head
###2-new weights, by head
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":

@ -27,12 +27,11 @@ from random import randint
from module import commons
from module.data_utils import (
DistributedBucketSampler,
)
from module.data_utils import (
TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
)
from module.data_utils import (
TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
TextAudioSpeakerCollateV3,
TextAudioSpeakerLoaderV3,
TextAudioSpeakerCollateV4,
TextAudioSpeakerLoaderV4,
)
from module.models import (
SynthesizerTrnV3 as SynthesizerTrn,
@ -89,6 +88,8 @@ def run(rank, n_gpus, hps):
if torch.cuda.is_available():
torch.cuda.set_device(rank)
TextAudioSpeakerLoader=TextAudioSpeakerLoaderV3 if hps.model.version=="v3"else TextAudioSpeakerLoaderV4
TextAudioSpeakerCollate=TextAudioSpeakerCollateV3 if hps.model.version=="v3"else TextAudioSpeakerCollateV4
train_dataset = TextAudioSpeakerLoader(hps.data) ########
train_sampler = DistributedBucketSampler(
train_dataset,
@ -364,7 +365,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
epoch,
global_step,
hps,
hps,cfm_version=hps.model.version,
lora_rank=lora_rank,
),
)

Loading…
Cancel
Save