From ed89a023378dabba9d4b6580235bb9742245816d Mon Sep 17 00:00:00 2001 From: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com> Date: Wed, 11 Jun 2025 23:14:52 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E2=80=9C=E4=BF=AE=E5=A4=8Dge?= =?UTF-8?q?.sum=E6=95=B0=E5=80=BC=E5=8F=AF=E8=83=BD=E7=88=86=E7=82=B8?= =?UTF-8?q?=E7=9A=84=E2=80=9D=E5=8F=AF=E8=83=BD=E5=AF=BC=E8=87=B4=E7=9A=84?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E7=88=86=E7=82=B8=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题 --- GPT_SoVITS/module/modules.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/GPT_SoVITS/module/modules.py b/GPT_SoVITS/module/modules.py index 0969900..9a94898 100644 --- a/GPT_SoVITS/module/modules.py +++ b/GPT_SoVITS/module/modules.py @@ -1,4 +1,6 @@ import math +import pdb + import numpy as np import torch from torch import nn @@ -716,11 +718,11 @@ class MelStyleEncoder(nn.Module): if mask is None: out = torch.mean(x, dim=1) else: - len_ = (~mask).sum() + len_ = (~mask).sum(dim=1).unsqueeze(1) x = x.masked_fill(mask.unsqueeze(-1), 0) dtype=x.dtype x = x.float() - x=torch.div(x,len_) + x=torch.div(x,len_.unsqueeze(1)) out=x.sum(dim=1).to(dtype) return out