@@ -11,7 +11,7 @@
 调用请求缺少参考音频时使用
 `-dr` - `默认参考音频路径`
 `-dt` - `默认参考音频文本`
-`-dl` - `默认参考音频语种, "中文","英文","日文","zh","en","ja"`
+`-dl` - `默认参考音频语种, "中文","英文","日文","韩文","粤语","zh","en","ja","ko","yue"`

 `-d` - `推理设备, "cuda","cpu"`
 `-a` - `绑定地址, 默认"127.0.0.1"`
@@ -201,6 +201,11 @@ def change_sovits_weights(sovits_path):
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
+    if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+        hps.model.version = "v1"
+    else:
+        hps.model.version = "v2"
+    print("sovits版本:",hps.model.version)
     model_params_dict = vars(hps.model)
     vq_model = SynthesizerTrn(
         hps.data.filter_length // 2 + 1,
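Note: the added lines infer the SoVITS version from the checkpoint itself: a text embedding with 322 rows corresponds to the v1 symbol table, anything else is treated as v2. A minimal standalone sketch of the same check (detect_sovits_version and the checkpoint path are hypothetical, not part of this change):

    import torch

    def detect_sovits_version(sovits_path: str) -> str:
        # v1 checkpoints carry a 322-symbol text embedding; anything else is treated as v2
        dict_s2 = torch.load(sovits_path, map_location="cpu")
        rows = dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0]
        return "v1" if rows == 322 else "v2"

    # print(detect_sovits_version("SoVITS_weights/your_model.pth"))  # hypothetical path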
@@ -251,9 +256,9 @@ def get_bert_feature(text, word2ph):
     return phone_level_feature.T


-def clean_text_inf(text, language):
-    phones, word2ph, norm_text = clean_text(text, language)
-    phones = cleaned_text_to_sequence(phones)
+def clean_text_inf(text, language, version):
+    phones, word2ph, norm_text = clean_text(text, language, version)
+    phones = cleaned_text_to_sequence(phones, version)
     return phones, word2ph, norm_text

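Note: clean_text_inf now threads the model version down into clean_text and cleaned_text_to_sequence, so every caller has to supply it. A small compatibility sketch (clean_text_inf_compat is a hypothetical helper, not part of this change) that reads the version off the loaded model when the caller does not pass one:

    def clean_text_inf_compat(text, language, version=None):
        # prefer an explicit version; otherwise take it from the loaded SoVITS model,
        # falling back to "v2" (the default this change assigns to non-322-symbol checkpoints)
        version = version or getattr(vq_model, "version", "v2")
        return clean_text_inf(text, language, version)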
@@ -269,36 +274,48 @@ def get_bert_inf(phones, word2ph, norm_text, language):
     return bert


-def get_phones_and_bert(text,language):
-    if language in {"en","all_zh","all_ja"}:
+from text import chinese
+def get_phones_and_bert(text,language,version):
+    if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
         language = language.replace("all_","")
         if language == "en":
             LangSegment.setfilters(["en"])
             formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
         else:
-            # 因无法区别中日文汉字,以用户输入为准
+            # 因无法区别中日韩文汉字,以用户输入为准
             formattext = text
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
-        phones, word2ph, norm_text = clean_text_inf(formattext, language)
         if language == "zh":
-            bert = get_bert_feature(norm_text, word2ph).to(device)
+            if re.search(r'[A-Za-z]', formattext):
+                formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+                formattext = chinese.text_normalize(formattext)
+                return get_phones_and_bert(formattext,"zh",version)
+            else:
+                phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
+                bert = get_bert_feature(norm_text, word2ph).to(device)
+        elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+            formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
+            formattext = chinese.text_normalize(formattext)
+            return get_phones_and_bert(formattext,"yue",version)
         else:
+            phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
             bert = torch.zeros(
                 (1024, len(phones)),
                 dtype=torch.float16 if is_half == True else torch.float32,
             ).to(device)
-    elif language in {"zh", "ja","auto"}:
+    elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
         textlist=[]
         langlist=[]
         LangSegment.setfilters(["zh","ja","en","ko"])
         if language == "auto":
             for tmp in LangSegment.getTexts(text):
-                if tmp["lang"] == "ko":
-                    langlist.append("zh")
-                    textlist.append(tmp["text"])
-                else:
-                    langlist.append(tmp["lang"])
-                    textlist.append(tmp["text"])
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
+        elif language == "auto_yue":
+            for tmp in LangSegment.getTexts(text):
+                if tmp["lang"] == "zh":
+                    tmp["lang"] = "yue"
+                langlist.append(tmp["lang"])
+                textlist.append(tmp["text"])
         else:
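Note: for the mixed-language modes, "auto" now keeps whatever label LangSegment detects, while the new "auto_yue" runs the same segmentation and relabels detected Chinese segments as Cantonese. A standalone sketch of that routing (split_by_language is a hypothetical helper name; the LangSegment calls mirror the ones used above):

    import LangSegment

    def split_by_language(text, language):
        # keep detected labels for "auto"; relabel detected zh segments as yue for "auto_yue"
        LangSegment.setfilters(["zh", "ja", "en", "ko"])
        textlist, langlist = [], []
        for tmp in LangSegment.getTexts(text):
            lang = tmp["lang"]
            if language == "auto_yue" and lang == "zh":
                lang = "yue"
            langlist.append(lang)
            textlist.append(tmp["text"])
        return textlist, langlist

    # e.g. split_by_language("你好 hello", "auto_yue") labels the Chinese span as "yue"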
@@ -306,17 +323,15 @@ def get_phones_and_bert(text,language):
                 if tmp["lang"] == "en":
                     langlist.append(tmp["lang"])
                 else:
-                    # 因无法区别中日文汉字,以用户输入为准
+                    # 因无法区别中日韩文汉字,以用户输入为准
                     langlist.append(language)
                 textlist.append(tmp["text"])
-        # logger.info(textlist)
-        # logger.info(langlist)
         phones_list = []
         bert_list = []
         norm_text_list = []
         for i in range(len(textlist)):
             lang = langlist[i]
-            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
+            phones, word2ph, norm_text = clean_text_inf(textlist[i], lang, version)
             bert = get_bert_inf(phones, word2ph, norm_text, lang)
             phones_list.append(phones)
             norm_text_list.append(norm_text)
@@ -328,15 +343,33 @@ def get_phones_and_bert(text,language):
     return phones,bert.to(torch.float16 if is_half == True else torch.float32),norm_text


-class DictToAttrRecursive:
+class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
+        super().__init__(input_dict)
         for key, value in input_dict.items():
             if isinstance(value, dict):
-                # 如果值是字典,递归调用构造函数
-                setattr(self, key, DictToAttrRecursive(value))
-            else:
-                setattr(self, key, value)
+                value = DictToAttrRecursive(value)
+            self[key] = value
+            setattr(self, key, value)
+
+    def __getattr__(self, item):
+        try:
+            return self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")
+
+    def __setattr__(self, key, value):
+        if isinstance(value, dict):
+            value = DictToAttrRecursive(value)
+        super(DictToAttrRecursive, self).__setitem__(key, value)
+        super().__setattr__(key, value)
+
+    def __delattr__(self, item):
+        try:
+            del self[item]
+        except KeyError:
+            raise AttributeError(f"Attribute {item} not found")


 def get_spepc(hps, filename):
     audio = load_audio(filename, int(hps.data.sampling_rate))
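Note: subclassing dict means the wrapped config can be used both as a plain mapping and through attribute access, with nested dicts wrapped recursively and the new __setattr__ keeping both views in sync. A small usage sketch, assuming the DictToAttrRecursive class above is in scope (the config values are illustrative):

    hps = DictToAttrRecursive({"model": {"semantic_frame_rate": "25hz"}})

    print(hps.model.semantic_frame_rate)        # attribute access -> 25hz
    print(hps["model"]["semantic_frame_rate"])  # dict-style access -> 25hz

    hps.model.version = "v2"                    # __setattr__ also writes the dict entry
    print(hps["model"]["version"])              # -> v2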
@@ -488,9 +521,10 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         codes = vq_model.extract_latent(ssl_content)
         prompt_semantic = codes[0, 0]
     t1 = ttime()
+    version = vq_model.version
     prompt_language = dict_language[prompt_language.lower()]
     text_language = dict_language[text_language.lower()]
-    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language)
+    phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
     texts = text.split("\n")
     audio_bytes = BytesIO()

@@ -500,7 +534,7 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
             continue

         audio_opt = []
-        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language)
+        phones2, bert2, norm_text2 = get_phones_and_bert(text, text_language, version)
         bert = torch.cat([bert1, bert2], 1)

         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@@ -606,17 +640,27 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cu
 # --------------------------------
 dict_language = {
     "中文": "all_zh",
+    "粤语": "all_yue",
     "英文": "en",
     "日文": "all_ja",
+    "韩文": "all_ko",
     "中英混合": "zh",
+    "粤英混合": "yue",
     "日英混合": "ja",
+    "韩英混合": "ko",
     "多语种混合": "auto", #多语种启动切分识别语种
+    "多语种混合(粤语)": "auto_yue",
     "all_zh": "all_zh",
+    "all_yue": "all_yue",
     "en": "en",
     "all_ja": "all_ja",
+    "all_ko": "all_ko",
     "zh": "zh",
+    "yue": "yue",
     "ja": "ja",
+    "ko": "ko",
     "auto": "auto",
+    "auto_yue": "auto_yue",
 }

 # logger
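Note: the module-level dict_language table gains Cantonese and Korean entries, and get_tts_wav lowercases the incoming value before the lookup, so the English-style keys are effectively case-insensitive. A small lookup sketch, assuming the dict_language mapping above is in scope (the request values are hypothetical):

    # values a caller might send in the "prompt_language" / "text_language" fields
    for requested in ("粤语", "韩英混合", "AUTO_YUE"):
        print(requested, "->", dict_language[requested.lower()])
    # 粤语 -> all_yue
    # 韩英混合 -> ko
    # AUTO_YUE -> auto_yue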