add wav2lip384

main
lipku 6 months ago
parent 4ef68b2995
commit ffa20cd10c

@@ -9,6 +9,7 @@ Real time interactive streaming digital human realize audio video synchronous
 - 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of sessions
 - 2024.12.21 Added wav2lip and musetalk model warm-up to fix the stutter on the first inference. Thanks @heimaojinzhangyz
 - 2024.12.28 Added the Ultralight-Digital-Human digital human model. Thanks @lijihua2017
+- 2025.1.26 Added the wav2lip384 model. Thanks @不蠢不蠢

 ## Features
 1. Supports multiple digital human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
@@ -41,29 +42,20 @@ Linux CUDA environment setup can refer to this article https://zhuanlan.zhihu.com/p/6749
 ## 2. Quick Start
-By default the ernerf model is used, with webrtc streaming pushed to srs
-### 2.1 Run srs
-```bash
-export CANDIDATE='<server public IP>' # not needed if srs and the browser are on the same internal network
-docker run --rm --env CANDIDATE=$CANDIDATE \
-  -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
-  registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 \
-  objs/srs -c conf/rtc.conf
-```
-Note: <font color=red>the server must open ports tcp:8000,8010,1985; udp:8000</font>
-### 2.2 Start the digital human:
-```python
-python app.py
-```
+- Download the models
+  Download link for the models needed to run wav2lip: <https://pan.baidu.com/s/1yOsQ06-RIDTJd3HFCw4wtA> password: ltua
+  Copy wav2lip384.pth into this project's models directory and rename it to wav2lip.pth;
+  extract wav2lip384_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory
+- Run
+  python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
+  Open http://serverip:8010/webrtcapi.html in a browser, type any text into the text box and submit; the digital human will speak the text
+<font color=red>The server must open ports tcp:8010; udp:1-65536</font>
 If huggingface is not accessible, before running:
 ```
 export HF_ENDPOINT=https://hf-mirror.com
 ```
-Open http://serverip:8010/rtcpushapi.html in a browser, type any text into the text box and submit; the digital human will speak the text
 ## 3. More Usage
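The new Quick Start above boils down to two file operations before launching app.py. A minimal sketch of them in Python, assuming the archive from the Baidu link has already been downloaded to a local folder (the download path and the shutil/tarfile approach are illustrative, not part of the project):

```python
# Minimal sketch of the file layout described above, assuming the Baidu archive
# was already downloaded into ~/Downloads. Only models/wav2lip.pth and
# data/avatars/ come from the README; everything else is illustrative.
import shutil
import tarfile
from pathlib import Path

downloads = Path.home() / "Downloads"   # hypothetical download location
repo = Path(".")                         # run from the project root

# 1. copy the checkpoint and rename it to the name app.py expects
(repo / "models").mkdir(exist_ok=True)
shutil.copy(downloads / "wav2lip384.pth", repo / "models" / "wav2lip.pth")

# 2. unpack the avatar folder into data/avatars/
avatars_dir = repo / "data" / "avatars"
avatars_dir.mkdir(parents=True, exist_ok=True)
with tarfile.open(downloads / "wav2lip384_avatar1.tar.gz") as tar:
    tar.extractall(avatars_dir)          # yields data/avatars/wav2lip384_avatar1
```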

@@ -478,7 +478,7 @@ if __name__ == '__main__':
     print(opt)
     model = load_model("./models/wav2lip.pth")
     avatar = load_avatar(opt.avatar_id)
-    warm_up(opt.batch_size,model,96)
+    warm_up(opt.batch_size,model,384)
     # for k in range(opt.max_session):
     #     opt.sessionid=k
     #     nerfreal = LipReal(opt,model)
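The only change in app.py is the warm-up resolution, 96 → 384, so the dummy forward pass matches the new model's face size. A rough sketch of what such a warm-up does, assuming the usual wav2lip inputs (a 6-channel face batch plus a mel window); the project's real warm_up() may differ in detail:

```python
# Hedged sketch of a wav2lip warm-up at the new 384 face resolution; the repo's
# own warm_up(batch_size, model, img_size) is not shown in this diff.
import torch

def warm_up_sketch(batch_size, model, img_size=384, mel_steps=16, device="cuda"):
    model = model.to(device).eval()
    # masked face + reference face stacked on the channel axis -> 6 channels
    img_batch = torch.zeros(batch_size, 6, img_size, img_size, device=device)
    mel_batch = torch.zeros(batch_size, 1, 80, mel_steps, device=device)
    with torch.no_grad():
        model(mel_batch, img_batch)   # one dummy forward pass initializes CUDA kernels/caches
```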

@@ -3,13 +3,15 @@ import torch
 import numpy as np
 from baseasr import BaseASR

-class LightASR(BaseASR):
-    def __init__(self, opt, parent, audio_processor):
+# hubert audio feature
+class HubertASR(BaseASR):
+    #audio_feat_length: select audio feature before and after
+    def __init__(self, opt, parent, audio_processor, audio_feat_length=[8,8]):
         super().__init__(opt, parent)
         self.audio_processor = audio_processor
-        self.stride_left_size = 32
-        self.stride_right_size = 32
+        #self.stride_left_size = 32
+        #self.stride_right_size = 32
+        self.audio_feat_length = audio_feat_length

     def run_step(self):

@@ -26,7 +28,7 @@ class LightASR(BaseASR):
         inputs = np.concatenate(self.frames) # [N * chunk]
         mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
-        mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2)
+        mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,audio_feat_length=self.audio_feat_length,start=self.stride_left_size/2)
         self.feat_queue.put(mel_chunks)
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
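audio_feat_length = [8, 8] tells feature2chunks to take 8 hubert feature frames before and 8 after the current position for every video frame. The slicing itself is not part of this diff; the sketch below is only a guess at that windowing, with edge padding as an assumption:

```python
# Illustrative sketch of the [8, 8] windowing implied by audio_feat_length; the real
# slicing lives in audio_processor.feature2chunks(), which is not shown in this commit.
import numpy as np

def window_feature(feature_array, idx, audio_feat_length=(8, 8)):
    """Return hubert features from idx-8 to idx+8, edge-padded at clip boundaries."""
    left, right = audio_feat_length
    pad_left = max(0, left - idx)
    pad_right = max(0, idx + right + 1 - len(feature_array))
    start = max(0, idx - left)
    end = min(len(feature_array), idx + right + 1)
    window = feature_array[start:end]
    if pad_left or pad_right:
        window = np.pad(window, ((pad_left, pad_right), (0, 0)), mode="edge")
    return window  # shape: (left + right + 1, feature_dim)
```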

@@ -33,7 +33,7 @@ from threading import Thread, Event
 import torch.multiprocessing as mp
-from lightasr import LightASR
+from hubertasr import HubertASR
 import asyncio
 from av import AudioFrame, VideoFrame
 from basereal import BaseReal

@@ -241,7 +241,7 @@ class LightReal(BaseReal):
         audio_processor = model
         self.model,self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
-        self.asr = LightASR(opt,self,audio_processor)
+        self.asr = HubertASR(opt,self,audio_processor)
         self.asr.warm_up()
         #self.__warm_up()

@@ -1,2 +1,2 @@
-from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
-from .syncnet import SyncNet_color
+from .wav2lip import Wav2Lip #, Wav2Lip_disc_qual
+#from .syncnet import SyncNet_color

@@ -0,0 +1,44 @@
import torch
from torch import nn
from torch.nn import functional as F


class Conv2d(nn.Module):
    """Conv + BatchNorm + ReLU block, with an optional residual connection."""
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)


class nonorm_Conv2d(nn.Module):
    """Conv + LeakyReLU block without normalization."""
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
        )
        self.act = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class Conv2dTranspose(nn.Module):
    """Transposed conv + BatchNorm + ReLU block, used for upsampling in the decoder."""
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
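A quick shape check of the three blocks, useful when wiring them into the 384 encoder/decoder below (the import path is assumed from this repo's layout):

```python
# Quick shape check for the blocks above (illustrative only).
import torch
from wav2lip.models.conv_384 import Conv2d, Conv2dTranspose  # import path assumed

x = torch.randn(1, 64, 48, 48)
down = Conv2d(64, 128, kernel_size=3, stride=2, padding=1)                          # halves spatial size
res = Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)           # residual needs cin == cout
up = Conv2dTranspose(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) # doubles it back

y = res(down(x))
print(y.shape)      # torch.Size([1, 128, 24, 24])
print(up(y).shape)  # torch.Size([1, 64, 48, 48])
```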

@@ -1,184 +1,198 @@
import torch
from torch import nn
from torch.nn import functional as F
import math

from .conv_384 import Conv2dTranspose, Conv2d, nonorm_Conv2d


class SpatialAttention(nn.Module):
    """Channel-pooled (avg + max) spatial attention map in [0, 1]."""
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)


class SAM(nn.Module):
    """Gates the decoder feature with spatial attention computed from the encoder skip feature."""
    def __init__(self):
        super(SAM, self).__init__()
        self.sa = SpatialAttention()

    def forward(self, sp, se):
        sp_att = self.sa(sp)
        out = se * sp_att + se
        return out


class Wav2Lip(nn.Module):
    def __init__(self, audio_encoder=None):
        super(Wav2Lip, self).__init__()
        self.sam = SAM()

        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3),
                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True)),  # 192, 192

            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # 96, 96
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # 48, 48
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # 24, 24
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),  # 12, 12
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 6, 6
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),  # 3, 3
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),

            nn.Sequential(Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0),  # 1, 1
                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                          Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)), ])

        if audio_encoder is None:
            self.audio_encoder = nn.Sequential(
                Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

                Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

                Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

                Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

                Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),

                Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
        else:
            self.audio_encoder = audio_encoder
            for p in self.audio_encoder.parameters():
                p.requires_grad = False

            self.audio_refine = nn.Sequential(
                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))

        self.face_decoder_blocks = nn.ModuleList([
            nn.Sequential(Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0), ),  # + 1024

            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0),  # 3,3
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ),  # + 1024

            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ),  # 6, 6 + 512

            nn.Sequential(Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True), ),  # 12, 12 + 256

            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),  # 24, 24 + 128

            nn.Sequential(Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ),  # 48, 48 + 64

            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ),  # 96, 96 + 32

            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
                          Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ])  # 192, 192 + 16

        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
                                          nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
                                          nn.Sigmoid())

    def freeze_audio_encoder(self):
        for p in self.audio_encoder.parameters():
            p.requires_grad = False

    def forward(self, audio_sequences, face_sequences):
        B = audio_sequences.size(0)

        input_dim_size = len(face_sequences.size())
        if input_dim_size > 4:
            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1

        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)

        x = audio_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            try:
                x = self.sam(feats[-1], x)
                x = torch.cat((x, feats[-1]), dim=1)
            except Exception as e:
                print(x.size())
                print(feats[-1].size())
                raise e

            feats.pop()

        x = self.output_block(x)

        if input_dim_size > 4:
            x = torch.split(x, B, dim=0)  # [(B, C, H, W)]
            outputs = torch.stack(x, dim=2)  # (B, C, T, H, W)
        else:
            outputs = x

        return outputs
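Compared with the 96-pixel model it replaces, this version widens the channels (up to 1024) and gates every skip connection through SAM before concatenation. A minimal dummy forward pass, assuming the file lives at wav2lip/models/wav2lip.py as the __init__.py hunk above suggests:

```python
# Illustrative forward pass on dummy data (CPU, no checkpoint loaded); the 6 face
# channels are the masked frame plus the reference frame, as in the original Wav2Lip.
import torch
from wav2lip.models.wav2lip import Wav2Lip   # module path assumed from this repo layout

model = Wav2Lip().eval()
mel = torch.randn(2, 1, 80, 16)              # (B, 1, n_mels, mel_steps)
faces = torch.randn(2, 6, 384, 384)          # (B, 6, H, W) at the new 384 resolution
with torch.no_grad():
    out = model(mel, faces)
print(out.shape)                             # torch.Size([2, 3, 384, 384])
```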

@@ -102,6 +102,7 @@
         fetch('/record', {
             body: JSON.stringify({
                 type: 'start_record',
+                sessionid:parseInt(document.getElementById('sessionid').value),
             }),
             headers: {
                 'Content-Type': 'application/json'

@@ -127,6 +128,7 @@
         fetch('/record', {
             body: JSON.stringify({
                 type: 'end_record',
+                sessionid:parseInt(document.getElementById('sessionid').value),
             }),
             headers: {
                 'Content-Type': 'application/json'
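Both record calls now carry the sessionid, so the server knows which of the concurrent sessions to start or stop recording. The same requests can be issued outside the page, for example from Python; the base URL and the requests dependency are assumptions, and the JSON shape mirrors the diff above:

```python
# Hedged sketch of the /record calls made from Python instead of the page's JS.
import requests

BASE = "http://serverip:8010"   # assumed server address
sessionid = 0                   # the id of the webrtc session to record

requests.post(f"{BASE}/record", json={"type": "start_record", "sessionid": sessionid})
# ... interact with the avatar for a while ...
requests.post(f"{BASE}/record", json={"type": "end_record", "sessionid": sessionid})
```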
