diff --git a/README.md b/README.md
index 06cd923..00582c3 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@ Real time interactive streaming digital human, realize audio video synchronous
 - 2024.12.8 Improved multi-session concurrency; GPU memory no longer grows with the number of concurrent sessions
 - 2024.12.21 Added model warm-up for wav2lip and musetalk to remove the stutter on the first inference. Thanks to @heimaojinzhangyz
 - 2024.12.28 Added the Ultralight-Digital-Human digital-human model. Thanks to @lijihua2017
+- 2025.1.26 Added the wav2lip384 model. Thanks to @不蠢不蠢
 
 ## Features
 1. Supports multiple digital-human models: ernerf, musetalk, wav2lip, Ultralight-Digital-Human
@@ -41,29 +42,20 @@ For setting up a Linux CUDA environment you can refer to this article: https://zhuanlan.zhihu.com/p/6749
 
 ## 2. Quick Start
-The ernerf model is used by default, with WebRTC streaming to SRS
-### 2.1 Run SRS
-```bash
-export CANDIDATE='<server public IP>' # not needed if SRS and the browser are on the same internal network
-docker run --rm --env CANDIDATE=$CANDIDATE \
-  -p 1935:1935 -p 8080:8080 -p 1985:1985 -p 8000:8000/udp \
-  registry.cn-hangzhou.aliyuncs.com/ossrs/srs:5 \
-  objs/srs -c conf/rtc.conf
-```
-Note: the server must open ports tcp:8000,8010,1985; udp:8000
+- Download the models
+Download the models needed to run wav2lip. Link:  Password: ltua
+Copy wav2lip384.pth into this project's models directory and rename it to wav2lip.pth;
+Extract wav2lip384_avatar1.tar.gz and copy the whole folder into this project's data/avatars directory
+- Run
+python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
+Open http://serverip:8010/webrtcapi.html in a browser, type any text into the text box and submit; the digital human will speak that text (a consolidated setup sketch follows this diff)
 
-### 2.2 Start the digital human:
-
-```python
-python app.py
-```
+The server must open ports tcp:8010; udp:1-65536
 
 If huggingface is not reachable, run this before starting
 ```
 export HF_ENDPOINT=https://hf-mirror.com
-```
-
-Open http://serverip:8010/rtcpushapi.html in a browser, type any text into the text box and submit; the digital human will speak that text
+```
 
 ## 3. More Usage
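Editor's note: the quick-start steps added above boil down to a few shell commands. The sketch below is only an illustration of those README instructions, not part of the patch; the download URL is left blank in the README, and the source paths of the downloaded files are assumed.

```bash
# Sketch of the wav2lip384 setup described in the README above (download paths assumed)
cp /path/to/wav2lip384.pth models/wav2lip.pth                   # rename the checkpoint to the name app.py loads
tar -xzf /path/to/wav2lip384_avatar1.tar.gz -C data/avatars/    # assumes the archive contains the avatar folder
python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
# then open http://serverip:8010/webrtcapi.html and submit text
# (the server must allow tcp:8010 and udp:1-65536)
```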
diff --git a/app.py b/app.py
index a87cd6d..07b0a4d 100644
--- a/app.py
+++ b/app.py
@@ -478,7 +478,7 @@ if __name__ == '__main__':
         print(opt)
         model = load_model("./models/wav2lip.pth")
         avatar = load_avatar(opt.avatar_id)
-        warm_up(opt.batch_size,model,96)
+        warm_up(opt.batch_size,model,384)
         # for k in range(opt.max_session):
         #     opt.sessionid=k
         #     nerfreal = LipReal(opt,model)
diff --git a/data/avatars/.gitkeep b/data/avatars/.gitkeep
new file mode 100644
index 0000000..e69de29
diff --git a/lightasr.py b/hubertasr.py
similarity index 66%
rename from lightasr.py
rename to hubertasr.py
index b3df50f..6e37e37 100644
--- a/lightasr.py
+++ b/hubertasr.py
@@ -3,13 +3,15 @@ import torch
 import numpy as np
 from baseasr import BaseASR
 
-
-class LightASR(BaseASR):
-    def __init__(self, opt, parent, audio_processor):
+# hubert audio feature
+class HubertASR(BaseASR):
+    # audio_feat_length: how many audio feature frames to take before and after each video frame
+    def __init__(self, opt, parent, audio_processor, audio_feat_length=[8, 8]):
         super().__init__(opt, parent)
         self.audio_processor = audio_processor
-        self.stride_left_size = 32
-        self.stride_right_size = 32
+        #self.stride_left_size = 32
+        #self.stride_right_size = 32
+        self.audio_feat_length = audio_feat_length
 
     def run_step(self):
@@ -26,7 +28,7 @@ class LightASR(BaseASR):
         inputs = np.concatenate(self.frames) # [N * chunk]
         mel = self.audio_processor.get_hubert_from_16k_speech(inputs)
-        mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,start=self.stride_left_size/2)
+        mel_chunks=self.audio_processor.feature2chunks(feature_array=mel,fps=self.fps/2,batch_size=self.batch_size,audio_feat_length=self.audio_feat_length,start=self.stride_left_size/2)
 
         self.feat_queue.put(mel_chunks)
 
         self.frames = self.frames[-(self.stride_left_size + self.stride_right_size):]
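Editor's note: HubertASR now threads an `audio_feat_length` window (default `[8, 8]`) through to `feature2chunks`. That helper's implementation is not part of this diff, so the snippet below is only a hypothetical illustration of what "before and after" plausibly means: for each video frame, take 8 hubert feature frames on each side and pad at the clip boundaries.

```python
import numpy as np

def select_window(feature_array, center_idx, audio_feat_length=(8, 8)):
    """Hypothetical sketch of an [8, 8] audio-feature window around one video frame;
    the real slicing lives in audio_processor.feature2chunks, which is not shown here."""
    left, right = audio_feat_length
    start = max(center_idx - left, 0)
    end = min(center_idx + right, len(feature_array))
    window = feature_array[start:end]
    # pad at the clip edges so every window has exactly left + right rows
    pad_before = left - (center_idx - start)
    pad_after = (left + right) - len(window) - pad_before
    return np.pad(window, ((pad_before, pad_after), (0, 0)), mode="edge")

feats = np.random.randn(200, 1024).astype(np.float32)  # e.g. hubert features, one row per frame
print(select_window(feats, 3).shape)                    # (16, 1024), padded on the left
```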
diff --git a/lightreal.py b/lightreal.py
index 443424f..9a3ec03 100644
--- a/lightreal.py
+++ b/lightreal.py
@@ -33,7 +33,7 @@
 from threading import Thread, Event
 import torch.multiprocessing as mp
 
-from lightasr import LightASR
+from hubertasr import HubertASR
 import asyncio
 from av import AudioFrame, VideoFrame
 from basereal import BaseReal
@@ -241,7 +241,7 @@ class LightReal(BaseReal):
         audio_processor = model
         self.model,self.frame_list_cycle,self.face_list_cycle,self.coord_list_cycle = avatar
 
-        self.asr = LightASR(opt,self,audio_processor)
+        self.asr = HubertASR(opt,self,audio_processor)
         self.asr.warm_up()
         #self.__warm_up()
 
diff --git a/wav2lip/models/__init__.py b/wav2lip/models/__init__.py
index 4374370..b0eb492 100644
--- a/wav2lip/models/__init__.py
+++ b/wav2lip/models/__init__.py
@@ -1,2 +1,2 @@
-from .wav2lip import Wav2Lip, Wav2Lip_disc_qual
-from .syncnet import SyncNet_color
\ No newline at end of file
+from .wav2lip import Wav2Lip #, Wav2Lip_disc_qual
+#from .syncnet import SyncNet_color
\ No newline at end of file
diff --git a/wav2lip/models/conv_384.py b/wav2lip/models/conv_384.py
new file mode 100644
index 0000000..6fcda7a
--- /dev/null
+++ b/wav2lip/models/conv_384.py
@@ -0,0 +1,44 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+class nonorm_Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+        )
+        self.act = nn.LeakyReLU(negative_slope=0.01)
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
+
+class Conv2dTranspose(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
+            nn.BatchNorm2d(cout)
+        )
+        self.act = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        return self.act(out)
diff --git a/wav2lip/models/wav2lip.py b/wav2lip/models/wav2lip.py
index ae5d691..f6511dd 100644
--- a/wav2lip/models/wav2lip.py
+++ b/wav2lip/models/wav2lip.py
@@ -1,184 +1,198 @@
-import torch
-from torch import nn
-from torch.nn import functional as F
-import math
-
-from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
-
-class Wav2Lip(nn.Module):
-    def __init__(self):
-        super(Wav2Lip, self).__init__()
-
-        self.face_encoder_blocks = nn.ModuleList([
-            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96
-
-            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
-
-            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
-
-            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
-
-            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
-
-            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3
-            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
-
-            nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
-            Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
-
-        self.audio_encoder = nn.Sequential(
-            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-
-            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
-            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
-
-        self.face_decoder_blocks = nn.ModuleList([
-            nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),),
-
-            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3
-            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),),
-
-            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
-            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6
-
-            nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
-            Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12
-
-            nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24
-
-            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48
-
-            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
-            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96
-
-        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
-            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
-            nn.Sigmoid())
-
-    def forward(self, audio_sequences, face_sequences):
-        # audio_sequences = (B, T, 1, 80, 16)
-        B = audio_sequences.size(0)
-
-        input_dim_size = len(face_sequences.size())
-        if input_dim_size > 4:
-            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
-            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
-
-        audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
-
-        feats = []
-        x = face_sequences
-        for f in self.face_encoder_blocks:
-            x = f(x)
-            feats.append(x)
-
-        x = audio_embedding
-        for f in self.face_decoder_blocks:
-            x = f(x)
-            try:
-                x = torch.cat((x, feats[-1]), dim=1)
-            except Exception as e:
-                print(x.size())
-                print(feats[-1].size())
-                raise e
-
-            feats.pop()
-
-        x = self.output_block(x)
-
-        if input_dim_size > 4:
-            x = torch.split(x, B, dim=0) # [(B, C, H, W)]
-            outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
-
-        else:
-            outputs = x
-
-        return outputs
-
-class Wav2Lip_disc_qual(nn.Module):
-    def __init__(self):
-        super(Wav2Lip_disc_qual, self).__init__()
-
-        self.face_encoder_blocks = nn.ModuleList([
-            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
-
-            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
-            nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
-            nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
-            nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
-
-            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
-            nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
-
-            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
-            nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
-
-            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
-            nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
-
-        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
-        self.label_noise = .0
-
-    def get_lower_half(self, face_sequences):
-        return face_sequences[:, :, face_sequences.size(2)//2:]
-
-    def to_2d(self, face_sequences):
-        B = face_sequences.size(0)
-        face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
-        return face_sequences
-
-    def perceptual_forward(self, false_face_sequences):
-        false_face_sequences = self.to_2d(false_face_sequences)
-        false_face_sequences = self.get_lower_half(false_face_sequences)
-
-        false_feats = false_face_sequences
-        for f in self.face_encoder_blocks:
-            false_feats = f(false_feats)
-
-        false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
-                                        torch.ones((len(false_feats), 1)).cuda())
-
-        return false_pred_loss
-
-    def forward(self, face_sequences):
-        face_sequences = self.to_2d(face_sequences)
-        face_sequences = self.get_lower_half(face_sequences)
-
-        x = face_sequences
-        for f in self.face_encoder_blocks:
-            x = f(x)
-
-        return self.binary_pred(x).view(len(x), -1)
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from .conv_384 import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+
+class SAM(nn.Module):
+    def __init__(self):
+        super(SAM, self).__init__()
+        self.sa = SpatialAttention()
+
+    def forward(self, sp, se):
+        sp_att = self.sa(sp)
+        out = se * sp_att + se
+        return out
+
+
+class Wav2Lip(nn.Module):
+    def __init__(self, audio_encoder=None):
+        super(Wav2Lip, self).__init__()
+        self.sam = SAM()
+        self.face_encoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3),
+                Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True)), # 192, 192
+
+            nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 96, 96
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 48, 48
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 24, 24
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 12, 12
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6, 6
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(512, 1024, kernel_size=3, stride=2, padding=1), # 3, 3
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True)),
+
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0), # 1, 1
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)), ])
+
+
+        if audio_encoder is None:
+            self.audio_encoder = nn.Sequential(
+                Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+
+                Conv2d(512, 1024, kernel_size=3, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+
+        else:
+            self.audio_encoder = audio_encoder
+
+            for p in self.audio_encoder.parameters():
+                p.requires_grad = False
+
+            self.audio_refine = nn.Sequential(
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0),
+                Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0))
+
+        self.face_decoder_blocks = nn.ModuleList([
+            nn.Sequential(Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0), ), # + 1024
+
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=1, padding=0), # 3,3
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # + 1024
+
+            nn.Sequential(Conv2dTranspose(2048, 1024, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True), ), # 6, 6 + 512
+
+            nn.Sequential(Conv2dTranspose(1536, 768, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(768, 768, kernel_size=3, stride=1, padding=1, residual=True), ), # 12, 12 + 256
+
+            nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ), # 24, 24 + 128
+
+            nn.Sequential(Conv2dTranspose(640, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ), # 48, 48 + 64
+
+            nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ), # 96, 96 + 32
+
+            nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+                Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ]) # 192, 192 + 16
+
+        self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+            nn.Sigmoid())
+
+    def freeze_audio_encoder(self):
+        for p in self.audio_encoder.parameters():
+            p.requires_grad = False
+
+    def forward(self, audio_sequences, face_sequences):
+
+        B = audio_sequences.size(0)
+
+        input_dim_size = len(face_sequences.size())
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+
+        audio_embedding = self.audio_encoder(audio_sequences) # B, 1024, 1, 1
+
+        feats = []
+        x = face_sequences
+
+        for f in self.face_encoder_blocks:
+            x = f(x)
+            feats.append(x)
+
+        x = audio_embedding
+        for f in self.face_decoder_blocks:
+            x = f(x)
+            try:
+                x = self.sam(feats[-1], x)
+                x = torch.cat((x, feats[-1]), dim=1)
+            except Exception as e:
+                print(x.size())
+                print(feats[-1].size())
+                raise e
+
+            feats.pop()
+
+        x = self.output_block(x)
+
+        if input_dim_size > 4:
+            x = torch.split(x, B, dim=0) # [(B, C, H, W)]
+            outputs = torch.stack(x, dim=2) # (B, C, T, H, W)
+
+        else:
+            outputs = x
+
+        return outputs
+
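Editor's note: the spatial-size comments inside the new encoder ("# 96, 96", "# 3, 3", ...) still read as if the input were 192×192, while app.py now warms the model up at 384 and the avatar is named wav2lip384; both resolutions appear to flow through, because the SAM block broadcasts the 1×1 audio embedding over whatever bottleneck size the face encoder produces. Below is a hedged smoke test of the rewritten generator, assuming the original Wav2Lip input convention (80×16 mel windows, 6-channel face crops formed from the masked frame plus a reference frame); it is only a shape sanity check, not the project's inference path.

```python
import torch
from wav2lip.models import Wav2Lip  # __init__.py above now exports only Wav2Lip

model = Wav2Lip().eval()                 # randomly initialised; enough for a shape check
mel = torch.randn(2, 1, 80, 16)          # (B, 1, 80, 16) mel window per frame
faces = torch.randn(2, 6, 384, 384)      # (B, 6, H, W) masked frame + reference frame on the channel axis
with torch.no_grad():
    out = model(mel, faces)
print(out.shape)                         # expected torch.Size([2, 3, 384, 384])
```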
diff --git a/web/webrtcapi.html b/web/webrtcapi.html
index 042a269..b64ba22 100644
--- a/web/webrtcapi.html
+++ b/web/webrtcapi.html
@@ -102,6 +102,7 @@
             fetch('/record', {
                 body: JSON.stringify({
                     type: 'start_record',
+                    sessionid:parseInt(document.getElementById('sessionid').value),
                 }),
                 headers: {
                     'Content-Type': 'application/json'
@@ -127,6 +128,7 @@
             fetch('/record', {
                 body: JSON.stringify({
                     type: 'end_record',
+                    sessionid:parseInt(document.getElementById('sessionid').value),
                 }),
                 headers: {
                     'Content-Type': 'application/json'
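Editor's note: the webrtcapi.html change above adds the active `sessionid` to the JSON body sent to `/record`. A hedged sketch of the same calls from Python follows; the endpoint and field names are taken from the page, the host/port from the README, and the HTTP method is assumed to be POST since the page submits a JSON body.

```python
import requests

SERVER = "http://serverip:8010"   # host and port assumed from the README
sessionid = 0                     # the id shown in the page's sessionid field

# start and then stop a recording for one session
requests.post(f"{SERVER}/record", json={"type": "start_record", "sessionid": sessionid})
requests.post(f"{SERVER}/record", json={"type": "end_record", "sessionid": sessionid})
```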