support gpt-sovits v4

main
RVC-Boss authored 4 months ago, committed by GitHub
parent e0c452f007
commit c6cb6b45f3

GPT_SoVITS/inference_webui.py

@@ -39,21 +39,25 @@ except:
 version = model_version = os.environ.get("version", "v2")
 path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+path_sovits_v4 = "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"
 is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+is_exist_s2gv4 = os.path.exists(path_sovits_v4)
 pretrained_sovits_name = [
     "GPT_SoVITS/pretrained_models/s2G488k.pth",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth",
-    path_sovits_v3,
+    "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+    "GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth",
 ]
 pretrained_gpt_name = [
     "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
     "GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt",
     "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+    "GPT_SoVITS/pretrained_models/s1v3.ckpt",
 ]
 _ = [[], []]
-for i in range(3):
+for i in range(4):
     if os.path.exists(pretrained_gpt_name[i]):
         _[0].append(pretrained_gpt_name[i])
     if os.path.exists(pretrained_sovits_name[i]):
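
v4 ships a new SoVITS generator checkpoint but reuses the v3 GPT weights, which is why s1v3.ckpt is listed twice on purpose and the availability scan grows from range(3) to range(4). A minimal standalone sketch of what that scan collects (illustrative only; the real lists are the four-entry ones above):

    import os

    # The GPT and SoVITS halves are checked independently, exactly as the
    # range(4) loop above does; missing files simply drop out of the UI lists.
    pretrained_gpt_name = ["GPT_SoVITS/pretrained_models/s1v3.ckpt"]  # four entries in the real list
    pretrained_sovits_name = ["GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth"]
    found_gpt = [p for p in pretrained_gpt_name if os.path.exists(p)]
    found_sovits = [p for p in pretrained_sovits_name if os.path.exists(p)]
    print(found_gpt, found_sovits)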
@@ -102,7 +106,7 @@ cnhubert.cnhubert_base_path = cnhubert_base_path
 import random

-from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3
+from GPT_SoVITS.module.models import SynthesizerTrn, SynthesizerTrnV3, Generator

 def set_seed(seed):
@@ -222,23 +226,25 @@ else:
 resample_transform_dict = {}

-def resample(audio_tensor, sr0):
+def resample(audio_tensor, sr0, sr1):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s" % (sr0, sr1)
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)

 ### todo: move these into process_ckpt and modify the my_save func (for saving SoVITS weights); GPT weight saving already uses my_save in process_ckpt
 # symbol_version-model_version-if_lora_v3
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

+v3v4set = {"v3", "v4"}

 def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
     global vq_model, hps, version, model_version, dict_language, if_lora_v3
     version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
-    # print(sovits_path, version, model_version, if_lora_v3)
-    if if_lora_v3 == True and is_exist_s2gv3 == False:
+    print(sovits_path, version, model_version, if_lora_v3)
+    is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
+    if if_lora_v3 == True and is_exist == False:
         info = "GPT_SoVITS/pretrained_models/s2Gv3.pth" + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
         gr.Warning(info)
         raise FileExistsError(info)
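
The resample helper used to hard-code a 24 kHz target (v3's rate); it now takes the target rate as an argument and caches one torchaudio kernel per (source, target) pair, since v3 wants 24 kHz references and v4 wants 32 kHz. A standalone illustration of the keyed cache (simplified, CPU-only, names are not the repo's):

    import torch
    import torchaudio

    _resamplers = {}

    def resample(audio, sr0, sr1):
        key = (sr0, sr1)
        if key not in _resamplers:
            _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1)
        return _resamplers[key](audio)

    wav = torch.randn(1, 48000)                # one second at 48 kHz
    print(resample(wav, 48000, 24000).shape)   # v3 reference rate -> torch.Size([1, 24000])
    print(resample(wav, 48000, 32000).shape)   # v4 reference rate -> torch.Size([1, 32000])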
@@ -257,7 +263,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
     else:
         text_update = {"__type__": "update", "value": ""}
         text_language_update = {"__type__": "update", "value": i18n("中文")}
-    if model_version == "v3":
+    if model_version in v3v4set:
         visible_sample_steps = True
         visible_inp_refs = False
     else:
@@ -270,10 +276,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             prompt_language_update,
             text_update,
             text_language_update,
-            {"__type__": "update", "visible": visible_sample_steps, "value": 32},
+            {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
             {"__type__": "update", "visible": visible_inp_refs},
-            {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
+            {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
             {"__type__": "update", "visible": True if model_version == "v3" else False},
             {"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
         )
@@ -289,7 +295,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             hps.model.version = "v2"
     version = hps.model.version
     # print("sovits version:", hps.model.version)
-    if model_version != "v3":
+    if model_version not in v3v4set:
         vq_model = SynthesizerTrn(
             hps.data.filter_length // 2 + 1,
             hps.train.segment_size // hps.data.hop_length,
@@ -317,9 +323,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
     if if_lora_v3 == False:
         print("loading sovits_%s" % model_version, vq_model.load_state_dict(dict_s2["weight"], strict=False))
     else:
+        path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
         print(
-            "loading sovits_v3pretrained_G",
-            vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False),
+            "loading sovits_%spretrained_G" % model_version,
+            vq_model.load_state_dict(load_sovits_new(path_sovits)["weight"], strict=False),
         )
         lora_rank = dict_s2["lora_rank"]
         lora_config = LoraConfig(
@@ -329,7 +336,7 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             init_lora_weights=True,
         )
         vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
-        print("loading sovits_v3_lora%s" % (lora_rank))
+        print("loading sovits_%s_lora%s" % (model_version, lora_rank))
         vq_model.load_state_dict(dict_s2["weight"], strict=False)
         vq_model.cfm = vq_model.cfm.merge_and_unload()
         # torch.save(vq_model.state_dict(), "merge_win.pth")
@@ -342,10 +349,10 @@ def change_sovits_weights(sovits_path, prompt_language=None, text_language=None)
             prompt_language_update,
             text_update,
             text_language_update,
-            {"__type__": "update", "visible": visible_sample_steps, "value": 32},
+            {"__type__": "update", "visible": visible_sample_steps, "value": 32 if model_version == "v3" else 8, "choices": [4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32]},
             {"__type__": "update", "visible": visible_inp_refs},
-            {"__type__": "update", "value": False, "interactive": True if model_version != "v3" else False},
+            {"__type__": "update", "value": False, "interactive": True if model_version not in v3v4set else False},
             {"__type__": "update", "visible": True if model_version == "v3" else False},
             {"__type__": "update", "value": i18n("合成语音"), "interactive": True},
         )
     with open("./weight.json") as f:
@@ -392,7 +399,7 @@ now_dir = os.getcwd()

 def init_bigvgan():
-    global bigvgan_model
+    global bigvgan_model, hifigan_model
     from BigVGAN import bigvgan

     bigvgan_model = bigvgan.BigVGAN.from_pretrained(
@@ -402,16 +409,47 @@ def init_bigvgan():
     # remove weight norm in the model and set to eval mode
     bigvgan_model.remove_weight_norm()
     bigvgan_model = bigvgan_model.eval()
+    if hifigan_model:
+        hifigan_model = hifigan_model.cpu()
+        hifigan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
     if is_half == True:
         bigvgan_model = bigvgan_model.half().to(device)
     else:
         bigvgan_model = bigvgan_model.to(device)

+def init_hifigan():
+    global hifigan_model, bigvgan_model
+    hifigan_model = Generator(
+        initial_channel=100,
+        resblock="1",
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        upsample_rates=[10, 6, 2, 2, 2],
+        upsample_initial_channel=512,
+        upsample_kernel_sizes=[20, 12, 4, 4, 4],
+        gin_channels=0,
+        is_bias=True,
+    )
+    hifigan_model.eval()
+    hifigan_model.remove_weight_norm()
+    state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+    print("loading vocoder", hifigan_model.load_state_dict(state_dict_g))
+    if bigvgan_model:
+        bigvgan_model = bigvgan_model.cpu()
+        bigvgan_model = None
+        try:
+            torch.cuda.empty_cache()
+        except:
+            pass
+    if is_half == True:
+        hifigan_model = hifigan_model.half().to(device)
+    else:
+        hifigan_model = hifigan_model.to(device)

-if model_version != "v3":
-    bigvgan_model = None
-else:
+bigvgan_model = hifigan_model = None
+if model_version == "v3":
     init_bigvgan()
+if model_version == "v4":
+    init_hifigan()

 def get_spepc(hps, filename):
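
init_hifigan mirrors init_bigvgan: each loads its own vocoder and evicts the other, so at most one of the two sits in VRAM at a time. The Generator config also fixes v4's output rate: the upsample stack expands every mel frame into 10 * 6 * 2 * 2 * 2 = 480 samples. A quick check of that factor (illustrative):

    # Total upsampling factor of the v4 Generator configured above.
    factor = 1
    for r in [10, 6, 2, 2, 2]:
        factor *= r
    print(factor)  # 480 output samples per mel frame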
@@ -576,6 +614,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)

 def merge_short_text_in_array(texts, threshold):
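
mel_fn_v4 analyzes the 32 kHz reference with a 1280-point window and hop 320; the existing v3 frontend (per the mel_fn settings earlier in the file, not shown in this hunk) uses 24 kHz audio with hop 256. Both produce 100 mel bins. Combined with the 480x upsampling of the v4 vocoder, the hop works out to the 48 kHz output rate selected later in get_tts_wav:

    # Frame rates of the two mel frontends, and the resulting v4 output rate.
    print(24000 / 256)             # v3: 93.75 mel frames per second
    print(32000 / 320)             # v4: 100 mel frames per second
    print(int(32000 / 320) * 480)  # 100 frames/s * 480 samples/frame = 48000 Hz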
@@ -647,7 +698,7 @@ def get_tts_wav(
     t = []
     if prompt_text is None or len(prompt_text) == 0:
         ref_free = True
-    if model_version == "v3":
+    if model_version in v3v4set:
         ref_free = False  # s2v3 does not support ref_free yet
     else:
         if_sr = False
@@ -755,7 +806,7 @@ def get_tts_wav(
             cache[i_text] = pred_semantic
         t3 = ttime()
         ### The logic below and inp_refs do not exist in v3
-        if model_version != "v3":
+        if model_version not in v3v4set:
             refers = []
             if inp_refs:
                 for path in inp_refs:
@@ -779,25 +830,24 @@ def get_tts_wav(
             ref_audio = ref_audio.to(device).float()
             if ref_audio.shape[0] == 2:
                 ref_audio = ref_audio.mean(0).unsqueeze(0)
-            if sr != 24000:
-                ref_audio = resample(ref_audio, sr)
+            tgt_sr = 24000 if model_version == "v3" else 32000
+            if sr != tgt_sr:
+                ref_audio = resample(ref_audio, sr, tgt_sr)
             # print("ref_audio", ref_audio.abs().mean())
-            mel2 = mel_fn(ref_audio)
+            mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
             mel2 = norm_spec(mel2)
             T_min = min(mel2.shape[2], fea_ref.shape[2])
             mel2 = mel2[:, :, :T_min]
             fea_ref = fea_ref[:, :, :T_min]
-            if T_min > 468:
-                mel2 = mel2[:, :, -468:]
-                fea_ref = fea_ref[:, :, -468:]
-                T_min = 468
-            chunk_len = 934 - T_min
-            # print("fea_ref", fea_ref, fea_ref.shape)
-            # print("mel2", mel2)
+            Tref = 468 if model_version == "v3" else 500
+            Tchunk = 934 if model_version == "v3" else 1000
+            if T_min > Tref:
+                mel2 = mel2[:, :, -Tref:]
+                fea_ref = fea_ref[:, :, -Tref:]
+                T_min = Tref
+            chunk_len = Tchunk - T_min
             mel2 = mel2.to(dtype)
             fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
-            # print("fea_todo", fea_todo)
-            # print("ge", ge.abs().mean())
             cfm_resss = []
             idx = 0
             while 1:
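
The fixed 468/934 window sizes become version-dependent: v3 keeps at most Tref = 468 reference frames inside a 934-frame total window, v4 uses 500/1000. Each pass of the loop that follows conditions the CFM on the last T_min frames and synthesizes chunk_len new ones. A toy walk-through of the indices (illustrative values):

    Tref, Tchunk = 500, 1000        # v4 values; v3 uses 468 and 934
    T_min = min(350, Tref)          # suppose the reference mel has 350 frames
    chunk_len = Tchunk - T_min      # 650 new frames generated per pass
    total, idx = 1600, 0            # suppose 1600 frames are left to render
    while idx < total:
        n = min(chunk_len, total - idx)
        print("pass: frames", idx, "to", idx + n)
        idx += chunk_len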
@@ -806,22 +856,24 @@ def get_tts_wav(
                     break
                 idx += chunk_len
                 fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
-                # set_seed(123)
                 cfm_res = vq_model.cfm.inference(
                     fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0
                 )
                 cfm_res = cfm_res[:, :, mel2.shape[2] :]
                 mel2 = cfm_res[:, :, -T_min:]
-                # print("fea", fea)
-                # print("mel2in", mel2)
                 fea_ref = fea_todo_chunk[:, :, -T_min:]
                 cfm_resss.append(cfm_res)
-            cmf_res = torch.cat(cfm_resss, 2)
-            cmf_res = denorm_spec(cmf_res)
-            if bigvgan_model == None:
-                init_bigvgan()
+            cfm_res = torch.cat(cfm_resss, 2)
+            cfm_res = denorm_spec(cfm_res)
+            if model_version == "v3":
+                if bigvgan_model == None:
+                    init_bigvgan()
+            else:  # v4
+                if hifigan_model == None:
+                    init_hifigan()
+            vocoder_model = bigvgan_model if model_version == "v3" else hifigan_model
             with torch.inference_mode():
-                wav_gen = bigvgan_model(cmf_res)
+                wav_gen = vocoder_model(cfm_res)
                 audio = wav_gen[0][0]  # .cpu().detach().numpy()
         max_audio = torch.abs(audio).max()  # simple guard against 16-bit clipping
         if max_audio > 1:
@@ -833,16 +885,18 @@ def get_tts_wav(
     t1 = ttime()
     print("%.3f\t%.3f\t%.3f\t%.3f" % (t[0], sum(t[1::3]), sum(t[2::3]), sum(t[3::3])))
     audio_opt = torch.cat(audio_opt, 0)  # np.concatenate
-    sr = hps.data.sampling_rate if model_version != "v3" else 24000
-    if if_sr == True and sr == 24000:
+    if model_version in {"v1", "v2"}:
+        opt_sr = 32000
+    elif model_version == "v3":
+        opt_sr = 24000
+    else:
+        opt_sr = 48000  # v4
+    if if_sr == True and opt_sr == 24000:
         print(i18n("音频超分中"))
-        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        audio_opt, opt_sr = audio_sr(audio_opt.unsqueeze(0), opt_sr)
         max_audio = np.abs(audio_opt).max()
         if max_audio > 1:
             audio_opt /= max_audio
     else:
         audio_opt = audio_opt.cpu().detach().numpy()
-    yield sr, (audio_opt * 32767).astype(np.int16)
+    yield opt_sr, (audio_opt * 32767).astype(np.int16)

 def split(todo_text):
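
The output rate is now pinned per model version instead of being derived from hps.data.sampling_rate, and only the 24 kHz v3 path is eligible for the optional super-resolution step. Summarized as a sketch:

    # Native output rates implied by the branch above; v3 audio may optionally
    # be super-resolved from 24 kHz when if_sr is enabled.
    OPT_SR = {"v1": 32000, "v2": 32000, "v3": 24000, "v4": 48000}
    print(OPT_SR["v4"])  # 48000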
@@ -971,8 +1025,8 @@ def change_choices():
 }

-SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3"]
-GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3"]
+SoVITS_weight_root = ["SoVITS_weights", "SoVITS_weights_v2", "SoVITS_weights_v3", "SoVITS_weights_v4"]
+GPT_weight_root = ["GPT_weights", "GPT_weights_v2", "GPT_weights_v3", "GPT_weights_v4"]
 for path in SoVITS_weight_root + GPT_weight_root:
     os.makedirs(path, exist_ok=True)
@@ -1039,7 +1093,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。")
                     + i18n("v3暂不支持该模式使用了会报错。"),
                     value=False,
-                    interactive=True if model_version != "v3" else False,
+                    interactive=True if model_version not in v3v4set else False,
                     show_label=True,
                     scale=1,
                 )
@@ -1064,7 +1118,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                     ),
                     file_count="multiple",
                 )
-                if model_version != "v3"
+                if model_version not in v3v4set
                 else gr.File(
                     label=i18n(
                         "可选项:通过拖拽多个文件上传多个参考音频(建议同性),平均融合他们的音色。如不填写此项,音色由左侧单个参考音频控制。如是微调模型,建议参考音频全部在微调训练集音色内,底模不用管。"
@@ -1076,16 +1130,16 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
             sample_steps = (
                 gr.Radio(
                     label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                    value=32,
-                    choices=[4, 8, 16, 32],
+                    value=32 if model_version == "v3" else 8,
+                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                     visible=True,
                 )
-                if model_version == "v3"
+                if model_version in v3v4set
                 else gr.Radio(
                     label=i18n("采样步数,如果觉得电,提高试试,如果觉得慢,降低试试"),
-                    choices=[4, 8, 16, 32],
+                    choices=[4, 8, 16, 32, 64, 128] if model_version == "v3" else [4, 8, 16, 32],
                     visible=False,
-                    value=32,
+                    value=32 if model_version == "v3" else 8,
                 )
             )
             if_sr_Checkbox = gr.Checkbox(
@@ -1093,7 +1147,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
                 value=False,
                 interactive=True,
                 show_label=True,
                 visible=False if model_version != "v3" else True,
             )
         gr.Markdown(html_center(i18n("*请填写需要合成的目标文本和语种模式"), "h3"))
         with gr.Row():

GPT_SoVITS/process_ckpt.py

@@ -22,23 +22,24 @@ def my_save(fea, path):  #####fix issue: torch.save doesn't support chinese path
 01:v2
 02:v3
 03:v3lora
+04:v4lora
 """

 from io import BytesIO

-def my_save2(fea, path):
+def my_save2(fea, path, cfm_version):
     bio = BytesIO()
     torch.save(fea, bio)
     bio.seek(0)
     data = bio.getvalue()
-    data = b"03" + data[2:]  ### temp for v3lora only, todo
+    byte = b"03" if cfm_version == "v3" else b"04"
+    data = byte + data[2:]
     with open(path, "wb") as f:
         f.write(data)

-def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
+def savee(ckpt, name, epoch, steps, hps, cfm_version=None, lora_rank=None):
     try:
         opt = OrderedDict()
         opt["weight"] = {}
@@ -50,7 +51,7 @@ def savee(ckpt, name, epoch, steps, hps, lora_rank=None):
         opt["info"] = "%sepoch_%siteration" % (epoch, steps)
         if lora_rank:
             opt["lora_rank"] = lora_rank
-            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
+            my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), cfm_version)
         else:
             my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
         return "Success."
@@ -63,11 +64,13 @@ head2version = {
     b"01": ["v2", "v2", False],
     b"02": ["v2", "v3", False],
     b"03": ["v2", "v3", True],
+    b"04": ["v2", "v4", True],
 }
 hash_pretrained_dict = {
     "dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False],  # s2G488k.pth#sovits_v1_pretrained
     "43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False],  # s2Gv3.pth#sovits_v3_pretrained
     "6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False],  # s2G2333K.pth#sovits_v2_pretrained
+    "4f26b9476d0c5033e04162c486074374": ["v2", "v4", False],  # s2Gv4.pth#sovits_v4_pretrained
 }

 import hashlib
@@ -85,7 +88,7 @@ def get_sovits_version_from_path_fast(sovits_path):
         hash = get_hash_from_file(sovits_path)
         if hash in hash_pretrained_dict:
             return hash_pretrained_dict[hash]
-    ###2-new weights or old weights, by head
+    ###2-new weights, by head
     with open(sovits_path, "rb") as f:
         version = f.read(2)
     if version != b"PK":
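
Taken together, my_save2 and head2version implement a tiny versioning scheme: a plain torch.save archive starts with the zip magic b"PK", while patched SoVITS weights overwrite the first two bytes with a version tag. A minimal reader for that convention (sketch; the helper name is not from the repo):

    def sniff_sovits_header(path):
        with open(path, "rb") as f:
            head = f.read(2)
        if head == b"PK":
            return "plain torch.save archive (pre-tag format)"
        return {b"01": "v2", b"02": "v3", b"03": "v3-LoRA", b"04": "v4-LoRA"}.get(head, "unknown")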

GPT_SoVITS/s2_train_v3_lora.py

@@ -27,12 +27,11 @@ from random import randint
 from module import commons
 from module.data_utils import (
     DistributedBucketSampler,
-)
-from module.data_utils import (
-    TextAudioSpeakerCollateV3 as TextAudioSpeakerCollate,
-)
-from module.data_utils import (
-    TextAudioSpeakerLoaderV3 as TextAudioSpeakerLoader,
+    TextAudioSpeakerCollateV3,
+    TextAudioSpeakerLoaderV3,
+    TextAudioSpeakerCollateV4,
+    TextAudioSpeakerLoaderV4,
 )
 from module.models import (
     SynthesizerTrnV3 as SynthesizerTrn,
@@ -89,6 +88,8 @@ def run(rank, n_gpus, hps):
     if torch.cuda.is_available():
         torch.cuda.set_device(rank)
+    TextAudioSpeakerLoader = TextAudioSpeakerLoaderV3 if hps.model.version == "v3" else TextAudioSpeakerLoaderV4
+    TextAudioSpeakerCollate = TextAudioSpeakerCollateV3 if hps.model.version == "v3" else TextAudioSpeakerCollateV4
     train_dataset = TextAudioSpeakerLoader(hps.data)  ########
     train_sampler = DistributedBucketSampler(
         train_dataset,
@@ -364,7 +365,7 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade
                         hps.name + "_e%s_s%s_l%s" % (epoch, global_step, lora_rank),
                         epoch,
                         global_step,
-                        hps,
+                        hps, cfm_version=hps.model.version,
                         lora_rank=lora_rank,
                     ),
                 )
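
With cfm_version threaded through savee, LoRA checkpoints written during v4 training receive the b"04" header that get_sovits_version_from_path_fast decodes as ["v2", "v4", True]. The byte choice reduces to (illustrative):

    cfm_version = "v4"  # hps.model.version during v4 LoRA training
    byte = b"03" if cfm_version == "v3" else b"04"
    print(byte)  # b"04"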
