add wav2lip384

main
lipku 6 months ago
parent 4ef68b2995
commit ffa20cd10c

@@ -9,6 +9,7 @@ Real time interactive streaming digital human realize audio video synchronous
 - 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of sessions
 - 2024.12.21 Added model warm-up for wav2lip and musetalk to fix the stutter on first inference. Thanks to @heimaojinzhangyz
 - 2024.12.28 Added the Ultralight-Digital-Human digital-human model. Thanks to @lijihua2017
+- 2025.1.26 Added the wav2lip384 model. Thanks to @不蠢不蠢

 ## Features
 1. Supports multiple digital-human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
@@ -41,30 +42,21 @@ For setting up a CUDA environment on Linux, see https://zhuanlan.zhihu.com/p/6749
 ## 2. Quick Start
-By default the ernerf model is used, with WebRTC streams pushed to srs
-### 2.1 Run srs
-```bash
-export CANDIDATE='<server public IP>' # skip this step if srs and the browser are on the same LAN
-docker run --rm --env CANDIDATE=$CANDIDATE \
-  -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
-  registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 \
-  objs/srs -c conf/rtc.conf
-```
-Note: <font color=red>the server must open ports tcp:8000,8010,1985; udp:8000</font>
-### 2.2 Start the digital human:
-```python
-python app.py
-```
+- Download the models
+  Models needed to run wav2lip: <https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> password: ltua
+  Copy wav2lip384.pth into this project's models directory and rename it to wav2lip.pth;
+  extract wav2lip384_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory
+- Run
+  python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
+  Open http://serverip:8010/webrtcapi.html in a browser, type any text into the box and submit; the digital human reads it aloud.
+<font color=red>The server must open ports tcp:8010; udp:1-65536</font>
 If huggingface is unreachable, before running:
 ```
 export HF_ENDPOINT=https://hf-mirror.com
 ```
-Open http://serverip:8010/rtcpushapi.html in a browser, type any text into the box and submit; the digital human reads it aloud.
 ## 3. More Usage
 Documentation: <https://livetalking-doc.readthedocs.io/>
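
If the files land in the wrong place the server fails at startup. A minimal sanity check, assuming the repository root as the working directory and the exact paths from the steps above:

```python
from pathlib import Path

# Check the layout described above: the renamed checkpoint and the extracted avatar folder.
assert Path("models/wav2lip.pth").is_file(), "copy wav2lip384.pth into models/ and rename it to wav2lip.pth"
assert Path("data/avatars/wav2lip384_avatar1").is_dir(), "extract wav2lip384_avatar1.tar.gz into data/avatars/"
print("model and avatar files are in place")
```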

@@ -478,7 +478,7 @@ if __name__ == '__main__':
     print(opt)
     model = load_model("./models/wav2lip.pth")
     avatar = load_avatar(opt.avatar_id)
-    warm_up(opt.batch_size,model,96)
+    warm_up(opt.batch_size,model,384)
     # for k in range(opt.max_session):
     #     opt.sessionid=k
     #     nerfreal = LipReal(opt,model)
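
warm_up itself is defined elsewhere in the repo and is not part of this commit. A minimal sketch of what such a warm-up pass plausibly does, assuming it pushes one dummy batch through the network so CUDA kernel selection happens before the first real request; the shapes follow the (B, T, 1, 80, 16) audio comment and the 6-channel face input in the model diff below:

```python
import torch

def warm_up_sketch(batch_size, model, img_size, device="cuda"):
    # Hypothetical re-implementation, not the project's warm_up: run one
    # throwaway forward pass so the first real inference does not stutter.
    model = model.to(device).eval()
    mel_batch = torch.zeros(batch_size, 1, 80, 16, device=device)              # one audio chunk
    img_batch = torch.zeros(batch_size, 6, img_size, img_size, device=device)  # masked face + reference face
    with torch.no_grad():
        model(mel_batch, img_batch)  # called with img_size=384 for the wav2lip384 checkpoint
```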

@@ -3,13 +3,15 @@ import torch
 import numpy as np
 from baseasr import BaseASR

-class LightASR(BaseASR):
-    def __init__(self, opt, parent, audio_processor):
+# hubert audio feature
+class HubertASR(BaseASR):
+    # audio_feat_length: select audio feature before and after
+    def __init__(self, opt, parent, audio_processor, audio_feat_length=[8, 8]):
         super().__init__(opt, parent)
         self.audio_processor = audio_processor
-        self.stride_left_size = 32
-        self.stride_right_size = 32
+        #self.stride_left_size = 32
+        #self.stride_right_size = 32
+        self.audio_feat_length = audio_feat_length

     def run_step(self):
@@ -26,7 +28,7 @@ class LightASR(BaseASR):
         inputs = np.concatenate(self.frames)  # [N * chunk]
         mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
-        mel_chunks = self.audio_processor.feature2chunks(feature_array=mel, fps=self.fps/2, batch_size=self.batch_size, start=self.stride_left_size/2)
+        mel_chunks = self.audio_processor.feature2chunks(feature_array=mel, fps=self.fps/2, batch_size=self.batch_size, audio_feat_length=self.audio_feat_length, start=self.stride_left_size/2)
         self.feat_queue.put(mel_chunks)
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
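
feature2chunks is not shown in this commit, so the following is only a hypothetical sketch of the windowing that audio_feat_length appears to control: for each output chunk, take `before` hubert-feature rows before the aligned index and `after` rows after it, padding at the boundaries. The helper name and the edge-padding policy are illustrative assumptions, not the project's code:

```python
import numpy as np

def select_audio_feat(feature_array, idx, audio_feat_length=(8, 8)):
    # Take `before` rows before idx and `after` rows after it, repeating
    # edge rows when the window runs past either end of the features.
    before, after = audio_feat_length
    left = max(0, idx - before)
    right = min(len(feature_array), idx + after)
    pad_l = left - (idx - before)
    pad_r = (idx + after) - right
    return np.pad(feature_array[left:right], ((pad_l, pad_r), (0, 0)), mode="edge")

feats = np.random.randn(100, 1024)        # e.g. hubert features, one row per step
print(select_audio_feat(feats, 3).shape)  # (16, 1024): 8 before + 8 after
```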

@@ -33,7 +33,7 @@ from threading import Thread, Event
 import torch.multiprocessing as mp

-from lightasr import LightASR
+from hubertasr import HubertASR
 import asyncio
 from av import AudioFrame, VideoFrame
 from basereal import BaseReal
@@ -241,7 +241,7 @@ class LightReal(BaseReal):
         audio_processor = model
         self.model, self.frame_list_cycle, self.face_list_cycle, self.coord_list_cycle = avatar
-        self.asr = LightASR(opt, self, audio_processor)
+        self.asr = HubertASR(opt, self, audio_processor)
         self.asr.warm_up()
         #self.__warm_up()

@@ -1,2 +1,2 @@
-from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
-from .syncnet import SyncNet_color
+from .wav2lip import Wav2Lip #, Wav2Lip_disc_qual
+#from .syncnet import SyncNet_color

@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+        )
+        self.act = nn.LeakyReLU(negative_slope=0.01)
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
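
Two properties of these helpers are load-bearing for the model below: the residual add requires cin == cout (and stride 1), and the transpose-conv configuration used throughout the decoder (kernel 3, stride 2, padding 1, output_padding 1) exactly doubles H and W, since out = (in - 1)*2 - 2 + 3 + 1 = 2*in. A quick check of the doubling:

```python
import torch
from torch import nn

# Same configuration Conv2dTranspose wraps above: kernel 3, stride 2,
# padding 1, output_padding 1 -> spatial size doubles.
up = nn.ConvTranspose2d(1024, 1024, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.randn(1, 1024, 6, 6)
print(up(x).shape)  # torch.Size([1, 1024, 12, 12])
```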

@@ -1,91 +1,159 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
+import math

-from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+from .conv_384 import Conv2dTranspose, Conv2d, nonorm_Conv2d

-class Wav2Lip(nn.Module):
-    def __init__(self):
-        super(Wav2Lip, self).__init__()
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+class SAM(nn.Module):
+    def __init__(self):
+        super(SAM, self).__init__()
+        self.sa = SpatialAttention()
+
+    def forward(self, sp, se):
+        sp_att = self.sa(sp)
+        out = se * sp_att + se
+        return out
+
+class Wav2Lip(nn.Module):
+    def __init__(self, audio_encoder=None):
+        super(Wav2Lip, self).__init__()
+        self.sam = SAM()

         self.face_encoder_blocks = nn.ModuleList([
-            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True)), # 192, 192

-            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
-                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 96, 96
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

-            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
-                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 48, 48
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

-            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
-                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 24, 24
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

-            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
-                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 12, 12
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

-            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
-                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6, 6
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

-            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
-                          Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
+            nn.Sequential(Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), # 3, 3
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0), # 1, 1
+                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)), ])

-        self.audio_encoder = nn.Sequential(
-            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
-            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+        if audio_encoder is None:
+            self.audio_encoder = nn.Sequential(
+                Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+        else:
+            self.audio_encoder = audio_encoder
+            for p in self.audio_encoder.parameters():
+                p.requires_grad = False
+            self.audio_refine = nn.Sequential(
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))

         self.face_decoder_blocks = nn.ModuleList([
-            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0), ), # + 1024

-            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
-                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0), # 3,3
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # + 1024

-            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
-                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # 6, 6 + 512

-            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
-                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
+            nn.Sequential(Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True), ), # 12, 12 + 256

-            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
-                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ), # 24, 24 + 128
+
+            nn.Sequential(Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ), # 48, 48 + 64

             nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                           Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
+                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ), # 96, 96 + 32

             nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                           Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
+                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ]) # 192, 192 + 16

         self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                           nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                           nn.Sigmoid())

+    def freeze_audio_encoder(self):
+        for p in self.audio_encoder.parameters():
+            p.requires_grad = False
+
     def forward(self, audio_sequences, face_sequences):
+        # audio_sequences = (B, T, 1, 80, 16)
         B = audio_sequences.size(0)

         input_dim_size = len(face_sequences.size())
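
A quick shape check of the new SAM gating (the import path below is an assumption; SAM and SpatialAttention are the classes defined above): the attention map is computed from the encoder skip feature sp, has a single channel, and is broadcast over the decoder feature se, so shapes are preserved:

```python
import torch
from wav2lip import SAM  # assumed import path; SAM is defined in the model file above

sam = SAM()
sp = torch.randn(2, 512, 24, 24)  # encoder skip feature
se = torch.randn(2, 512, 24, 24)  # decoder feature at the same scale
out = sam(sp, se)                 # attention map is (2, 1, 24, 24), broadcast over channels
print(out.shape)                  # torch.Size([2, 512, 24, 24])
```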
@@ -93,10 +161,13 @@ class Wav2Lip(nn.Module):
             audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
             face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

         audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1

         feats = []
         x = face_sequences
         for f in self.face_encoder_blocks:
             x = f(x)
             feats.append(x)
@@ -105,6 +176,7 @@ class Wav2Lip(nn.Module):
         for f in self.face_decoder_blocks:
             x = f(x)
             try:
+                x = self.sam(feats[-1], x)
                 x = torch.cat((x, feats[-1]), dim=1)
             except Exception as e:
                 print(x.size())
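
The decoder widths above follow directly from this concat: each block consumes the previous block's output plus the matching encoder skip, which is what the "+ 1024", "+ 512", ... comments record. A quick bookkeeping check that reproduces the in-channel numbers:

```python
# Encoder block outputs (deepest first) and decoder block output widths,
# read off the ModuleLists above.
skips = [1024, 1024, 512, 256, 128, 64, 32, 16]
outs = [1024, 1024, 1024, 768, 512, 256, 128, 64]

# The first decoder block sees the 1024-dim audio embedding; every later
# block sees previous output + skip after the torch.cat in the loop above.
ins = [1024] + [o + s for o, s in zip(outs[:-1], skips[:-1])]
print(ins)                   # [1024, 2048, 2048, 1536, 1024, 640, 320, 160]
print(outs[-1] + skips[-1])  # 80, matching Conv2d(80, 32, ...) in output_block
```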
@@ -116,69 +188,11 @@ class Wav2Lip(nn.Module):
         x = self.output_block(x)

         if input_dim_size > 4:
             x = torch.split(x, B, dim=0) # [(B, C, H, W)]
             outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
         else:
             outputs = x

         return outputs
-
-class Wav2Lip_disc_qual(nn.Module):
-    def __init__(self):
-        super(Wav2Lip_disc_qual, self).__init__()
-
-        self.face_encoder_blocks = nn.ModuleList([
-            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
-
-            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
-                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
-                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
-                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
-                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
-
-            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
-                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
-
-            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
-                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
-
-        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
-        self.label_noise = .0
-
-    def get_lower_half(self, face_sequences):
-        return face_sequences[:, :, face_sequences.size(2)//2:]
-
-    def to_2d(self, face_sequences):
-        B = face_sequences.size(0)
-        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
-        return face_sequences
-
-    def perceptual_forward(self, false_face_sequences):
-        false_face_sequences = self.to_2d(false_face_sequences)
-        false_face_sequences = self.get_lower_half(false_face_sequences)
-
-        false_feats = false_face_sequences
-        for f in self.face_encoder_blocks:
-            false_feats = f(false_feats)
-
-        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
-                                                 torch.ones((len(false_feats), 1)).cuda())
-
-        return false_pred_loss
-
-    def forward(self, face_sequences):
-        face_sequences = self.to_2d(face_sequences)
-        face_sequences = self.get_lower_half(face_sequences)
-
-        x = face_sequences
-        for f in self.face_encoder_blocks:
-            x = f(x)
-
-        return self.binary_pred(x).view(len(x), -1)

@@ -102,6 +102,7 @@
         fetch('/record', {
             body: JSON.stringify({
                 type: 'start_record',
+                sessionid: parseInt(document.getElementById('sessionid').value),
             }),
             headers: {
                 'Content-Type': 'application/json'
@@ -127,6 +128,7 @@
         fetch('/record', {
             body: JSON.stringify({
                 type: 'end_record',
+                sessionid: parseInt(document.getElementById('sessionid').value),
            }),
             headers: {
                 'Content-Type': 'application/json'
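
The recording control can also be driven outside the browser; a sketch with Python requests mirroring the JSON bodies above (host and port follow the README; the sessionid value is whatever the page's session field holds, 0 here as a placeholder):

```python
import requests

base = "http://serverip:8010"

# Start and later stop server-side recording for one session, matching the
# fetch('/record', ...) calls above.
requests.post(f"{base}/record", json={"type": "start_record", "sessionid": 0})
# ... later ...
requests.post(f"{base}/record", json={"type": "end_record", "sessionid": 0})
```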
