|
|
@ -1,4 +1,6 @@
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
|
|
|
|
import pdb
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn
|
|
|
@ -716,11 +718,11 @@ class MelStyleEncoder(nn.Module):
|
|
|
|
if mask is None:
|
|
|
|
if mask is None:
|
|
|
|
out = torch.mean(x, dim=1)
|
|
|
|
out = torch.mean(x, dim=1)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
len_ = (~mask).sum()
|
|
|
|
len_ = (~mask).sum(dim=1).unsqueeze(1)
|
|
|
|
x = x.masked_fill(mask.unsqueeze(-1), 0)
|
|
|
|
x = x.masked_fill(mask.unsqueeze(-1), 0)
|
|
|
|
dtype=x.dtype
|
|
|
|
dtype=x.dtype
|
|
|
|
x = x.float()
|
|
|
|
x = x.float()
|
|
|
|
x=torch.div(x,len_)
|
|
|
|
x=torch.div(x,len_.unsqueeze(1))
|
|
|
|
out=x.sum(dim=1).to(dtype)
|
|
|
|
out=x.sum(dim=1).to(dtype)
|
|
|
|
return out
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|