|
|
@ -331,27 +331,29 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|
|
|
int(hps.data.sampling_rate * 0.3),
|
|
|
|
int(hps.data.sampling_rate * 0.3),
|
|
|
|
dtype=np.float16 if is_half == True else np.float32,
|
|
|
|
dtype=np.float16 if is_half == True else np.float32,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
with torch.no_grad():
|
|
|
|
if not ref_free:
|
|
|
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
|
|
|
with torch.no_grad():
|
|
|
|
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
|
|
|
|
wav16k, sr = librosa.load(ref_wav_path, sr=16000)
|
|
|
|
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
|
|
|
if (wav16k.shape[0] > 160000 or wav16k.shape[0] < 48000):
|
|
|
|
wav16k = torch.from_numpy(wav16k)
|
|
|
|
raise OSError(i18n("参考音频在3~10秒范围外,请更换!"))
|
|
|
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
|
|
|
wav16k = torch.from_numpy(wav16k)
|
|
|
|
if is_half == True:
|
|
|
|
zero_wav_torch = torch.from_numpy(zero_wav)
|
|
|
|
wav16k = wav16k.half().to(device)
|
|
|
|
if is_half == True:
|
|
|
|
zero_wav_torch = zero_wav_torch.half().to(device)
|
|
|
|
wav16k = wav16k.half().to(device)
|
|
|
|
else:
|
|
|
|
zero_wav_torch = zero_wav_torch.half().to(device)
|
|
|
|
wav16k = wav16k.to(device)
|
|
|
|
else:
|
|
|
|
zero_wav_torch = zero_wav_torch.to(device)
|
|
|
|
wav16k = wav16k.to(device)
|
|
|
|
wav16k = torch.cat([wav16k, zero_wav_torch])
|
|
|
|
zero_wav_torch = zero_wav_torch.to(device)
|
|
|
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
|
|
|
wav16k = torch.cat([wav16k, zero_wav_torch])
|
|
|
|
"last_hidden_state"
|
|
|
|
ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
|
|
|
|
].transpose(
|
|
|
|
"last_hidden_state"
|
|
|
|
1, 2
|
|
|
|
].transpose(
|
|
|
|
) # .float()
|
|
|
|
1, 2
|
|
|
|
codes = vq_model.extract_latent(ssl_content)
|
|
|
|
) # .float()
|
|
|
|
|
|
|
|
codes = vq_model.extract_latent(ssl_content)
|
|
|
|
prompt_semantic = codes[0, 0]
|
|
|
|
prompt_semantic = codes[0, 0]
|
|
|
|
|
|
|
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
t1 = ttime()
|
|
|
|
t1 = ttime()
|
|
|
|
|
|
|
|
|
|
|
|
if (how_to_cut == i18n("凑四句一切")):
|
|
|
|
if (how_to_cut == i18n("凑四句一切")):
|
|
|
@ -391,7 +393,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
|
|
|
|
|
|
|
|
|
|
|
|
bert = bert.to(device).unsqueeze(0)
|
|
|
|
bert = bert.to(device).unsqueeze(0)
|
|
|
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
|
|
|
all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
|
|
|
|
prompt = prompt_semantic.unsqueeze(0).to(device)
|
|
|
|
|
|
|
|
t2 = ttime()
|
|
|
|
t2 = ttime()
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
# pred_semantic = t2s_model.model.infer(
|
|
|
|
# pred_semantic = t2s_model.model.infer(
|
|
|
|