diff --git a/README.md b/README.md
index 06cd923..00582c3 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ Real-time interactive streaming digital human with synchronized audio and video
- 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of concurrent sessions
- 2024.12.21 Added model warm-up for wav2lip and musetalk to fix the stall on the first inference. Thanks to @heimaojinzhangyz
- 2024.12.28 Added the digital-human model Ultralight-Digital-Human. Thanks to @lijihua2017
+- 2025.1.26 Added the wav2lip384 model. Thanks to @不蠢不蠢
## Features
1. Supports multiple digital-human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
@@ -41,29 +42,20 @@ For setting up a CUDA environment on Linux, see this article: https://zhuanlan.zhihu.com/p/6749
## 2. Quick Start
-The ernerf model is used by default, with the WebRTC stream pushed to SRS
-### 2.1 Run SRS
-```bash
-export CANDIDATE='<server public IP>' #Not needed if SRS and the browser are on the same internal network
-docker run --rm --env CANDIDATE=$CANDIDATE \
- -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
- registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 \
- objs/srs -c conf/rtc.conf
-```
-Note: the server must open ports tcp:8000,8010,1985 and udp:8000
+- Download the models
+Download the models needed to run wav2lip. Link: password: ltua
+Copy wav2lip384.pth into this project's models directory and rename it to wav2lip.pth;
+extract wav2lip384_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory (a small sanity-check sketch appears at the end of this section).
+- Run
+python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
+Open http://serverip:8010/webrtcapi.html in a browser, type any text into the text box, and submit it. The digital human will speak that text.
-### 2.2 启动数字人:
-
-```python
-python app.py
-```
+The server must open ports tcp:8010 and udp:1-65536
If huggingface is not reachable, run the following before starting:
```
export HF_ENDPOINT=https://hf-mirror.com
-```
-
-Open http://serverip:8010/rtcpushapi.html in a browser, type any text into the text box, and submit it. The digital human will speak that text.
+```
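
For reference, the file placement above can be verified before launching. The following is a minimal, hypothetical pre-flight check (not part of the repository; the script and helper name are assumptions, and the paths simply mirror the steps above):

```python
# preflight_check.py -- hypothetical helper; assumes the model/avatar layout described above
from pathlib import Path
import sys

def check_wav2lip384_setup(root: str = ".") -> bool:
    base = Path(root)
    model = base / "models" / "wav2lip.pth"                    # wav2lip384.pth renamed to wav2lip.pth
    avatar = base / "data" / "avatars" / "wav2lip384_avatar1"  # extracted avatar folder
    ok = True
    if not model.is_file():
        print(f"missing model weights: {model}")
        ok = False
    if not avatar.is_dir():
        print(f"missing avatar folder: {avatar}")
        ok = False
    return ok

if __name__ == "__main__":
    sys.exit(0 if check_wav2lip384_setup() else 1)
```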
## 3. More Usage
diff --git a/app.py b/app.py
index a87cd6d..07b0a4d 100644
--- a/app.py
+++ b/app.py
@@ -478,7 +478,7 @@ if __name__ == '__main__':
print(opt)
model = load_model("./models/wav2lip.pth")
avatar = load_avatar(opt.avatar_id)
- warm_up(opt.batch_size,model,96)
+ warm_up(opt.batch_size,model,384)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = LipReal(opt,model)
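
The third argument to warm_up is the face resolution used for the dummy warm-up pass, raised from 96 to 384 for the wav2lip384 weights. The project's actual warm_up is defined elsewhere and is not shown in this diff; purely as an illustration of the idea (input shapes are assumptions), such a warm-up pass could look like:

```python
# Hypothetical sketch of a model warm-up: run one dummy forward pass so CUDA kernels
# and memory pools are initialized before the first real request. The repo's real
# warm_up may differ; the input shapes below are assumptions.
import torch

@torch.no_grad()
def warm_up_sketch(batch_size: int, model: torch.nn.Module, img_size: int = 384) -> None:
    device = next(model.parameters()).device
    img_batch = torch.zeros(batch_size, 6, img_size, img_size, device=device)  # masked + reference face, stacked on channels
    mel_batch = torch.zeros(batch_size, 1, 80, 16, device=device)              # assumed audio-feature chunk shape
    model(mel_batch, img_batch)
```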
diff --git a/data/avatars/.gitkeep b/data/avatars/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/lightasr.py b/hubertasr.py
similarity index 66%
rename from lightasr.py
rename to hubertasr.py
index b3df50f..6e37e37 100644
--- a/lightasr.py
+++ b/hubertasr.py
@@ -3,13 +3,15 @@ import torch
import numpy as np
from baseasr import BaseASR
-
-class LightASR(BaseASR):
- def __init__(self, opt, parent, audio_processor):
+# ASR wrapper that extracts HuBERT audio features
+class HubertASR(BaseASR):
+    #audio_feat_length: number of audio-feature frames taken before and after the current video frame
+    def __init__(self, opt, parent, audio_processor, audio_feat_length=[8, 8]):
super().__init__(opt, parent)
self.audio_processor = audio_processor
- self.stride_left_size = 32
- self.stride_right_size = 32
+ #self.stride_left_size = 32
+ #self.stride_right_size = 32
+ self.audio_feat_length = audio_feat_length
def run_step(self):
@@ -26,7 +28,7 @@ class LightASR(BaseASR):
inputs = np.concatenate(self.frames) # [N * chunk]
mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
- mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2)
+ mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,audio_feat_length = self.audio_feat_length, start=self.stride_left_size/2)
self.feat_queue.put(mel_chunks)
self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
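
audio_feat_length controls how many HuBERT feature frames are taken before and after each video frame when feature2chunks (defined in the repo's audio processor, not shown in this diff) cuts the feature array into per-frame chunks. Purely to illustrate that windowing idea, and not the repo's actual implementation, a single chunk could be sliced like this:

```python
# Illustrative sketch only; the repo's feature2chunks may index and pad differently.
import numpy as np

def chunk_for_frame(feature_array: np.ndarray, idx: int, audio_feat_length=(8, 8)) -> np.ndarray:
    """Take audio_feat_length[0] feature frames before and audio_feat_length[1] after frame idx."""
    left, right = audio_feat_length
    start = max(idx - left, 0)
    end = min(idx + right, len(feature_array))
    chunk = feature_array[start:end]
    # pad by repeating edge frames so every chunk has the same length
    pad_left = start - (idx - left)
    pad_right = (idx + right) - end
    return np.pad(chunk, ((pad_left, pad_right),) + ((0, 0),) * (feature_array.ndim - 1), mode="edge")
```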
diff --git a/lightreal.py b/lightreal.py
index 443424f..9a3ec03 100644
--- a/lightreal.py
+++ b/lightreal.py
@@ -33,7 +33,7 @@ from threading import Thread, Event
import torch.multiprocessing as mp
-from lightasr import LightASR
+from hubertasr import HubertASR
import asyncio
from av import AudioFrame, VideoFrame
from basereal import BaseReal
@@ -241,7 +241,7 @@ class LightReal(BaseReal):
audio_processor = model
self.model,self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
- self.asr = LightASR(opt,self,audio_processor)
+ self.asr = HubertASR(opt,self,audio_processor)
self.asr.warm_up()
#self.__warm_up()
diff --git a/wav2lip/models/__init__.py b/wav2lip/models/__init__.py
index 4374370..b0eb492 100644
--- a/wav2lip/models/__init__.py
+++ b/wav2lip/models/__init__.py
@@ -1,2 +1,2 @@
-from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
-from .syncnet import SyncNet_color
\ No newline at end of file
+from .wav2lip import Wav2Lip #, Wav2Lip_disc_qual
+#from .syncnet import SyncNet_color
\ No newline at end of file
diff --git a/wav2lip/models/conv_384.py b/wav2lip/models/conv_384.py
new file mode 100644
index 0000000..6fcda7a
--- /dev/null
+++ b/wav2lip/models/conv_384.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
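+    # Conv -> BatchNorm -> ReLU; optional residual add (requires cin == cout and unchanged spatial size)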
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+ self.residual = residual
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ if self.residual:
+ out += x
+ return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
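+    # Conv without normalization, followed by LeakyReLU; the residual flag is accepted but not used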
+ def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.Conv2d(cin, cout, kernel_size, stride, padding),
+ )
+ self.act = nn.LeakyReLU(negative_slope=0.01)
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
+
+class Conv2dTranspose(nn.Module):
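+    # Transposed conv -> BatchNorm -> ReLU, used to upsample decoder features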
+ def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.conv_block = nn.Sequential(
+ nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+ nn.BatchNorm2d(cout)
+ )
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ out = self.conv_block(x)
+ return self.act(out)
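
These are the same building blocks as the original wav2lip conv.py, kept in a separate module for the 384 variant. A quick, hypothetical shape check (not part of the repo) of how the down/residual/up blocks compose:

```python
# Hypothetical shape check for the building blocks above; import path follows this PR's layout.
import torch
from wav2lip.models.conv_384 import Conv2d, Conv2dTranspose

x = torch.randn(1, 16, 96, 96)
down = Conv2d(16, 32, kernel_size=3, stride=2, padding=1)                        # halves spatial size: 96 -> 48
res = Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)          # residual needs cin == cout
up = Conv2dTranspose(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)  # 48 -> 96

y = up(res(down(x)))
print(y.shape)  # torch.Size([1, 16, 96, 96])
```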
diff --git a/wav2lip/models/wav2lip.py b/wav2lip/models/wav2lip.py
index ae5d691..f6511dd 100644
--- a/wav2lip/models/wav2lip.py
+++ b/wav2lip/models/wav2lip.py
@@ -1,184 +1,198 @@
-import torch
-from torch import nn
-from torch.nn import functional as F
-import math
-
-from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
-
-class Wav2Lip(nn.Module):
- def __init__(self):
- super(Wav2Lip, self).__init__()
-
- self.face_encoder_blocks = nn.ModuleList([
- nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
-
- nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
-
- nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
-
- nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
-
- nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
-
- nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
- Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
-
- nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
-
- self.audio_encoder = nn.Sequential(
- Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-
- Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-
- Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-
- Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-
- Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
- Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
-
- self.face_decoder_blocks = nn.ModuleList([
- nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
-
- nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
- Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
-
- nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
- Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
-
- nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
- Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
-
- nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
-
- nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
-
- nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
- Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
-
- self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
- nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
- nn.Sigmoid())
-
- def forward(self, audio_sequences, face_sequences):
- # audio_sequences = (B, T, 1, 80, 16)
- B = audio_sequences.size(0)
-
- input_dim_size = len(face_sequences.size())
- if input_dim_size > 4:
- audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
- face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
-
- audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
-
- feats = []
- x = face_sequences
- for f in self.face_encoder_blocks:
- x = f(x)
- feats.append(x)
-
- x = audio_embedding
- for f in self.face_decoder_blocks:
- x = f(x)
- try:
- x = torch.cat((x, feats[-1]), dim=1)
- except Exception as e:
- print(x.size())
- print(feats[-1].size())
- raise e
-
- feats.pop()
-
- x = self.output_block(x)
-
- if input_dim_size > 4:
- x = torch.split(x, B, dim=0) # [(B, C, H, W)]
- outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
-
- else:
- outputs = x
-
- return outputs
-
-class Wav2Lip_disc_qual(nn.Module):
- def __init__(self):
- super(Wav2Lip_disc_qual, self).__init__()
-
- self.face_encoder_blocks = nn.ModuleList([
- nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
-
- nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
- nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
-
- nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
- nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
-
- nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
- nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
-
- nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
- nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
-
- nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
- nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
-
- nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
- nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
-
- self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
- self.label_noise = .0
-
- def get_lower_half(self, face_sequences):
- return face_sequences[:, :, face_sequences.size(2)//2:]
-
- def to_2d(self, face_sequences):
- B = face_sequences.size(0)
- face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
- return face_sequences
-
- def perceptual_forward(self, false_face_sequences):
- false_face_sequences = self.to_2d(false_face_sequences)
- false_face_sequences = self.get_lower_half(false_face_sequences)
-
- false_feats = false_face_sequences
- for f in self.face_encoder_blocks:
- false_feats = f(false_feats)
-
- false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
- torch.ones((len(false_feats), 1)).cuda())
-
- return false_pred_loss
-
- def forward(self, face_sequences):
- face_sequences = self.to_2d(face_sequences)
- face_sequences = self.get_lower_half(face_sequences)
-
- x = face_sequences
- for f in self.face_encoder_blocks:
- x = f(x)
-
- return self.binary_pred(x).view(len(x), -1)
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv_384 import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class SpatialAttention(nn.Module):
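+    # CBAM-style spatial attention: channel-wise mean and max -> 7x7 conv -> sigmoid mask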
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ x = torch.cat([avg_out, max_out], dim=1)
+ x = self.conv1(x)
+ return self.sigmoid(x)
+
+
+class SAM(nn.Module):
+ def __init__(self):
+ super(SAM, self).__init__()
+ self.sa = SpatialAttention()
+
+ def forward(self, sp, se):
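+        # sp: encoder (face) features used to build the attention map; se: features re-weighted by that map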
+ sp_att = self.sa(sp)
+ out = se * sp_att + se
+ return out
+
+
+class Wav2Lip(nn.Module):
+ def __init__(self, audio_encoder=None):
+ super(Wav2Lip, self).__init__()
+ self.sam = SAM()
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3),
+ Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True)), # 192, 192
+
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 96, 96
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 48, 48
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 24, 24
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 12, 12
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6, 6
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), # 3, 3
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0), # 1, 1
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)), ])
+
+
+ if audio_encoder is None:
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+
+ else:
+ self.audio_encoder = audio_encoder
+
+ for p in self.audio_encoder.parameters():
+ p.requires_grad = False
+
+ self.audio_refine = nn.Sequential(
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+ Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+
+ self.face_decoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0), ), # + 1024
+
+
+ nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0), # 3,3
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # + 1024
+
+ nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # 6, 6 + 512
+
+ nn.Sequential(Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True), ), # 12, 12 + 256
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ), # 24, 24 + 128
+
+
+ nn.Sequential(Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ), # 48, 48 + 64
+
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ), # 96, 96 + 32
+
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ]) # 192, 192 + 16
+
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid())
+
+ def freeze_audio_encoder(self):
+ for p in self.audio_encoder.parameters():
+ p.requires_grad = False
+
+ def forward(self, audio_sequences, face_sequences):
+
+ B = audio_sequences.size(0)
+
+ input_dim_size = len(face_sequences.size())
+ if input_dim_size > 4:
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+        audio_embedding = self.audio_encoder(audio_sequences) # B, 1024, 1, 1
+
+
+
+ feats = []
+ x = face_sequences
+
+ for f in self.face_encoder_blocks:
+ x = f(x)
+ feats.append(x)
+
+ x = audio_embedding
+ for f in self.face_decoder_blocks:
+ x = f(x)
+ try:
+ x = self.sam(feats[-1], x)
+ x = torch.cat((x, feats[-1]), dim=1)
+ except Exception as e:
+ print(x.size())
+ print(feats[-1].size())
+ raise e
+
+ feats.pop()
+
+ x = self.output_block(x)
+
+ if input_dim_size > 4:
+ x = torch.split(x, B, dim=0) # [(B, C, H, W)]
+ outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
+
+ else:
+ outputs = x
+
+ return outputs
+
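
Relative to the original 96x96 Wav2Lip, this variant widens the channel counts and inserts a SAM block at each skip connection: a spatial-attention map computed from the encoder (face) feature re-weights the decoder feature before the two are concatenated. A small standalone sketch of that fusion step, with hypothetical shapes:

```python
# Standalone illustration of the SAM fusion used at each skip connection (shapes are hypothetical).
import torch
from wav2lip.models.wav2lip import SAM  # path as introduced in this PR

sam = SAM()
sp = torch.randn(2, 512, 24, 24)      # encoder (face) feature at some resolution
se = torch.randn(2, 512, 24, 24)      # decoder feature at the same resolution
fused = sam(sp, se)                   # se * sigmoid(attention(sp)) + se
skip = torch.cat((fused, sp), dim=1)  # then concatenated with the encoder feature, as in forward()
print(skip.shape)                     # torch.Size([2, 1024, 24, 24])
```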
diff --git a/web/webrtcapi.html b/web/webrtcapi.html
index 042a269..b64ba22 100644
--- a/web/webrtcapi.html
+++ b/web/webrtcapi.html
@@ -102,6 +102,7 @@
fetch('/record', {
body: JSON.stringify({
type: 'start_record',
+ sessionid:parseInt(document.getElementById('sessionid').value),
}),
headers: {
'Content-Type': 'application/json'
@@ -127,6 +128,7 @@
fetch('/record', {
body: JSON.stringify({
type: 'end_record',
+ sessionid:parseInt(document.getElementById('sessionid').value),
}),
headers: {
'Content-Type': 'application/json'
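
With the sessionid now included in the record requests, the same start/stop calls can also be issued outside the browser. A minimal sketch using Python requests (the server address and session id 0 are assumptions):

```python
# Minimal sketch: drive the /record endpoint shown above from Python
# (assumes the server runs on localhost:8010 and that session 0 exists).
import requests

BASE = "http://localhost:8010"

def set_record(state: str, sessionid: int = 0) -> None:
    # state is 'start_record' or 'end_record', matching the webrtcapi.html buttons
    resp = requests.post(f"{BASE}/record", json={"type": state, "sessionid": sessionid})
    resp.raise_for_status()

set_record("start_record")
# ... interact with the digital human ...
set_record("end_record")
```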