|
|
@ -202,16 +202,6 @@ if is_half == True:
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ssl_model = ssl_model.to(device)
|
|
|
|
ssl_model = ssl_model.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
resample_transform_dict = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resample(audio_tensor, sr0, sr1):
|
|
|
|
|
|
|
|
global resample_transform_dict
|
|
|
|
|
|
|
|
key = "%s-%s" % (sr0, sr1)
|
|
|
|
|
|
|
|
if key not in resample_transform_dict:
|
|
|
|
|
|
|
|
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
|
|
|
|
|
|
|
|
return resample_transform_dict[key](audio_tensor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
|
|
|
|
###todo:put them to process_ckpt and modify my_save func (save sovits weights), gpt save weights use my_save in process_ckpt
|
|
|
|
# symbol_version-model_version-if_lora_v3
|
|
|
|
# symbol_version-model_version-if_lora_v3
|
|
|
@ -899,7 +889,7 @@ def get_tts_wav(
|
|
|
|
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
|
|
|
ref_audio = ref_audio.mean(0).unsqueeze(0)
|
|
|
|
tgt_sr = 24000 if model_version == "v3" else 32000
|
|
|
|
tgt_sr = 24000 if model_version == "v3" else 32000
|
|
|
|
if sr != tgt_sr:
|
|
|
|
if sr != tgt_sr:
|
|
|
|
ref_audio = resample(ref_audio, sr, tgt_sr)
|
|
|
|
ref_audio = resample(ref_audio, sr, tgt_sr,device)
|
|
|
|
# print("ref_audio",ref_audio.abs().mean())
|
|
|
|
# print("ref_audio",ref_audio.abs().mean())
|
|
|
|
mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
|
|
|
|
mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
|
|
|
|
mel2 = norm_spec(mel2)
|
|
|
|
mel2 = norm_spec(mel2)
|
|
|
|