@@ -4,6 +4,7 @@ import os, sys, gc
import random
import traceback
import torchaudio
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -15,10 +16,11 @@ import torch
import torch.nn.functional as F
import yaml
from transformers import AutoModelForMaskedLM, AutoTokenizer
from tools.audio_sr import AP_BWE
from AR.models.t2s_lightning_module import Text2SemanticLightningModule
from feature_extractor.cnhubert import CNHubert
from module.models import SynthesizerTrn
from module.models import SynthesizerTrn, SynthesizerTrnV3
from peft import LoraConfig, get_peft_model
import librosa
from time import time as ttime
from tools.i18n.i18n import I18nAuto, scan_language_list
@@ -26,10 +28,98 @@ from tools.my_utils import load_audio
from module.mel_processing import spectrogram_torch
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from BigVGAN.bigvgan import BigVGAN
from module.mel_processing import spectrogram_torch, mel_spectrogram_torch
from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new

language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)

spec_min = -12
spec_max = 2

def norm_spec(x):
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1

def denorm_spec(x):
    return (x + 1) / 2 * (spec_max - spec_min) + spec_min

mel_fn = lambda x: mel_spectrogram_torch(x, **{
    "n_fft": 1024,
    "win_size": 1024,
    "hop_size": 256,
    "num_mels": 100,
    "sampling_rate": 24000,
    "fmin": 0,
    "fmax": None,
    "center": False
})
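# Note: spec_min/spec_max and norm_spec/denorm_spec above map log-mel values into
# [-1, 1] and back for the v3 flow-matching (CFM) decoder, while mel_fn produces the
# 100-band, 24 kHz mel spectrogram that the CFM decoder and the BigVGAN vocoder
# consume in v3_synthesis below.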
def speed_change(input_audio: np.ndarray, speed: float, sr: int):
    # Convert the NumPy array into a raw PCM byte stream
    raw_audio = input_audio.astype(np.int16).tobytes()
    # Set up the ffmpeg input stream
    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
    # Apply the tempo change
    output_stream = input_stream.filter('atempo', speed)
    # Pipe the output stream
    out, _ = (
        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
    )
    # Decode the piped output back into a NumPy array
    processed_audio = np.frombuffer(out, np.int16)
    return processed_audio
resample_transform_dict = {}

def resample(audio_tensor, sr0, device):
    global resample_transform_dict
    if sr0 not in resample_transform_dict:
        resample_transform_dict[sr0] = torchaudio.transforms.Resample(
            sr0, 24000
        ).to(device)
    return resample_transform_dict[sr0](audio_tensor)
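# resample() above caches one torchaudio Resample transform per source sample rate,
# so repeated reference audios are converted to the 24 kHz rate expected by mel_fn
# without rebuilding the transform on every call.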
class DictToAttrRecursive(dict):
    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

class NO_PROMPT_ERROR(Exception):
    pass
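# NO_PROMPT_ERROR is raised by run() when prompt_text is empty while the v3
# synthesizer is active, since SoVITS V3 requires a reference prompt text.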
# configs/tts_infer.yaml
"""
custom:
@@ -56,11 +146,19 @@ default_v2:
  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
  version: v2

default_v3:
  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
  device: cpu
  is_half: false
  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
  version: v3
"""
def set_seed(seed: int):
    seed = int(seed)
    seed = seed if seed != -1 else random.randrange(1 << 32)
    seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
    print(f"Set seed to {seed}")
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
@@ -82,7 +180,7 @@ def set_seed(seed:int):
class TTS_Config:
    default_configs = {
        "default": {
        "v1": {
            "device": "cpu",
            "is_half": False,
            "version": "v1",
@@ -91,7 +189,7 @@ class TTS_Config:
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "default_v2": {
        "v2": {
            "device": "cpu",
            "is_half": False,
            "version": "v2",
@@ -100,6 +198,15 @@ class TTS_Config:
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
        "v3": {
            "device": "cpu",
            "is_half": False,
            "version": "v3",
            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
        },
    }
    configs: dict = None
    v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -136,12 +243,9 @@ class TTS_Config:
        assert isinstance(configs, dict)
        version = configs.get("version", "v2").lower()
        assert version in ["v1", "v2"]
        self.default_configs["default"] = configs.get("default", self.default_configs["default"])
        self.default_configs["default_v2"] = configs.get("default_v2", self.default_configs["default_v2"])
        default_config_key = "default" if version == "v1" else "default_v2"
        self.configs: dict = configs.get("custom", deepcopy(self.default_configs[default_config_key]))
        assert version in ["v1", "v2", "v3"]
        self.default_configs[version] = configs.get(version, self.default_configs[version])
        self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))

        self.device = self.configs.get("device", torch.device("cpu"))
@@ -159,20 +263,22 @@ class TTS_Config:
        self.vits_weights_path = self.configs.get("vits_weights_path", None)
        self.bert_base_path = self.configs.get("bert_base_path", None)
        self.cnhuhbert_base_path = self.configs.get("cnhuhbert_base_path", None)
        self.languages = self.v2_languages if self.version == "v2" else self.v1_languages
        self.languages = self.v1_languages if self.version == "v1" else self.v2_languages
        self.is_v3_synthesizer: bool = False

        if (self.t2s_weights_path in [None, ""]) or (not os.path.exists(self.t2s_weights_path)):
            self.t2s_weights_path = self.default_configs[default_config_key]['t2s_weights_path']
            self.t2s_weights_path = self.default_configs[version]['t2s_weights_path']
            print(f"fall back to default t2s_weights_path: {self.t2s_weights_path}")
        if (self.vits_weights_path in [None, ""]) or (not os.path.exists(self.vits_weights_path)):
            self.vits_weights_path = self.default_configs[default_config_key]['vits_weights_path']
            self.vits_weights_path = self.default_configs[version]['vits_weights_path']
            print(f"fall back to default vits_weights_path: {self.vits_weights_path}")
        if (self.bert_base_path in [None, ""]) or (not os.path.exists(self.bert_base_path)):
            self.bert_base_path = self.default_configs[default_config_key]['bert_base_path']
            self.bert_base_path = self.default_configs[version]['bert_base_path']
            print(f"fall back to default bert_base_path: {self.bert_base_path}")
        if (self.cnhuhbert_base_path in [None, ""]) or (not os.path.exists(self.cnhuhbert_base_path)):
            self.cnhuhbert_base_path = self.default_configs[default_config_key]['cnhuhbert_base_path']
            self.cnhuhbert_base_path = self.default_configs[version]['cnhuhbert_base_path']
            print(f"fall back to default cnhuhbert_base_path: {self.cnhuhbert_base_path}")

        self.update_configs()
@@ -195,7 +301,7 @@ class TTS_Config:
        else:
            print(i18n("路径不存在,使用默认配置"))
            self.save_configs(configs_path)
        with open(configs_path, 'r') as f:
        with open(configs_path, 'r', encoding='utf-8') as f:
            configs = yaml.load(f, Loader=yaml.FullLoader)
        return configs
@@ -224,7 +330,7 @@ class TTS_Config:
    def update_version(self, version: str) -> None:
        self.version = version
        self.languages = self.v2_languages if self.version == "v2" else self.v1_languages
        self.languages = self.v1_languages if self.version == "v1" else self.v2_languages

    def __str__(self):
        self.configs = self.update_configs()
@@ -252,10 +358,13 @@ class TTS:
        self.configs: TTS_Config = TTS_Config(configs)
        self.t2s_model: Text2SemanticLightningModule = None
        self.vits_model: SynthesizerTrn = None
        self.vits_model: Union[SynthesizerTrn, SynthesizerTrnV3] = None
        self.bert_tokenizer: AutoTokenizer = None
        self.bert_model: AutoModelForMaskedLM = None
        self.cnhuhbert_model: CNHubert = None
        self.bigvgan_model: BigVGAN = None
        self.sr_model: AP_BWE = None
        self.sr_model_not_exist: bool = False

        self._init_models()
@@ -310,38 +419,83 @@ class TTS:
            self.bert_model = self.bert_model.half()

    def init_vits_weights(self, weights_path: str):
        print(f"Loading VITS weights from {weights_path}")
        self.configs.vits_weights_path = weights_path
        dict_s2 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
        version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
        path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]

        if if_lora_v3 == True and os.path.exists(path_sovits_v3) == False:
            info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
            raise FileExistsError(info)

        # dict_s2 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
        dict_s2 = load_sovits_new(weights_path)
        hps = dict_s2["config"]

        if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
            self.configs.update_version("v1")
        hps["model"]["semantic_frame_rate"] = "25hz"
        if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
            hps["model"]["version"] = "v2"  # v3 model, v2 symbols
        elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
            hps["model"]["version"] = "v1"
        else:
            self.configs.update_version("v2")
        self.configs.save_configs()
            hps["model"]["version"] = "v2"
        # version = hps["model"]["version"]

        hps["model"]["version"] = self.configs.version
        self.configs.filter_length = hps["data"]["filter_length"]
        self.configs.segment_size = hps["train"]["segment_size"]
        self.configs.sampling_rate = hps["data"]["sampling_rate"]
        self.configs.hop_length = hps["data"]["hop_length"]
        self.configs.win_length = hps["data"]["win_length"]
        self.configs.n_speakers = hps["data"]["n_speakers"]
        self.configs.semantic_frame_rate = "25hz"
        self.configs.semantic_frame_rate = hps["model"]["semantic_frame_rate"]
        kwargs = hps["model"]
        vits_model = SynthesizerTrn(
            self.configs.filter_length // 2 + 1,
            self.configs.segment_size // self.configs.hop_length,
            n_speakers=self.configs.n_speakers,
            **kwargs
        )
        # print(f"self.configs.sampling_rate:{self.configs.sampling_rate}")
        self.configs.update_version(model_version)
        # print(f"model_version:{model_version}")
        # print(f'hps["model"]["version"]:{hps["model"]["version"]}')

        if model_version != "v3":
            vits_model = SynthesizerTrn(
                self.configs.filter_length // 2 + 1,
                self.configs.segment_size // self.configs.hop_length,
                n_speakers=self.configs.n_speakers,
                **kwargs
            )
            if hasattr(vits_model, "enc_q"):
                del vits_model.enc_q
            self.configs.is_v3_synthesizer = False
        else:
            vits_model = SynthesizerTrnV3(
                self.configs.filter_length // 2 + 1,
                self.configs.segment_size // self.configs.hop_length,
                n_speakers=self.configs.n_speakers,
                **kwargs
            )
            self.configs.is_v3_synthesizer = True
            self.init_bigvgan()
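        # Two loading paths follow: a plain SoVITS checkpoint is loaded directly, while a
        # v3 LoRA checkpoint first loads the s2Gv3 base weights, wraps the CFM module with
        # a LoRA adapter, loads the LoRA weights, and then merges the adapter back in.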
        if if_lora_v3 == False:
            print(f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
        else:
            print(f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}")
            lora_rank = dict_s2["lora_rank"]
            lora_config = LoraConfig(
                target_modules=["to_k", "to_q", "to_v", "to_out.0"],
                r=lora_rank,
                lora_alpha=lora_rank,
                init_lora_weights=True,
            )
            vits_model.cfm = get_peft_model(vits_model.cfm, lora_config)
            print(f"Loading LoRA weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}")
            vits_model.cfm = vits_model.cfm.merge_and_unload()

        if hasattr(vits_model, "enc_q"):
            del vits_model.enc_q

        vits_model = vits_model.to(self.configs.device)
        vits_model = vits_model.eval()
        vits_model.load_state_dict(dict_s2["weight"], strict=False)

        self.vits_model = vits_model
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.vits_model = self.vits_model.half()
@@ -363,6 +517,30 @@ class TTS:
        if self.configs.is_half and str(self.configs.device) != "cpu":
            self.t2s_model = self.t2s_model.half()

    def init_bigvgan(self):
        if self.bigvgan_model is not None:
            return
        self.bigvgan_model = BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True, RuntimeError: Ninja is required to load C++ extensions
        # remove weight norm in the model and set to eval mode
        self.bigvgan_model.remove_weight_norm()
        self.bigvgan_model = self.bigvgan_model.eval()
        if self.configs.is_half == True:
            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
        else:
            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
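    # init_bigvgan above loads the 24 kHz, 100-band BigVGAN vocoder that turns the
    # CFM-generated mel spectrogram from v3_synthesis into a waveform; it is only
    # initialized when a v3 synthesizer is loaded (see init_vits_weights above).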
    def init_sr_model(self):
        if self.sr_model is not None:
            return
        try:
            self.sr_model: AP_BWE = AP_BWE(self.configs.device, DictToAttrRecursive)
            self.sr_model_not_exist = False
        except FileNotFoundError:
            print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
            self.sr_model_not_exist = True
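    # init_sr_model above creates the AP_BWE super-resolution model lazily on the first
    # request with super_sampling=True; if its weights are missing, sr_model_not_exist is
    # set and audio_postprocess skips the upsampling step instead of failing.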
    def enable_half_precision(self, enable: bool = True, save: bool = True):
        '''
        To enable half precision for the TTS model.
@@ -387,6 +565,8 @@ class TTS:
                self.bert_model = self.bert_model.half()
            if self.cnhuhbert_model is not None:
                self.cnhuhbert_model = self.cnhuhbert_model.half()
            if self.bigvgan_model is not None:
                self.bigvgan_model = self.bigvgan_model.half()
        else:
            if self.t2s_model is not None:
                self.t2s_model = self.t2s_model.float()
@@ -396,6 +576,8 @@ class TTS:
                self.bert_model = self.bert_model.float()
            if self.cnhuhbert_model is not None:
                self.cnhuhbert_model = self.cnhuhbert_model.float()
            if self.bigvgan_model is not None:
                self.bigvgan_model = self.bigvgan_model.float()
    def set_device(self, device: torch.device, save: bool = True):
        '''
@@ -414,6 +596,11 @@ class TTS:
            self.bert_model = self.bert_model.to(device)
        if self.cnhuhbert_model is not None:
            self.cnhuhbert_model = self.cnhuhbert_model.to(device)
        if self.bigvgan_model is not None:
            self.bigvgan_model = self.bigvgan_model.to(device)
        if self.sr_model is not None:
            self.sr_model = self.sr_model.to(device)

    def set_ref_audio(self, ref_audio_path: str):
        '''
@@ -437,6 +624,11 @@ class TTS:
        self.prompt_cache["refer_spec"][0] = spec

    def _get_ref_spec(self, ref_audio_path):
        raw_audio, raw_sr = torchaudio.load(ref_audio_path)
        raw_audio = raw_audio.to(self.configs.device).float()
        self.prompt_cache["raw_audio"] = raw_audio
        self.prompt_cache["raw_sr"] = raw_sr
        audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
        audio = torch.FloatTensor(audio)
        maxx = audio.abs().max()
@@ -625,11 +817,11 @@ class TTS:
        Recover the order of the audio according to the batch_index_list.
        Args:
            data (List[list(np.ndarray)]): the out-of-order audio.
            data (List[list(torch.Tensor)]): the out-of-order audio.
            batch_index_list (List[list[int]]): the batch index list.
        Returns:
            list (List[np.ndarray]): the data in the original order.
            list (List[torch.Tensor]): the data in the original order.
        '''
        length = len(sum(batch_index_list, []))
        _data = [None] * length
@@ -671,6 +863,8 @@ class TTS:
                    "seed": -1,                   # int. random seed for reproducibility.
                    "parallel_infer": True,       # bool. whether to use parallel inference.
                    "repetition_penalty": 1.35,   # float. repetition penalty for T2S model.
                    "sample_steps": 32,           # int. number of sampling steps for VITS model V3.
                    "super_sampling": False,      # bool. whether to use super-sampling for audio when using VITS model V3.
                }
        returns:
            Tuple[int, np.ndarray]: sampling rate and audio data.
@@ -698,6 +892,8 @@ class TTS:
        actual_seed = set_seed(seed)
        parallel_infer = inputs.get("parallel_infer", True)
        repetition_penalty = inputs.get("repetition_penalty", 1.35)
        sample_steps = inputs.get("sample_steps", 32)
        super_sampling = inputs.get("super_sampling", False)

        if parallel_infer:
            print(i18n("并行推理模式已开启"))
@@ -732,6 +928,9 @@ class TTS:
        if not no_prompt_text:
            assert prompt_lang in self.configs.languages

        if no_prompt_text and self.configs.is_v3_synthesizer:
            raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")

        if ref_audio_path in [None, ""] and \
                ((self.prompt_cache["prompt_semantic"] is None) or (self.prompt_cache["refer_spec"] in [None, []])):
            raise ValueError("ref_audio_path cannot be empty, when the reference audio is not set using set_ref_audio()")
@@ -761,13 +960,13 @@ class TTS:
            if (prompt_text[-1] not in splits): prompt_text += "。" if prompt_lang != "en" else "."
            print(i18n("实际输入的参考文本:"), prompt_text)
            if self.prompt_cache["prompt_text"] != prompt_text:
                self.prompt_cache["prompt_text"] = prompt_text
                self.prompt_cache["prompt_lang"] = prompt_lang
                phones, bert_features, norm_text = \
                    self.text_preprocessor.segment_and_extract_feature_for_text(
                        prompt_text,
                        prompt_lang,
                        self.configs.version)
                self.prompt_cache["prompt_text"] = prompt_text
                self.prompt_cache["prompt_lang"] = prompt_lang
                self.prompt_cache["phones"] = phones
                self.prompt_cache["bert_features"] = bert_features
                self.prompt_cache["norm_text"] = norm_text
@@ -781,8 +980,7 @@ class TTS:
        if not return_fragment:
            data = self.text_preprocessor.preprocess(text, text_lang, text_split_method, self.configs.version)
            if len(data) == 0:
                yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                           dtype=np.int16)
                yield 16000, np.zeros(int(16000), dtype=np.int16)
                return

            batch_index_list: list = None
@@ -836,6 +1034,7 @@ class TTS:
            t_34 = 0.0
            t_45 = 0.0
            audio = []
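            # For the v3 synthesizer, decoding goes through the 24 kHz CFM + BigVGAN path,
            # so the output sample rate below is fixed at 24000; otherwise the SoVITS
            # model's native sampling rate is used.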
            output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
            for item in data:
                t3 = ttime()
                if return_fragment:
@@ -858,7 +1057,7 @@
                else:
                    prompt = self.prompt_cache["prompt_semantic"].expand(len(all_phoneme_ids), -1).to(self.configs.device)

                print(f"############ {i18n('预测语义Token')} ############")
                pred_semantic_list, idx_list = self.t2s_model.model.infer_panel(
                    all_phoneme_ids,
                    all_phoneme_lens,
@@ -892,70 +1091,80 @@ class TTS:
                # batch_audio_fragment = (self.vits_model.batched_decode(
                #     pred_semantic, pred_semantic_len, batch_phones, batch_phones_len, refer_audio_spec
                # ))
                if speed_factor == 1.0:
                    # ## VITS parallel inference, method 2
                    pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                    upsample_rate = math.prod(self.vits_model.upsample_rates)
                    audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate for i in range(0, len(pred_semantic_list))]
                    audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                    all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                    _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
                    _batch_audio_fragment = (self.vits_model.decode(
                        all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
                    ).detach()[0, 0, :])
                    audio_frag_end_idx.insert(0, 0)
                    batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]

                print(f"############ {i18n('合成音频')} ############")
                if not self.configs.is_v3_synthesizer:
                    if speed_factor == 1.0:
                        # ## VITS parallel inference, method 2
                        pred_semantic_list = [item[-idx:] for item, idx in zip(pred_semantic_list, idx_list)]
                        upsample_rate = math.prod(self.vits_model.upsample_rates)
                        audio_frag_idx = [pred_semantic_list[i].shape[0] * 2 * upsample_rate for i in range(0, len(pred_semantic_list))]
                        audio_frag_end_idx = [sum(audio_frag_idx[:i + 1]) for i in range(0, len(audio_frag_idx))]
                        all_pred_semantic = torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                        _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
                        _batch_audio_fragment = (self.vits_model.decode(
                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
                        ).detach()[0, 0, :])
                        audio_frag_end_idx.insert(0, 0)
                        batch_audio_fragment = [_batch_audio_fragment[audio_frag_end_idx[i - 1]:audio_frag_end_idx[i]] for i in range(1, len(audio_frag_end_idx))]
                    else:
                        # ## VITS sequential inference
                        for i, idx in enumerate(tqdm(idx_list)):
                            phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                            _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq: needs one extra unsqueeze
                            audio_fragment = (self.vits_model.decode(
                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor
                            ).detach()[0, 0, :])
                            batch_audio_fragment.append(
                                audio_fragment
                            )  ### try reconstructing without the prompt part
                else:
                    # ## VITS sequential inference
                    for i, idx in enumerate(idx_list):
                    for i, idx in enumerate(tqdm(idx_list)):
                        phones = batch_phones[i].unsqueeze(0).to(self.configs.device)
                        _pred_semantic = (pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0))  # .unsqueeze(0)  # mq: needs one extra unsqueeze
                        audio_fragment = (self.vits_model.decode(
                            _pred_semantic, phones, refer_audio_spec, speed=speed_factor
                        ).detach()[0, 0, :])
                        audio_fragment = self.v3_synthesis(
                            _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
                        )
                        batch_audio_fragment.append(
                            audio_fragment
                        )  ### try reconstructing without the prompt part
                        )
                t5 = ttime()
                t_45 += t5 - t4
                if return_fragment:
                    print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t4 - t3, t5 - t4))
                    yield self.audio_postprocess([batch_audio_fragment],
                                                 self.configs.sampling_rate,
                                                 output_sr,
                                                 None,
                                                 speed_factor,
                                                 False,
                                                 fragment_interval
                                                 fragment_interval,
                                                 super_sampling if self.configs.is_v3_synthesizer else False
                                                 )
                else:
                    audio.append(batch_audio_fragment)

                if self.stop_flag:
                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                               dtype=np.int16)
                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                    return

            if not return_fragment:
                print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t_34, t_45))
                if len(audio) == 0:
                    yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                               dtype=np.int16)
                    yield 16000, np.zeros(int(16000), dtype=np.int16)
                    return
                yield self.audio_postprocess(audio,
                                             self.configs.sampling_rate,
                                             output_sr,
                                             batch_index_list,
                                             speed_factor,
                                             split_bucket,
                                             fragment_interval
                                             fragment_interval,
                                             super_sampling if self.configs.is_v3_synthesizer else False
                                             )

        except Exception as e:
            traceback.print_exc()
            # An empty audio clip must be yielded here, otherwise GPU memory is not released.
            yield self.configs.sampling_rate, np.zeros(int(self.configs.sampling_rate),
                                                       dtype=np.int16)
            yield 16000, np.zeros(int(16000), dtype=np.int16)
            # Reset the models, otherwise GPU memory is not fully released.
            del self.t2s_model
            del self.vits_model
@@ -983,7 +1192,8 @@ class TTS:
                          batch_index_list: list = None,
                          speed_factor: float = 1.0,
                          split_bucket: bool = True,
                          fragment_interval: float = 0.3
                          fragment_interval: float = 0.3,
                          super_sampling: bool = False,
                          ) -> Tuple[int, np.ndarray]:
        zero_wav = torch.zeros(
            int(self.configs.sampling_rate * fragment_interval),
@@ -996,7 +1206,7 @@ class TTS:
                max_audio = torch.abs(audio_fragment).max()  # simple guard against 16-bit clipping
                if max_audio > 1: audio_fragment /= max_audio
                audio_fragment: torch.Tensor = torch.cat([audio_fragment, zero_wav], dim=0)
                audio[i][j] = audio_fragment.cpu().numpy()
                audio[i][j] = audio_fragment

        if split_bucket:
@@ -1005,8 +1215,21 @@ class TTS:
            # audio = [item for batch in audio for item in batch]
            audio = sum(audio, [])

        audio = torch.cat(audio, dim=0)

        if super_sampling:
            print(f"############ {i18n('音频超采样')} ############")
            t1 = ttime()
            self.init_sr_model()
            if not self.sr_model_not_exist:
                audio, sr = self.sr_model(audio.unsqueeze(0), sr)
                max_audio = np.abs(audio).max()
                if max_audio > 1: audio /= max_audio
            t2 = ttime()
            print(f"超采样用时:{t2 - t1:.3f}s")
        else:
            audio = audio.cpu().numpy()

        audio = np.concatenate(audio, 0)
        audio = (audio * 32768).astype(np.int16)

        # try:
@@ -1018,25 +1241,59 @@ class TTS:
        return sr, audio

def speed_change(input_audio: np.ndarray, speed: float, sr: int):
    # Convert the NumPy array into a raw PCM byte stream
    raw_audio = input_audio.astype(np.int16).tobytes()
    # Set up the ffmpeg input stream
    input_stream = ffmpeg.input('pipe:', format='s16le', acodec='pcm_s16le', ar=str(sr), ac=1)
    # Apply the tempo change
    output_stream = input_stream.filter('atempo', speed)
    # Pipe the output stream
    out, _ = (
        output_stream.output('pipe:', format='s16le', acodec='pcm_s16le')
        .run(input=raw_audio, capture_stdout=True, capture_stderr=True)
    )
    # Decode the piped output back into a NumPy array
    processed_audio = np.frombuffer(out, np.int16)
    return processed_audio
    def v3_synthesis(self,
                     semantic_tokens: torch.Tensor,
                     phones: torch.Tensor,
                     speed: float = 1.0,
                     sample_steps: int = 32
                     ):

        prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
        prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
        refer_audio_spec = self.prompt_cache["refer_spec"][0].to(dtype=self.precision, device=self.configs.device)

        fea_ref, ge = self.vits_model.decode_encp(prompt_semantic_tokens, prompt_phones, refer_audio_spec)
        ref_audio: torch.Tensor = self.prompt_cache["raw_audio"]
        ref_sr = self.prompt_cache["raw_sr"]
        ref_audio = ref_audio.to(self.configs.device).float()
        if (ref_audio.shape[0] == 2):
            ref_audio = ref_audio.mean(0).unsqueeze(0)
        if ref_sr != 24000:
            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
        # print("ref_audio", ref_audio.abs().mean())

        mel2 = mel_fn(ref_audio)
        mel2 = norm_spec(mel2)
        T_min = min(mel2.shape[2], fea_ref.shape[2])
        mel2 = mel2[:, :, :T_min]
        fea_ref = fea_ref[:, :, :T_min]
        if (T_min > 468):
            mel2 = mel2[:, :, -468:]
            fea_ref = fea_ref[:, :, -468:]
            T_min = 468
        chunk_len = 934 - T_min
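        # The reference conditioning is capped at 468 mel frames, and chunk_len is chosen
        # so that each CFM call below sees at most 934 frames in total (reference tail
        # plus newly generated content).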
        mel2 = mel2.to(self.precision)
        fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)

        cfm_resss = []
        idx = 0
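        # Sliding-window CFM inference: each iteration generates mel frames for the next
        # chunk of fea_todo, conditioned on the tail of the previous chunk (fea_ref / mel2),
        # which keeps the generated spectrogram continuous across chunk boundaries.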
        while (1):
            fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
            if (fea_todo_chunk.shape[-1] == 0): break
            idx += chunk_len
            fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
            cfm_res = self.vits_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
            cfm_res = cfm_res[:, :, mel2.shape[2]:]
            mel2 = cfm_res[:, :, -T_min:]
            fea_ref = fea_todo_chunk[:, :, -T_min:]
            cfm_resss.append(cfm_res)

        cmf_res = torch.cat(cfm_resss, 2)
        cmf_res = denorm_spec(cmf_res)

        with torch.inference_mode():
            wav_gen = self.bigvgan_model(cmf_res)
            audio = wav_gen[0][0]  # .cpu().detach().numpy()

        return audio