|
|
@ -1,8 +1,7 @@
|
|
|
|
import warnings
|
|
|
|
|
|
|
|
warnings.filterwarnings("ignore")
|
|
|
|
|
|
|
|
import copy
|
|
|
|
import copy
|
|
|
|
import math
|
|
|
|
import math
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
|
|
|
|
import pdb
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
from torch import nn
|
|
|
|
from torch import nn
|
|
|
@ -984,6 +983,7 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
@torch.no_grad()
|
|
|
|
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
|
|
|
|
def decode(self, codes, text, refer, noise_scale=0.5,speed=1):
|
|
|
|
|
|
|
|
def get_ge(refer):
|
|
|
|
ge = None
|
|
|
|
ge = None
|
|
|
|
if refer is not None:
|
|
|
|
if refer is not None:
|
|
|
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
|
|
|
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
|
|
@ -994,6 +994,15 @@ class SynthesizerTrn(nn.Module):
|
|
|
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
|
|
|
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
|
|
|
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
|
|
|
|
|
|
|
return ge
|
|
|
|
|
|
|
|
if(type(refer)==list):
|
|
|
|
|
|
|
|
ges=[]
|
|
|
|
|
|
|
|
for _refer in refer:
|
|
|
|
|
|
|
|
ge=get_ge(_refer)
|
|
|
|
|
|
|
|
ges.append(ge)
|
|
|
|
|
|
|
|
ge=torch.stack(ges,0).mean(0)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ge=get_ge(refer)
|
|
|
|
|
|
|
|
|
|
|
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
|
|
|
y_lengths = torch.LongTensor([codes.size(2) * 2]).to(codes.device)
|
|
|
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
|
|
|
text_lengths = torch.LongTensor([text.size(-1)]).to(text.device)
|
|
|
|