diff --git a/GPT_SoVITS/module/modules.py b/GPT_SoVITS/module/modules.py index 0969900..9a94898 100644 --- a/GPT_SoVITS/module/modules.py +++ b/GPT_SoVITS/module/modules.py @@ -1,4 +1,6 @@ import math +import pdb + import numpy as np import torch from torch import nn @@ -716,11 +718,11 @@ class MelStyleEncoder(nn.Module): if mask is None: out = torch.mean(x, dim=1) else: - len_ = (~mask).sum() + len_ = (~mask).sum(dim=1).unsqueeze(1) x = x.masked_fill(mask.unsqueeze(-1), 0) dtype=x.dtype x = x.float() - x=torch.div(x,len_) + x=torch.div(x,len_.unsqueeze(1)) out=x.sum(dim=1).to(dtype) return out