修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题

修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
main
RVC-Boss 2 months ago committed by GitHub
parent cd6de7398e
commit ed89a02337
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,6 @@
import math
import pdb
import numpy as np
import torch
from torch import nn
@ -716,11 +718,11 @@ class MelStyleEncoder(nn.Module):
if mask is None:
out = torch.mean(x, dim=1)
else:
len_ = (~mask).sum()
len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0)
dtype=x.dtype
x = x.float()
x=torch.div(x,len_)
x=torch.div(x,len_.unsqueeze(1))
out=x.sum(dim=1).to(dtype)
return out

Loading…
Cancel
Save