From 5169d52b1b586a410ff4d92a837520b629ba567d Mon Sep 17 00:00:00 2001
From: Kakaru <97896816+KakaruHayate@users.noreply.github.com>
Date: Mon, 26 May 2025 11:27:36 +0800
Subject: [PATCH] condition cache (#2377)

---
 GPT_SoVITS/f5_tts/model/backbones/dit.py | 20 +++++++++--
 GPT_SoVITS/module/models.py              | 46 +++++++++++++++---------
 2 files changed, 47 insertions(+), 19 deletions(-)

diff --git a/GPT_SoVITS/f5_tts/model/backbones/dit.py b/GPT_SoVITS/f5_tts/model/backbones/dit.py
index 7d98a85..f64a3c3 100644
--- a/GPT_SoVITS/f5_tts/model/backbones/dit.py
+++ b/GPT_SoVITS/f5_tts/model/backbones/dit.py
@@ -143,6 +143,9 @@ class DiT(nn.Module):
         drop_audio_cond=False,  # cfg for cond audio
         drop_text=False,  # cfg for text
         # mask: bool["b n"] | None = None,  # noqa: F722
+        infer=False,  # bool
+        text_cache=None,  # torch tensor as text_embed
+        dt_cache=None,  # torch tensor as dt
     ):
         x = x0.transpose(2, 1)
         cond = cond0.transpose(2, 1)
@@ -155,9 +158,17 @@
 
         # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        dt = self.d_embed(dt_base_bootstrap)
+        if infer and dt_cache is not None:
+            dt = dt_cache
+        else:
+            dt = self.d_embed(dt_base_bootstrap)
         t += dt
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
+
+        if infer and text_cache is not None:
+            text_embed = text_cache
+        else:
+            text_embed = self.text_embed(text, seq_len, drop_text=drop_text)  ###need to change
+
         x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
 
         rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -177,4 +188,7 @@
         x = self.norm_out(x, t)
         output = self.proj_out(x)
 
-        return output
+        if infer:
+            return output, text_embed, dt
+        else:
+            return output
\ No newline at end of file
diff --git a/GPT_SoVITS/module/models.py b/GPT_SoVITS/module/models.py
index 3e37f0f..b73612f 100644
--- a/GPT_SoVITS/module/models.py
+++ b/GPT_SoVITS/module/models.py
@@ -981,7 +981,6 @@ class SynthesizerTrn(nn.Module):
         quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
         return codes.transpose(0, 1)
 
-
 class CFM(torch.nn.Module):
     def __init__(self, in_channels, dit):
         super().__init__()
@@ -993,6 +992,8 @@
 
         self.criterion = torch.nn.MSELoss()
 
+        self.use_conditioner_cache = True
+
     @torch.inference_mode()
     def inference(self, mu, x_lens, prompt, n_timesteps, temperature=1.0, inference_cfg_rate=0):
         """Forward diffusion"""
@@ -1005,25 +1006,38 @@
         mu = mu.transpose(2, 1)
         t = 0
         d = 1 / n_timesteps
+        text_cache = None
+        text_cfg_cache = None
+        dt_cache = None
+        d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
         for j in range(n_timesteps):
             t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
-            d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
             # v_pred = model(x, t_tensor, d_tensor, **extra_args)
-            v_pred = self.estimator(
-                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False
-            ).transpose(2, 1)
+            v_pred, text_emb, dt = self.estimator(
+                x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=False, drop_text=False, infer=True, text_cache=text_cache, dt_cache=dt_cache
+            )
+            v_pred = v_pred.transpose(2, 1)
+            if self.use_conditioner_cache:
+                text_cache = text_emb
+                dt_cache = dt
             if inference_cfg_rate > 1e-5:
-                neg = self.estimator(
-                    x,
-                    prompt_x,
-                    x_lens,
-                    t_tensor,
-                    d_tensor,
-                    mu,
-                    use_grad_ckpt=False,
-                    drop_audio_cond=True,
-                    drop_text=True,
-                ).transpose(2, 1)
+                neg, text_cfg_emb, _ = self.estimator(
+                    x,
+                    prompt_x,
+                    x_lens,
+                    t_tensor,
+                    d_tensor,
+                    mu,
+                    use_grad_ckpt=False,
+                    drop_audio_cond=True,
+                    drop_text=True,
+                    infer=True,
+                    text_cache=text_cfg_cache,
+                    dt_cache=dt_cache
+                )
+                neg = neg.transpose(2, 1)
+                if self.use_conditioner_cache:
+                    text_cfg_cache = text_cfg_emb
             v_pred = v_pred + (v_pred - neg) * inference_cfg_rate
             x = x + d * v_pred
             t = t + d
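
Note on why the cache is sound: text, seq_len, and the step size d are loop invariants of the Euler integration in CFM.inference, so self.text_embed(...) and self.d_embed(...) return identical tensors at every one of the n_timesteps steps. The first iteration computes them; every later iteration replays them through text_cache and dt_cache. Below is a minimal, self-contained sketch of the same conditioner-cache pattern. ToyEstimator, its shapes, and the way drop_text is emulated are illustrative stand-ins for this note only, not the real DiT/CFM API.

import torch


class ToyEstimator(torch.nn.Module):
    # Stand-in for DiT: the text and step-size embeddings do not depend on the
    # current ODE state, so they can be computed once and replayed.
    def __init__(self, dim=64, vocab=32):
        super().__init__()
        self.text_embed = torch.nn.Embedding(vocab, dim)  # expensive in the real model
        self.d_embed = torch.nn.Linear(1, dim)            # step-size embedding
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, x, text, d, drop_text=False, infer=False, text_cache=None, dt_cache=None):
        # Reuse cached conditioner outputs at inference; recompute otherwise.
        if infer and dt_cache is not None:
            dt = dt_cache
        else:
            dt = self.d_embed(d)
        if infer and text_cache is not None:
            text_emb = text_cache
        else:
            src = torch.zeros_like(text) if drop_text else text  # CFG drops the text
            text_emb = self.text_embed(src).mean(dim=1)
        out = self.proj(x + text_emb + dt)
        if infer:
            return out, text_emb, dt  # hand the caches back to the caller
        return out


est = ToyEstimator()
x = torch.randn(2, 64)               # noised input
text = torch.randint(0, 32, (2, 8))  # token ids
n_timesteps, cfg_rate = 8, 2.0
d = 1 / n_timesteps
d_tensor = torch.full((2, 1), d)     # hoisted out of the loop, as in the patch
text_cache = text_cfg_cache = dt_cache = None
with torch.inference_mode():
    for _ in range(n_timesteps):
        v_pred, text_cache, dt_cache = est(
            x, text, d_tensor, infer=True, text_cache=text_cache, dt_cache=dt_cache
        )
        if cfg_rate > 1e-5:
            # The unconditional branch embeds dropped text, so it needs its own
            # text cache; the dt cache is shared with the conditional call.
            neg, text_cfg_cache, _ = est(
                x, text, d_tensor, drop_text=True, infer=True,
                text_cache=text_cfg_cache, dt_cache=dt_cache
            )
            v_pred = v_pred + (v_pred - neg) * cfg_rate
        x = x + d * v_pred           # Euler step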
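
Design note on the two text caches: the CFG branch calls text_embed with drop_text=True, which yields a different embedding than the conditional call, so it keeps its own text_cfg_cache; dt_cache can be shared because d_embed never sees the text. Net effect of the patch, per its own logic: with CFG enabled, text_embed runs twice in total (once per branch) instead of twice per step, and d_embed runs once instead of twice per step; hoisting d_tensor out of the loop follows the same reasoning, since d is fixed for the whole integration.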