|
|
@ -4,12 +4,30 @@ logging.getLogger("urllib3").setLevel(logging.ERROR)
|
|
|
|
# Quiet chatty third-party loggers: each logger named below only emits
# records at ERROR level or above.
for _noisy_logger_name in (
    "httpcore",
    "httpx",
    "asyncio",
    "charset_normalizer",
    "torchaudio._extension",
):
    logging.getLogger(_noisy_logger_name).setLevel(logging.ERROR)
|
|
|
|
import pdb
|
|
|
|
import pdb
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_weight_path(env_key, cache_file, default):
    """Pick the weight path to use at startup.

    Precedence: the ``env_key`` environment variable, then the path stored in
    ``cache_file`` (when that file exists and is not blank), then ``default``.

    Args:
        env_key: Name of the environment variable that overrides everything.
        cache_file: Text file containing a previously saved path.
        default: Path used when neither of the above supplies one.

    Returns:
        The chosen path as a string.
    """
    fallback = default
    if os.path.exists(cache_file):
        with open(cache_file, "r", encoding="utf-8") as file:
            # Strip so a trailing newline (e.g. from a hand-edited file) does
            # not end up inside the path.
            saved = file.read().strip()
        # An empty or whitespace-only file must not produce an empty path.
        if saved:
            fallback = saved
    return os.environ.get(env_key, fallback)


gpt_path = resolve_weight_path(
    "gpt_path",
    "./gweight.txt",
    "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt",
)
sovits_path = resolve_weight_path(
    "sovits_path",
    "./sweight.txt",
    "GPT_SoVITS/pretrained_models/s2G488k.pth",
)
# NOTE(review): unlike the two defaults above, this one has no "GPT_SoVITS/"
# prefix - confirm that is intentional.
cnhubert_base_path = os.environ.get(
    "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
)
|
|
|
@ -60,7 +78,7 @@ def get_bert_feature(text, word2ph):
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
inputs = tokenizer(text, return_tensors="pt")
|
|
|
|
inputs = tokenizer(text, return_tensors="pt")
|
|
|
|
for i in inputs:
|
|
|
|
for i in inputs:
|
|
|
|
inputs[i] = inputs[i].to(device) #####输入是long不用管精度问题,精度随bert_model
|
|
|
|
inputs[i] = inputs[i].to(device)
|
|
|
|
res = bert_model(**inputs, output_hidden_states=True)
|
|
|
|
res = bert_model(**inputs, output_hidden_states=True)
|
|
|
|
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
|
|
|
res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
|
|
|
|
assert len(word2ph) == len(text)
|
|
|
|
assert len(word2ph) == len(text)
|
|
|
@ -124,6 +142,7 @@ def change_sovits_weights(sovits_path):
|
|
|
|
vq_model = vq_model.to(device)
|
|
|
|
vq_model = vq_model.to(device)
|
|
|
|
vq_model.eval()
|
|
|
|
vq_model.eval()
|
|
|
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
|
|
|
print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
|
|
|
|
|
|
|
|
with open("./sweight.txt","w",encoding="utf-8")as f:f.write(sovits_path)
|
|
|
|
change_sovits_weights(sovits_path)
|
|
|
|
change_sovits_weights(sovits_path)
|
|
|
|
|
|
|
|
|
|
|
|
def change_gpt_weights(gpt_path):
|
|
|
|
def change_gpt_weights(gpt_path):
|
|
|
@ -140,6 +159,7 @@ def change_gpt_weights(gpt_path):
|
|
|
|
t2s_model.eval()
|
|
|
|
t2s_model.eval()
|
|
|
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
|
|
|
total = sum([param.nelement() for param in t2s_model.parameters()])
|
|
|
|
print("Number of parameter: %.2fM" % (total / 1e6))
|
|
|
|
print("Number of parameter: %.2fM" % (total / 1e6))
|
|
|
|
|
|
|
|
with open("./gweight.txt","w",encoding="utf-8")as f:f.write(gpt_path)
|
|
|
|
change_gpt_weights(gpt_path)
|
|
|
|
change_gpt_weights(gpt_path)
|
|
|
|
|
|
|
|
|
|
|
|
def get_spepc(hps, filename):
|
|
|
|
def get_spepc(hps, filename):
|
|
|
@ -188,19 +208,19 @@ def splite_en_inf(sentence, language):
|
|
|
|
def clean_text_inf(text, language):
    """Clean ``text`` for ``language`` and convert its phones to a sequence.

    Returns:
        A ``(phones, word2ph, norm_text)`` tuple. ``phones`` is the result of
        ``cleaned_text_to_sequence`` applied to the phones produced by
        ``clean_text``; ``word2ph`` and ``norm_text`` are passed through
        unchanged from ``clean_text``.
    """
    raw_phones, word2ph, norm_text = clean_text(text, language)
    return cleaned_text_to_sequence(raw_phones), word2ph, norm_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_bert_inf(phones, word2ph, norm_text, language):
    """Return the BERT feature tensor for one text segment, on ``device``.

    For ``language == "zh"`` the features come from
    ``get_bert_feature(norm_text, word2ph)``. For any other language an
    all-zero tensor of shape ``(1024, len(phones))`` is returned instead,
    using float16 when ``is_half`` is set and float32 otherwise.
    """
    if language == "zh":
        return get_bert_feature(norm_text, word2ph).to(device)

    # Non-Chinese segments get a zero placeholder with one column per phone.
    zero_dtype = torch.float16 if is_half == True else torch.float32
    placeholder = torch.zeros((1024, len(phones)), dtype=zero_dtype)
    return placeholder.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -213,7 +233,7 @@ def nonen_clean_text_inf(text, language):
|
|
|
|
lang = langlist[i]
|
|
|
|
lang = langlist[i]
|
|
|
|
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
|
|
|
phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
|
|
|
|
phones_list.append(phones)
|
|
|
|
phones_list.append(phones)
|
|
|
|
# Keep word2ph only for segments that are neither English nor Japanese.
# The earlier condition `lang == "en" or "ja"` was always true because the
# bare string "ja" is truthy, so word2ph was never collected for any language.
if lang not in ("en", "ja"):
    word2ph_list.append(word2ph)
|
|
|
@ -222,7 +242,7 @@ def nonen_clean_text_inf(text, language):
|
|
|
|
phones = sum(phones_list, [])
|
|
|
|
phones = sum(phones_list, [])
|
|
|
|
word2ph = sum(word2ph_list, [])
|
|
|
|
word2ph = sum(word2ph_list, [])
|
|
|
|
norm_text = ' '.join(norm_text_list)
|
|
|
|
norm_text = ' '.join(norm_text_list)
|
|
|
|
|
|
|
|
|
|
|
|
return phones, word2ph, norm_text
|
|
|
|
return phones, word2ph, norm_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -238,7 +258,7 @@ def nonen_get_bert_inf(text, language):
|
|
|
|
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
|
|
|
bert = get_bert_inf(phones, word2ph, norm_text, lang)
|
|
|
|
bert_list.append(bert)
|
|
|
|
bert_list.append(bert)
|
|
|
|
bert = torch.cat(bert_list, dim=1)
|
|
|
|
bert = torch.cat(bert_list, dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
return bert
|
|
|
|
return bert
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -271,6 +291,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|
|
|
t1 = ttime()
|
|
|
|
t1 = ttime()
|
|
|
|
prompt_language = dict_language[prompt_language]
|
|
|
|
prompt_language = dict_language[prompt_language]
|
|
|
|
text_language = dict_language[text_language]
|
|
|
|
text_language = dict_language[text_language]
|
|
|
|
|
|
|
|
|
|
|
|
if prompt_language == "en":
|
|
|
|
if prompt_language == "en":
|
|
|
|
phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
|
|
|
|
phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -281,21 +302,21 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language)
|
|
|
|
bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
|
|
|
|
bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
|
|
|
|
bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
|
|
|
|
|
|
|
|
|
|
|
|
for text in texts:
|
|
|
|
for text in texts:
|
|
|
|
# 解决输入目标文本的空行导致报错的问题
|
|
|
|
# 解决输入目标文本的空行导致报错的问题
|
|
|
|
if (len(text.strip()) == 0):
|
|
|
|
if (len(text.strip()) == 0):
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
|
|
if text_language == "en":
|
|
|
|
if text_language == "en":
|
|
|
|
phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
|
|
|
|
phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
|
|
|
|
phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
|
|
|
|
|
|
|
|
|
|
|
|
if text_language == "en":
|
|
|
|
if text_language == "en":
|
|
|
|
bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
|
|
|
|
bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
bert2 = nonen_get_bert_inf(text, text_language)
|
|
|
|
bert2 = nonen_get_bert_inf(text, text_language)
|
|
|
|
|
|
|
|
|
|
|
|
bert = torch.cat([bert1, bert2], 1)
|
|
|
|
bert = torch.cat([bert1, bert2], 1)
|
|
|
|
|
|
|
|
|
|
|
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
|
|
|
all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
|
|
|
|