@@ -150,9 +150,9 @@ sys.path.append(now_dir)
 sys.path.append("%s/GPT_SoVITS" % (now_dir))
 import signal
-import LangSegment
+from text.LangSegmenter import LangSegmenter
 from time import time as ttime
-import torch
+import torch, torchaudio
 import librosa
 import soundfile as sf
 from fastapi import FastAPI, Request, Query, HTTPException
@@ -162,7 +162,8 @@ from transformers import AutoModelForMaskedLM, AutoTokenizer
 import numpy as np
 from feature_extractor import cnhubert
 from io import BytesIO
-from module.models import SynthesizerTrn
+from module.models import SynthesizerTrn, SynthesizerTrnV3
+from peft import LoraConfig, PeftModel, get_peft_model
 from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from text import cleaned_text_to_sequence
 from text.cleaner import clean_text
@@ -197,6 +198,61 @@ def is_full(*items):  # returns False if any item is empty
     return True
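+
+
+# BigVGAN vocoder used by the v3 decode path: converts CFM-predicted mel
+# spectrograms back to a 24 kHz waveform. Loaded lazily so v1/v2 models never
+# pay for it.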
+def init_bigvgan():
+    global bigvgan_model
+    from BigVGAN import bigvgan
+    bigvgan_model = bigvgan.BigVGAN.from_pretrained("%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,), use_cuda_kernel=False)  # if True: RuntimeError: Ninja is required to load C++ extensions
+    # remove weight norm in the model and set to eval mode
+    bigvgan_model.remove_weight_norm()
+    bigvgan_model = bigvgan_model.eval()
+    if is_half == True:
+        bigvgan_model = bigvgan_model.half().to(device)
+    else:
+        bigvgan_model = bigvgan_model.to(device)
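+
+
+# Cache one torchaudio Resample transform per source sample rate (the target is
+# always 24 kHz), since reference clips tend to repeat the same few rates.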
+resample_transform_dict = {}
+def resample(audio_tensor, sr0):
+    global resample_transform_dict
+    if sr0 not in resample_transform_dict:
+        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
+    return resample_transform_dict[sr0](audio_tensor)
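+
+
+# v3's flow-matching decoder works on mel spectrograms normalized to [-1, 1];
+# spec_min/spec_max are the fixed dynamic-range bounds the model was trained with.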
+from module.mel_processing import spectrogram_torch, mel_spectrogram_torch
+spec_min = -12
+spec_max = 2
+def norm_spec(x):
+    return (x - spec_min) / (spec_max - spec_min) * 2 - 1
+def denorm_spec(x):
+    return (x + 1) / 2 * (spec_max - spec_min) + spec_min
+mel_fn = lambda x: mel_spectrogram_torch(x, **{
+    "n_fft": 1024,
+    "win_size": 1024,
+    "hop_size": 256,
+    "num_mels": 100,
+    "sampling_rate": 24000,
+    "fmin": 0,
+    "fmax": None,
+    "center": False
+})
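+
+
+# Optional 24 kHz -> 48 kHz bandwidth extension via AP-BWE, loaded on first use;
+# if its weights are missing, the input audio is returned unchanged.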
+sr_model = None
+def audio_sr(audio, sr):
+    global sr_model
+    if sr_model == None:
+        from tools.audio_sr import AP_BWE
+        try:
+            sr_model = AP_BWE(device, DictToAttrRecursive)
+        except FileNotFoundError:
+            logger.info("Super-resolution model weights are not downloaded, so super-resolution is skipped. To enable it, download the files first as described in the docs.")
+            return audio.cpu().detach().numpy(), sr
+    return sr_model(audio, sr)
+
+
 class Speaker:
     def __init__(self, name, gpt, sovits, phones=None, bert=None, prompt=None):
         self.name = name
@@ -214,31 +270,72 @@ class Sovits:
         self.vq_model = vq_model
         self.hps = hps

+from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 def get_sovits_weights(sovits_path):
-    dict_s2 = torch.load(sovits_path, map_location="cpu")
+    path_sovits_v3 = "GPT_SoVITS/pretrained_models/s2Gv3.pth"
+    is_exist_s2gv3 = os.path.exists(path_sovits_v3)
+    version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
+    if if_lora_v3 == True and is_exist_s2gv3 == False:
+        logger.info("SoVITS v3 base model is missing; cannot load the matching LoRA weights")
+    dict_s2 = load_sovits_new(sovits_path)
     hps = dict_s2["config"]
     hps = DictToAttrRecursive(hps)
     hps.model.semantic_frame_rate = "25hz"
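+    # Infer the symbol set from the checkpoint itself: v3 checkpoints carry no
+    # enc_p.text_embedding at all, and a 322-row embedding marks v1.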
-    if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
+    if 'enc_p.text_embedding.weight' not in dict_s2['weight']:
+        hps.model.version = "v2"  # v3 model, v2 symbols
+    elif dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
         hps.model.version = "v1"
     else:
         hps.model.version = "v2"
-    logger.info(f"Model version: {hps.model.version}")
+    if model_version == "v3":
+        hps.model.version = "v3"
     model_params_dict = vars(hps.model)
-    vq_model = SynthesizerTrn(
-        hps.data.filter_length // 2 + 1,
-        hps.train.segment_size // hps.data.hop_length,
-        n_speakers=hps.data.n_speakers,
-        **model_params_dict
-    )
+    if model_version != "v3":
+        vq_model = SynthesizerTrn(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **model_params_dict
+        )
+    else:
+        vq_model = SynthesizerTrnV3(
+            hps.data.filter_length // 2 + 1,
+            hps.train.segment_size // hps.data.hop_length,
+            n_speakers=hps.data.n_speakers,
+            **model_params_dict
+        )
+        init_bigvgan()
+    model_version = hps.model.version
+    logger.info(f"Model version: {model_version}")
if ( " pretrained " not in sovits_path ) :
if ( " pretrained " not in sovits_path ) :
del vq_model . enc_q
try :
del vq_model . enc_q
except : pass
if is_half == True :
if is_half == True :
vq_model = vq_model . half ( ) . to ( device )
vq_model = vq_model . half ( ) . to ( device )
else :
else :
vq_model = vq_model . to ( device )
vq_model = vq_model . to ( device )
vq_model . eval ( )
vq_model . eval ( )
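+    # Full checkpoints load directly. v3 LoRA checkpoints start from the s2Gv3
+    # base weights, attach a LoRA adapter to the CFM module, load the adapter
+    # weights, then merge the adapter back into the base model.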
-    vq_model.load_state_dict(dict_s2["weight"], strict=False)
+    if if_lora_v3 == False:
+        vq_model.load_state_dict(dict_s2["weight"], strict=False)
+    else:
+        vq_model.load_state_dict(load_sovits_new(path_sovits_v3)["weight"], strict=False)
+        lora_rank = dict_s2["lora_rank"]
+        lora_config = LoraConfig(
+            target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+            r=lora_rank,
+            lora_alpha=lora_rank,
+            init_lora_weights=True,
+        )
+        vq_model.cfm = get_peft_model(vq_model.cfm, lora_config)
+        vq_model.load_state_dict(dict_s2["weight"], strict=False)
+        vq_model.cfm = vq_model.cfm.merge_and_unload()
+        # torch.save(vq_model.state_dict(), "merge_win.pth")
+        vq_model.eval()
     sovits = Sovits(vq_model, hps)
     return sovits
@@ -260,8 +357,8 @@ def get_gpt_weights(gpt_path):
         t2s_model = t2s_model.half()
     t2s_model = t2s_model.to(device)
     t2s_model.eval()
-    total = sum([param.nelement() for param in t2s_model.parameters()])
-    logger.info("Number of parameter: %.2fM" % (total / 1e6))
+    # total = sum([param.nelement() for param in t2s_model.parameters()])
+    # logger.info("Number of parameter: %.2fM" % (total / 1e6))
     gpt = Gpt(max_sec, t2s_model)
     return gpt
@@ -295,6 +392,7 @@ def get_bert_feature(text, word2ph):
 def clean_text_inf(text, language, version):
+    language = language.replace("all_", "")
     phones, word2ph, norm_text = clean_text(text, language, version)
     phones = cleaned_text_to_sequence(phones, version)
     return phones, word2ph, norm_text
@@ -315,16 +413,10 @@ def get_bert_inf(phones, word2ph, norm_text, language):
 from text import chinese
 def get_phones_and_bert(text, language, version, final=False):
     if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
-        language = language.replace("all_", "")
-        formattext = text
-        if language == "en":
-            LangSegment.setfilters(["en"])
-            formattext = " ".join(tmp["text"] for tmp in LangSegment.getTexts(text))
-        else:
-            # CJK Han characters cannot be told apart by language, so trust the user's input
-            formattext = text
+        formattext = text
         while "  " in formattext:
             formattext = formattext.replace("  ", " ")
-        if language == "zh":
+        if language == "all_zh":
             if re.search(r'[A-Za-z]', formattext):
                 formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
                 formattext = chinese.mix_text_normalize(formattext)
@@ -332,7 +424,7 @@ def get_phones_and_bert(text,language,version,final=False):
             else:
                 phones, word2ph, norm_text = clean_text_inf(formattext, language, version)
                 bert = get_bert_feature(norm_text, word2ph).to(device)
-        elif language == "yue" and re.search(r'[A-Za-z]', formattext):
+        elif language == "all_yue" and re.search(r'[A-Za-z]', formattext):
             formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
             formattext = chinese.mix_text_normalize(formattext)
             return get_phones_and_bert(formattext, "yue", version)
@@ -345,19 +437,18 @@ def get_phones_and_bert(text,language,version,final=False):
     elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
         textlist = []
         langlist = []
-        LangSegment.setfilters(["zh", "ja", "en", "ko"])
         if language == "auto":
-            for tmp in LangSegment.getTexts(text):
+            for tmp in LangSegmenter.getTexts(text):
                 langlist.append(tmp["lang"])
                 textlist.append(tmp["text"])
         elif language == "auto_yue":
-            for tmp in LangSegment.getTexts(text):
+            for tmp in LangSegmenter.getTexts(text):
                 if tmp["lang"] == "zh":
                     tmp["lang"] = "yue"
                 langlist.append(tmp["lang"])
                 textlist.append(tmp["text"])
         else:
-            for tmp in LangSegment.getTexts(text):
+            for tmp in LangSegmenter.getTexts(text):
                 if tmp["lang"] == "en":
                     langlist.append(tmp["lang"])
                 else:
@@ -556,10 +647,11 @@ def only_punc(text):
 splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
-def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k=15, top_p=0.6, temperature=0.6, speed=1, inp_refs=None, spk="default"):
+def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language, top_k=15, top_p=0.6, temperature=0.6, speed=1, inp_refs=None, sample_steps=32, if_sr=False, spk="default"):
     infer_sovits = speaker_list[spk].sovits
     vq_model = infer_sovits.vq_model
     hps = infer_sovits.hps
+    version = vq_model.version
     infer_gpt = speaker_list[spk].gpt
     t2s_model = infer_gpt.t2s_model
@@ -587,20 +679,22 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
         prompt_semantic = codes[0, 0]
         prompt = prompt_semantic.unsqueeze(0).to(device)
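+        # v1/v2 condition on a list of reference spectrograms; v3 takes a
+        # single reference.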
-        refers = []
-        if (inp_refs):
-            for path in inp_refs:
-                try:
-                    refer = get_spepc(hps, path).to(dtype).to(device)
-                    refers.append(refer)
-                except Exception as e:
-                    logger.error(e)
-        if (len(refers) == 0):
-            refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+        if version != "v3":
+            refers = []
+            if (inp_refs):
+                for path in inp_refs:
+                    try:
+                        refer = get_spepc(hps, path).to(dtype).to(device)
+                        refers.append(refer)
+                    except Exception as e:
+                        logger.error(e)
+            if (len(refers) == 0):
+                refers = [get_spepc(hps, ref_wav_path).to(dtype).to(device)]
+        else:
+            refer = get_spepc(hps, ref_wav_path).to(device).to(dtype)
     t1 = ttime()
-    version = vq_model.version
-    os.environ['version'] = version
+    # os.environ['version'] = version
     prompt_language = dict_language[prompt_language.lower()]
     text_language = dict_language[text_language.lower()]
     phones1, bert1, norm_text1 = get_phones_and_bert(prompt_text, prompt_language, version)
@@ -634,20 +728,82 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
                                                      early_stop_num=hz * max_sec)
         pred_semantic = pred_semantic[:, -idx:].unsqueeze(0)
         t3 = ttime()
-        audio = \
-            vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
-                            refers, speed=speed).detach().cpu().numpy()[
-                0, 0]  ### try reconstructing without the prompt part
+        if version != "v3":
+            audio = \
+                vq_model.decode(pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0),
+                                refers, speed=speed).detach().cpu().numpy()[
+                    0, 0]  ### try reconstructing without the prompt part
+        else:
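+            # v3: decode the semantic tokens into mel frames with the CFM
+            # (flow-matching) decoder conditioned on the reference mel, then
+            # vocode the result with BigVGAN.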
+            phoneme_ids0 = torch.LongTensor(phones1).to(device).unsqueeze(0)
+            phoneme_ids1 = torch.LongTensor(phones2).to(device).unsqueeze(0)
+            # print(11111111, phoneme_ids0, phoneme_ids1)
+            fea_ref, ge = vq_model.decode_encp(prompt.unsqueeze(0), phoneme_ids0, refer)
+            ref_audio, sr = torchaudio.load(ref_wav_path)
+            ref_audio = ref_audio.to(device).float()
+            if (ref_audio.shape[0] == 2):
+                ref_audio = ref_audio.mean(0).unsqueeze(0)
+            if sr != 24000:
+                ref_audio = resample(ref_audio, sr)
+            # print("ref_audio", ref_audio.abs().mean())
+            mel2 = mel_fn(ref_audio)
+            mel2 = norm_spec(mel2)
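+            # Align reference mel and encoder features, capped at 468 frames
+            # (~5 s at 24 kHz with hop 256); chunk_len fills the remainder of
+            # the 934-frame CFM window.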
+            T_min = min(mel2.shape[2], fea_ref.shape[2])
+            mel2 = mel2[:, :, :T_min]
+            fea_ref = fea_ref[:, :, :T_min]
+            if (T_min > 468):
+                mel2 = mel2[:, :, -468:]
+                fea_ref = fea_ref[:, :, -468:]
+                T_min = 468
+            chunk_len = 934 - T_min
+            # print("fea_ref", fea_ref, fea_ref.shape)
+            # print("mel2", mel2)
+            mel2 = mel2.to(dtype)
+            fea_todo, ge = vq_model.decode_encp(pred_semantic, phoneme_ids1, refer, ge, speed)
+            # print("fea_todo", fea_todo)
+            # print("ge", ge.abs().mean())
+            cfm_resss = []
+            idx = 0
+            while (1):
+                fea_todo_chunk = fea_todo[:, :, idx:idx + chunk_len]
+                if (fea_todo_chunk.shape[-1] == 0): break
+                idx += chunk_len
+                fea = torch.cat([fea_ref, fea_todo_chunk], 2).transpose(2, 1)
+                # set_seed(123)
+                cfm_res = vq_model.cfm.inference(fea, torch.LongTensor([fea.size(1)]).to(fea.device), mel2, sample_steps, inference_cfg_rate=0)
+                cfm_res = cfm_res[:, :, mel2.shape[2]:]
+                mel2 = cfm_res[:, :, -T_min:]
+                # print("fea", fea)
+                # print("mel2in", mel2)
+                fea_ref = fea_todo_chunk[:, :, -T_min:]
+                cfm_resss.append(cfm_res)
+            cmf_res = torch.cat(cfm_resss, 2)
+            cmf_res = denorm_spec(cmf_res)
+            if bigvgan_model == None: init_bigvgan()
+            with torch.inference_mode():
+                wav_gen = bigvgan_model(cmf_res)
+                audio = wav_gen[0][0].cpu().detach().numpy()
         max_audio = np.abs(audio).max()
         if max_audio > 1:
             audio /= max_audio
         audio_opt.append(audio)
         audio_opt.append(zero_wav)
+    audio_opt = np.concatenate(audio_opt, 0)
     t4 = ttime()
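+    # v3 outputs 24 kHz natively; if_sr optionally upsamples the result to 48 kHz.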
+    sr = hps.data.sampling_rate if version != "v3" else 24000
+    if if_sr and sr == 24000:
+        audio_opt = torch.from_numpy(audio_opt).float().to(device)
+        audio_opt, sr = audio_sr(audio_opt.unsqueeze(0), sr)
+        max_audio = np.abs(audio_opt).max()
+        if max_audio > 1: audio_opt /= max_audio
+        sr = 48000
     if is_int32:
-        audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 2147483647).astype(np.int32), hps.data.sampling_rate)
+        audio_bytes = pack_audio(audio_bytes, (audio_opt * 2147483647).astype(np.int32), sr)
     else:
-        audio_bytes = pack_audio(audio_bytes, (np.concatenate(audio_opt, 0) * 32768).astype(np.int16), hps.data.sampling_rate)
+        audio_bytes = pack_audio(audio_bytes, (audio_opt * 32768).astype(np.int16), sr)
     # logger.info("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
     if stream_mode == "normal":
         audio_bytes, audio_chunk = read_clean_buffer(audio_bytes)
@@ -655,7 +811,9 @@ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language,
     if not stream_mode == "normal":
         if media_type == "wav":
-            audio_bytes = pack_wav(audio_bytes, hps.data.sampling_rate)
+            sr = 48000 if if_sr else 24000
+            sr = hps.data.sampling_rate if version != "v3" else sr
+            audio_bytes = pack_wav(audio_bytes, sr)
         yield audio_bytes.getvalue()
@@ -688,7 +846,7 @@ def handle_change(path, text, language):
     return JSONResponse({"code": 0, "message": "Success"}, status_code=200)


-def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs):
+def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr):
     if (
             refer_wav_path == "" or refer_wav_path is None
             or prompt_text == "" or prompt_text is None
@@ -702,12 +860,15 @@ def handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc,
     if not default_refer.is_ready():
         return JSONResponse({"code": 400, "message": "no reference audio specified and no default preset configured"}, status_code=400)

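+    # Clamp sample_steps to the expected CFM step counts, defaulting to 32.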
+    if not sample_steps in [4, 8, 16, 32]:
+        sample_steps = 32
+
     if cut_punc == None:
         text = cut_text(text, default_cut_punc)
     else:
         text = cut_text(text, cut_punc)

-    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs), media_type="audio/" + media_type)
+    return StreamingResponse(get_tts_wav(refer_wav_path, prompt_text, prompt_language, text, text_language, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr), media_type="audio/" + media_type)
@@ -915,7 +1076,9 @@ async def tts_endpoint(request: Request):
         json_post_raw.get("top_p", 1.0),
         json_post_raw.get("temperature", 1.0),
         json_post_raw.get("speed", 1.0),
-        json_post_raw.get("inp_refs", [])
+        json_post_raw.get("inp_refs", []),
+        json_post_raw.get("sample_steps", 32),
+        json_post_raw.get("if_sr", False)
     )
@@ -931,9 +1094,11 @@ async def tts_endpoint(
         top_p: float = 1.0,
         temperature: float = 1.0,
         speed: float = 1.0,
-        inp_refs: list = Query(default=[])
+        inp_refs: list = Query(default=[]),
+        sample_steps: int = 32,
+        if_sr: bool = False
 ):
-    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs)
+    return handle(refer_wav_path, prompt_text, prompt_language, text, text_language, cut_punc, top_k, top_p, temperature, speed, inp_refs, sample_steps, if_sr)


 if __name__ == "__main__":