|
|
@ -586,11 +586,12 @@ class DiscriminatorS(torch.nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
return x, fmap
|
|
|
|
return x, fmap
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
v2pro_set={"v2Pro","v2ProPlus"}
|
|
|
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
|
|
|
class MultiPeriodDiscriminator(torch.nn.Module):
|
|
|
|
def __init__(self, use_spectral_norm=False):
|
|
|
|
def __init__(self, use_spectral_norm=False,version=None):
|
|
|
|
super(MultiPeriodDiscriminator, self).__init__()
|
|
|
|
super(MultiPeriodDiscriminator, self).__init__()
|
|
|
|
periods = [2, 3, 5, 7, 11]
|
|
|
|
if version in v2pro_set:periods = [2, 3, 5, 7, 11,17,23]
|
|
|
|
|
|
|
|
else:periods = [2, 3, 5, 7, 11]
|
|
|
|
|
|
|
|
|
|
|
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
|
|
|
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
|
|
|
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
|
|
|
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
|
|
@ -786,7 +787,6 @@ class CodePredictor(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
return pred_codes.transpose(0, 1)
|
|
|
|
return pred_codes.transpose(0, 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthesizerTrn(nn.Module):
|
|
|
|
class SynthesizerTrn(nn.Module):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Synthesizer for Training
|
|
|
|
Synthesizer for Training
|
|
|
@ -886,12 +886,23 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
|
|
|
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
|
|
|
self.freeze_quantizer = freeze_quantizer
|
|
|
|
self.freeze_quantizer = freeze_quantizer
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, ssl, y, y_lengths, text, text_lengths):
|
|
|
|
self.is_v2pro=self.version in v2pro_set
|
|
|
|
|
|
|
|
if self.is_v2pro:
|
|
|
|
|
|
|
|
self.sv_emb = nn.Linear(20480, gin_channels)
|
|
|
|
|
|
|
|
self.ge_to512 = nn.Linear(gin_channels, 512)
|
|
|
|
|
|
|
|
self.prelu = nn.PReLU(num_parameters=gin_channels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, ssl, y, y_lengths, text, text_lengths,sv_emb=None):
|
|
|
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
|
|
|
y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y.size(2)), 1).to(y.dtype)
|
|
|
|
if self.version == "v1":
|
|
|
|
if self.version == "v1":
|
|
|
|
ge = self.ref_enc(y * y_mask, y_mask)
|
|
|
|
ge = self.ref_enc(y * y_mask, y_mask)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
|
|
|
ge = self.ref_enc(y[:, :704] * y_mask, y_mask)
|
|
|
|
|
|
|
|
if self.is_v2pro:
|
|
|
|
|
|
|
|
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
|
|
|
|
|
|
|
ge += sv_emb.unsqueeze(-1)
|
|
|
|
|
|
|
|
ge = self.prelu(ge)
|
|
|
|
|
|
|
|
ge512 = self.ge_to512(ge.transpose(2, 1)).transpose(2, 1)
|
|
|
|
with autocast(enabled=False):
|
|
|
|
with autocast(enabled=False):
|
|
|
|
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
|
|
|
maybe_no_grad = torch.no_grad() if self.freeze_quantizer else contextlib.nullcontext()
|
|
|
|
with maybe_no_grad:
|
|
|
|
with maybe_no_grad:
|
|
|
@ -904,7 +915,7 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
if self.semantic_frame_rate == "25hz":
|
|
|
|
if self.semantic_frame_rate == "25hz":
|
|
|
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
|
|
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
|
|
|
|
|
|
|
|
|
|
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge)
|
|
|
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge512 if self.is_v2pro else ge)
|
|
|
|
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
|
|
|
z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=ge)
|
|
|
|
z_p = self.flow(z, y_mask, g=ge)
|
|
|
|
z_p = self.flow(z, y_mask, g=ge)
|
|
|
|
|
|
|
|
|
|
|
@ -941,8 +952,8 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
return o, y_mask, (z, z_p, m_p, logs_p)
|
|
|
|
return o, y_mask, (z, z_p, m_p, logs_p)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def decode(self, codes, text, refer, noise_scale=0.5, speed=1):
|
|
|
|
def decode(self, codes, text, refer,noise_scale=0.5, speed=1, sv_emb=None):
|
|
|
|
def get_ge(refer):
|
|
|
|
def get_ge(refer, sv_emb):
|
|
|
|
ge = None
|
|
|
|
ge = None
|
|
|
|
if refer is not None:
|
|
|
|
if refer is not None:
|
|
|
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
|
|
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
|
|
@ -951,16 +962,20 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
|
|
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
|
|
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
|
|
|
|
|
|
|
if self.is_v2pro:
|
|
|
|
|
|
|
|
sv_emb = self.sv_emb(sv_emb) # B*20480->B*512
|
|
|
|
|
|
|
|
ge += sv_emb.unsqueeze(-1)
|
|
|
|
|
|
|
|
ge = self.prelu(ge)
|
|
|
|
return ge
|
|
|
|
return ge
|
|
|
|
|
|
|
|
|
|
|
|
if type(refer) == list:
|
|
|
|
if type(refer) == list:
|
|
|
|
ges = []
|
|
|
|
ges = []
|
|
|
|
for _refer in refer:
|
|
|
|
for idx,_refer in enumerate(refer):
|
|
|
|
ge = get_ge(_refer)
|
|
|
|
ge = get_ge(_refer, sv_emb[idx]if self.is_v2pro else None)
|
|
|
|
ges.append(ge)
|
|
|
|
ges.append(ge)
|
|
|
|
ge = torch.stack(ges, 0).mean(0)
|
|
|
|
ge = torch.stack(ges, 0).mean(0)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ge = get_ge(refer)
|
|
|
|
ge = get_ge(refer, sv_emb)
|
|
|
|
|
|
|
|
|
|
|
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
|
|
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
|
|
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
|
|
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
|
|
@ -968,7 +983,7 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
quantized = self.quantizer.decode(codes)
|
|
|
|
quantized = self.quantizer.decode(codes)
|
|
|
|
if self.semantic_frame_rate == "25hz":
|
|
|
|
if self.semantic_frame_rate == "25hz":
|
|
|
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
|
|
|
quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")
|
|
|
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, ge, speed)
|
|
|
|
x, m_p, logs_p, y_mask = self.enc_p(quantized, y_lengths, text, text_lengths, self.ge_to512(ge.transpose(2,1)).transpose(2,1)if self.is_v2pro else ge, speed)
|
|
|
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
|
|
|
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
|
|
|
|
|
|
|
|
|
|
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
|
|
|
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
|
|
|