@@ -25,7 +25,7 @@ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
 from BigVGAN.bigvgan import BigVGAN
 from feature_extractor.cnhubert import CNHubert
 from module.mel_processing import mel_spectrogram_torch, spectrogram_torch
-from module.models import SynthesizerTrn, SynthesizerTrnV3
+from module.models import SynthesizerTrn, SynthesizerTrnV3, Generator
 from peft import LoraConfig, get_peft_model
 from process_ckpt import get_sovits_version_from_path_fast, load_sovits_new
 from transformers import AutoModelForMaskedLM, AutoTokenizer
@@ -66,6 +66,19 @@ mel_fn = lambda x: mel_spectrogram_torch(
         "center": False,
     },
 )
+mel_fn_v4 = lambda x: mel_spectrogram_torch(
+    x,
+    **{
+        "n_fft": 1280,
+        "win_size": 1280,
+        "hop_size": 320,
+        "num_mels": 100,
+        "sampling_rate": 32000,
+        "fmin": 0,
+        "fmax": None,
+        "center": False,
+    },
+)
 
 
 def speed_change(input_audio: np.ndarray, speed: float, sr: int):
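
For context: v3's `mel_fn` (n_fft 1024, hop 256 at 24 kHz) and the new `mel_fn_v4` (n_fft 1280, hop 320 at 32 kHz) both emit 100 mel bins. A minimal sketch of the difference, assuming the two lambdas above are in scope (the 1-second dummy inputs are purely illustrative):

    import torch

    wav_v3 = torch.rand(1, 24000) * 2 - 1  # 1 s of dummy audio at v3's 24 kHz
    wav_v4 = torch.rand(1, 32000) * 2 - 1  # 1 s of dummy audio at v4's 32 kHz
    print(mel_fn(wav_v3).shape)     # 100 mel bins at 24000/256 = 93.75 frames/s
    print(mel_fn_v4(wav_v4).shape)  # 100 mel bins at 32000/320 = 100 frames/s
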
@@ -92,11 +105,12 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
 resample_transform_dict = {}
 
 
-def resample(audio_tensor, sr0, device):
+def resample(audio_tensor, sr0, sr1, device):
     global resample_transform_dict
-    if sr0 not in resample_transform_dict:
-        resample_transform_dict[sr0] = torchaudio.transforms.Resample(sr0, 24000).to(device)
-    return resample_transform_dict[sr0](audio_tensor)
+    key = "%s-%s" % (sr0, sr1)
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)
 
 
 class DictToAttrRecursive(dict):
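
The rewrite generalizes the resampler cache from a fixed 24 kHz target to arbitrary (source, target) pairs, since v3 and v4 reference audio use different rates. A standalone sketch of the same idea (the names here are illustrative, not from the file):

    import torch
    import torchaudio

    _resamplers = {}

    def cached_resample(audio: torch.Tensor, sr0: int, sr1: int, device: str) -> torch.Tensor:
        key = (sr0, sr1)  # one transform per (source, target) rate pair
        if key not in _resamplers:
            _resamplers[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
        return _resamplers[key](audio)

    wav48k = torch.rand(1, 48000) * 2 - 1
    wav32k = cached_resample(wav48k, 48000, 32000, "cpu")  # v4 reference rate
    wav24k = cached_resample(wav48k, 48000, 24000, "cpu")  # v3 reference rate
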
@@ -130,44 +144,6 @@ class DictToAttrRecursive(dict):
 class NO_PROMPT_ERROR(Exception):
     pass
 
 
-# configs/tts_infer.yaml
-"""
-custom:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-  version: v2
-default:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
-  version: v1
-default_v2:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
-  version: v2
-default_v3:
-  bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
-  cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
-  device: cpu
-  is_half: false
-  t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
-  vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
-  version: v3
-"""
-
-
 def set_seed(seed: int):
     seed = int(seed)
     seed = seed if seed != -1 else random.randint(0, 2**32 - 1)
@@ -220,6 +196,15 @@ class TTS_Config:
             "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
             "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
         },
+        "v4": {
+            "device": "cpu",
+            "is_half": False,
+            "version": "v4",
+            "t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
+            "vits_weights_path": "GPT_SoVITS/pretrained_models/s2Gv3.pth",
+            "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
+            "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
+        },
     }
     configs: dict = None
     v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@@ -255,7 +240,7 @@ class TTS_Config:
 
         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3"]
+        assert version in ["v1", "v2", "v3", "v4"]
         self.default_configs[version] = configs.get(version, self.default_configs[version])
         self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
 
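
A hypothetical smoke test of the widened version check and the new v4 defaults; this assumes `TTS_Config` accepts a plain dict, as the `configs.get(...)` handling above suggests. Note this diff points v4 at `s1v3.ckpt`/`s2Gv3.pth` rather than dedicated v4 weight files:

    cfg = TTS_Config({"version": "v4"})       # no "custom" block: v4 defaults apply
    print(cfg.configs["t2s_weights_path"])    # GPT_SoVITS/pretrained_models/s1v3.ckpt
    print(cfg.configs["vits_weights_path"])   # GPT_SoVITS/pretrained_models/s2Gv3.pth
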
@@ -356,7 +341,7 @@ class TTS_Config:
     def __eq__(self, other):
         return isinstance(other, TTS_Config) and self.configs_path == other.configs_path
 
-
+from inference_webui import v3v4set
 class TTS:
     def __init__(self, configs: Union[dict, str, TTS_Config]):
         if isinstance(configs, TTS_Config):
@@ -369,7 +354,7 @@ class TTS:
         self.bert_tokenizer: AutoTokenizer = None
         self.bert_model: AutoModelForMaskedLM = None
         self.cnhuhbert_model: CNHubert = None
-        self.bigvgan_model: BigVGAN = None
+        self.vocoder_model = None
         self.sr_model: AP_BWE = None
         self.sr_model_not_exist: bool = False
 
@@ -423,10 +408,11 @@ class TTS:
     def init_vits_weights(self, weights_path: str):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
-        path_sovits_v3 = self.configs.default_configs["v3"]["vits_weights_path"]
+        print(self.configs.default_configs)
+        path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
 
-        if if_lora_v3 == True and os.path.exists(path_sovits_v3) == False:
-            info = path_sovits_v3 + i18n("SoVITS V3 底模缺失,无法加载相应 LoRA 权重")
+        if if_lora_v3 == True and os.path.exists(path_sovits) == False:
+            info = path_sovits + i18n("SoVITS %s 底模缺失,无法加载相应 LoRA 权重" % model_version)
             raise FileExistsError(info)
 
         # dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
@@ -456,7 +442,7 @@ class TTS:
 
         # print(f"model_version:{model_version}")
         # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
-        if model_version != "v3":
+        if model_version not in v3v4set:
             vits_model = SynthesizerTrn(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
@@ -465,14 +451,14 @@ class TTS:
             )
             self.configs.is_v3_synthesizer = False
         else:
+            self.configs.is_v3_synthesizer = kwargs["version"] = model_version
             vits_model = SynthesizerTrnV3(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
                 n_speakers=self.configs.n_speakers,
                 **kwargs,
             )
-            self.configs.is_v3_synthesizer = True
-            self.init_bigvgan()
+            self.init_vocoder()
 
         if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
             del vits_model.enc_q
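
The new first line of the `else:` branch is a chained assignment: `model_version` is stored both into `kwargs["version"]` (forwarded to `SynthesizerTrnV3`) and into `configs.is_v3_synthesizer`, which therefore now holds `False` for v1/v2 or the string `"v3"`/`"v4"` rather than a plain bool; the `== "v3"` and `!= False` checks later in this diff rely on exactly that. A tiny demonstration:

    kwargs = {}
    is_v3_synthesizer = kwargs["version"] = "v4"  # assigns right-to-left to both targets
    assert is_v3_synthesizer == "v4" and kwargs["version"] == "v4"
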
@@ -482,7 +468,7 @@ class TTS:
             )
         else:
             print(
-                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits_v3)['weight'], strict=False)}"
+                f"Loading VITS pretrained weights from {weights_path}. {vits_model.load_state_dict(load_sovits_new(path_sovits)['weight'], strict=False)}"
             )
             lora_rank = dict_s2["lora_rank"]
             lora_config = LoraConfig(
@@ -521,20 +507,36 @@ class TTS:
         if self.configs.is_half and str(self.configs.device) != "cpu":
             self.t2s_model = self.t2s_model.half()
 
-    def init_bigvgan(self):
-        if self.bigvgan_model is not None:
+    def init_vocoder(self):
+        if self.vocoder_model is not None:
             return
-        self.bigvgan_model = BigVGAN.from_pretrained(
-            "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
-            use_cuda_kernel=False,
-        )  # if True, RuntimeError: Ninja is required to load C++ extensions
-        # remove weight norm in the model and set to eval mode
-        self.bigvgan_model.remove_weight_norm()
-        self.bigvgan_model = self.bigvgan_model.eval()
+        if self.configs.is_v3_synthesizer == "v3":
+            self.vocoder_model = BigVGAN.from_pretrained(
+                "%s/GPT_SoVITS/pretrained_models/models--nvidia--bigvgan_v2_24khz_100band_256x" % (now_dir,),
+                use_cuda_kernel=False,
+            )  # if True, RuntimeError: Ninja is required to load C++ extensions
+            # remove weight norm in the model and set to eval mode
+            self.vocoder_model.remove_weight_norm()
+            self.vocoder_model = self.vocoder_model.eval()
+        else:
+            self.vocoder_model = Generator(
+                initial_channel=100,
+                resblock="1",
+                resblock_kernel_sizes=[3, 7, 11],
+                resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+                upsample_rates=[10, 6, 2, 2, 2],
+                upsample_initial_channel=512,
+                upsample_kernel_sizes=[20, 12, 4, 4, 4],
+                gin_channels=0, is_bias=True
+            )
+            self.vocoder_model.eval()
+            self.vocoder_model.remove_weight_norm()
+            state_dict_g = torch.load("%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu")
+            print("loading v4 vocoder", self.vocoder_model.load_state_dict(state_dict_g))
         if self.configs.is_half == True:
-            self.bigvgan_model = self.bigvgan_model.half().to(self.configs.device)
+            self.vocoder_model = self.vocoder_model.half().to(self.configs.device)
         else:
-            self.bigvgan_model = self.bigvgan_model.to(self.configs.device)
+            self.vocoder_model = self.vocoder_model.to(self.configs.device)
 
     def init_sr_model(self):
         if self.sr_model is not None:
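
The two vocoder branches imply different output rates: the BigVGAN checkpoint is NVIDIA's 24 kHz/100-band model, while the v4 `Generator`'s upsample factors multiply out to 480, turning 100 Hz v4 mel frames into 48 kHz audio. A back-of-envelope check, assuming hop 320 at 32 kHz per `mel_fn_v4`:

    import math

    upsample_rates = [10, 6, 2, 2, 2]   # from the Generator(...) call above
    mel_frame_rate = 32000 / 320        # v4 mel frames per second = 100.0
    print(mel_frame_rate * math.prod(upsample_rates))  # 100.0 * 480 = 48000.0
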
@@ -570,8 +572,8 @@ class TTS:
                 self.bert_model = self.bert_model.half()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.half()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.half()
+            if self.vocoder_model is not None:
+                self.vocoder_model = self.vocoder_model.half()
         else:
             if self.t2s_model is not None:
                 self.t2s_model = self.t2s_model.float()
@@ -581,8 +583,8 @@ class TTS:
                 self.bert_model = self.bert_model.float()
             if self.cnhuhbert_model is not None:
                 self.cnhuhbert_model = self.cnhuhbert_model.float()
-            if self.bigvgan_model is not None:
-                self.bigvgan_model = self.bigvgan_model.float()
+            if self.vocoder_model is not None:
+                self.vocoder_model = self.vocoder_model.float()
 
     def set_device(self, device: torch.device, save: bool = True):
         """
@@ -601,8 +603,8 @@ class TTS:
             self.bert_model = self.bert_model.to(device)
         if self.cnhuhbert_model is not None:
             self.cnhuhbert_model = self.cnhuhbert_model.to(device)
-        if self.bigvgan_model is not None:
-            self.bigvgan_model = self.bigvgan_model.to(device)
+        if self.vocoder_model is not None:
+            self.vocoder_model = self.vocoder_model.to(device)
         if self.sr_model is not None:
             self.sr_model = self.sr_model.to(device)
@@ -913,13 +915,13 @@ class TTS:
             split_bucket = False
             print(i18n("分段返回模式不支持分桶处理,已自动关闭分桶处理"))
 
-        if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer and parallel_infer):
+        if split_bucket and speed_factor == 1.0 and not (self.configs.is_v3_synthesizer != False and parallel_infer):
             print(i18n("分桶处理模式已开启"))
         elif speed_factor != 1.0:
             print(i18n("语速调节不支持分桶处理,已自动关闭分桶处理"))
             split_bucket = False
-        elif self.configs.is_v3_synthesizer and parallel_infer:
-            print(i18n("当开启并行推理模式时,SoVits V3模型不支持分桶处理,已自动关闭分桶处理"))
+        elif self.configs.is_v3_synthesizer != False and parallel_infer:
+            print(i18n("当开启并行推理模式时,SoVits V3V4模型不支持分桶处理,已自动关闭分桶处理"))
             split_bucket = False
         else:
             print(i18n("分桶处理模式已关闭"))
@@ -936,7 +938,7 @@ class TTS:
         if not no_prompt_text:
             assert prompt_lang in self.configs.languages
 
-        if no_prompt_text and self.configs.is_v3_synthesizer:
+        if no_prompt_text and self.configs.is_v3_synthesizer != False:
             raise NO_PROMPT_ERROR("prompt_text cannot be empty when using SoVITS_V3")
 
         if ref_audio_path in [None, ""] and (
@@ -1044,7 +1046,12 @@ class TTS:
         t_34 = 0.0
         t_45 = 0.0
        audio = []
-        output_sr = self.configs.sampling_rate if not self.configs.is_v3_synthesizer else 24000
+        if self.configs.is_v3_synthesizer == False:
+            output_sr = 32000
+        elif self.configs.is_v3_synthesizer == "v3":
+            output_sr = 24000
+        else:
+            output_sr = 48000  # v4
         for item in data:
             t3 = time.perf_counter()
             if return_fragment:
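
`output_sr` now keys off the marker set in `init_vits_weights`: `False` means a v1/v2 `SynthesizerTrn` (32 kHz), `"v3"` the BigVGAN path (24 kHz), anything else v4's `Generator` (48 kHz). The same rule as a standalone helper (hypothetical name, mirroring the file's `== False` style of comparison):

    def pick_output_sr(is_v3_synthesizer) -> int:
        if is_v3_synthesizer == False:  # v1/v2 synthesizer
            return 32000
        if is_v3_synthesizer == "v3":   # v3 + BigVGAN vocoder
            return 24000
        return 48000                    # v4 + Generator vocoder
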
@@ -1144,7 +1151,7 @@ class TTS:
                     if parallel_infer:
                         print(f"{i18n('并行合成中')}...")
                         audio_fragments = self.v3_synthesis_batched_infer(
-                            idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps
+                            idx_list, pred_semantic_list, batch_phones, speed=speed_factor, sample_steps=sample_steps, model_version=self.configs.is_v3_synthesizer
                         )
                         batch_audio_fragment.extend(audio_fragments)
                     else:
@@ -1154,7 +1161,7 @@ class TTS:
                                 pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                             )  # .unsqueeze(0)#mq要多unsqueeze一次
                             audio_fragment = self.v3_synthesis(
-                                _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps
+                                _pred_semantic, phones, speed=speed_factor, sample_steps=sample_steps, model_version=self.configs.is_v3_synthesizer
                             )
                             batch_audio_fragment.append(audio_fragment)
 
@@ -1169,7 +1176,7 @@ class TTS:
                         speed_factor,
                         False,
                         fragment_interval,
-                        super_sampling if self.configs.is_v3_synthesizer else False,
+                        super_sampling if self.configs.is_v3_synthesizer == "v3" else False,
                     )
                 else:
                     audio.append(batch_audio_fragment)
@@ -1190,7 +1197,7 @@ class TTS:
                     speed_factor,
                     split_bucket,
                     fragment_interval,
-                    super_sampling if self.configs.is_v3_synthesizer else False,
+                    super_sampling if self.configs.is_v3_synthesizer == "v3" else False,
                 )
 
         except Exception as e:
@@ -1273,7 +1280,7 @@ class TTS:
         return sr, audio
 
     def v3_synthesis(
-        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32
+        self, semantic_tokens: torch.Tensor, phones: torch.Tensor, speed: float = 1.0, sample_steps: int = 32, model_version="v4"
    ):
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1285,19 +1292,22 @@ class TTS:
         ref_audio = ref_audio.to(self.configs.device).float()
         if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
-        if ref_sr != 24000:
-            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        tgt_sr = 24000 if model_version == "v3" else 32000
+        if ref_sr != tgt_sr:
+            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
 
-        mel2 = mel_fn(ref_audio)
+        mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if T_min > 468:
-            mel2 = mel2[:, :, -468:]
-            fea_ref = fea_ref[:, :, -468:]
-            T_min = 468
-        chunk_len = 934 - T_min
+        Tref = 468 if model_version == "v3" else 500
+        Tchunk = 934 if model_version == "v3" else 1000
+        if T_min > Tref:
+            mel2 = mel2[:, :, -Tref:]
+            fea_ref = fea_ref[:, :, -Tref:]
+            T_min = Tref
+        chunk_len = Tchunk - T_min
 
         mel2 = mel2.to(self.precision)
         fea_todo, ge = self.vits_model.decode_encp(semantic_tokens, phones, refer_audio_spec, ge, speed)
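
Parameterizing the old hard-coded 468/934 frame limits as `Tref`/`Tchunk` keeps roughly the same wall-clock budget at each version's mel frame rate: about a 5 s cap on the reference mel and a 10 s decode window. A sanity check under that reading:

    for version, sr, hop, tref, tchunk in [("v3", 24000, 256, 468, 934),
                                           ("v4", 32000, 320, 500, 1000)]:
        frame_rate = sr / hop  # mel frames per second
        print(version, round(tref / frame_rate, 2), round(tchunk / frame_rate, 2))
    # v3: 4.99 s / 9.96 s    v4: 5.0 s / 10.0 s
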
@@ -1324,7 +1334,7 @@ class TTS:
         cfm_res = denorm_spec(cfm_res)
 
         with torch.inference_mode():
-            wav_gen = self.bigvgan_model(cfm_res)
+            wav_gen = self.vocoder_model(cfm_res)
             audio = wav_gen[0][0]  # .cpu().detach().numpy()
 
         return audio
@@ -1335,7 +1345,7 @@ class TTS:
         semantic_tokens_list: List[torch.Tensor],
         batch_phones: List[torch.Tensor],
         speed: float = 1.0,
-        sample_steps: int = 32,
+        sample_steps: int = 32, model_version="v4"
     ) -> List[torch.Tensor]:
         prompt_semantic_tokens = self.prompt_cache["prompt_semantic"].unsqueeze(0).unsqueeze(0).to(self.configs.device)
         prompt_phones = torch.LongTensor(self.prompt_cache["phones"]).unsqueeze(0).to(self.configs.device)
@@ -1347,19 +1357,22 @@ class TTS:
         ref_audio = ref_audio.to(self.configs.device).float()
         if ref_audio.shape[0] == 2:
             ref_audio = ref_audio.mean(0).unsqueeze(0)
-        if ref_sr != 24000:
-            ref_audio = resample(ref_audio, ref_sr, self.configs.device)
+        tgt_sr = 24000 if model_version == "v3" else 32000
+        if ref_sr != tgt_sr:
+            ref_audio = resample(ref_audio, ref_sr, tgt_sr, self.configs.device)
 
-        mel2 = mel_fn(ref_audio)
+        mel2 = mel_fn(ref_audio) if model_version == "v3" else mel_fn_v4(ref_audio)
         mel2 = norm_spec(mel2)
         T_min = min(mel2.shape[2], fea_ref.shape[2])
         mel2 = mel2[:, :, :T_min]
         fea_ref = fea_ref[:, :, :T_min]
-        if T_min > 468:
-            mel2 = mel2[:, :, -468:]
-            fea_ref = fea_ref[:, :, -468:]
-            T_min = 468
-        chunk_len = 934 - T_min
+        Tref = 468 if model_version == "v3" else 500
+        Tchunk = 934 if model_version == "v3" else 1000
+        if T_min > Tref:
+            mel2 = mel2[:, :, -Tref:]
+            fea_ref = fea_ref[:, :, -Tref:]
+            T_min = Tref
+        chunk_len = Tchunk - T_min
 
         mel2 = mel2.to(self.precision)
 
@@ -1413,7 +1426,7 @@ class TTS:
         pred_spec = denorm_spec(pred_spec)
 
         with torch.no_grad():
-            wav_gen = self.bigvgan_model(pred_spec)
+            wav_gen = self.vocoder_model(pred_spec)
             audio = wav_gen[0][0]  # .cpu().detach().numpy()
 
         audio_fragments = []