support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
main
RVC-Boss 2 months ago committed by GitHub
parent 0621259549
commit 92819d0b31
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -35,7 +35,16 @@ from tools.i18n.i18n import I18nAuto, scan_language_list
from tools.my_utils import load_audio
from TTS_infer_pack.text_segmentation_method import splits
from TTS_infer_pack.TextPreprocessor import TextPreprocessor
from sv import SV
resample_transform_dict={}
def resample(audio_tensor, sr0,sr1,device):
global resample_transform_dict
key="%s-%s-%s"%(sr0,sr1,str(device))
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(
sr0, sr1
).to(device)
return resample_transform_dict[key](audio_tensor)
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
@ -102,18 +111,6 @@ def speed_change(input_audio: np.ndarray, speed: float, sr: int):
return processed_audio
resample_transform_dict = {}
def resample(audio_tensor, sr0, sr1, device):
global resample_transform_dict
key = "%s-%s" % (sr0, sr1)
if key not in resample_transform_dict:
resample_transform_dict[key] = torchaudio.transforms.Resample(sr0, sr1).to(device)
return resample_transform_dict[key](audio_tensor)
class DictToAttrRecursive(dict):
def __init__(self, input_dict):
super().__init__(input_dict)
@ -252,6 +249,24 @@ class TTS_Config:
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"v2Pro": {
"device": "cpu",
"is_half": False,
"version": "v2Pro",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2Pro_pre1.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
"v2ProPlus": {
"device": "cpu",
"is_half": False,
"version": "v2ProPlus",
"t2s_weights_path": "GPT_SoVITS/pretrained_models/s1v3.ckpt",
"vits_weights_path": "GPT_SoVITS/pretrained_models/v2Pro/s2Gv2ProPlus_pre1.pth",
"cnhuhbert_base_path": "GPT_SoVITS/pretrained_models/chinese-hubert-base",
"bert_base_path": "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
},
}
configs: dict = None
v1_languages: list = ["auto", "en", "zh", "ja", "all_zh", "all_ja"]
@ -287,7 +302,7 @@ class TTS_Config:
assert isinstance(configs, dict)
version = configs.get("version", "v2").lower()
assert version in ["v1", "v2", "v3", "v4"]
assert version in ["v1", "v2", "v3", "v4", "v2Pro", "v2ProPlus"]
self.default_configs[version] = configs.get(version, self.default_configs[version])
self.configs: dict = configs.get("custom", deepcopy(self.default_configs[version]))
@ -403,6 +418,7 @@ class TTS:
self.cnhuhbert_model: CNHubert = None
self.vocoder = None
self.sr_model: AP_BWE = None
self.sv_model = None
self.sr_model_not_exist: bool = False
self.vocoder_configs: dict = {
@ -463,6 +479,8 @@ class TTS:
def init_vits_weights(self, weights_path: str):
self.configs.vits_weights_path = weights_path
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(weights_path)
if "Pro"in model_version:
self.init_sv_model()
path_sovits = self.configs.default_configs[model_version]["vits_weights_path"]
if if_lora_v3 == True and os.path.exists(path_sovits) == False:
@ -472,7 +490,6 @@ class TTS:
# dict_s2 = torch.load(weights_path, map_location=self.configs.device,weights_only=False)
dict_s2 = load_sovits_new(weights_path)
hps = dict_s2["config"]
hps["model"]["semantic_frame_rate"] = "25hz"
if "enc_p.text_embedding.weight" not in dict_s2["weight"]:
hps["model"]["version"] = "v2" # v3model,v2sybomls
@ -480,7 +497,15 @@ class TTS:
hps["model"]["version"] = "v1"
else:
hps["model"]["version"] = "v2"
# version = hps["model"]["version"]
version = hps["model"]["version"]
v3v4set={"v3", "v4"}
if model_version not in v3v4set:
if "Pro"not in model_version:
model_version = version
else:
hps["model"]["version"] = model_version
else:
hps["model"]["version"] = model_version
self.configs.filter_length = hps["data"]["filter_length"]
self.configs.segment_size = hps["train"]["segment_size"]
@ -496,7 +521,7 @@ class TTS:
# print(f"model_version:{model_version}")
# print(f'hps["model"]["version"]:{hps["model"]["version"]}')
if model_version not in {"v3", "v4"}:
if model_version not in v3v4set:
vits_model = SynthesizerTrn(
self.configs.filter_length // 2 + 1,
self.configs.segment_size // self.configs.hop_length,
@ -517,6 +542,8 @@ class TTS:
if "pretrained" not in weights_path and hasattr(vits_model, "enc_q"):
del vits_model.enc_q
self.is_v2pro=model_version in {"v2Pro","v2ProPlus"}
if if_lora_v3 == False:
print(
f"Loading VITS weights from {weights_path}. {vits_model.load_state_dict(dict_s2['weight'], strict=False)}"
@ -551,7 +578,7 @@ class TTS:
self.configs.t2s_weights_path = weights_path
self.configs.save_configs()
self.configs.hz = 50
dict_s1 = torch.load(weights_path, map_location=self.configs.device)
dict_s1 = torch.load(weights_path, map_location=self.configs.device, weights_only=False)
config = dict_s1["config"]
self.configs.max_sec = config["data"]["max_sec"]
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
@ -605,7 +632,7 @@ class TTS:
)
self.vocoder.remove_weight_norm()
state_dict_g = torch.load(
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu"
"%s/GPT_SoVITS/pretrained_models/gsv-v4-pretrained/vocoder.pth" % (now_dir,), map_location="cpu", weights_only=False
)
print("loading vocoder", self.vocoder.load_state_dict(state_dict_g))
@ -631,6 +658,11 @@ class TTS:
print(i18n("你没有下载超分模型的参数,因此不进行超分。如想超分请先参照教程把文件下载好"))
self.sr_model_not_exist = True
def init_sv_model(self):
if self.sv_model is not None:
return
self.sv_model = SV(self.configs.device, self.configs.is_half)
def enable_half_precision(self, enable: bool = True, save: bool = True):
"""
To enable half precision for the TTS model.
@ -706,11 +738,11 @@ class TTS:
self.prompt_cache["ref_audio_path"] = ref_audio_path
def _set_ref_spec(self, ref_audio_path):
spec = self._get_ref_spec(ref_audio_path)
spec_audio = self._get_ref_spec(ref_audio_path)
if self.prompt_cache["refer_spec"] in [[], None]:
self.prompt_cache["refer_spec"] = [spec]
self.prompt_cache["refer_spec"] = [spec_audio]
else:
self.prompt_cache["refer_spec"][0] = spec
self.prompt_cache["refer_spec"][0] = spec_audio
def _get_ref_spec(self, ref_audio_path):
raw_audio, raw_sr = torchaudio.load(ref_audio_path)
@ -718,25 +750,33 @@ class TTS:
self.prompt_cache["raw_audio"] = raw_audio
self.prompt_cache["raw_sr"] = raw_sr
audio = load_audio(ref_audio_path, int(self.configs.sampling_rate))
audio = torch.FloatTensor(audio)
if raw_sr != self.configs.sampling_rate:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
audio = resample(audio, raw_sr, self.configs.sampling_rate, self.configs.device)
else:
audio = raw_audio.to(self.configs.device)
if (audio.shape[0] == 2): audio = audio.mean(0).unsqueeze(0)
maxx = audio.abs().max()
if maxx > 1:
audio /= min(2, maxx)
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm,
audio,
self.configs.filter_length,
self.configs.sampling_rate,
self.configs.hop_length,
self.configs.win_length,
center=False,
)
spec = spec.to(self.configs.device)
if self.configs.is_half:
spec = spec.half()
return spec
if self.is_v2pro == True:
audio = resample(audio, self.configs.sampling_rate, 16000, self.configs.device)
if self.configs.is_half:
audio = audio.half()
else:audio=None
return spec,audio
def _set_prompt_semantic(self, ref_wav_path: str):
zero_wav = np.zeros(
@ -1171,10 +1211,13 @@ class TTS:
t4 = time.perf_counter()
t_34 += t4 - t3
refer_audio_spec: torch.Tensor = [
item.to(dtype=self.precision, device=self.configs.device)
for item in self.prompt_cache["refer_spec"]
]
refer_audio_spec = []
if self.is_v2pro:sv_emb=[]
for spec,audio_tensor in self.prompt_cache["refer_spec"]:
spec=spec.to(dtype=self.precision, device=self.configs.device)
refer_audio_spec.append(spec)
if self.is_v2pro:
sv_emb.append(self.sv_model.compute_embedding3(audio_tensor))
batch_audio_fragment = []
@ -1206,9 +1249,10 @@ class TTS:
torch.cat(pred_semantic_list).unsqueeze(0).unsqueeze(0).to(self.configs.device)
)
_batch_phones = torch.cat(batch_phones).unsqueeze(0).to(self.configs.device)
_batch_audio_fragment = self.vits_model.decode(
all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
if self.is_v2pro!=True:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
else:
_batch_audio_fragment = self.vits_model.decode(all_pred_semantic, _batch_phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
audio_frag_end_idx.insert(0, 0)
batch_audio_fragment = [
_batch_audio_fragment[audio_frag_end_idx[i - 1] : audio_frag_end_idx[i]]
@ -1221,9 +1265,10 @@ class TTS:
_pred_semantic = (
pred_semantic_list[i][-idx:].unsqueeze(0).unsqueeze(0)
) # .unsqueeze(0)#mq要多unsqueeze一次
audio_fragment = self.vits_model.decode(
_pred_semantic, phones, refer_audio_spec, speed=speed_factor
).detach()[0, 0, :]
if self.is_v2pro != True:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor).detach()[0, 0, :]
else:
audio_fragment = self.vits_model.decode(_pred_semantic, phones, refer_audio_spec, speed=speed_factor,sv_emb=sv_emb).detach()[0, 0, :]
batch_audio_fragment.append(audio_fragment) ###试试重建不带上prompt部分
else:
if parallel_infer:

Loading…
Cancel
Save