support sovits v2Pro v2ProPlus

support sovits v2Pro v2ProPlus
main
RVC-Boss 2 months ago committed by GitHub
parent 3f46359652
commit 0621259549
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
def __init__(self, hparams, version=None,val=False):
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
if self.is_v2Pro:
self.path7 = "%s/7-sv_cn" % exp_dir
assert os.path.exists(self.path7)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
if self.is_v2Pro:
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
@ -40,7 +46,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
if self.is_v2Pro:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
else:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
@ -109,13 +117,20 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
if self.is_v2Pro:
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
except:
traceback.print_exc()
spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100)
text = text[-1:]
if self.is_v2Pro:
sv_emb=torch.zeros(1,20480)
print("load audio or ssl error!!!!!!", audiopath)
if self.is_v2Pro:
return (ssl, spec, wav, text,sv_emb)
else:
return (ssl, spec, wav, text)
def get_audio(self, filename):
@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
def __init__(self, return_ids=False,version=None):
self.return_ids = return_ids
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
ssl_padded.zero_()
text_padded.zero_()
if self.is_v2Pro:
sv_embs=torch.FloatTensor(len(batch),20480)
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
@ -230,6 +249,11 @@ class TextAudioSpeakerCollate:
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
if self.is_v2Pro:
sv_embs[i]=row[4]
if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths

@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):
return x, fmap
v2pro_set={"v2Pro","v2ProPlus"}
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
def __init__(self, use_spectral_norm=False,version=None):
super(MultiPeriodDiscriminator, self).__init__()
periods = [2, 3, 5, 7, 11]
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
else:periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@ -786,7 +787,6 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module):
"""
Synthesizer for Training
@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths):
self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask)
else:
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad:
@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge)
@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
def get_ge(refer):
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer, sv_emb):
ge = None
if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge
if type(refer) == list:
ges = []
for _refer in refer:
ge = get_ge(_refer)
for idx,_refer in enumerate(refer):
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
ges.append(ge)
ge = torch.stack(ges, 0).mean(0)
else:
ge = get_ge(refer)
ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True)

Loading…
Cancel
Save