Support SoVITS v2Pro and v2ProPlus

RVC-Boss committed 2 months ago via GitHub
parent 0621259549
commit 92819d0b31
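
This commit wires the new v2Pro and v2ProPlus SoVITS variants into TTS_infer_pack/TTS.py: two new entries in TTS_Config.default_configs, a lazily initialized speaker-verification (SV) model for Pro checkpoints, a device-aware resampler cache, a _get_ref_spec that additionally returns 16 kHz audio for SV embedding, and an sv_emb keyword threaded into vits_model.decode. The torch.load calls also gain weights_only=False.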

@@ -35,7 +35,16 @@ from tools.i18n.i18n import I18nAuto, scan_language_list
 from tools.my_utils import load_audio
 from TTS_infer_pack.text_segmentation_method import splits
 from TTS_infer_pack.TextPreprocessor import TextPreprocessor
+from sv import SV
+
+resample_transform_dict = {}
+
+
+def resample(audio_tensor, sr0, sr1, device):
+    global resample_transform_dict
+    key = "%s-%s-%s" % (sr0, sr1, str(device))
+    if key not in resample_transform_dict:
+        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
+    return resample_transform_dict[key](audio_tensor)
+
 language = os.environ.get("language", "Auto")
 language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
 i18n = I18nAuto(language=language)
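
The memoized resampler moves above the i18n setup, and its cache key now includes the device: the old module-level helper (removed in the next hunk) keyed only on "sr0-sr1", so once a transform was built on one device, a tensor living on another device would be fed to it. A minimal usage sketch, assuming the resample helper above is in scope and a (channels, samples) float tensor:

    import torch

    wav = torch.randn(1, 32000)                   # 1 s of mono audio at 32 kHz
    wav_16k = resample(wav, 32000, 16000, "cpu")  # first call builds and caches the Resample transform
    wav_16k = resample(wav, 32000, 16000, "cpu")  # same "32000-16000-cpu" key, cached transform reused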
@@ -102,18 +111,6 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
     return processed_audio
-resample_transform_dict = {}
-
-
-def resample(audio_tensor, sr0, sr1, device):
-    global resample_transform_dict
-    key = "%s-%s" % (sr0, sr1)
-    if key not in resample_transform_dict:
-        resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
-    return resample_transform_dict[key](audio_tensor)
 class DictToAttrRecursive(dict):
     def __init__(self, input_dict):
         super().__init__(input_dict)
@@ -252,6 +249,24 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base", "cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large", "bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
}, },
"v2Pro": {
"device": "cpu",
"is_half": False,
"version": "v2Pro",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro_pre1.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"v2ProPlus": {
"device": "cpu",
"is_half": False,
"version": "v2ProPlus",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus_pre1.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
     }
     configs: dict = None
     v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
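
Both new presets share the v2 text front-end (chinese-hubert-base, chinese-roberta-wwm-ext-large) and the s1v3.ckpt T2S weights; they differ only in the SoVITS generator checkpoint under pretrained_models/v2Pro (s2Gv2Pro_pre1.pth vs. s2Gv2ProPlus_pre1.pth).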
@@ -287,7 +302,7 @@ class TTS_Config:
         assert isinstance(configs, dict)
         version = configs.get("version", "v2").lower()
-        assert version in ["v1", "v2", "v3", "v4"]
+        assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]
         self.default_configs[version] = configs.get(version, self.default_configs[version])
         self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
@@ -403,6 +418,7 @@ class TTS:
         self.cnhuhbert_model: CNHubert = None
         self.vocoder = None
         self.sr_model: AP_BWE = None
+        self.sv_model = None
         self.sr_model_not_exist: bool = False
 
         self.vocoder_configs: dict = {
@@ -463,6 +479,8 @@ class TTS:
     def init_vits_weights(self, weights_path: str):
         self.configs.vits_weights_path = weights_path
         version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
+        if "Pro" in model_version:
+            self.init_sv_model()
         path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
 
         if if_lora_v3 == True and os.path.exists(path_sovits) == False:
@@ -472,7 +490,6 @@ class TTS:
         # dict_s2 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
         dict_s2 = load_sovits_new(weights_path)
         hps = dict_s2["config"]
         hps["model"]["semantic_frame_rate"] = "25hz"
         if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
             hps["model"]["version"] = "v2"  # v3model,v2sybomls
@@ -480,7 +497,15 @@ class TTS:
hps["model"]["version"] = "v1" hps["model"]["version"] = "v1"
else: else:
hps["model"]["version"] = "v2" hps["model"]["version"] = "v2"
# version = hps["model"]["version"] version = hps["model"]["version"]
v3v4set={"v3", "v4"}
if model_version not in v3v4set:
if "Pro"not in model_version:
model_version = version
else:
hps["model"]["version"] = model_version
else:
hps["model"]["version"] = model_version
         self.configs.filter_length = hps["data"]["filter_length"]
         self.configs.segment_size = hps["train"]["segment_size"]
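
The resolution logic above: v3/v4 checkpoints and the Pro variants force hps["model"]["version"] to model_version, while plain v1/v2 weights fall back to the version detected from the checkpoint's symbols (the enc_p.text_embedding.weight probe in the previous hunk).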
@@ -496,7 +521,7 @@ class TTS:
# print(f"model_version:{model_version}") # print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}') # print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version not in {"v3", "v4"}: if model_version not in v3v4set:
             vits_model = SynthesizerTrn(
                 self.configs.filter_length // 2 + 1,
                 self.configs.segment_size // self.configs.hop_length,
@@ -517,6 +542,8 @@ class TTS:
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"): if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q del vits_model.enc_q
self.is_v2pro=model_version in {"v2Pro","v2ProPlus"}
         if if_lora_v3 == False:
             print(
                 f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
@@ -551,7 +578,7 @@ class TTS:
         self.configs.t2s_weights_path = weights_path
         self.configs.save_configs()
         self.configs.hz = 50
-        dict_s1 = torch.load(weights_path, map_location=self.configs.device)
+        dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
config = dict_s1["config"] config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"] self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False) t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
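
Spelling out weights_only=False (here and in the vocoder load below) keeps these checkpoints loadable on recent PyTorch releases, where torch.load switched its default to weights_only=True; the GPT checkpoint carries pickled, non-tensor objects that the restricted loader can reject.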
@@ -605,7 +632,7 @@ class TTS:
             )
             self.vocoder.remove_weight_norm()
             state_dict_g = torch.load(
-                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
+                "%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
             )
             print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
@@ -631,6 +658,11 @@ class TTS:
             print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
             self.sr_model_not_exist = True
+
+    def init_sv_model(self):
+        if self.sv_model is not None:
+            return
+        self.sv_model = SV(self.configs.device, self.configs.is_half)
+
     def enable_half_precision(self, enable: bool = True, save: bool = True):
         """
         To enable half precision for the TTS model.
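
init_sv_model is idempotent and only invoked from init_vits_weights when model_version contains "Pro", so non-Pro setups never load the speaker-verification model. SV is the class imported from sv at the top of the file, constructed with the configured device and half-precision flag.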
@@ -706,11 +738,11 @@ class TTS:
self.prompt_cache["ref_audio_path"] = ref_audio_path self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path): def _set_ref_spec(self, ref_audio_path):
spec = self._get_ref_spec(ref_audio_path) spec_audio = self._get_ref_spec(ref_audio_path)
if self.prompt_cache["refer_spec"] in [[], None]: if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [spec] self.prompt_cache["refer_spec"] = [spec_audio]
else: else:
self.prompt_cache["refer_spec"][0] = spec self.prompt_cache["refer_spec"][0] = spec_audio
def _get_ref_spec(self, ref_audio_path): def _get_ref_spec(self, ref_audio_path):
raw_audio, raw_sr = torchaudio.load(ref_audio_path) raw_audio, raw_sr = torchaudio.load(ref_audio_path)
@@ -718,25 +750,33 @@ class TTS:
self.prompt_cache["raw_audio"] = raw_audio self.prompt_cache["raw_audio"] = raw_audio
self.prompt_cache["raw_sr"] = raw_sr self.prompt_cache["raw_sr"] = raw_sr
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate)) if raw_sr != self.configs.sampling_rate:
audio = torch.FloatTensor(audio) audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
         maxx = audio.abs().max()
         if maxx > 1:
             audio /= min(2, maxx)
-        audio_norm = audio
-        audio_norm = audio_norm.unsqueeze(0)
         spec = spectrogram_torch(
-            audio_norm,
+            audio,
             self.configs.filter_length,
             self.configs.sampling_rate,
             self.configs.hop_length,
             self.configs.win_length,
             center=False,
         )
-        spec = spec.to(self.configs.device)
         if self.configs.is_half:
             spec = spec.half()
-        return spec
+        if self.is_v2pro == True:
+            audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
+            if self.configs.is_half:
+                audio = audio.half()
+        else:
+            audio = None
+        return spec, audio
     def _set_prompt_semantic(self, ref_wav_path: str):
         zero_wav = np.zeros(
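
_get_ref_spec now returns a (spec, audio) pair: the reference spectrogram as before, plus a mono 16 kHz tensor (half precision when enabled) that Pro models feed to the SV encoder; for non-Pro models the second element is None, and prompt_cache["refer_spec"] stores these pairs. A hedged consumption sketch, with tts standing in for a loaded TTS instance and ref.wav a hypothetical reference clip:

    spec, sv_audio = tts._get_ref_spec("ref.wav")  # tts: a loaded TTS instance (hypothetical)
    if sv_audio is not None:  # non-None only for v2Pro / v2ProPlus checkpoints
        emb = tts.sv_model.compute_embedding3(sv_audio)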
@@ -1171,10 +1211,13 @@ class TTS:
                     t4 = time.perf_counter()
                     t_34 += t4 - t3
 
-                    refer_audio_spec: torch.Tensor = [
-                        item.to(dtype=self.precision, device=self.configs.device)
-                        for item in self.prompt_cache["refer_spec"]
-                    ]
+                    refer_audio_spec = []
+                    if self.is_v2pro:
+                        sv_emb = []
+                    for spec, audio_tensor in self.prompt_cache["refer_spec"]:
+                        spec = spec.to(dtype=self.precision, device=self.configs.device)
+                        refer_audio_spec.append(spec)
+                        if self.is_v2pro:
+                            sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
                     batch_audio_fragment = []
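
For Pro models, each cached (spec, audio) pair thus yields both a spectrogram for decode and one SV embedding via compute_embedding3, computed once per reference rather than once per decoded fragment.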
@@ -1206,9 +1249,10 @@ class TTS:
                             torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
                         )
                         _batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
-                        _batch_audio_fragment = self.vits_model.decode(
-                            all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
-                        ).detach()[0, 0, :]
+                        if self.is_v2pro != True:
+                            _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                        else:
+                            _batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb).detach()[0, 0, :]
                         audio_frag_end_idx.insert(0, 0)
                         batch_audio_fragment = [
                             _batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
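
vits_model.decode keeps its original call shape for non-Pro models; only the Pro branch passes the extra sv_emb keyword. The per-fragment path in the next hunk repeats the same two-way branch.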
@@ -1221,9 +1265,10 @@ class TTS:
                             _pred_semantic = (
                                 pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
                             )  # .unsqueeze(0)  # mq needs one extra unsqueeze
-                            audio_fragment = self.vits_model.decode(
-                                _pred_semantic, phones, refer_audio_spec, speed=speed_factor
-                            ).detach()[0, 0, :]
+                            if self.is_v2pro != True:
+                                audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
+                            else:
+                                audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor, sv_emb=sv_emb).detach()[0, 0, :]
                             batch_audio_fragment.append(audio_fragment)  # try reconstructing without the prompt part
                 else:
                     if parallel_infer:
