From 06212595494c3ab0f132e13ec194061d1ff260b1 Mon Sep 17 00:00:00 2001
From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
Date: Wed, 4 Jun 2025 15:18:55 +0800
Subject: [PATCH] support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
---
 GPT_SoVITS/module/data_utils.py | 36 +++++++++++++++++++++++++-----
 GPT_SoVITS/module/models.py     | 39 +++++++++++++++++++++++++----------
 2 files changed, 57 insertions(+), 18 deletions(-)

diff --git a/GPT_SoVITS/module/data_utils.py b/GPT_SoVITS/module/data_utils.py
index 11f6b09..8182968 100644
--- a/GPT_SoVITS/module/data_utils.py
+++ b/GPT_SoVITS/module/data_utils.py
@@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         3) computes spectrograms from audio files.
     """

-    def __init__(self, hparams, val=False):
+    def __init__(self, hparams, version=None, val=False):
         exp_dir = hparams.exp_dir
         self.path2 = "%s/2-name2text.txt" % exp_dir
         self.path4 = "%s/4-cnhubert" % exp_dir
@@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
         assert os.path.exists(self.path2)
         assert os.path.exists(self.path4)
         assert os.path.exists(self.path5)
+        self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}
+        if self.is_v2Pro:
+            self.path7 = "%s/7-sv_cn" % exp_dir
+            assert os.path.exists(self.path7)
         names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
         names5 = set(os.listdir(self.path5))
+        if self.is_v2Pro:
+            names6 = set([name[:-3] for name in list(os.listdir(self.path7))])  # strip the .pt suffix
         self.phoneme_data = {}
         with open(self.path2, "r", encoding="utf8") as f:
             lines = f.read().strip("\n").split("\n")
@@ -40,8 +46,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             if len(tmp) != 4:
                 continue
             self.phoneme_data[tmp[0]] = [tmp[1]]
-
-        self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
+        if self.is_v2Pro:
+            self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
+        else:
+            self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
         tmp = self.audiopaths_sid_text
         leng = len(tmp)
         min_num = 100
@@ -109,14 +117,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
             typee = ssl.dtype
             ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
             ssl.requires_grad = False
+            if self.is_v2Pro:
+                sv_emb = torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
         except:
             traceback.print_exc()
             spec = torch.zeros(1025, 100)
             wav = torch.zeros(1, 100 * self.hop_length)
             ssl = torch.zeros(1, 768, 100)
             text = text[-1:]
+            if self.is_v2Pro:
+                sv_emb = torch.zeros(1, 20480)
             print("load audio or ssl error!!!!!!", audiopath)
-        return (ssl, spec, wav, text)
+        if self.is_v2Pro:
+            return (ssl, spec, wav, text, sv_emb)
+        else:
+            return (ssl, spec, wav, text)

     def get_audio(self, filename):
         audio_array = load_audio(filename, self.sampling_rate)  # load_audio already returns audio normalized to the -1~1 range, so no further /32768 is needed
@@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
 class TextAudioSpeakerCollate:
     """Zero-pads model inputs and targets"""

-    def __init__(self, return_ids=False):
+    def __init__(self, return_ids=False, version=None):
         self.return_ids = return_ids
+        self.is_v2Pro = version in {"v2Pro", "v2ProPlus"}

     def __call__(self, batch):
         """Collates a training batch from normalized text, audio and speaker identities
@@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
         ssl_padded.zero_()
         text_padded.zero_()

+        if self.is_v2Pro:
+            sv_embs = torch.FloatTensor(len(batch), 20480)
+
         for i in range(len(ids_sorted_decreasing)):
             row = batch[ids_sorted_decreasing[i]]

@@ -230,7 +249,12 @@ class TextAudioSpeakerCollate:
             text_padded[i, : text.size(0)] = text
             text_lengths[i] = text.size(0)

-        return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
+            if self.is_v2Pro:
+                sv_embs[i] = row[4]
+        if self.is_v2Pro:
+            return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths, sv_embs
+        else:
+            return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths


 class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
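The data-side change threads a single version flag through both classes: for "v2Pro" and "v2ProPlus", each dataset item gains a fifth element, a 20480-dim speaker-verification embedding loaded from 7-sv_cn, and each collated batch gains a trailing (batch, 20480) tensor. A minimal wiring sketch, assuming an hparams object like the one the training scripts build (the batch size and version choice here are illustrative, not part of the patch):

from torch.utils.data import DataLoader
from GPT_SoVITS.module.data_utils import TextAudioSpeakerLoader, TextAudioSpeakerCollate

version = "v2Pro"  # or "v2ProPlus"; any other value keeps the legacy 4-tuple samples
dataset = TextAudioSpeakerLoader(hparams.data, version=version)  # hparams: assumed training config
collate_fn = TextAudioSpeakerCollate(version=version)
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

for batch in loader:
    # v2Pro/v2ProPlus batches carry 9 elements; the new last one is (B, 20480)
    ssl, ssl_len, spec, spec_len, wav, wav_len, text, text_len, sv_embs = batch
    break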
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index b73612f..4fbec59 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):

         return x, fmap

-
+v2pro_set = {"v2Pro", "v2ProPlus"}
 class MultiPeriodDiscriminator(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
+    def __init__(self, use_spectral_norm=False, version=None):
         super(MultiPeriodDiscriminator, self).__init__()
-        periods = [2, 3, 5, 7, 11]
+        if version in v2pro_set: periods = [2, 3, 5, 7, 11, 17, 23]
+        else: periods = [2, 3, 5, 7, 11]

         discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
         discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@@ -786,7 +787,6 @@ class CodePredictor(nn.Module):

         return pred_codes.transpose(0, 1)

-
 class SynthesizerTrn(nn.Module):
     """
     Synthesizer for Training
@@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
             self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
         self.freeze_quantizer = freeze_quantizer

-    def forward(self, ssl, y, y_lengths, text, text_lengths):
+        self.is_v2pro = self.version in v2pro_set
+        if self.is_v2pro:
+            self.sv_emb = nn.Linear(20480, gin_channels)
+            self.ge_to512 = nn.Linear(gin_channels, 512)
+            self.prelu = nn.PReLU(num_parameters=gin_channels)
+
+    def forward(self, ssl, y, y_lengths, text, text_lengths, sv_emb=None):
         y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
         if self.version == "v1":
             ge = self.ref_enc(y * y_mask, y_mask)
         else:
             ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
+        if self.is_v2pro:
+            sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
+            ge += sv_emb.unsqueeze(-1)
+            ge = self.prelu(ge)
+            ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
         with autocast(enabled=False):
             maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
             with maybe_no_grad:
@@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
             if self.semantic_frame_rate == "25hz":
                 quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")

-        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
         z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
         z_p = self.flow(z, y_mask, g=ge)
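The fusion itself is small: project the 20480-dim SV embedding into the reference-encoder space, add it to ge along the channel axis, pass the sum through a channel-wise PReLU, and derive a 512-dim view (ge512) for the text encoder while the posterior encoder and flow keep the raw ge. A standalone shape check of that arithmetic, assuming gin_channels = 512 as in the stock v2-family configs (module names here stand in for self.sv_emb etc.):

import torch
import torch.nn as nn

gin_channels = 512  # assumed; matches the stock v2-family configs
sv_linear = nn.Linear(20480, gin_channels)  # plays the role of self.sv_emb
ge_to512 = nn.Linear(gin_channels, 512)
prelu = nn.PReLU(num_parameters=gin_channels)

ge = torch.randn(2, gin_channels, 1)  # reference encoder output: (B, gin_channels, 1)
sv = torch.randn(2, 20480)            # speaker-verification embedding: (B, 20480)

ge = ge + sv_linear(sv).unsqueeze(-1)  # broadcast add over the trailing time axis
ge = prelu(ge)                         # PReLU with one learned slope per channel
ge512 = ge_to512(ge.transpose(2, 1)).transpose(2, 1)  # (B, 512, 1), fed to enc_p

print(ge.shape, ge512.shape)  # torch.Size([2, 512, 1]) torch.Size([2, 512, 1])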
@@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
         return o, y_mask, (z, z_p, m_p, logs_p)

     @torch.no_grad()
-    def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
-        def get_ge(refer):
+    def decode(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
+        def get_ge(refer, sv_emb):
             ge = None
             if refer is not None:
                 refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
                     ge = self.ref_enc(refer * refer_mask, refer_mask)
                 else:
                     ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
+                if self.is_v2pro:
+                    sv_emb = self.sv_emb(sv_emb)  # B*20480 -> B*512
+                    ge += sv_emb.unsqueeze(-1)
+                    ge = self.prelu(ge)
             return ge

         if type(refer) == list:
             ges = []
-            for _refer in refer:
-                ge = get_ge(_refer)
+            for idx, _refer in enumerate(refer):
+                ge = get_ge(_refer, sv_emb[idx] if self.is_v2pro else None)
                 ges.append(ge)
             ge = torch.stack(ges, 0).mean(0)
         else:
-            ge = get_ge(refer)
+            ge = get_ge(refer, sv_emb)

         y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
         text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
             quantized = self.quantizer.decode(codes)
             if self.semantic_frame_rate == "25hz":
                 quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
-        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
+        x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2, 1)).transpose(2, 1) if self.is_v2pro else ge, speed)
         z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
         z = self.flow(z_p, y_mask, g=ge, reverse=True)
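On the inference side, callers of decode must now supply the matching embedding(s) alongside the reference spectrogram(s); when refer is a list, sv_emb is expected to be a list of equal length, consumed index-for-index. A hypothetical call under those assumptions (all variable names illustrative):

# vq_model: a SynthesizerTrn constructed with version="v2Pro" or "v2ProPlus"
# codes: (1, 1, T) semantic tokens, phones: (1, L) phoneme ids
# refer_spec: (1, 1025, T_ref) reference spectrogram, sv: (1, 20480) SV embedding
audio = vq_model.decode(codes, phones, refer_spec, noise_scale=0.5, sv_emb=sv)

# with several references, the two lists must stay aligned index-for-index:
audio = vq_model.decode(codes, phones, [spec_a, spec_b], sv_emb=[sv_a, sv_b])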