diff --git a/wav2lip/models384/__init__.py b/wav2lip/models384/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/wav2lip/models384/conv.py b/wav2lip/models384/conv.py
new file mode 100644
index 0000000..fec03c5
--- /dev/null
+++ b/wav2lip/models384/conv.py
@@ -0,0 +1,49 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act="relu", *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+        )
+        # pick the activation by name; syncnet.py and wav2lip.py construct this block with act="leaky" / "relu"
+        if act == "leaky":
+            self.act = nn.LeakyReLU(negative_slope=0.01)
+        else:
+            self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+    # same interface as Conv2d but without BatchNorm; residual is accepted for signature parity and unused
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+        )
+        self.act = nn.LeakyReLU(negative_slope=0.01)
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
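
Note: the Conv2d block is given an explicit act parameter ("relu" default, "leaky" supported) because the encoders below construct it with act=... keywords. A minimal shape check for the three building blocks, as a sketch (batch size and channel counts here are illustrative, assuming the package is importable as wav2lip.models384):

    import torch
    from wav2lip.models384.conv import Conv2d, Conv2dTranspose

    x = torch.randn(2, 16, 96, 96)
    down = Conv2d(16, 32, kernel_size=3, stride=2, padding=1, act="leaky")    # 96 -> 48
    same = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)  # shape-preserving, with skip
    up = Conv2dTranspose(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)  # 48 -> 96

    y = up(same(down(x)))
    assert y.shape == (2, 16, 96, 96)
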
diff --git a/wav2lip/models384/syncnet.py b/wav2lip/models384/syncnet.py
new file mode 100644
index 0000000..211bfd0
--- /dev/null
+++ b/wav2lip/models384/syncnet.py
@@ -0,0 +1,92 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv2d
+
+class SyncNet_color(nn.Module):
+    def __init__(self):
+        super(SyncNet_color, self).__init__()
+
+        self.face_encoder = nn.Sequential(
+            Conv2d(15, 16, kernel_size=(7, 7), stride=1, padding=3, act="leaky"),  # 192, 384
+
+            Conv2d(16, 32, kernel_size=5, stride=(1, 2), padding=1, act="leaky"),  # 192, 192
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(32, 64, kernel_size=3, stride=2, padding=1, act="leaky"),  # 96, 96
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act="leaky"),  # 48, 48
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(128, 256, kernel_size=3, stride=2, padding=1, act="leaky"),  # 24, 24
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            ###################
+            # Modified blocks
+            ##################
+            Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act="leaky"),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),  # 12, 12
+
+            Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act="leaky"),
+            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),  # 6, 6
+
+            Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act="leaky"),  # 3, 3
+            Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act="leaky"),
+            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act="relu"))  # 1, 1
+            ##################
+
+        # print(summary(self.face_encoder, (15, 96, 192)))
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act="leaky"),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act="leaky"),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act="leaky"),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act="leaky"),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            ###################
+            # Modified blocks
+            ##################
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act="leaky"),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act="leaky"),
+
+            Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act="relu"),
+            Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act="relu"))
+            ##################
+
+        # print(summary(self.audio_encoder, (1, 80, 16)))
+
+    def forward(self, audio_sequences, face_sequences):  # audio_sequences: (B, 1, 80, T)
+        face_embedding = self.face_encoder(face_sequences)
+        audio_embedding = self.audio_encoder(audio_sequences)
+
+        audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
+        face_embedding = face_embedding.view(face_embedding.size(0), -1)
+
+        audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
+        face_embedding = F.normalize(face_embedding, p=2, dim=1)
+
+        return audio_embedding, face_embedding
+
+    def audio_forward(self, audio_sequences):
+        return self.audio_encoder(audio_sequences)
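
Reading the # H, W comments above, the face branch appears to take a 5-frame window of half-face crops stacked on the channel axis (5 x 3 = 15 channels at 192 x 384, i.e. the lower half of a 384 x 384 face) and the audio branch a (1, 80, 16) mel chunk; both bottleneck to 1024-d L2-normalized embeddings. A hedged smoke test under those assumed input sizes:

    import torch
    import torch.nn.functional as F
    from wav2lip.models384.syncnet import SyncNet_color

    model = SyncNet_color().eval()
    faces = torch.randn(2, 15, 192, 384)  # 5 stacked RGB half-face crops (assumed size)
    mels = torch.randn(2, 1, 80, 16)      # 80 mel bins x 16 spectrogram steps

    with torch.no_grad():
        audio_emb, face_emb = model(mels, faces)

    assert audio_emb.shape == face_emb.shape == (2, 1024)
    # cosine similarity between the embeddings is the usual SyncNet sync score
    score = F.cosine_similarity(audio_emb, face_emb)
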
diff --git a/wav2lip/models384/wav2lip.py b/wav2lip/models384/wav2lip.py
new file mode 100644
index 0000000..7606748
--- /dev/null
+++ b/wav2lip/models384/wav2lip.py
@@ -0,0 +1,198 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # channel-wise mean and max pooling -> 2-channel map -> 1-channel attention mask
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+
+class SAM(nn.Module):
+    def __init__(self):
+        super(SAM, self).__init__()
+        self.sa = SpatialAttention()
+
+    def forward(self, sp, se):
+        # gate the decoder feature (se) with a spatial mask computed from the encoder skip (sp)
+        sp_att = self.sa(sp)
+        out = se * sp_att + se
+        return out
+
+
+class Wav2Lip(nn.Module):
+    def __init__(self, audio_encoder=None):
+        super(Wav2Lip, self).__init__()
+        self.sam = SAM()
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True)),  # 192, 192
+
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 96, 96
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 48, 48
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 24, 24
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 12, 12
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 6, 6
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),  # 3, 3
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0),  # 1, 1
+                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))])
+
+        if audio_encoder is None:
+            self.audio_encoder = nn.Sequential(
+                Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+        else:
+            self.audio_encoder = audio_encoder
+
+            # freeze the supplied (pretrained) audio encoder;
+            # audio_refine is defined here but not applied in forward()
+            for p in self.audio_encoder.parameters():
+                p.requires_grad = False
+
+            self.audio_refine = nn.Sequential(
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+
+        self.face_decoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)),  # + 1024
+
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0),  # 3, 3
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),  # + 1024
+
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),  # 6, 6 + 512
+
+            nn.Sequential(Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True)),  # 12, 12 + 256
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),  # 24, 24 + 128
+
+            nn.Sequential(Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),  # 48, 48 + 64
+
+            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),  # 96, 96 + 32
+
+            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True))])  # 192, 192 + 16
+
+        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+                                          nn.Sigmoid())
+
+    def freeze_audio_encoder(self):
+        for p in self.audio_encoder.parameters():
+            p.requires_grad = False
+
+    def forward(self, audio_sequences, face_sequences):
+        B = audio_sequences.size(0)
+
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            # fold the time axis into the batch: (B, T, ...) -> (B * T, ...)
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+        audio_embedding = self.audio_encoder(audio_sequences)  # B, 1024, 1, 1
+
+        feats = []
+        x = face_sequences
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                # spatially attend to the encoder skip, then concatenate it
+                x = self.sam(feats[-1], x)
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)
+
+        if input_dim_size > 4:
+            x = torch.split(x, B, dim=0)  # T chunks of (B, C, H, W)
+            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)
+        else:
+            outputs = x
+
+        return outputs
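
End to end, the generator consumes per-frame mel chunks and 6-channel face inputs (reference frame plus masked target stacked on channels, per the Conv2d(6, 16, ...) stem) and emits sigmoid RGB frames; the encoder comments imply 192 x 192 crops. A hedged forward-pass sketch (T and the 6-channel packing are assumptions read off the reshaping logic in forward):

    import torch
    from wav2lip.models384.wav2lip import Wav2Lip

    model = Wav2Lip().eval()  # audio_encoder=None -> uses its own audio branch
    B, T = 2, 5
    mels = torch.randn(B, T, 1, 80, 16)     # one (1, 80, 16) mel chunk per output frame
    faces = torch.randn(B, 6, T, 192, 192)  # masked target + reference, stacked on channels

    with torch.no_grad():
        out = model(mels, faces)

    assert out.shape == (B, 3, T, 192, 192)  # (B, C, T, H, W), values in (0, 1) after Sigmoid
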