Support SoVITS v2Pro and v2ProPlus

Support SoVITS v2Pro and v2ProPlus
main
RVC-Boss 2 months ago committed by GitHub
parent 3f46359652
commit 0621259549
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@@ -21,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files. 3) computes spectrograms from audio files.
""" """
def __init__(self, hparams, val=False): def __init__(self, hparams, version=None,val=False):
exp_dir = hparams.exp_dir exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir self.path4 = "%s/4-cnhubert" % exp_dir
@@ -29,8 +29,14 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert os.path.exists(self.path2) assert os.path.exists(self.path2)
assert os.path.exists(self.path4) assert os.path.exists(self.path4)
assert os.path.exists(self.path5) assert os.path.exists(self.path5)
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
if self.is_v2Pro:
self.path7 = "%s/7-sv_cn" % exp_dir
assert os.path.exists(self.path7)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀 names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5)) names5 = set(os.listdir(self.path5))
if self.is_v2Pro:
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
self.phoneme_data = {} self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f: with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n") lines = f.read().strip("\n").split("\n")
@@ -40,8 +46,10 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if len(tmp) != 4: if len(tmp) != 4:
continue continue
self.phoneme_data[tmp[0]] = [tmp[1]] self.phoneme_data[tmp[0]] = [tmp[1]]
if self.is_v2Pro:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5) self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
else:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text tmp = self.audiopaths_sid_text
leng = len(tmp) leng = len(tmp)
min_num = 100 min_num = 100
@@ -109,14 +117,21 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
typee = ssl.dtype typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee) ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False ssl.requires_grad = False
if self.is_v2Pro:
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
except: except:
traceback.print_exc() traceback.print_exc()
spec = torch.zeros(1025, 100) spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100 * self.hop_length) wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100) ssl = torch.zeros(1, 768, 100)
text = text[-1:] text = text[-1:]
if self.is_v2Pro:
sv_emb=torch.zeros(1,20480)
print("load audio or ssl error!!!!!!", audiopath) print("load audio or ssl error!!!!!!", audiopath)
return (ssl, spec, wav, text) if self.is_v2Pro:
return (ssl, spec, wav, text,sv_emb)
else:
return (ssl, spec, wav, text)
def get_audio(self, filename): def get_audio(self, filename):
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768 audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
@@ -177,8 +192,9 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
class TextAudioSpeakerCollate: class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets""" """Zero-pads model inputs and targets"""
def __init__(self, return_ids=False): def __init__(self, return_ids=False,version=None):
self.return_ids = return_ids self.return_ids = return_ids
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
def __call__(self, batch): def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities """Collate's training batch from normalized text, audio and speaker identities
@@ -211,6 +227,9 @@ class TextAudioSpeakerCollate:
ssl_padded.zero_() ssl_padded.zero_()
text_padded.zero_() text_padded.zero_()
if self.is_v2Pro:
sv_embs=torch.FloatTensor(len(batch),20480)
for i in range(len(ids_sorted_decreasing)): for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]] row = batch[ids_sorted_decreasing[i]]
@@ -230,7 +249,12 @@ class TextAudioSpeakerCollate:
text_padded[i, : text.size(0)] = text text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0) text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths if self.is_v2Pro:
sv_embs[i]=row[4]
if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset): class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):

@@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):
return x, fmap return x, fmap
v2pro_set={"v2Pro","v2ProPlus"}
class MultiPeriodDiscriminator(torch.nn.Module): class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self, use_spectral_norm=False): def __init__(self, use_spectral_norm=False,version=None):
super(MultiPeriodDiscriminator, self).__init__() super(MultiPeriodDiscriminator, self).__init__()
periods = [2, 3, 5, 7, 11] if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
else:periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
@@ -786,7 +787,6 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1) return pred_codes.transpose(0, 1)
class SynthesizerTrn(nn.Module): class SynthesizerTrn(nn.Module):
""" """
Synthesizer for Training Synthesizer for Training
@@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024) self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
self.freeze_quantizer = freeze_quantizer self.freeze_quantizer = freeze_quantizer
def forward(self, ssl, y, y_lengths, text, text_lengths): self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype) y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
if self.version == "v1": if self.version == "v1":
ge = self.ref_enc(y * y_mask, y_mask) ge = self.ref_enc(y * y_mask, y_mask)
else: else:
ge = self.ref_enc(y[:, :704] * y_mask, y_mask) ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
with autocast(enabled=False): with autocast(enabled=False):
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext() maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
with maybe_no_grad: with maybe_no_grad:
@@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge) z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
z_p = self.flow(z, y_mask, g=ge) z_p = self.flow(z, y_mask, g=ge)
@@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
return o, y_mask, (z, z_p, m_p, logs_p) return o, y_mask, (z, z_p, m_p, logs_p)
@torch.no_grad() @torch.no_grad()
def decode(self, codes, text, refer, noise_scale=0.5, speed=1): def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
def get_ge(refer): def get_ge(refer, sv_emb):
ge = None ge = None
if refer is not None: if refer is not None:
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device) refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
@@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
ge = self.ref_enc(refer * refer_mask, refer_mask) ge = self.ref_enc(refer * refer_mask, refer_mask)
else: else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask) ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
return ge return ge
if type(refer) == list: if type(refer) == list:
ges = [] ges = []
for _refer in refer: for idx,_refer in enumerate(refer):
ge = get_ge(_refer) ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
ges.append(ge) ges.append(ge)
ge = torch.stack(ges, 0).mean(0) ge = torch.stack(ges, 0).mean(0)
else: else:
ge = get_ge(refer) ge = get_ge(refer, sv_emb)
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device) y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device) text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
@@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
quantized = self.quantizer.decode(codes) quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz": if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest") quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed) x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
z = self.flow(z_p, y_mask, g=ge, reverse=True) z = self.flow(z_p, y_mask, g=ge, reverse=True)

Loading…
Cancel
Save