From 8056efe4ab7bbc3610c72ae356a6f37518441f7d Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Mon, 9 Jun 2025 23:53:16 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Dge.sum=E6=95=B0=E5=80=BC?= =?UTF-8?q?=E5=8F=AF=E8=83=BD=E7=88=86=E7=82=B8=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复ge.sum数值可能爆炸问题 --- GPT_SoVITS/module/modules.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/GPT_SoVITS/module/modules.py b/GPT_SoVITS/module/modules.py index 7493f0b..0969900 100644 --- a/GPT_SoVITS/module/modules.py +++ b/GPT_SoVITS/module/modules.py @@ -716,10 +716,12 @@ class MelStyleEncoder(nn.Module): if mask is None: out = torch.mean(x, dim=1) else: - len_ = (~mask).sum(dim=1).unsqueeze(1) + len_ = (~mask).sum() x = x.masked_fill(mask.unsqueeze(-1), 0) - x = x.sum(dim=1) - out = torch.div(x, len_) + dtype=x.dtype + x = x.float() + x=torch.div(x,len_) + out=x.sum(dim=1).to(dtype) return out def forward(self, x, mask=None): @@ -743,7 +745,6 @@ class MelStyleEncoder(nn.Module): x = self.fc(x) # temoral average pooling w = self.temporal_avg_pool(x, mask=mask) - return w.unsqueeze(-1)