@ -1,5 +1,6 @@
import os , sys
import threading
from tqdm import tqdm
now_dir = os . getcwd ( )
@ -54,6 +55,7 @@ class TextPreprocessor:
self . bert_model = bert_model
self . tokenizer = tokenizer
self . device = device
self . bert_lock = threading . RLock ( )
def preprocess ( self , text : str , lang : str , text_split_method : str , version : str = " v2 " ) - > List [ Dict ] :
print ( f ' ############ { i18n ( " 切分文本 " ) } ############ ' )
@ -117,70 +119,71 @@ class TextPreprocessor:
return self . get_phones_and_bert ( text , language , version )
def get_phones_and_bert ( self , text : str , language : str , version : str , final : bool = False ) :
if language in { " en " , " all_zh " , " all_ja " , " all_ko " , " all_yue " } :
# language = language.replace("all_","")
formattext = text
while " " in formattext :
formattext = formattext . replace ( " " , " " )
if language == " all_zh " :
if re . search ( r ' [A-Za-z] ' , formattext ) :
formattext = re . sub ( r ' [a-z] ' , lambda x : x . group ( 0 ) . upper ( ) , formattext )
formattext = chinese . mix_text_normalize ( formattext )
return self . get_phones_and_bert ( formattext , " zh " , version )
else :
phones , word2ph , norm_text = self . clean_text_inf ( formattext , language , version )
bert = self . get_bert_feature ( norm_text , word2ph ) . to ( self . device )
elif language == " all_yue " and re . search ( r ' [A-Za-z] ' , formattext ) :
formattext = re . sub ( r ' [a-z] ' , lambda x : x . group ( 0 ) . upper ( ) , formattext )
formattext = chinese . mix_text_normalize ( formattext )
return self . get_phones_and_bert ( formattext , " yue " , version )
else :
phones , word2ph , norm_text = self . clean_text_inf ( formattext , language , version )
bert = torch . zeros (
( 1024 , len ( phones ) ) ,
dtype = torch . float32 ,
) . to ( self . device )
elif language in { " zh " , " ja " , " ko " , " yue " , " auto " , " auto_yue " } :
textlist = [ ]
langlist = [ ]
if language == " auto " :
for tmp in LangSegmenter . getTexts ( text ) :
langlist . append ( tmp [ " lang " ] )
textlist . append ( tmp [ " text " ] )
elif language == " auto_yue " :
for tmp in LangSegmenter . getTexts ( text ) :
if tmp [ " lang " ] == " zh " :
tmp [ " lang " ] = " yue "
langlist . append ( tmp [ " lang " ] )
textlist . append ( tmp [ " text " ] )
else :
for tmp in LangSegmenter . getTexts ( text ) :
if tmp [ " lang " ] == " en " :
langlist . append ( tmp [ " lang " ] )
else :
# 因无法区别中日韩文汉字,以用户输入为准
langlist . append ( language )
textlist . append ( tmp [ " text " ] )
# print(textlist)
# print(langlist)
phones_list = [ ]
bert_list = [ ]
norm_text_list = [ ]
for i in range ( len ( textlist ) ) :
lang = langlist [ i ]
phones , word2ph , norm_text = self . clean_text_inf ( textlist [ i ] , lang , version )
bert = self . get_bert_inf ( phones , word2ph , norm_text , lang )
phones_list . append ( phones )
norm_text_list . append ( norm_text )
bert_list . append ( bert )
bert = torch . cat ( bert_list , dim = 1 )
phones = sum ( phones_list , [ ] )
norm_text = ' ' . join ( norm_text_list )
if not final and len ( phones ) < 6 :
return self . get_phones_and_bert ( " . " + text , language , version , final = True )
return phones , bert , norm_text
with self . bert_lock :
if language in { " en " , " all_zh " , " all_ja " , " all_ko " , " all_yue " } :
# language = language.replace("all_","")
formattext = text
while " " in formattext :
formattext = formattext . replace ( " " , " " )
if language == " all_zh " :
if re . search ( r ' [A-Za-z] ' , formattext ) :
formattext = re . sub ( r ' [a-z] ' , lambda x : x . group ( 0 ) . upper ( ) , formattext )
formattext = chinese . mix_text_normalize ( formattext )
return self . get_phones_and_bert ( formattext , " zh " , version )
else :
phones , word2ph , norm_text = self . clean_text_inf ( formattext , language , version )
bert = self . get_bert_feature ( norm_text , word2ph ) . to ( self . device )
elif language == " all_yue " and re . search ( r ' [A-Za-z] ' , formattext ) :
formattext = re . sub ( r ' [a-z] ' , lambda x : x . group ( 0 ) . upper ( ) , formattext )
formattext = chinese . mix_text_normalize ( formattext )
return self . get_phones_and_bert ( formattext , " yue " , version )
else :
phones , word2ph , norm_text = self . clean_text_inf ( formattext , language , version )
bert = torch . zeros (
( 1024 , len ( phones ) ) ,
dtype = torch . float32 ,
) . to ( self . device )
elif language in { " zh " , " ja " , " ko " , " yue " , " auto " , " auto_yue " } :
textlist = [ ]
langlist = [ ]
if language == " auto " :
for tmp in LangSegmenter . getTexts ( text ) :
langlist . append ( tmp [ " lang " ] )
textlist . append ( tmp [ " text " ] )
elif language == " auto_yue " :
for tmp in LangSegmenter . getTexts ( text ) :
if tmp [ " lang " ] == " zh " :
tmp [ " lang " ] = " yue "
langlist . append ( tmp [ " lang " ] )
textlist . append ( tmp [ " text " ] )
else :
for tmp in LangSegmenter . getTexts ( text ) :
if tmp [ " lang " ] == " en " :
langlist . append ( tmp [ " lang " ] )
else :
# 因无法区别中日韩文汉字,以用户输入为准
langlist . append ( language )
textlist . append ( tmp [ " text " ] )
# print(textlist)
# print(langlist)
phones_list = [ ]
bert_list = [ ]
norm_text_list = [ ]
for i in range ( len ( textlist ) ) :
lang = langlist [ i ]
phones , word2ph , norm_text = self . clean_text_inf ( textlist [ i ] , lang , version )
bert = self . get_bert_inf ( phones , word2ph , norm_text , lang )
phones_list . append ( phones )
norm_text_list . append ( norm_text )
bert_list . append ( bert )
bert = torch . cat ( bert_list , dim = 1 )
phones = sum ( phones_list , [ ] )
norm_text = ' ' . join ( norm_text_list )
if not final and len ( phones ) < 6 :
return self . get_phones_and_bert ( " . " + text , language , version , final = True )
return phones , bert , norm_text
def get_bert_feature ( self , text : str , word2ph : list ) - > torch . Tensor :