@@ -331,6 +331,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         int(hps.data.sampling_rate * 0.3),
         dtype=np.float16 if is_half == True else np.float32,
     )
+    if not ref_free:
         with torch.no_grad():
             wav16k, sr = librosa.load(ref_wav_path, sr=16000)
             if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
@@ -350,8 +351,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                 1, 2
             )  # .float()
             codes = vq_model.extract_latent(ssl_content)
             prompt_semantic = codes[0, 0]
+            prompt = prompt_semantic.unsqueeze(0).to(device)

     t1 = ttime()

     if (how_to_cut == i18n("凑四句一切")):
@@ -391,7 +393,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         bert = bert.to(device).unsqueeze(0)
         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
-        prompt = prompt_semantic.unsqueeze(0).to(device)
         t2 = ttime()
         with torch.no_grad():
             # pred_semantic = t2s_model.model.infer(
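
Read together, the three hunks hoist reference-prompt construction out of the per-segment loop: hunk 1 gates the reference-audio pass behind ref_free, hunk 2 builds the prompt tensor once right after vq_model.extract_latent, and hunk 3 drops the old per-segment copy. Below is a minimal sketch of the resulting control flow; extract_prompt_codes and synthesize_segment are hypothetical stand-ins for the SSL/VQ and per-segment synthesis stages, which the diff only shows in part.

import torch

# Hypothetical sketch, not the repository's code: extract_prompt_codes and
# synthesize_segment stand in for the SSL/VQ and per-segment stages.

device = "cuda" if torch.cuda.is_available() else "cpu"

def extract_prompt_codes(ref_wav_path):
    # Stand-in for ssl_model + vq_model.extract_latent; fake codes shaped
    # so that codes[0, 0] plays the role of prompt_semantic in the patch.
    return torch.randint(0, 1024, (1, 1, 50))

def synthesize_segment(text, prompt):
    # Stand-in for the bert/t2s/vocoder stage; just echoes its inputs.
    shape = None if prompt is None else tuple(prompt.shape)
    return f"{text!r} synthesized with prompt shape {shape}"

def get_tts_wav_sketch(ref_wav_path, texts, ref_free=False):
    prompt = None
    if not ref_free:  # hunk 1: the reference pass is skipped in ref-free mode
        with torch.no_grad():
            codes = extract_prompt_codes(ref_wav_path)
            prompt_semantic = codes[0, 0]
            # hunk 2: the prompt tensor is built once, here...
            prompt = prompt_semantic.unsqueeze(0).to(device)
    for text in texts:
        # hunk 3: ...instead of being rebuilt for every text segment
        yield synthesize_segment(text, prompt)

for out in get_tts_wav_sketch("ref.wav", ["first segment", "second segment"]):
    print(out)

The behavioral point, as the sketch shows, is that prompt is computed at most once and then reused for every cut segment, and in ref-free mode the reference pass is skipped entirely.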