@@ -981,7 +981,6 @@ class SynthesizerTrn(nn.Module):
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
         return codes.transpose(0, 1)
 
-
 class CFM(torch.nn.Module):
     def __init__(self, in_channels, dit):
         super().__init__()
@@ -993,6 +992,8 @@ class CFM(torch.nn.Module):
         self.criterion = torch.nn.MSELoss()
+        self.use_conditioner_cache = True
 
     @torch.inference_mode()
     def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
         """Forward diffusion"""
@@ -1005,25 +1006,38 @@ class CFM(torch.nn.Module):
         mu = mu.transpose(2, 1)
         t = 0
         d = 1 / n_timesteps
+        text_cache = None
+        text_cfg_cache = None
+        dt_cache = None
+        d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
         for j in range(n_timesteps):
             t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
-            d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
             # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(
-                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
-            ).transpose(2, 1)
+            v_pred, text_emb, dt = self.estimator(
+                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
+            )
+            v_pred = v_pred.transpose(2, 1)
+            if self.use_conditioner_cache:
+                text_cache = text_emb
+                dt_cache = dt
             if inference_cfg_rate > 1e-5:
-                neg = self.estimator(
+                neg, text_cfg_emb, _ = self.estimator(
                     x,
                     prompt_x,
                     x_lens,
                     t_tensor,
                     d_tensor,
                     mu,
                     use_grad_ckpt=False,
                     drop_audio_cond=True,
                     drop_text=True,
-                ).transpose(2, 1)
+                    infer=True,
+                    text_cache=text_cfg_cache,
+                    dt_cache=dt_cache
+                )
+                neg = neg.transpose(2, 1)
+                if self.use_conditioner_cache:
+                    text_cfg_cache = text_cfg_emb
                 v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
             x = x + d * v_pred
             t = t + d
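
The hoisted d_tensor and the new text_cache / dt_cache round-trip both exploit the same fact: with a fixed Euler step size d = 1 / n_timesteps, the step-size embedding and the text-conditioner embedding do not depend on the evolving state x or on t, so the estimator only needs to compute them on the first step. The CFG branch keeps a separate text_cfg_cache because its unconditional pass embeds the text with drop_text=True. Below is a minimal, self-contained sketch of this pattern; TinyEstimator and its text_proj / dt_proj projections are illustrative stand-ins, not the real DiT API.

import torch


class TinyEstimator(torch.nn.Module):
    # Stand-in for the DiT estimator. Neither the text embedding nor the
    # step-size ("dt") embedding depends on x or t, so both are
    # loop-invariant during sampling and safe to cache.
    def __init__(self, dim=8):
        super().__init__()
        self.text_proj = torch.nn.Linear(dim, dim)
        self.dt_proj = torch.nn.Linear(1, dim)
        self.out = torch.nn.Linear(dim, dim)

    def forward(self, x, text, d_tensor, text_cache=None, dt_cache=None):
        # Recompute the conditioner outputs only when no cache is supplied.
        text_emb = self.text_proj(text) if text_cache is None else text_cache
        dt_emb = self.dt_proj(d_tensor[:, None]) if dt_cache is None else dt_cache
        v_pred = self.out(x + text_emb + dt_emb)
        # Hand the embeddings back so the sampling loop can cache them.
        return v_pred, text_emb, dt_emb


estimator = TinyEstimator()
x = torch.randn(2, 8)
text = torch.randn(2, 8)
d = 1 / 10
d_tensor = torch.full((2,), d)  # hoisted: constant across all steps

text_cache = dt_cache = None
for j in range(10):
    v_pred, text_emb, dt_emb = estimator(x, text, d_tensor, text_cache=text_cache, dt_cache=dt_cache)
    text_cache, dt_cache = text_emb, dt_emb  # computed at j == 0, reused afterwards
    x = x + d * v_pred

On the first iteration both caches are None and the embeddings are computed; every later iteration reuses them, so the per-step cost shrinks to the x-dependent path. The same reasoning applies to the negative (CFG) pass in the diff, which shares dt_cache but maintains its own text_cfg_cache.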