diff --git a/README.md b/README.md
index ba6c7a6..dfc84f1 100644
--- a/README.md
+++ b/README.md
@@ -9,8 +9,8 @@ Real time interactive streaming digital human, realize audio video synchronous
- 2024.12.8 完善多并发,显存不随并发数增加
- 2024.12.21 添加wav2lip、musetalk模型预热,解决第一次推理卡顿问题。感谢@heimaojinzhangyz
- 2024.12.28 添加数字人模型Ultralight-Digital-Human。 感谢@lijihua2017
-- 2025.1.26 添加wav2lip384开源模型 感谢@不蠢不蠢
- 2025.2.7 添加fish-speech tts
+- 2025.2.21 添加wav2lip256开源模型 感谢@不蠢不蠢
## Features
1. 支持多种数字人模型: ernerf、musetalk、wav2lip、Ultralight-Digital-Human
@@ -33,10 +33,10 @@ conda activate nerfstream
#如果cuda版本不为11.3(运行nvidia-smi确认版本),根据安装对应版本的pytorch
conda install pytorch==1.12.1 torchvision==0.13.1 cudatoolkit=11.3 -c pytorch
pip install -r requirements.txt
-#如果不训练ernerf模型,不需要安装下面的库
-pip install "git+https://github.com/facebookresearch/pytorch3d.git"
-pip install tensorflow-gpu==2.8.0
-pip install --upgrade "protobuf<=3.20.1"
+#如果需要训练ernerf模型,安装下面的库
+# pip install "git+https://github.com/facebookresearch/pytorch3d.git"
+# pip install tensorflow-gpu==2.8.0
+# pip install --upgrade "protobuf<=3.20.1"
```
安装常见问题[FAQ](https://livetalking-doc.readthedocs.io/en/latest/faq.html)
linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/674972886
@@ -46,11 +46,12 @@ linux cuda环境搭建可以参考这篇文章 https://zhuanlan.zhihu.com/p/6749
- 下载模型
百度云盘 密码: ltua
GoogleDriver
-将wav2lip384.pth拷到本项目的models下, 重命名为wav2lip.pth;
-将wav2lip384_avatar1.tar.gz解压后整个文件夹拷到本项目的data/avatars下
+将wav2lip256.pth拷到本项目的models下, 重命名为wav2lip.pth;
+将wav2lip256_avatar1.tar.gz解压后整个文件夹拷到本项目的data/avatars下
- 运行
-python app.py --transport webrtc --model wav2lip --avatar_id wav2lip384_avatar1
-用浏览器打开http://serverip:8010/webrtcapi.html , 在文本框输入任意文字,提交。数字人播报该段文字
+python app.py --transport webrtc --model wav2lip --avatar_id wav2lip256_avatar1
+用浏览器打开http://serverip:8010/webrtcapi.html , 先点‘start',播放数字人视频;然后在文本框输入任意文字,提交。数字人播报该段文字
+如果需要商用高清wav2lip模型,可以与我联系购买
服务端需要开放端口 tcp:8010; udp:1-65536
diff --git a/app.py b/app.py
index 07b0a4d..2adac66 100644
--- a/app.py
+++ b/app.py
@@ -478,7 +478,7 @@ if __name__ == '__main__':
print(opt)
model = load_model("./models/wav2lip.pth")
avatar = load_avatar(opt.avatar_id)
- warm_up(opt.batch_size,model,384)
+ warm_up(opt.batch_size,model,256)
# for k in range(opt.max_session):
# opt.sessionid=k
# nerfreal = LipReal(opt,model)
diff --git a/wav2lip/models/__init__.py b/wav2lip/models/__init__.py
index b0eb492..9f185e6 100644
--- a/wav2lip/models/__init__.py
+++ b/wav2lip/models/__init__.py
@@ -1,2 +1,2 @@
-from .wav2lip import Wav2Lip #, Wav2Lip_disc_qual
-#from .syncnet import SyncNet_color
\ No newline at end of file
+from .wav2lip_v2 import Wav2Lip, Wav2Lip_disc_qual
+from .syncnet import SyncNet_color
\ No newline at end of file
diff --git a/wav2lip/models/conv_384.py b/wav2lip/models/conv_384.py
deleted file mode 100644
index 6fcda7a..0000000
--- a/wav2lip/models/conv_384.py
+++ /dev/null
@@ -1,44 +0,0 @@
-import torch
-from torch import nn
-from torch.nn import functional as F
-
-class Conv2d(nn.Module):
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.conv_block = nn.Sequential(
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
- nn.BatchNorm2d(cout)
- )
- self.act = nn.ReLU()
- self.residual = residual
-
- def forward(self, x):
- out = self.conv_block(x)
- if self.residual:
- out += x
- return self.act(out)
-
-class nonorm_Conv2d(nn.Module):
- def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.conv_block = nn.Sequential(
- nn.Conv2d(cin, cout, kernel_size, stride, padding),
- )
- self.act = nn.LeakyReLU(negative_slope=0.01)
-
- def forward(self, x):
- out = self.conv_block(x)
- return self.act(out)
-
-class Conv2dTranspose(nn.Module):
- def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
- super().__init__(*args, **kwargs)
- self.conv_block = nn.Sequential(
- nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
- nn.BatchNorm2d(cout)
- )
- self.act = nn.ReLU()
-
- def forward(self, x):
- out = self.conv_block(x)
- return self.act(out)
diff --git a/wav2lip/models/syncnet.py b/wav2lip/models/syncnet.py
index e773cdc..15ac672 100644
--- a/wav2lip/models/syncnet.py
+++ b/wav2lip/models/syncnet.py
@@ -1,7 +1,7 @@
import torch
from torch import nn
from torch.nn import functional as F
-
+import pdb
from .conv import Conv2d
class SyncNet_color(nn.Module):
@@ -56,10 +56,10 @@ class SyncNet_color(nn.Module):
face_embedding = self.face_encoder(face_sequences)
audio_embedding = self.audio_encoder(audio_sequences)
- audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
- face_embedding = face_embedding.view(face_embedding.size(0), -1)
+ audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)#[4, 512]
+ face_embedding = face_embedding.view(face_embedding.size(0), -1) #[4, 512]
- audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
+ audio_embedding = F.normalize(audio_embedding, p=2, dim=1) #按照宽度方向进行l2归一化
face_embedding = F.normalize(face_embedding, p=2, dim=1)
diff --git a/wav2lip/models/wav2lip_v2.py b/wav2lip/models/wav2lip_v2.py
new file mode 100644
index 0000000..297a183
--- /dev/null
+++ b/wav2lip/models/wav2lip_v2.py
@@ -0,0 +1,223 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+import pdb
+from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d
+
+
+class Wav2Lip(nn.Module):
+ def __init__(self):
+ super(Wav2Lip, self).__init__()
+
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)),
+
+ nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)),
+
+ nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+ self.audio_encoder = nn.Sequential(
+ Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+ Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=1, stride=1, padding=0), )
+
+ self.face_decoder_blocks = nn.ModuleList([
+ nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0), ),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), ),
+
+ nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+ Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), ), ])
+
+ self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1),
+ nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
+ nn.Sigmoid())
+
+ def audio_forward(self, audio_sequences, a_alpha=1.):
+ audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+ if a_alpha != 1.:
+ audio_embedding *= a_alpha
+ return audio_embedding
+
+ def inference(self, audio_embedding, face_sequences):
+ feats = []
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+ feats.append(x)
+
+ x = audio_embedding
+ for f in self.face_decoder_blocks:
+ x = f(x)
+ try:
+ x = torch.cat((x, feats[-1]), dim=1)
+ except Exception as e:
+ print(x.size())
+ print(feats[-1].size())
+ raise e
+
+ feats.pop()
+
+ x = self.output_block(x)
+ outputs = x
+
+ return outputs
+
+ def forward(self, audio_sequences, face_sequences, a_alpha=1.):
+ # audio_sequences = (B, T, 1, 80, 16)
+ B = audio_sequences.size(0)
+
+ input_dim_size = len(face_sequences.size())
+ if input_dim_size > 4:
+ audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)#[bz, 5, 1, 80, 16]->[bz*5, 1, 80, 16]
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)#[bz, 6, 5, 256, 256]->[bz*5, 6, 256, 256]
+
+ audio_embedding = self.audio_encoder(audio_sequences) # [bz*5, 1, 80, 16]->[bz*5, 512, 1, 1]
+ if a_alpha != 1.:
+ audio_embedding *= a_alpha #放大音频强度
+
+ feats = []
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+ feats.append(x)
+
+ x = audio_embedding
+ for f in self.face_decoder_blocks:
+ x = f(x)
+ try:
+ x = torch.cat((x, feats[-1]), dim=1)
+ except Exception as e:
+ print(x.size())
+ print(feats[-1].size())
+ raise e
+
+ feats.pop()
+
+ x = self.output_block(x) #[bz*5, 80, 256, 256]->[bz*5, 3, 256, 256]
+
+ if input_dim_size > 4: #[bz*5, 3, 256, 256]->[B, 3, 5, 256, 256]
+ x = torch.split(x, B, dim=0)
+ outputs = torch.stack(x, dim=2)
+
+ else:
+ outputs = x
+
+ return outputs
+
+
+class Wav2Lip_disc_qual(nn.Module):
+ def __init__(self):
+ super(Wav2Lip_disc_qual, self).__init__()
+
+ self.face_encoder_blocks = nn.ModuleList([
+ nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),
+
+ nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),
+ nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),
+ nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),
+ nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
+
+ nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
+
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
+ nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1), ),
+
+ nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=4, stride=1, padding=0),
+ nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)), ])
+
+ self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
+ self.label_noise = .0
+
+ def get_lower_half(self, face_sequences): #取得输入图片的下半部分。
+ return face_sequences[:, :, face_sequences.size(2) // 2:]
+
+ def to_2d(self, face_sequences): #将输入的图片序列连接起来,形成一个二维的tensor。
+ B = face_sequences.size(0)
+ face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
+ return face_sequences
+
+ def perceptual_forward(self, false_face_sequences): #前传生成图像
+ false_face_sequences = self.to_2d(false_face_sequences) #[bz, 3, 5, 256, 256]->[bz*5, 3, 256, 256]
+ false_face_sequences = self.get_lower_half(false_face_sequences)#[bz*5, 3, 256, 256]->[bz*5, 3, 128, 256]
+
+ false_feats = false_face_sequences
+ for f in self.face_encoder_blocks: #[bz*5, 3, 128, 256]->[bz*5, 512, 1, 1]
+ false_feats = f(false_feats)
+
+ return self.binary_pred(false_feats).view(len(false_feats), -1) #[bz*5, 512, 1, 1]->[bz*5, 1, 1]
+
+ def forward(self, face_sequences): #前传真值图像
+ face_sequences = self.to_2d(face_sequences)
+ face_sequences = self.get_lower_half(face_sequences)
+
+ x = face_sequences
+ for f in self.face_encoder_blocks:
+ x = f(x)
+
+ return self.binary_pred(x).view(len(x), -1)