修复ge.sum数值可能爆炸问题

修复ge.sum数值可能爆炸问题
main
RVC-Boss 2 months ago committed by GitHub
parent d6b78c927a
commit 8056efe4ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -716,10 +716,12 @@ class MelStyleEncoder(nn.Module):
if mask is None:
out = torch.mean(x, dim=1)
else:
len_ = (~mask).sum(dim=1).unsqueeze(1)
len_ = (~mask).sum()
x = x.masked_fill(mask.unsqueeze(-1), 0)
x = x.sum(dim=1)
out = torch.div(x, len_)
dtype=x.dtype
x = x.float()
x=torch.div(x,len_)
out=x.sum(dim=1).to(dtype)
return out
def forward(self, x, mask=None):
@ -743,7 +745,6 @@ class MelStyleEncoder(nn.Module):
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w.unsqueeze(-1)

Loading…
Cancel
Save