condition cache (#2377)

main
Kakaru committed by GitHub, 2 months ago
parent e0e6d333b5
commit 5169d52b1b

@@ -143,6 +143,9 @@ class DiT(nn.Module):
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
+        infer=False,  # bool
+        text_cache=None,  # torch tensor as text_embed
+        dt_cache=None,  # torch tensor as dt
     ):
         x = x0.transpose(2, 1)
         cond = cond0.transpose(2, 1)
@@ -155,9 +158,17 @@ class DiT(nn.Module):
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        dt = self.d_embed(dt_base_bootstrap)
+        if infer and dt_cache is not None:
+            dt = dt_cache
+        else:
+            dt = self.d_embed(dt_base_bootstrap)
         t += dt
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
+        if infer and text_cache is not None:
+            text_embed = text_cache
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
         x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -177,4 +188,7 @@ class DiT(nn.Module):
         x = self.norm_out(x, t)
         output = self.proj_out(x)
-        return output
+        if infer:
+            return output, text_embed, dt
+        else:
+            return output
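
Taken together, the three hunks above give DiT.forward an inference-only fast path: text_embed depends only on (text, seq_len, drop_text) and dt only on dt_base_bootstrap, and neither input changes between denoising steps, so both tensors can be computed on the first step and replayed afterwards. A minimal sketch of the resulting calling contract, assuming a constructed DiT instance named dit and the sampler's own tensors (x, prompt_x, x_lens, t_tensor, d_tensor, mu, n_timesteps); this loop is illustrative, not part of the diff:

# Sketch: drive the new cache from a sampling loop (caching always on).
text_cache, dt_cache = None, None
for _ in range(n_timesteps):
    v, text_cache, dt_cache = dit(
        x, prompt_x, x_lens, t_tensor, d_tensor, mu,
        use_grad_ckpt=False, drop_audio_cond=False, drop_text=False,
        infer=True, text_cache=text_cache, dt_cache=dt_cache,
    )
    # First pass: both caches are None, so text_embed and dt are computed and
    # returned alongside the output. Later passes: the cached tensors are fed
    # back in and the embedding work is skipped.

The real loop in CFM.inference below does the same thing, gated by self.use_conditioner_cache.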

@@ -981,7 +981,6 @@ class SynthesizerTrn(nn.Module):
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
         return codes.transpose(0, 1)
-
 class CFM(torch.nn.Module):
     def __init__(self, in_channels, dit):
         super().__init__()
@@ -993,6 +992,8 @@ class CFM(torch.nn.Module):
         self.criterion = torch.nn.MSELoss()
+        self.use_conditioner_cache = True
     @torch.inference_mode()
     def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
         """Forward diffusion"""
@@ -1005,25 +1006,38 @@ class CFM(torch.nn.Module):
         mu = mu.transpose(2, 1)
         t = 0
         d = 1 / n_timesteps
+        text_cache = None
+        text_cfg_cache = None
+        dt_cache = None
+        d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
         for j in range(n_timesteps):
             t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
-            d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
             # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(
-                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
-            ).transpose(2, 1)
+            v_pred, text_emb, dt = self.estimator(
+                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
+            )
+            v_pred = v_pred.transpose(2, 1)
+            if self.use_conditioner_cache:
+                text_cache = text_emb
+                dt_cache = dt
             if inference_cfg_rate > 1e-5:
-                neg = self.estimator(
-                    x,
-                    prompt_x,
-                    x_lens,
-                    t_tensor,
-                    d_tensor,
-                    mu,
-                    use_grad_ckpt=False,
-                    drop_audio_cond=True,
-                    drop_text=True,
-                ).transpose(2, 1)
+                neg, text_cfg_emb, _ = self.estimator(
+                    x,
+                    prompt_x,
+                    x_lens,
+                    t_tensor,
+                    d_tensor,
+                    mu,
+                    use_grad_ckpt=False,
+                    drop_audio_cond=True,
+                    drop_text=True,
+                    infer=True,
+                    text_cache=text_cfg_cache,
+                    dt_cache=dt_cache
+                )
+                neg = neg.transpose(2, 1)
+                if self.use_conditioner_cache:
+                    text_cfg_cache = text_cfg_emb
                 v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
             x = x + d * v_pred
             t = t + d
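
Two design points in the new loop: the CFG branch runs with drop_text=True, which produces a different text embedding than the conditional branch, so it keeps its own text_cfg_cache while dt_cache is shared (dt depends only on the constant step size d, which is also why d_tensor is now built once before the loop instead of every iteration). The switch added in __init__ is a plain attribute rather than a constructor argument, so reverting to per-step recomputation is a one-line toggle; a sketch, where cfm stands for any constructed CFM instance:

# Sketch: opt out of conditioner caching (re-embed text and dt every step).
cfm = CFM(in_channels, dit)
cfm.use_conditioner_cache = False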
