diff --git a/GPT_SoVITS/module/modules.py b/GPT_SoVITS/module/modules.py index 7493f0b..0969900 100644 --- a/GPT_SoVITS/module/modules.py +++ b/GPT_SoVITS/module/modules.py @@ -716,10 +716,12 @@ class MelStyleEncoder(nn.Module): if mask is None: out = torch.mean(x, dim=1) else: - len_ = (~mask).sum(dim=1).unsqueeze(1) + len_ = (~mask).sum() x = x.masked_fill(mask.unsqueeze(-1), 0) - x = x.sum(dim=1) - out = torch.div(x, len_) + dtype=x.dtype + x = x.float() + x=torch.div(x,len_) + out=x.sum(dim=1).to(dtype) return out def forward(self, x, mask=None): @@ -743,7 +745,6 @@ class MelStyleEncoder(nn.Module): x = self.fc(x) # temoral average pooling w = self.temporal_avg_pool(x, mask=mask) - return w.unsqueeze(-1)