人物图片生成视频

main
fanpt 1 year ago
commit 6dda9868f0

167
.gitignore vendored

@ -0,0 +1,167 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
examples/results/*
gfpgan/*
checkpoints/
results/*
Dockerfile
start_docker.sh

8
.idea/.gitignore vendored

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
<serverData>
<paths name="fanpt@192.168.0.102:22 password">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (2)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (3)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (4)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (5)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (6)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
<paths name="fanpt@192.168.0.102:22 password (7)">
<serverdata>
<mappings>
<mapping local="$PROJECT_DIR$" web="/" />
</mappings>
</serverdata>
</paths>
</serverData>
</component>
</project>

@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="13">
<item index="0" class="java.lang.String" itemvalue="protobuf" />
<item index="1" class="java.lang.String" itemvalue="transformers" />
<item index="2" class="java.lang.String" itemvalue="tensorboard" />
<item index="3" class="java.lang.String" itemvalue="icetk" />
<item index="4" class="java.lang.String" itemvalue="cpm_kernels" />
<item index="5" class="java.lang.String" itemvalue="peft" />
<item index="6" class="java.lang.String" itemvalue="accelerate" />
<item index="7" class="java.lang.String" itemvalue="torch" />
<item index="8" class="java.lang.String" itemvalue="datasets" />
<item index="9" class="java.lang.String" itemvalue="bitsandbytes" />
<item index="10" class="java.lang.String" itemvalue="ConcurrentLogHandler" />
<item index="11" class="java.lang.String" itemvalue="uwsgi" />
<item index="12" class="java.lang.String" itemvalue="ultralytics" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectRootManager" version="2" project-jdk-name="Sadtalker" project-jdk-type="Python SDK" />
</project>

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/SadTalker.iml" filepath="$PROJECT_DIR$/.idea/SadTalker.iml" />
</modules>
</component>
</project>

@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Tencent AI Lab
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

@ -0,0 +1,56 @@
# SadTalker with GFPGAN 图像语音合成
SadTalker with GFPGAN 是一个图像和语音合成项目,它结合了 SadTalker 模型和 GFPGAN 图像增强技术,使用户能够通过图像和语音生成合成视频。
## 安装步骤
1. 下载 [Anaconda](https://www.anaconda.com/products/distribution) 并安装。
2. 设置 pip 源:
```bash
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
```
3. 进入 SadTalker 目录:
```bash
cd path/to/SadTalker
```
4. 创建并激活虚拟环境:
```bash
conda create -n sadtalker python=3.8
conda activate sadtalker
```
5. 安装 PyTorch 和其他依赖:
```bash
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
conda install ffmpeg
pip install -r requirements.txt
```
## 生成合成视频
执行以下命令生成合成视频:
```bash
# 确保将 path/to/your/audio.wav 替换为您的语音文件路径path/to/your/image.png 替换为您的图像文件路径path/to/output 替换为您的输出目录。
python inference.py --driven_audio path/to/your/audio.wav --source_image path/to/your/image.png --result_dir path/to/output --still --preprocess full --enhancer gfpgan
```
| 名称 | 配置 | 默认值 | 说明 |
|:-------------------|:----------------------|:-------------|:-----------------------------------------|
| 增强模式 | `--enhancer` | None | 使用 `gfpgan``RestoreFormer` 通过面部修复网络增强生成的面部 |
| 背景增强模式 | `--background_enhancer` | None | 使用 `realesrgan` 增强整个视频。 |
| 静态模式 | ` --still` | False | 使用与原始图像相同的姿势参数,减少头部运动。 |
| 表达模式 | `--expression_scale` | 1.0 | 较大的值将增强表情动作。 |
| 保存路径 | `--result_dir` | `./results` | 文件将保存在新的位置。 |
| 预处理模式 | `--preprocess` | `crop` | 在裁剪的输入图像上运行并生成结果。其他选择:`resize`,图像将被调整为特定分辨率。`full`,运行完整图像动画,与 `--still` 一起使用以获得更好的结果。|
| 参考模式 (眼部) | `--ref_eyeblink` | None | 视频路径,我们从该参考视频中借用眨眼动作以提供更自然的眉毛运动。|
| 参考模式 (姿势) | `--ref_pose` | None | 视频路径,我们从该头部参考视频中借用姿势。|
| 3D 模式 | `--face3dvis` | False | 需要额外的安装。有关生成3D人脸的更多详细信息请参见 [这里](docs/face3d.md)。|
| 自由视角模式 | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | 从单个图像生成新视角或自由视角的4D对话头。有关更多详细信息请参见 [这里](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image)。|

@ -0,0 +1,313 @@
<div align="center">
<img src='https://user-images.githubusercontent.com/4397546/229094115-862c747e-7397-4b54-ba4a-bd368bfe2e0f.png' width='500px'/>
<!--<h2> 😭 SadTalker <span style="font-size:12px">Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation </span> </h2> -->
<a href='https://arxiv.org/abs/2211.12194'><img src='https://img.shields.io/badge/ArXiv-PDF-red'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;<a href='https://sadtalker.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker)
<div>
<a target='_blank'>Wenxuan Zhang <sup>*,1,2</sup> </a>&emsp;
<a href='https://vinthony.github.io/' target='_blank'>Xiaodong Cun <sup>*,2</a>&emsp;
<a href='https://xuanwangvc.github.io/' target='_blank'>Xuan Wang <sup>3</sup></a>&emsp;
<a href='https://yzhang2016.github.io/' target='_blank'>Yong Zhang <sup>2</sup></a>&emsp;
<a href='https://xishen0220.github.io/' target='_blank'>Xi Shen <sup>2</sup></a>&emsp; </br>
<a href='https://yuguo-xjtu.github.io/' target='_blank'>Yu Guo<sup>1</sup> </a>&emsp;
<a href='https://scholar.google.com/citations?hl=zh-CN&user=4oXBp9UAAAAJ' target='_blank'>Ying Shan <sup>2</sup> </a>&emsp;
<a target='_blank'>Fei Wang <sup>1</sup> </a>&emsp;
</div>
<br>
<div>
<sup>1</sup> Xi'an Jiaotong University &emsp; <sup>2</sup> Tencent AI Lab &emsp; <sup>3</sup> Ant Group &emsp;
</div>
<br>
<i><strong><a href='https://arxiv.org/abs/2211.12194' target='_blank'>CVPR 2023</a></strong></i>
<br>
<br>
![sadtalker](https://user-images.githubusercontent.com/4397546/222490039-b1f6156b-bf00-405b-9fda-0c9a9156f991.gif)
<b>TL;DR: &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; single portrait image 🙎‍♂️ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp;+ &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; audio 🎤 &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; = &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; talking head video 🎞.</b>
<br>
</div>
## 🔥 Highlight
- 🔥 The extension of the [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is online. Just install it in `extensions -> install from URL -> https://github.com/Winfredy/SadTalker`, checkout more details [here](#sd-webui-extension).
https://user-images.githubusercontent.com/4397546/222513483-89161f58-83d0-40e4-8e41-96c32b47bd4e.mp4
- 🔥 `full image mode` is online! checkout [here](https://github.com/Winfredy/SadTalker#beta-full-bodyimage-generation) for more details.
| still+enhancer in v0.0.1 | still + enhancer in v0.0.2 | [input image @bagbag1815](https://twitter.com/bagbag1815/status/1642754319094108161) |
|:--------------------: |:--------------------: | :----: |
| <video src="https://user-images.githubusercontent.com/48216707/229484996-5d7be64f-2553-4c9e-a452-c5cf0b8ebafe.mp4" type="video/mp4"> </video> | <video src="https://user-images.githubusercontent.com/4397546/230717873-355b7bf3-d3de-49f9-a439-9220e623fce7.mp4" type="video/mp4"> </video> | <img src='./examples/source_image/full_body_2.png' width='380'>
- 🔥 Several new mode, eg, `still mode`, `reference mode`, `resize mode` are online for better and custom applications.
- 🔥 Happy to see our method is used in various talking or singing avatar, checkout these wonderful demos at [bilibili](https://search.bilibili.com/all?keyword=sadtalker&from_source=webtop_search&spm_id_from=333.1007&search_source=3
) and [twitter #sadtalker](https://twitter.com/search?q=%23sadtalker&src=typed_query).
## 📋 Changelog (Previous changelog can be founded [here](docs/changlelog.md))
- __[2023.04.08]__: In v0.0.2, we add a logo watermark to the generated video to prevent abusing since it is very realistic.
- __[2023.04.08]__: v0.0.2, full image animation, adding baidu driver for download checkpoints. Optimizing the logic about enhancer.
- __[2023.04.06]__: stable-diffiusion webui extension is release.
- __[2023.04.03]__: Enable TTS in huggingface and gradio local demo.
- __[2023.03.30]__: Launch beta version of the full body mode.
- __[2023.03.30]__: Launch new feature: through using reference videos, our algorithm can generate videos with more natural eye blinking and some eyebrow movement.
- __[2023.03.29]__: `resize mode` is online by `python inference.py --preprocess resize`! Where we can produce a larger crop of the image as discussed in https://github.com/Winfredy/SadTalker/issues/35.
- __[2023.03.29]__: local gradio demo is online! `python app.py` to start the demo. New `requirements.txt` is used to avoid the bugs in `librosa`.
- __[2023.03.28]__: Online demo is launched in [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/vinthony/SadTalker), thanks AK!
## 🎼 Pipeline
![main_of_sadtalker](https://user-images.githubusercontent.com/4397546/222490596-4c8a2115-49a7-42ad-a2c3-3bb3288a5f36.png)
> Our method uses the coefficients of 3DMM as intermediate motion representation. To this end, we first generate
realistic 3D motion coefficients (facial expression β, head pose ρ)
from audio, then these coefficients are used to implicitly modulate
the 3D-aware face render for final video generation.
## 🚧 TODO
<details><summary> Previous TODOs </summary>
- [x] Generating 2D face from a single Image.
- [x] Generating 3D face from Audio.
- [x] Generating 4D free-view talking examples from audio and a single image.
- [x] Gradio/Colab Demo.
- [x] Full body/image Generation.
</details>
- [ ] training code of each componments.
- [ ] Audio-driven Anime Avatar.
- [ ] interpolate ChatGPT for a conversation demo 🤔
- [x] integrate with stable-diffusion-web-ui. (stay tuned!)
## ⚙️ Installation ([中文教程](https://www.bilibili.com/video/BV17N411P7m7/?vd_source=653f1e6e187ffc29a9b677b6ed23169a))
#### Installing Sadtalker on Linux:
```bash
git clone https://github.com/Winfredy/SadTalker.git
cd SadTalker
conda create -n sadtalker python=3.8
conda activate sadtalker
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
conda install ffmpeg
pip install -r requirements.txt
### tts is optional for gradio demo.
### pip install TTS
```
More tips about installation on Windows and the Docker file can be found [here](docs/install.md)
#### Sd-Webui-Extension:
<details><summary>CLICK ME</summary>
Install the latest version of [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) and install the sadtalker via `extension`.
<img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698519-267d1d1f-6e99-4dd4-81e1-7b889259efbd.png">
Then, restart the stable-diffusion-webui and set some command-line args. The models will be downloaded automatically in the right place. Alternatively, you can add the path of pre-downloaded sadtalker checkpoints to `SADTALKTER_CHECKPOINTS` in `webui_user.sh`(linux) or `webui_user.bat`(windows) by:
```bash
# windows (webui_user.bat)
set COMMANDLINE_ARGS=--no-gradio-queue --disable-safe-unpickle
set SADTALKER_CHECKPOINTS=D:\SadTalker\checkpoints
# linux (webui_user.sh)
export COMMANDLINE_ARGS=--no-gradio-queue --disable-safe-unpickle
export SADTALKER_CHECKPOINTS=/path/to/SadTalker/checkpoints
```
After installation, the SadTalker can be used in stable-diffusion-webui directly.
<img width="726" alt="image" src="https://user-images.githubusercontent.com/4397546/230698614-58015182-2916-4240-b324-e69022ef75b3.png">
</details>
#### Download Trained Models
<details><summary>CLICK ME</summary>
You can run the following script to put all the models in the right place.
```bash
bash scripts/download_models.sh
```
OR download our pre-trained model from [google drive](https://drive.google.com/drive/folders/1Wd88VDoLhVzYsQ30_qDVluQr_Xm46yHT?usp=sharing) or our [github release page](https://github.com/Winfredy/SadTalker/releases/tag/v0.0.1), and then, put it in ./checkpoints.
OR we provided the downloaded model in [百度云盘](https://pan.baidu.com/s/1nXuVNd0exUl37ISwWqbFGA?pwd=sadt) 提取码: sadt.
| Model | Description
| :--- | :----------
|checkpoints/auido2exp_00300-model.pth | Pre-trained ExpNet in Sadtalker.
|checkpoints/auido2pose_00140-model.pth | Pre-trained PoseVAE in Sadtalker.
|checkpoints/mapping_00229-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/mapping_00109-model.pth.tar | Pre-trained MappingNet in Sadtalker.
|checkpoints/facevid2vid_00189-model.pth.tar | Pre-trained face-vid2vid model from [the reappearance of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis).
|checkpoints/epoch_20.pth | Pre-trained 3DMM extractor in [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction).
|checkpoints/wav2lip.pth | Highly accurate lip-sync model in [Wav2lip](https://github.com/Rudrabha/Wav2Lip).
|checkpoints/shape_predictor_68_face_landmarks.dat | Face landmark model used in [dilb](http://dlib.net/).
|checkpoints/BFM | 3DMM library file.
|checkpoints/hub | Face detection models used in [face alignment](https://github.com/1adrianb/face-alignment).
</details>
## 🔮 Quick Start
#### Generating 2D face from a single Image from default config.
```bash
python inference.py --driven_audio <audio.wav> --source_image <video.mp4 or picture.png>
```
The results will be saved in `results/$SOME_TIMESTAMP/*.mp4`.
Or a local gradio demo similar to our [hugging-face demo](https://huggingface.co/spaces/vinthony/SadTalker) can be run by:
```bash
## you need manually install TTS(https://github.com/coqui-ai/TTS) via `pip install tts` in advanced.
python app.py
```
#### Advanced Configuration
<details><summary> Click Me </summary>
| Name | Configuration | default | Explaination |
|:------------- |:------------- |:----- | :------------- |
| Enhance Mode | `--enhancer` | None | Using `gfpgan` or `RestoreFormer` to enhance the generated face via face restoration network
| Background Enhancer | `--background_enhancer` | None | Using `realesrgan` to enhance the full video.
| Still Mode | ` --still` | False | Using the same pose parameters as the original image, fewer head motion.
| Expressive Mode | `--expression_scale` | 1.0 | a larger value will make the expression motion stronger.
| save path | `--result_dir` |`./results` | The file will be save in the newer location.
| preprocess | `--preprocess` | `crop` | Run and produce the results in the croped input image. Other choices: `resize`, where the images will be resized to the specific resolution. `full` Run the full image animation, use with `--still` to get better results.
| ref Mode (eye) | `--ref_eyeblink` | None | A video path, where we borrow the eyeblink from this reference video to provide more natural eyebrow movement.
| ref Mode (pose) | `--ref_pose` | None | A video path, where we borrow the pose from the head reference video.
| 3D Mode | `--face3dvis` | False | Need additional installation. More details to generate the 3d face can be founded [here](docs/face3d.md).
| free-view Mode | `--input_yaw`,<br> `--input_pitch`,<br> `--input_roll` | None | Generating novel view or free-view 4D talking head from a single image. More details can be found [here](https://github.com/Winfredy/SadTalker#generating-4d-free-view-talking-examples-from-audio-and-a-single-image).
</details>
#### Examples
| basic | w/ still mode | w/ exp_scale 1.3 | w/ gfpgan |
|:-------------: |:-------------: |:-------------: |:-------------: |
| <video src="https://user-images.githubusercontent.com/4397546/226097707-bef1dd41-403e-48d3-a6e6-6adf923843af.mp4"></video> | <video src='https://user-images.githubusercontent.com/4397546/226804933-b717229f-1919-4bd5-b6af-bea7ab66cad3.mp4'></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226806013-7752c308-8235-4e7a-9465-72d8fc1aa03d.mp4"></video> | <video style='width:256px' src="https://user-images.githubusercontent.com/4397546/226097717-12a1a2a1-ac0f-428d-b2cb-bd6917aff73e.mp4"></video> |
> Kindly ensure to activate the audio as the default audio playing is incompatible with GitHub.
| Input, w/ reference video , reference video |
|:-------------: |
| ![free_view](docs/using_ref_video.gif)|
| If the reference video is shorter than the input audio, we will loop the reference video .
<!-- <video src="./docs/art_0##japanese_still.mp4"></video> -->
#### Generating 3D face from Audio
| Input | Animated 3d face |
|:-------------: | :-------------: |
| <img src='examples/source_image/art_0.png' width='200px'> | <video src="https://user-images.githubusercontent.com/4397546/226856847-5a6a0a4d-a5ec-49e2-9b05-3206db65e8e3.mp4"></video> |
> Kindly ensure to activate the audio as the default audio playing is incompatible with GitHub.
#### Generating 4D free-view talking examples from audio and a single image
We use `input_yaw`, `input_pitch`, `input_roll` to control head pose. For example, `--input_yaw -20 30 10` means the input head yaw degree changes from -20 to 30 and then changes from 30 to 10.
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a file to store results> \
--input_yaw -20 30 10
```
| Results, Free-view results, Novel view results |
|:-------------: |
| ![free_view](docs/free_view_result.gif)|
#### [Beta Application] Full body/image Generation
Now, you can use `--still` to generate a natural full body video. You can add `enhancer` or `full_img_enhancer` to improve the quality of the generated video. However, if you add other mode, such as `ref_eyeblinking`, `ref_pose`, the result will be bad. We are still trying to fix this problem.
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a file to store results> \
--still \
--preprocess full \
--enhancer gfpgan
```
## 🛎 Citation
If you find our work useful in your research, please consider citing:
```bibtex
@article{zhang2022sadtalker,
title={SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation},
author={Zhang, Wenxuan and Cun, Xiaodong and Wang, Xuan and Zhang, Yong and Shen, Xi and Guo, Yu and Shan, Ying and Wang, Fei},
journal={arXiv preprint arXiv:2211.12194},
year={2022}
}
```
## 💗 Acknowledgements
Facerender code borrows heavily from [zhanglonghao's reproduction of face-vid2vid](https://github.com/zhanglonghao1992/One-Shot_Free-View_Neural_Talking_Head_Synthesis) and [PIRender](https://github.com/RenYurui/PIRender). We thank the authors for sharing their wonderful code. In training process, We also use the model from [Deep3DFaceReconstruction](https://github.com/microsoft/Deep3DFaceReconstruction) and [Wav2lip](https://github.com/Rudrabha/Wav2Lip). We thank for their wonderful work.
## 🥂 Related Works
- [StyleHEAT: One-Shot High-Resolution Editable Talking Face Generation via Pre-trained StyleGAN (ECCV 2022)](https://github.com/FeiiYin/StyleHEAT)
- [CodeTalker: Speech-Driven 3D Facial Animation with Discrete Motion Prior (CVPR 2023)](https://github.com/Doubiiu/CodeTalker)
- [VideoReTalking: Audio-based Lip Synchronization for Talking Head Video Editing In the Wild (SIGGRAPH Asia 2022)](https://github.com/vinthony/video-retalking)
- [DPE: Disentanglement of Pose and Expression for General Video Portrait Editing (CVPR 2023)](https://github.com/Carlyx/DPE)
- [3D GAN Inversion with Facial Symmetry Prior (CVPR 2023)](https://github.com/FeiiYin/SPI/)
- [T2M-GPT: Generating Human Motion from Textual Descriptions with Discrete Representations (CVPR 2023)](https://github.com/Mael-zys/T2M-GPT)
## 📢 Disclaimer
This is not an official product of Tencent. This repository can only be used for personal/research/non-commercial purposes.
LOGO: color and font suggestion: [ChatGPT](ai.com), logo font[Montserrat Alternates
](https://fonts.google.com/specimen/Montserrat+Alternates?preview.text=SadTalker&preview.text_type=custom&query=mont).
All the copyrighted demo images are from community users or generated by stable diffusion. Feel free to contact us if you feel uncomfortable.

107
app.py

@ -0,0 +1,107 @@
import os, sys
import tempfile
import gradio as gr
from src.gradio_demo import SadTalker
from src.utils.text2speech import TTSTalker
def get_source_image(image):
    """Identity pass-through used as a gradio callback.

    Returns the uploaded image value unchanged so it can be wired
    straight into another component.
    """
    return image
def sadtalker_demo():
    """Build the SadTalker gradio demo interface.

    Lays out a two-column UI: the left column takes a source image plus
    driving audio (uploaded or synthesized from text via TTS); the right
    column holds generation settings and the output video.

    Returns:
        The assembled ``gr.Blocks`` interface; the caller is expected to
        ``.launch()`` it.
    """
    # Model wrappers (project-local): SadTalker renders the talking-head
    # video, TTSTalker synthesizes speech from typed text.
    sad_talker = SadTalker()
    tts_talker = TTSTalker()

    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        # Header banner with paper / homepage / repo links.
        gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
            <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
            <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> &nbsp;&nbsp;&nbsp;&nbsp;&nbsp; \
            <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")

        with gr.Row().style(equal_height=False):
            # ---- Left column: inputs (source image + driving audio) ----
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            # type="filepath": downstream code receives a
                            # path on disk, not a decoded array.
                            source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload OR TTS'):
                        with gr.Column(variant='panel'):
                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

                        with gr.Column(variant='panel'):
                            # Alternative input path: type text, press the
                            # button, and the TTS result is written back
                            # into the `driven_audio` component above.
                            input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
                            tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
                            tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

            # ---- Right column: settings + generated video ----
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        with gr.Column(variant='panel'):
                            is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer hand motion, works on full body)")
                            enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

                with gr.Tabs(elem_id="sadtalker_genearted"):
                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        with gr.Row():
            # Pre-canned demo inputs; each entry mirrors the click inputs:
            # (source_image, driven_audio, is_still_mode, enhancer).
            examples = [
                [
                    'examples/source_image/full_body_1.png',
                    'examples/driven_audio/bus_chinese.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/full_body_2.png',
                    'examples/driven_audio/itosinger1.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/art_13.png',
                    'examples/driven_audio/fayu.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/art_5.png',
                    'examples/driven_audio/chinese_news.wav',
                    True,
                    False
                ],
            ]
            gr.Examples(examples=examples,
                        inputs=[
                            source_image,
                            driven_audio,
                            is_still_mode,
                            enhancer],
                        outputs=[gen_video],
                        fn=sad_talker.test,
                        # Only pre-compute example outputs when running on
                        # Hugging Face Spaces (SYSTEM env var is 'spaces').
                        cache_examples=os.getenv('SYSTEM') == 'spaces')

        # Main generation trigger: runs SadTalker and shows the result.
        submit.click(
            fn=sad_talker.test,
            inputs=[source_image,
                    driven_audio,
                    is_still_mode,
                    enhancer],
            outputs=[gen_video]
        )

    return sadtalker_interface
if __name__ == "__main__":
    # Build the gradio Blocks UI and start the local web server.
    sadtalker_demo().launch()

@ -0,0 +1,14 @@
## changelogs
- __[2023.03.22]__: Launch new feature: generating the 3d face animation from a single image. New applications about it will be updated.
- __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`.
- __[2023.03.18]__: Support `expression intensity`, now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3 (some value > 1)`.
- __[2023.03.18]__: Reconfig the data folders, now you can download the checkpoint automatically using `bash scripts/download_models.sh`.
- __[2023.03.18]__: We have offically integrate the [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement, using `python inference.py --enhancer gfpgan` for better visualization performance.
- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`, [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!
- __[2023.03.06]__: Solve some bugs in code and errors in installation
- __[2023.03.03]__: Release the test code for audio-driven single image animation!
- __[2023.02.28]__: SadTalker has been accepted by CVPR 2023!

@ -0,0 +1,48 @@
## 3D Face visualization
We use pytorch3d to visualize the produced 3d face from a single image.
Since it is not easy to install, we provide a new install guidance here:
```bash
git clone https://github.com/Winfredy/SadTalker.git
cd SadTalker
conda create -n sadtalker3d python=3.8
source activate sadtalker3d
conda install ffmpeg
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install libgcc gmp
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
# install pytorch3d
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html
pip install -r requirements3d.txt
### install gpfgan for enhancer
pip install git+https://github.com/TencentARC/GFPGAN
### when occurs gcc version problem `from pytorch import _C` from pytorch3d, add the anaconda path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/
```
Then, generate the result via:
```bash
python inference.py --driven_audio <audio.wav> \
--source_image <video.mp4 or picture.png> \
--result_dir <a file to store results> \
--face3dvis
```
Then, the results will be given in the folders with the file name of `face3d.mp4`.
More applications about 3d face will be released.

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.4 MiB

@ -0,0 +1,25 @@
### Windows Native
- Make sure you have `ffmpeg` in the `%PATH%` as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54), following [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) installation to install `ffmpeg`.
### Windows WSL
- Make sure the environment: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH`
### Docker installation
A Dockerfile is also provided by [@thegenerativegeneration](https://github.com/thegenerativegeneration) in [docker hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker), which can be used directly as:
```bash
docker run --gpus "all" --rm -v $(pwd):/host_dir wawa9000/sadtalker \
--driven_audio /host_dir/deyu.wav \
--source_image /host_dir/image.jpg \
--expression_scale 1.0 \
--still \
--result_dir /host_dir
```

Binary file not shown.

After

Width:  |  Height:  |  Size: 34 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.7 MiB

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

After

Width:  |  Height:  |  Size: 733 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 478 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 556 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 478 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 704 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 617 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 635 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 657 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 115 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 462 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 812 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 694 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.5 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 98 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 509 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.0 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.2 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 617 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 26 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 134 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 238 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 108 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

@ -0,0 +1,159 @@
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser
from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data
def main(args):
    """End-to-end SadTalker inference: one source image + one driving audio -> talking-head video.

    Stages: (1) crop the source image and extract its 3DMM coefficients,
    (2) map the audio to motion coefficients (optionally conditioned on
    reference videos for eye blinks / head pose), (3) render the video.
    All outputs go to a timestamped sub-folder of ``args.result_dir``.
    """
    #torch.backends.cudnn.enabled = False
    pic_path = args.source_image
    audio_path = args.driven_audio
    # one timestamped folder per run, so repeated runs never collide
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    input_yaw_list = args.input_yaw
    input_pitch_list = args.input_pitch
    input_roll_list = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose

    # locate checkpoints/configs relative to where this script lives
    current_code_path = sys.argv[0]
    current_root_path = os.path.split(current_code_path)[0]
    # route torch.hub caches (e.g. for enhancers) into the checkpoint dir
    os.environ['TORCH_HOME']=os.path.join(current_root_path, args.checkpoint_dir)

    path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
    path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
    dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
    wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')

    # NOTE: 'auido' (sic) matches the released checkpoint file names
    audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
    audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')
    audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
    audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')
    free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')

    # 'full' preprocessing uses a dedicated mapping net and renderer config
    if args.preprocess == 'full':
        mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
    else:
        mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

    #init model
    print(path_of_net_recon_model)
    preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

    print(audio2pose_checkpoint)
    print(audio2exp_checkpoint)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, device)

    print(free_view_checkpoint)
    print(mapping_checkpoint)
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, device)

    #crop image and extract 3dmm from image
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    # optional reference video that drives the eye-blink coefficients
    if ref_eyeblink is not None:
        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
        print('3DMM Extraction for the reference video providing eye blinking')
        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir)
    else:
        ref_eyeblink_coeff_path=None

    # optional reference video for head pose (may reuse the eye-blink one)
    if ref_pose is not None:
        if ref_pose == ref_eyeblink:
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
            os.makedirs(ref_pose_frame_dir, exist_ok=True)
            print('3DMM Extraction for the reference video providing pose')
            ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
    else:
        ref_pose_coeff_path=None

    #audio2ceoff
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    # 3dface render (optional debug visualization of the 3D face)
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))

    #coeff2video
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, input_yaw_list, input_pitch_list, input_roll_list,
                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)

    animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)
if __name__ == '__main__':
    # CLI entry point: build the parser, choose the compute device, run main().
    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/driven_audio/eluosi.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/full3.png', help="path to source image")
    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
    # help text fixed: checkpoints are read from here, it is not an output dir
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to model checkpoints")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
    # help text fixed: was a copy-paste of the batch_size description
    parser.add_argument("--expression_scale", type=float, default=1., help="scale factor for the intensity of the generated expression (values > 1 exaggerate)")
    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user")
    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
    parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="force CPU inference even when CUDA is available")
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
    parser.add_argument("--still", action="store_true", help="produce fewer head motions so the crop can be pasted back onto the original video for full-body animation")
    parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images")

    # net structure and parameters (kept for checkpoint compatibility; unused at inference)
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
    parser.add_argument('--init_path', type=str, default=None, help='Useless')
    parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    args = parser.parse_args()

    # prefer CUDA unless explicitly disabled with --cpu
    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    main(args)

@ -0,0 +1,208 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "M74Gs_TjYl_B"
},
"source": [
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github"
},
"source": [
"### SadTalkerLearning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
"\n",
"[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
"\n",
"Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
"\n",
"Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
"\n",
"CVPR 2023\n",
"\n",
"TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "kA89DV-sKS4i"
},
"source": [
"Installation (around 5 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qJ4CplXsYl_E"
},
"outputs": [],
"source": [
"### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Mdq6j4E5KQAR"
},
"outputs": [],
"source": [
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2 \n",
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1 \n",
"!python --version \n",
"!apt-get update\n",
"!apt install software-properties-common\n",
"!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
"!apt-get install python3-pip\n",
"\n",
"print('Git clone project and install requirements...')\n",
"!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
"%cd SadTalker \n",
"!export PYTHONPATH=/content/SadTalker:$PYTHONPATH \n",
"!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
"!apt update\n",
"!apt install ffmpeg &> /dev/null \n",
"!python3.8 -m pip install -r requirements.txt"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DddcKB_nKsnk"
},
"source": [
"Download models (1 mins)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "eDw3_UN8K2xa"
},
"outputs": [],
"source": [
"print('Download pre-trained models...')\n",
"!rm -rf checkpoints\n",
"!bash scripts/download_models.sh"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kK7DYeo7Yl_H"
},
"outputs": [],
"source": [
"# borrow from makeittalk\n",
"import ipywidgets as widgets\n",
"import glob\n",
"import matplotlib.pyplot as plt\n",
"print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
"img_list = glob.glob1('examples/source_image', '*.png')\n",
"img_list.sort()\n",
"img_list = [item.split('.')[0] for item in img_list]\n",
"default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
"def on_change(change):\n",
" if change['type'] == 'change' and change['name'] == 'value':\n",
" plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
" plt.axis('off')\n",
" plt.show()\n",
"default_head_name.observe(on_change)\n",
"display(default_head_name)\n",
"plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
"plt.axis('off')\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-khNZcnGK4UK"
},
"source": [
"Animation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ToBlDusjK5sS"
},
"outputs": [],
"source": [
"# selected audio from exmaple/driven_audio\n",
"img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
"print(img)\n",
"!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
" --source_image {img} \\\n",
" --result_dir ./results --still --preprocess full --enhancer gfpgan"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fAjwGmKKYl_I"
},
"outputs": [],
"source": [
"# visualize code from makeittalk\n",
"from IPython.display import HTML\n",
"from base64 import b64encode\n",
"import os, sys\n",
"\n",
"# get the last from results\n",
"\n",
"results = sorted(os.listdir('./results/'))\n",
"\n",
"mp4_name = glob.glob('./results/'+results[-1]+'/*.mp4')[0]\n",
"\n",
"mp4 = open('{}'.format(mp4_name),'rb').read()\n",
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
"\n",
"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
"display(HTML(\"\"\"\n",
" <video width=256 controls>\n",
" <source src=\"%s\" type=\"video/mp4\">\n",
" </video>\n",
" \"\"\" % data_url))\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "base",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.7"
},
"vscode": {
"interpreter": {
"hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
}
},
"accelerator": "GPU",
"gpuClass": "standard"
},
"nbformat": 4,
"nbformat_minor": 0
}

@ -0,0 +1,20 @@
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 #
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
gradio
gfpgan
dlib-bin

@ -0,0 +1,21 @@
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2 #
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
trimesh==3.9.20
dlib-bin
gradio
gfpgan

@ -0,0 +1,14 @@
#!/usr/bin/env bash
# Download all SadTalker pre-trained checkpoints into ./checkpoints.
# Fix: use `mkdir -p` so re-running the script does not fail when the
# directory already exists; fold the repetitive wget calls into one loop.

mkdir -p ./checkpoints

BASE_URL=https://github.com/Winfredy/SadTalker/releases/download/v0.0.2

# wget -nc ("no clobber") skips files that were already downloaded.
for f in auido2exp_00300-model.pth \
         auido2pose_00140-model.pth \
         epoch_20.pth \
         facevid2vid_00189-model.pth.tar \
         shape_predictor_68_face_landmarks.dat \
         wav2lip.pth \
         mapping_00229-model.pth.tar \
         mapping_00109-model.pth.tar \
         BFM_Fitting.zip \
         hub.zip; do
    wget -nc "${BASE_URL}/${f}" -O "./checkpoints/${f}"
done

# unzip -n: never overwrite files that were already extracted
unzip -n ./checkpoints/hub.zip -d ./checkpoints/
unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/

@ -0,0 +1,133 @@
import os, sys
from pathlib import Path
import tempfile
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules.shared import opts, OptionInfo
from modules import shared, paths, script_callbacks
import launch
import glob
def get_source_image(image):
    """Identity pass-through used as a gradio callback: echo the chosen image."""
    return image
def get_img_from_txt2img(x):
    """Return the newest txt2img output image path, duplicated for two gradio outputs.

    ``x`` is the (unused) gradio input slot; the signature must stay as-is for
    the existing ``.click`` wiring. Raises IndexError when no txt2img image
    exists yet.

    Fix: ``glob.glob`` already yields full paths, so re-joining them with the
    directory (as the original sort key and return value did) only worked by
    accident when ``paths.script_path`` was absolute; use the glob results
    directly.
    """
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_txt_dir = str(talker_path / "txt2img-images/")
    imgs = glob.glob(imgs_from_txt_dir + '/*/*.png')
    # newest file last
    imgs.sort(key=os.path.getmtime)
    img_from_txt_path = imgs[-1]
    return img_from_txt_path, img_from_txt_path
def get_img_from_img2img(x):
    """Return the newest img2img output image path, duplicated for two gradio outputs.

    ``x`` is the (unused) gradio input slot; the signature must stay as-is for
    the existing ``.click`` wiring. Raises IndexError when no img2img image
    exists yet.

    Fix: ``glob.glob`` already yields full paths, so re-joining them with the
    directory (as the original sort key and return value did) only worked by
    accident when ``paths.script_path`` was absolute; use the glob results
    directly.
    """
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_img_dir = str(talker_path / "img2img-images/")
    imgs = glob.glob(imgs_from_img_dir + '/*/*.png')
    # newest file last
    imgs.sort(key=os.path.getmtime)
    img_from_img_path = imgs[-1]
    return img_from_img_path, img_from_img_path
def install():
    """Install SadTalker's pinned python requirements and fetch its checkpoints.

    Runs at UI start-up (called from on_ui_tabs). Packages are installed via
    the WebUI's `launch` helpers; checkpoints are downloaded unless the
    SADTALKER_CHECKPOINTS environment variable points at an existing set.
    """
    # key: distribution name queried via launch.is_installed
    # value: pip requirement spec actually installed
    kv = {
        "face-alignment": "face-alignment==1.3.5",
        "imageio": "imageio==2.19.3",
        "imageio-ffmpeg": "imageio-ffmpeg==0.4.7",
        "librosa":"librosa==0.8.0",
        "pydub":"pydub==0.25.1",
        "scipy":"scipy==1.8.1",
        "tqdm": "tqdm",
        "yacs":"yacs==0.1.8",
        "pyyaml": "pyyaml",
        "dlib": "dlib-bin",
        "gfpgan": "gfpgan",
    }

    for k,v in kv.items():
        print(k, launch.is_installed(k))
        if not launch.is_installed(k):
            launch.run_pip("install "+ v, "requirements for SadTalker")

    if os.getenv('SADTALKER_CHECKPOINTS'):
        # user supplied their own checkpoint directory — skip the download
        print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
    else:
        ### run the script to download models to the correct location.
        print('download models for SadTalker')
        launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
        print('SadTalker is successfully installed!')
def on_ui_tabs():
    """Build the SadTalker tab for the Stable Diffusion WebUI.

    Installs dependencies, instantiates the (lazily loaded) SadTalker model,
    and wires the gradio widgets to SadTalker.test. Returns the
    (component, title, elem_id) triple expected by script_callbacks.on_ui_tabs.
    """
    install()

    # make the extension's own `src` package importable
    sys.path.extend([paths.script_path+'/extensions/SadTalker'])

    repo_dir = paths.script_path+'/extensions/SadTalker/'

    result_dir = opts.sadtalker_result_dir
    os.makedirs(result_dir, exist_ok=True)

    # deferred import: only valid after install() has run
    from src.gradio_demo import SadTalker

    # allow overriding the checkpoint location via environment variable
    if os.getenv('SADTALKER_CHECKPOINTS'):
        checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
    else:
        checkpoint_path = repo_dir+'checkpoints/'

    # lazy_load=True: model weights are loaded on first generation, not here
    sad_talker = SadTalker(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', lazy_load=True)

    with gr.Blocks(analytics_enabled=False) as audio_to_video:
        with gr.Row().style(equal_height=False):
            # left column: source image + driving audio inputs
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=512,width=512)
                        with gr.Row():
                            # pull the newest image generated by txt2img / img2img
                            submit_image2 = gr.Button('load From txt2img', variant='primary')
                            submit_image2.click(fn=get_img_from_txt2img, inputs=input_image, outputs=[input_image, input_image])
                            submit_image3 = gr.Button('load from img2img', variant='primary')
                            submit_image3.click(fn=get_img_from_img2img, inputs=input_image, outputs=[input_image, input_image])

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload'):
                        with gr.Column(variant='panel'):
                            with gr.Row():
                                driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

            # right column: generation options + output video
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        with gr.Column(variant='panel'):
                            is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion)").style(container=True)
                            is_enhance_mode = gr.Checkbox(label="Enhance Mode (better face quality )").style(container=True)
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

                with gr.Tabs(elem_id="sadtalker_genearted"):
                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        ### gradio gpu call will always return the html,
        submit.click(
            fn=wrap_queued_call(sad_talker.test),
            inputs=[input_image,
                    driven_audio,
                    is_still_mode,
                    is_enhance_mode],
            outputs=[gen_video, ]
        )

    return [(audio_to_video, "SadTalker", "extension")]
def on_ui_settings():
    """Register the SadTalker output-directory option in the WebUI settings page."""
    out_root = Path(paths.script_path) / "outputs"
    opts.add_option(
        "sadtalker_result_dir",
        OptionInfo(str(out_root / "SadTalker/"),
                   "Path to save results of sadtalker",
                   section=('extension', "SadTalker")),
    )
# Hook the extension into the WebUI: a settings entry and the SadTalker tab.
script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)

@ -0,0 +1,41 @@
from tqdm import tqdm
import torch
from torch import nn
class Audio2Exp(nn.Module):
    """Wrap an audio-to-expression generator for windowed inference.

    ``netG`` maps (mel windows, reference expression coeffs, blink ratio) to
    64-dim expression coefficients; ``test`` runs it over a whole sequence in
    10-frame windows. ``prepare_training_loss`` is accepted for interface
    compatibility but unused here.
    """

    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):
        """Predict expression coefficients for every frame in *batch*.

        Expects ``batch['indiv_mels']`` of shape (bs, T, 1, 80, 16),
        ``batch['ref']`` with at least 64 trailing coefficients, and
        ``batch['ratio_gt']`` of shape (bs, T). Returns a dict with
        ``'exp_coeff_pred'`` of shape (bs, T, 64).
        """
        mel_input = batch['indiv_mels']            # bs x T x 1 x 80 x 16
        num_frames = mel_input.shape[1]

        predictions = []
        # process the sequence in windows of 10 frames
        for start in tqdm(range(0, num_frames, 10), 'audio2exp:'):
            stop = start + 10
            mel_window = mel_input[:, start:stop]
            # reference expression coefficients aligned with this window
            ref_window = batch['ref'][:, :, :64][:, start:stop]
            blink_ratio = batch['ratio_gt'][:, start:stop]     # bs x T
            audiox = mel_window.view(-1, 1, 80, 16)            # (bs*T) x 1 x 80 x 16
            predictions.append(self.netG(audiox, ref_window, blink_ratio))

        # concatenate the per-window outputs back into (bs, T, 64)
        return {'exp_coeff_pred': torch.cat(predictions, 1)}

@ -0,0 +1,74 @@
import torch
import torch.nn.functional as F
from torch import nn
class Conv2d(nn.Module):
    """Conv2d + BatchNorm block with optional residual add and optional ReLU.

    Args:
        cin, cout: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        residual: if True, add the input to the conv output (shapes must match).
        use_act: if True, apply ReLU to the (possibly residual) output.
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual
        self.use_act = use_act

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out = out + x
        return self.act(out) if self.use_act else out
class SimpleWrapperV2(nn.Module):
    """Audio-to-expression network: a wav2lip-style CNN audio encoder plus one linear head.

    forward() takes per-frame mel windows flattened into the batch dimension,
    a per-frame 64-dim reference coefficient tensor, and a per-frame blink
    ratio, and returns (bs, T, 64) expression coefficients.
    """

    def __init__(self) -> None:
        super().__init__()
        # wav2lip-style encoder: (bs*T, 1, 80, 16) -> (bs*T, 512, 1, 1).
        # NOTE: the exact Sequential layout is load-bearing — state_dict keys
        # must keep matching released checkpoints.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )

        # Historical code (kept for reference) initialized the encoder from a
        # pre-trained wav2lip checkpoint:
        '''
        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k,v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                print('init:', k)
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)
        '''

        # head input: 512 audio features + 64 reference coeffs + 1 blink ratio
        self.mapping1 = nn.Linear(512+64+1, 64)
        #nn.init.constant_(self.mapping1.weight, 0.)
        nn.init.constant_(self.mapping1.bias, 0.)

    def forward(self, x, ref, ratio):
        # x: (bs*T, 1, 80, 16) -> (bs*T, 512)
        features = self.audio_encoder(x).view(x.size(0), -1)
        flat_ref = ref.reshape(features.size(0), -1)      # (bs*T, 64)
        flat_ratio = ratio.reshape(features.size(0), -1)  # (bs*T, 1)

        y = self.mapping1(torch.cat([features, flat_ref, flat_ratio], dim=1))
        # back to (bs, T, 64); the residual add of `ref` was intentionally disabled
        return y.reshape(ref.shape[0], ref.shape[1], -1)

@ -0,0 +1,94 @@
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder
class Audio2Pose(nn.Module):
    """CVAE-based audio-to-head-pose model.

    Predicts a sequence of pose-motion offsets (the 6 coefficients taken from
    columns [-9:-3] of the 73-dim 3DMM vector) from audio, relative to the
    pose of the first (reference) frame, conditioned on a style class id.
    """
    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # frozen wav2lip audio encoder used as a fixed feature extractor
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)

    def forward(self, x):
        # Training path.
        # NOTE(review): calls .cuda() directly rather than using self.device,
        # so training assumes a CUDA machine — confirm before reusing on CPU.
        batch = {}
        coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
        # pose lives in columns [-9:-3]; motion is expressed relative to frame 0
        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, -9:-3] #bs 6
        batch['class'] = x['class'].squeeze(0).cuda() # bs
        indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16

        # forward
        audio_emb_list = []
        # frame 0 is the reference, so only frames 1.. feed the encoder
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, -9:-3].clone() # bs frame_len 6
        # absolute pose = reference pose + predicted offsets
        pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt
        return batch

    def test(self, x):
        """Inference path: sample pose motion chunk-by-chunk from the CVAE decoder."""
        batch = {}
        ref = x['ref'] #bs 1 70
        batch['ref'] = x['ref'][:,0,-6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels= x['indiv_mels'] # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
        num_frames = x['num_frames']
        num_frames = int(num_frames) - 1

        # split the sequence into seq_len-sized chunks plus a remainder
        div = num_frames//self.seq_len
        re = num_frames%self.seq_len
        audio_emb_list = []
        # first entry: zero motion for the reference frame itself
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(div):
            # a fresh latent sample per chunk
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6

        if re != 0:
            # tail chunk: take the last seq_len window, left-pad with the first
            # embedding if the audio is shorter than a full window, then keep
            # only the final `re` frames of the prediction
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len  512
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len-audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
        batch['pose_motion_pred'] = pose_motion_pred

        # absolute pose = reference pose + predicted offsets
        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6

        batch['pose_pred'] = pose_pred
        return batch

@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F
class Conv2d(nn.Module):
    """Conv2d + BatchNorm + ReLU block with an optional residual connection.

    Args:
        cin, cout: input / output channel counts.
        kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        residual: if True, add the input to the conv output before the ReLU
            (shapes must match).
    """

    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out + x) if self.residual else self.act(out)
class AudioEncoder(nn.Module):
    """Wav2lip CNN audio encoder, initialized from a pre-trained checkpoint.

    Maps batches of per-frame mel windows (B, T, 1, 80, 16) to 512-dim
    embeddings (B, T, 512).
    """

    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        # NOTE: the exact Sequential layout is load-bearing — state_dict keys
        # must keep matching the wav2lip checkpoint.
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        # copy the matching pre-trained weights out of the wav2lip checkpoint
        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        state_dict = self.audio_encoder.state_dict()
        for k,v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)

    def forward(self, audio_sequences):
        """Encode mel windows: (B, T, 1, 80, 16) -> (B, T, 512)."""
        batch = audio_sequences.size(0)
        # fold time into the batch dimension so the CNN sees (B*T, 1, 80, 16);
        # NOTE(review): cat-over-time then reshape(B, -1, ...) interleaves
        # batch and time in the original order-dependent way — kept verbatim.
        frames = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)

        emb = self.audio_encoder(frames)           # (B*T, 512, 1, 1)
        feat_dim = emb.shape[1]
        emb = emb.reshape((batch, -1, feat_dim, 1, 1))
        return emb.squeeze(-1).squeeze(-1)         # (B, T, 512)

@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet
def class2onehot(idx, class_num):
    """One-hot encode integer class indices.

    ``idx`` is an integer tensor of shape (N, 1) with values in
    [0, class_num); returns a float tensor of shape (N, class_num) on the
    same device. Asserts that every index is in range.
    """
    assert torch.max(idx).item() < class_num
    encoded = torch.zeros(idx.size(0), class_num).to(idx.device)
    # scatter_ returns self, so the filled tensor can be returned directly
    return encoded.scatter_(1, idx, 1)
class CVAE(nn.Module):
    """Conditional VAE over pose-motion sequences.

    The encoder produces a posterior (mu, logvar); training reparameterizes a
    latent z from it and decodes, while inference decodes from a z the caller
    has already placed in the batch.
    """

    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, batch):
        """Training pass: encode, sample a latent, decode."""
        batch = self.encoder(batch)
        batch['z'] = self.reparameterize(batch['mu'], batch['logvar'])
        return self.decoder(batch)

    def test(self, batch):
        """Inference pass: decode using the 'z' the caller put into *batch*."""
        return self.decoder(batch)
class ENCODER(nn.Module):
    """CVAE encoder: maps (pose-motion GT, reference pose, audio embedding,
    class bias) to the latent Gaussian parameters mu / logvar."""

    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        # Copy before widening the first layer so the caller's list (usually
        # taken straight from the config) is not mutated as a side effect.
        layer_sizes = list(layer_sizes)
        layer_sizes[0] += latent_size + seq_len * audio_emb_out_size + 6

        self.MLP = nn.Sequential()
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']  # bs seq_len 6
        ref = batch['ref']  # bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size

        # pose encode (debug print of audio_in.shape removed)
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  # bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)  # bs seq_len*6

        # audio mapping
        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]  # bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  # bs seq_len*(audio_emb_out_size+6)+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        # BUG FIX: logvar previously reused self.linear_means, leaving
        # self.linear_logvar defined but unused and making logvar == mu.
        logvar = self.linear_logvar(x_out)  # bs latent_size

        batch.update({'mu': mu, 'logvar': logvar})
        return batch
class DECODER(nn.Module):
    """CVAE decoder: maps (latent z + class bias, reference pose, audio
    embedding) back to a predicted pose-motion sequence (bs, seq_len, 6)."""

    def __init__(self, layer_sizes, latent_size, num_classes,
                audio_emb_in_size, audio_emb_out_size, seq_len):
        """Args mirror ENCODER; layer_sizes lists the MLP output widths."""
        super().__init__()
        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        # MLP input: latent z (with class bias added), flattened audio
        # embedding, and the 6-dim reference pose.
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                # Final layer squashes to (0, 1) before the ResUnet refinement.
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        """Read 'z', 'class', 'ref', 'audio_emb' from batch; add 'pose_motion_pred'."""
        z = batch['z']  # bs latent_size
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']  # bs 6
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size

        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape([bs, -1])  # bs seq_len*audio_emb_out_size
        class_bias = self.classbias[class_id]  # bs latent_size

        # Condition the latent on the speaker class by simple addition.
        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)  # bs layer_sizes[-1]
        x_out = x_out.reshape((bs, self.seq_len, -1))

        # Refine the coarse MLP output as a (1, seq_len, 6) "image".
        pose_emb = self.resunet(x_out.unsqueeze(1))  # bs 1 seq_len 6

        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))  # bs seq_len 6

        batch.update({'pose_motion_pred':pose_motion_pred})
        return batch

@ -0,0 +1,76 @@
import torch
import torch.nn.functional as F
from torch import nn
class ConvNormRelu(nn.Module):
def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
super().__init__()
if kernel_size is None:
if downsample:
kernel_size, stride, padding = 4, 2, 1
else:
kernel_size, stride, padding = 3, 1, 1
if conv_type == '2d':
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
)
if norm == 'BN':
self.norm = nn.BatchNorm2d(out_channels)
elif norm == 'IN':
self.norm = nn.InstanceNorm2d(out_channels)
else:
raise NotImplementedError
elif conv_type == '1d':
self.conv = nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding,
bias=False,
)
if norm == 'BN':
self.norm = nn.BatchNorm1d(out_channels)
elif norm == 'IN':
self.norm = nn.InstanceNorm1d(out_channels)
else:
raise NotImplementedError
nn.init.kaiming_normal_(self.conv.weight)
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv(x)
if isinstance(self.norm, nn.InstanceNorm1d):
x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
else:
x = self.norm(x)
x = self.act(x)
return x
class PoseSequenceDiscriminator(nn.Module):
    """1-D conv discriminator producing a per-position realism score for a
    pose-coefficient sequence."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU

        self.seq = nn.Sequential(
            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
        )

    def forward(self, x):
        # Flatten per-frame features, then go channels-first for the 1-D convs.
        feats = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
        scores = self.seq(feats)
        return scores.squeeze(1)

@ -0,0 +1,140 @@
import torch.nn as nn
import torch
class ResidualConv(nn.Module):
def __init__(self, input_dim, output_dim, stride, padding):
super(ResidualConv, self).__init__()
self.conv_block = nn.Sequential(
nn.BatchNorm2d(input_dim),
nn.ReLU(),
nn.Conv2d(
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
),
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
)
self.conv_skip = nn.Sequential(
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
nn.BatchNorm2d(output_dim),
)
def forward(self, x):
return self.conv_block(x) + self.conv_skip(x)
class Upsample(nn.Module):
def __init__(self, input_dim, output_dim, kernel, stride):
super(Upsample, self).__init__()
self.upsample = nn.ConvTranspose2d(
input_dim, output_dim, kernel_size=kernel, stride=stride
)
def forward(self, x):
return self.upsample(x)
class Squeeze_Excite_Block(nn.Module):
def __init__(self, channel, reduction=16):
super(Squeeze_Excite_Block, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid(),
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class ASPP(nn.Module):
def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
super(ASPP, self).__init__()
self.aspp_block1 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block2 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.aspp_block3 = nn.Sequential(
nn.Conv2d(
in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
),
nn.ReLU(inplace=True),
nn.BatchNorm2d(out_dims),
)
self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
self._init_weights()
def forward(self, x):
x1 = self.aspp_block1(x)
x2 = self.aspp_block2(x)
x3 = self.aspp_block3(x)
out = torch.cat([x1, x2, x3], dim=1)
return self.output(out)
def _init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor."""

    def __init__(self, scale=2):
        super(Upsample_, self).__init__()
        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        out = self.upsample(x)
        return out
class AttentionBlock(nn.Module):
def __init__(self, input_encoder, input_decoder, output_dim):
super(AttentionBlock, self).__init__()
self.conv_encoder = nn.Sequential(
nn.BatchNorm2d(input_encoder),
nn.ReLU(),
nn.Conv2d(input_encoder, output_dim, 3, padding=1),
nn.MaxPool2d(2, 2),
)
self.conv_decoder = nn.Sequential(
nn.BatchNorm2d(input_decoder),
nn.ReLU(),
nn.Conv2d(input_decoder, output_dim, 3, padding=1),
)
self.conv_attn = nn.Sequential(
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, 1, 1),
)
def forward(self, x1, x2):
out = self.conv_encoder(x1) + self.conv_decoder(x2)
out = self.conv_attn(out)
return out * x2

@ -0,0 +1,65 @@
import torch
import torch.nn as nn
from src.audio2pose_models.networks import ResidualConv, Upsample
class ResUnet(nn.Module):
    """Residual U-Net over (B, channel, seq_len, 6) pose maps.

    All stride-(2,1) stages downsample only the sequence axis; the decoder
    mirrors them with (2,1) transposed convolutions and skip concatenations.
    """

    def __init__(self, channel=1, filters=[32, 64, 128, 256]):
        super(ResUnet, self).__init__()

        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )

        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2, 1), padding=1)
        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2, 1), padding=1)

        self.bridge = ResidualConv(filters[2], filters[3], stride=(2, 1), padding=1)

        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)

        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)

        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2, 1), stride=(2, 1))
        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)

        self.output_layer = nn.Sequential(
            nn.Conv2d(filters[0], 1, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        # Encoder path
        enc0 = self.input_layer(x) + self.input_skip(x)
        enc1 = self.residual_conv_1(enc0)
        enc2 = self.residual_conv_2(enc1)

        # Bottleneck
        mid = self.bridge(enc2)

        # Decoder path with skip connections at matching resolutions
        dec = self.up_residual_conv1(torch.cat([self.upsample_1(mid), enc2], dim=1))
        dec = self.up_residual_conv2(torch.cat([self.upsample_2(dec), enc1], dim=1))
        dec = self.up_residual_conv3(torch.cat([self.upsample_3(dec), enc0], dim=1))

        return self.output_layer(dec)

@ -0,0 +1,58 @@
DATASET:
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
TRAIN_BATCH_SIZE: 32
EVAL_BATCH_SIZE: 32
EXP: True
EXP_DIM: 64
FRAME_LEN: 32
COEFF_LEN: 73
NUM_CLASSES: 46
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
DEBUG: True
NUM_REPEATS: 2
T: 40
MODEL:
FRAMEWORK: V2
AUDIOENCODER:
LEAKY_RELU: True
NORM: 'IN'
DISCRIMINATOR:
LEAKY_RELU: False
INPUT_CHANNELS: 6
CVAE:
AUDIO_EMB_IN_SIZE: 512
AUDIO_EMB_OUT_SIZE: 128
SEQ_LEN: 32
LATENT_SIZE: 256
ENCODER_LAYER_SIZES: [192, 1024]
DECODER_LAYER_SIZES: [1024, 192]
TRAIN:
MAX_EPOCH: 300
GENERATOR:
LR: 2.0e-5
DISCRIMINATOR:
LR: 1.0e-5
LOSS:
W_FEAT: 0
W_COEFF_EXP: 2
W_LM: 1.0e-2
W_LM_MOUTH: 0
W_REG: 0
W_SYNC: 0
W_COLOR: 0
W_EXPRESSION: 0
W_LIPREADING: 0.01
W_LIPREADING_VV: 0
W_EYE_BLINK: 4
TAG:
NAME: small_dataset

@ -0,0 +1,49 @@
DATASET:
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
TRAIN_BATCH_SIZE: 64
EVAL_BATCH_SIZE: 1
EXP: True
EXP_DIM: 64
FRAME_LEN: 32
COEFF_LEN: 73
NUM_CLASSES: 46
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
DEBUG: True
MODEL:
AUDIOENCODER:
LEAKY_RELU: True
NORM: 'IN'
DISCRIMINATOR:
LEAKY_RELU: False
INPUT_CHANNELS: 6
CVAE:
AUDIO_EMB_IN_SIZE: 512
AUDIO_EMB_OUT_SIZE: 6
SEQ_LEN: 32
LATENT_SIZE: 64
ENCODER_LAYER_SIZES: [192, 128]
DECODER_LAYER_SIZES: [128, 192]
TRAIN:
MAX_EPOCH: 150
GENERATOR:
LR: 1.0e-4
DISCRIMINATOR:
LR: 1.0e-4
LOSS:
LAMBDA_REG: 1
LAMBDA_LANDMARKS: 0
LAMBDA_VERTICES: 0
LAMBDA_GAN_MOTION: 0.7
LAMBDA_GAN_COEFF: 0
LAMBDA_KL: 1
TAG:
NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder

@ -0,0 +1,45 @@
model_params:
common_params:
num_kp: 15
image_channel: 3
feature_channel: 32
estimate_jacobian: False # True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25 # 0.25
num_blocks: 5
reshape_channel: 16384 # 16384 = 1024 * 16
reshape_depth: 16
he_estimator_params:
block_expansion: 64
max_features: 2048
num_bins: 66
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
reshape_depth: 16 # 512 = 32 * 16
num_resblocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
mapping_params:
coeff_nc: 70
descriptor_nc: 1024
layer: 3
num_kp: 15
num_bins: 66

@ -0,0 +1,45 @@
model_params:
common_params:
num_kp: 15
image_channel: 3
feature_channel: 32
estimate_jacobian: False # True
kp_detector_params:
temperature: 0.1
block_expansion: 32
max_features: 1024
scale_factor: 0.25 # 0.25
num_blocks: 5
reshape_channel: 16384 # 16384 = 1024 * 16
reshape_depth: 16
he_estimator_params:
block_expansion: 64
max_features: 2048
num_bins: 66
generator_params:
block_expansion: 64
max_features: 512
num_down_blocks: 2
reshape_channel: 32
reshape_depth: 16 # 512 = 32 * 16
num_resblocks: 6
estimate_occlusion_map: True
dense_motion_params:
block_expansion: 32
max_features: 1024
num_blocks: 5
reshape_depth: 16
compress: 4
discriminator_params:
scales: [1]
block_expansion: 32
max_features: 512
num_blocks: 4
sn: True
mapping_params:
coeff_nc: 73
descriptor_nc: 1024
layer: 3
num_kp: 15
num_bins: 66

@ -0,0 +1,116 @@
"""This package includes all the modules related to data loading and preprocessing
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
You need to implement four functions:
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
-- <__len__>: return the size of dataset.
-- <__getitem__>: get a data point from data loader.
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset
def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py" and return its class.

    The module must define a class named <DatasetName>Dataset (matched
    case-insensitively) that subclasses BaseDataset.

    Raises:
        NotImplementedError: if no matching subclass is found in the module.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        # ROBUSTNESS FIX: guard with isinstance(cls, type) — issubclass()
        # raises TypeError for non-class module attributes whose name
        # happens to match the target.
        if name.lower() == target_dataset_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset
def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    return find_dataset_using_name(dataset_name).modify_commandline_options
def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader and is the main
    interface between this package and 'train.py'/'test.py'.

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    loader = CustomDatasetDataLoader(opt, rank=rank)
    return loader.load_data()
class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading."""

    def __init__(self, opt, rank=0):
        """Create the dataset named by opt.dataset_mode, then wrap it in a
        torch DataLoader (with a DistributedSampler for DDP training)."""
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))

        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=not opt.serial_batches
            )
            # Worker/batch budgets are split evenly across the DDP ranks.
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                sampler=self.sampler,
                num_workers=int(opt.num_threads / world_size),
                batch_size=int(opt.batch_size / world_size),
                drop_last=True)
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True
            )

    def set_epoch(self, epoch):
        """Propagate the epoch to the dataset and the DDP sampler (shuffling)."""
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Yield batches, stopping once max_dataset_size samples were served."""
        for batch_idx, batch in enumerate(self.dataloader):
            if batch_idx * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield batch

@ -0,0 +1,125 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod
class BaseDataset(data.Dataset, ABC):
    """Abstract base class (ABC) for datasets.

    Subclasses must implement:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options.
    """

    def __init__(self, opt):
        """Save the options in the class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        # Epoch counter a training loop may bump via CustomDatasetDataLoader.set_epoch.
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase.

        Returns:
            the modified parser (unchanged in this base implementation).
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index - - a random integer for data indexing

        Returns:
            a dictionary of data with their names.
        """
        pass
def get_transform(grayscale=False):
    """Compose ToTensor with an optional 1-channel grayscale conversion."""
    steps = []
    if grayscale:
        steps.append(transforms.Grayscale(1))
    steps.append(transforms.ToTensor())
    return transforms.Compose(steps)
def get_affine_mat(opt, size):
    """Sample a random 2-D affine augmentation matrix.

    Parameters:
        opt  -- options; opt.preprocess selects which of 'shift', 'scale',
                'rot', 'flip' to apply (using opt.shift_pixs,
                opt.scale_delta, opt.rot_angle)
        size -- (w, h) of the image the matrix applies to

    Returns:
        (affine, affine_inv, flip): forward and inverse 3x3 matrices built
        around the image center, and whether a horizontal flip was drawn.
    """
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    # BUG FIX: rot_rad was previously only assigned inside the 'rot' branch,
    # causing a NameError whenever 'rot' was not in opt.preprocess.
    rot_rad = 0.
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
        rot_rad = -rot_angle * np.pi/180
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    # Compose: move origin to center, flip/shift/rotate/scale, move back.
    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])

    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip
def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    """Warp a PIL image by the inverse affine matrix (first 6 coefficients).

    BUG FIX: `resample` previously hard-coded Image.BICUBIC and ignored the
    `method` argument, so mask warps requested with Image.BILINEAR were
    silently bicubic-resampled.
    """
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
def apply_lm_affine(landmark, affine, flip, size):
    """Apply a 3x3 affine to (x, y) landmarks, in the y-up frame the affine
    was built for; when flip is True, re-order the points to their mirrored
    indices (assumes the 68-point landmark layout)."""
    _, h = size
    pts = landmark.copy()

    # Flip the y axis into the affine's coordinate frame.
    pts[:, 1] = h - 1 - pts[:, 1]
    # Homogeneous transform, then perspective divide.
    pts = np.concatenate((pts, np.ones([pts.shape[0], 1])), -1)
    pts = pts @ np.transpose(affine)
    pts[:, :2] = pts[:, :2] / pts[:, 2:]
    pts = pts[:, :2]
    # Back to the image's y-down frame.
    pts[:, 1] = h - 1 - pts[:, 1]

    if flip:
        # Swap left/right landmark groups (dst slice <- reversed src slice).
        mirrored = pts.copy()
        for dst, src in (
            (slice(0, 17), slice(16, None, -1)),   # jaw line
            (slice(17, 22), slice(26, 21, -1)),    # brows
            (slice(22, 27), slice(21, 16, -1)),
            (slice(31, 36), slice(35, 30, -1)),    # nostrils
            (slice(36, 40), slice(45, 41, -1)),    # eyes
            (slice(40, 42), slice(47, 45, -1)),
            (slice(42, 46), slice(39, 35, -1)),
            (slice(46, 48), slice(41, 39, -1)),
            (slice(48, 55), slice(54, 47, -1)),    # mouth
            (slice(55, 60), slice(59, 54, -1)),
            (slice(60, 65), slice(64, 59, -1)),
            (slice(65, 68), slice(67, 64, -1)),
        ):
            mirrored[dst] = pts[src]
        pts = mirrored

    return pts

@ -0,0 +1,125 @@
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""
import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d
def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ...(same to caffe's filelist)
    """
    with open(flist, 'r') as rf:
        return [line.strip() for line in rf.readlines()]
def jason_flist_reader(flist):
    """Load a JSON-formatted file list ('jason' spelling kept for callers)."""
    with open(flist, 'r') as fp:
        return json.load(fp)
def parse_label(label):
    """Convert a (nested) numeric label into a float32 torch tensor."""
    arr = np.array(label).astype(np.float32)
    return torch.tensor(arr)
class FlistDataset(BaseDataset):
    """
    It requires one directories to host training images '/path/to/data/train'
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)
        # Standard 3D landmark template used by align_img.
        self.lm3d_std = load_lm3d(opt.bfm_folder)

        # Mask paths serve as the index; image/landmark paths are derived
        # from them in __getitem__.
        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains:
            imgs (tensor)   -- an image in the input domain
            msks (tensor)   -- its corresponding attention mask
            lms (tensor)    -- its corresponding 3d landmarks
            M (tensor)      -- normalization matrix from estimate_norm
            im_paths (str)  -- image path
            aug_flag (bool) -- whether the sample was augmented
            dataset (str)   -- this dataset's name
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range

        # Derive sibling paths: the image lives next to the 'mask/' folder,
        # landmarks under 'landmarks/' with a .txt extension.
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        # Align crop to the 3DMM landmark template.
        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        # Per-image normalization matrix (see util.preprocess.estimate_norm).
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]  # keep a single-channel mask
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)

        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        # Apply one random affine consistently to image, landmarks and mask.
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk

    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return self.size

@ -0,0 +1,66 @@
"""A modified image folder class
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data
from PIL import Image
import os
import os.path
# Recognized image suffixes; case variants are listed explicitly because
# matching is done with case-sensitive str.endswith in is_image_file.
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]
def is_image_file(filename):
    """Return True if `filename` ends with a known image extension."""
    # str.endswith accepts a tuple of suffixes: one C-level call instead of
    # a Python-level any()/generator scan over IMG_EXTENSIONS.
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image paths under `dir` (follows symlinks),
    truncated to at most `max_dataset_size` entries."""
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    images = []
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        images.extend(os.path.join(root, fname)
                      for fname in fnames if is_image_file(fname))
    return images[:min(max_dataset_size, len(images))]
def default_loader(path):
    """Open `path` with PIL and force 3-channel RGB."""
    img = Image.open(path)
    return img.convert('RGB')
class ImageFolder(data.Dataset):
    """Dataset over every image found (recursively) under a root directory."""

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample, path) if self.return_paths else sample

    def __len__(self):
        return len(self.imgs)

@ -0,0 +1,75 @@
"""Dataset class template
This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset.py
You need to implement the following functions:
-- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
-- <__init__>: Initialize this dataset class.
-- <__getitem__>: Return a data point and its metadata information.
-- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image
class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        # NOTE(review): get_transform in this package's base_dataset takes a
        # `grayscale` flag, not opt — confirm the intended call signature.
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)

@ -0,0 +1,107 @@
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method
class KeypointExtractor():
    """Thin wrapper around face_alignment's 2-D facial-landmark detector."""

    def __init__(self, device):
        # device is forwarded to face_alignment (e.g. 'cuda' / 'cpu').
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)

    def extract_keypoint(self, images, name=None, info=True):
        """Detect 68x2 keypoints for one image or a list of frames.

        Args:
            images: a single PIL image, or a list of them (processed
                recursively, one call per frame).
            name: when given, the flattened keypoints are saved to
                <name-without-ext>.txt via np.savetxt.
            info: show a tqdm progress bar in the list case.

        Returns:
            np.ndarray of keypoints; in the failure case a (68, 2) array of
            -1 is returned for the frame.
        """
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images,desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # Re-use the previous frame's landmarks when detection failed
                # (a failed frame comes back as all -1) and we have history.
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    # CUDA OOM: wait a second and retry; any other runtime
                    # error is printed and aborts the loop.
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    # get_landmarks_from_image returned None (no face found);
                    # fall back to a sentinel array of -1.
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
def read_video(filename):
    """Decode every frame of a video file into a list of RGB PIL images."""
    cap = cv2.VideoCapture(filename)
    frames = []
    while cap.isOpened():
        ok, bgr = cap.read()
        if not ok:
            # End of stream (or decode failure): stop collecting frames.
            break
        # OpenCV decodes to BGR; convert before wrapping as a PIL image.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        frames.append(Image.fromarray(rgb))
    cap.release()
    return frames
def run(data):
    """Worker: extract and save facial landmarks for one video file.

    Parameters:
        data -- (filename, opt, device) tuple; `device` is a GPU index string
                taken from --device_ids.
    """
    filename, opt, device = data
    # Pin this worker process to a single GPU; inside the process that GPU is 'cuda'.
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    # BUG FIX: KeypointExtractor.__init__ requires a `device` argument; the
    # original `KeypointExtractor()` raised TypeError on every task.
    kp_extractor = KeypointExtractor('cuda')
    images = read_video(filename)
    # name == [parent_dir, basename] of the video path.
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )
if __name__ == '__main__':
    # Workers each own a CUDA context, so the 'spawn' start method is required.
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)
    opt = parser.parse_args()

    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})

    # BUG FIX: the original reassigned `filenames` on every extension (keeping
    # only the last one) and made a useless os.listdir() call per iteration.
    # Accumulate matches across all extensions, then sort once.
    filenames = []
    for ext in VIDEO_EXTENSIONS:
        filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
    filenames = sorted(filenames)
    print('Total number of videos:', len(filenames))

    # Round-robin the GPU ids over the videos; `opt` is repeated for every task.
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = cycle(opt.device_ids.split(","))
    for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids)),
                  total=len(filenames)):
        pass
    pool.close()
    pool.join()

@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
In the function <__init__>, you need to define four lists:
-- self.loss_names (str list): specify the training losses that you want to plot and save.
-- self.model_names (str list): define networks used in our training.
-- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""
import importlib
from src.face3d.models.base_model import BaseModel
def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py" and return its model class.

    The module must define a class named <ModelName>Model (matching is
    case-insensitive and ignores underscores) that subclasses BaseModel.

    Raises:
        SystemExit -- with a non-zero status when no matching class is found.
    """
    # NOTE(review): the top of this file imports from `src.face3d.models`;
    # confirm the "face3d.models." prefix below resolves in the deployed layout.
    model_filename = "face3d.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        # Guard with isinstance(cls, type): module dicts also contain functions
        # and data, and issubclass() raises TypeError on non-classes.
        if name.lower() == target_model_name.lower() \
           and isinstance(cls, type) and issubclass(cls, BaseModel):
            model = cls
    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(1)  # BUG FIX: was exit(0), which reported success on this failure path.
    return model
def get_option_setter(model_name):
    """Return the static <modify_commandline_options> of the named model class."""
    return find_model_using_name(model_name).modify_commandline_options
def create_model(opt):
    """Instantiate the model selected by `opt.model`.

    This wraps the lookup in find_model_using_name and is the main interface
    between this package and 'train.py'/'test.py'.

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model_cls = find_model_using_name(opt.model)
    instance = model_cls(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance

@ -0,0 +1,164 @@
# Distributed Arcface Training in Pytorch
This is a deep learning library that makes face recognition training efficient and effective, and which can train tens of
millions of identities on a single server.
## Requirements
- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
- `pip install -r requirements.txt`.
- Download the dataset
from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
.
## How to Train
To train a model, run `train.py` with the path to the configs:
### 1. Single node, 8 GPUs:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
```
### 2. Multiple nodes, each node 8 GPUs:
Node 0:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```
Node 1:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```
### 3. Training resnet2060 with 8 GPUs:
```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
```
## Model Zoo
- The models are available for non-commercial research purposes only.
- All models can be found in here.
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face
recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities.
As the result, we can evaluate the FAIR performance for different algorithms.
For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocol, with FAR less than 0.000001(e-6). The
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.
For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocol, with FAR less than 0.0001(e-4).
Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
There are totally 13,928 positive pairs and 96,983,824 negative pairs.
| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
| :---: | :--- | :--- | :--- |:--- |:--- |
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
### Performance on IJB-C and Verification Datasets
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
[comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
## [Speed Benchmark](docs/speed_benchmark.md)
**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
accuracy with several times faster training performance and smaller GPU memory.
Partial FC is a sparse variant of the model parallel architecture for large scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC,
we can scale to a training set of 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed precision training.
![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
More details see
[speed_benchmark.md](docs/speed_benchmark.md) in docs.
### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
`-` means training failed because of gpu memory limitations.
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|1400000 | **1672** | 3043 | 4738 |
|5500000 | **-** | **1389** | 3975 |
|8000000 | **-** | **-** | 3565 |
|16000000 | **-** | **-** | 2679 |
|29000000 | **-** | **-** | **1855** |
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|1400000 | 32252 | 11178 | 6056 |
|5500000 | **-** | 32188 | 9854 |
|8000000 | **-** | **-** | 12310 |
|16000000 | **-** | **-** | 19950 |
|29000000 | **-** | **-** | 32324 |
## Evaluation ICCV2021-MFR and IJB-C
More details see [eval.md](docs/eval.md) in docs.
## Test
We tested many versions of PyTorch. Please create an issue if you are having trouble.
- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0
## Citation
```
@inproceedings{deng2019arcface,
title={Arcface: Additive angular margin loss for deep face recognition},
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
pages={4690--4699},
year={2019}
}
@inproceedings{an2020partical_fc,
title={Partial FC: Training 10 Million Identities on a Single Machine},
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
Zhang, Debing and Fu Ying},
booktitle={Arxiv 2010.05222},
year={2020}
}
```

@ -0,0 +1,25 @@
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf
def get_model(name, **kwargs):
    """Build a face-recognition backbone by name.

    Parameters:
        name -- one of 'r18', 'r34', 'r50', 'r100', 'r200', 'r2060', 'mbf'.
        kwargs -- forwarded to the backbone constructor (e.g. fp16, num_features).

    Raises:
        ValueError -- if `name` is not a known backbone.
    """
    # IResNet variants; the leading False disables (unsupported) pretrained weights.
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        # Imported lazily: the iresnet2060 module asserts on the torch version at import time.
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        # BUG FIX: the original raised a bare ValueError() with no diagnostic.
        raise ValueError(f"unknown backbone name: {name!r}")

@ -0,0 +1,187 @@
import torch
from torch import nn
# Public factory functions exported by this module.
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution; padding equals dilation, so spatial size is kept at stride 1."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride,
        padding=dilation, groups=groups, dilation=dilation, bias=False,
    )
def conv1x1(in_planes, out_planes, stride=1):
    """1x1 (pointwise) convolution without bias."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class IBasicBlock(nn.Module):
    """Pre-activation residual block: BN-Conv-BN-PReLU-Conv-BN plus a skip path."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        # Only the vanilla configuration is implemented.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Skip path: identity, or a projection when shape changes.
        shortcut = x if self.downsample is None else self.downsample(x)
        # Main path; note there is no activation after the final BN.
        y = self.conv1(self.bn1(x))
        y = self.prelu(self.bn2(y))
        y = self.bn3(self.conv2(y))
        return y + shortcut
class IResNet(nn.Module):
    """ArcFace-style improved ResNet ("IResNet") trunk.

    Visible differences from torchvision's ResNet: a 3x3 stride-1 stem with no
    max-pool, PReLU activations, stride-2 applied already in layer1, and a
    BN-Dropout-FC-BN head producing `num_features`-dimensional embeddings.
    With fp16=True the convolutional trunk runs under torch.cuda.amp autocast.
    """
    # Spatial size of the final feature map entering the FC layer
    # (7x7 for 112x112 inputs, given the four stride-2 stages).
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 3x3 stride-1 conv; downsampling happens inside the stages.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        # Note: layer1 already uses stride=2, unlike torchvision's ResNet.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        # Final feature BN with its scale frozen at 1 (only bias/stats learn).
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            # Zero the last BN of each block so residual branches start as identity.
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of `blocks` residual blocks; the first may downsample."""
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation to preserve spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut (1x1 conv + BN) to match the skip path's shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))
        return nn.Sequential(*layers)
    def forward(self, x):
        """Map an image batch to embedding vectors."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # The FC and feature BN always run in fp32, even when the trunk used fp16.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
model = IResNet(block, layers, **kwargs)
if pretrained:
raise ValueError()
return model
def iresnet18(pretrained=False, progress=True, **kwargs):
    # Stage depths [2, 2, 2, 2].
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)
def iresnet34(pretrained=False, progress=True, **kwargs):
    # Stage depths [3, 4, 6, 3].
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)
def iresnet50(pretrained=False, progress=True, **kwargs):
    # Stage depths [3, 4, 14, 3] (basic blocks, not bottlenecks).
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)
def iresnet100(pretrained=False, progress=True, **kwargs):
    # Stage depths [3, 13, 30, 3].
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)
def iresnet200(pretrained=False, progress=True, **kwargs):
    # Stage depths [6, 26, 60, 6].
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)

@ -0,0 +1,176 @@
import torch
from torch import nn
# BUG FIX: lexicographic string comparison mis-orders versions (e.g.
# "1.10.0" < "1.8.1" as strings); compare numeric (major, minor) instead.
_torch_version = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:2])
assert _torch_version >= (1, 8), f"torch>=1.8.1 required, got {torch.__version__}"
from torch.utils.checkpoint import checkpoint_sequential
__all__ = ['iresnet2060']
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding (padding == dilation preserves size at stride 1)."""
    kwargs = dict(kernel_size=3, stride=stride, padding=dilation,
                  groups=groups, bias=False, dilation=dilation)
    return nn.Conv2d(in_planes, out_planes, **kwargs)
def conv1x1(in_planes, out_planes, stride=1):
    """Pointwise 1x1 convolution, bias-free."""
    return nn.Conv2d(in_planes, out_planes,
                     kernel_size=1, stride=stride, bias=False)
class IBasicBlock(nn.Module):
    """Pre-activation residual block (BN-Conv-BN-PReLU-Conv-BN) with skip connection."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        # Grouped / wide / dilated variants are not supported here.
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Main branch first, then add the (possibly projected) skip path.
        y = self.bn1(x)
        y = self.bn2(self.conv1(y))
        y = self.conv2(self.prelu(y))
        y = self.bn3(y)
        if self.downsample is None:
            return y + x
        return y + self.downsample(x)
class IResNet(nn.Module):
    """IResNet trunk for the extremely deep iresnet2060 variant.

    Identical to the regular IResNet except that layer2/layer3 are run through
    torch.utils.checkpoint.checkpoint_sequential during training to fit the
    very deep network in GPU memory (see `checkpoint` and `forward`).
    """
    # Spatial size of the final feature map entering the FC layer
    # (7x7 for 112x112 inputs, given the four stride-2 stages).
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        # Stem: 3x3 stride-1 conv; all downsampling happens in the stages.
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        # layer1 already uses stride=2, unlike torchvision's ResNet.
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        # Final feature BN with its scale frozen at 1.
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        if zero_init_residual:
            # Zero the last BN of each block so residual branches start as identity.
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)
    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one stage of `blocks` residual blocks; the first may downsample."""
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade stride for dilation to preserve spatial resolution.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut (1x1 conv + BN) to match the skip path's shape.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))
        return nn.Sequential(*layers)
    def checkpoint(self, func, num_seg, x):
        """Gradient-checkpoint `func` in `num_seg` segments while training; run it plainly in eval."""
        if self.training:
            return checkpoint_sequential(func, num_seg, x)
        else:
            return func(x)
    def forward(self, x):
        """Map an image batch to embedding vectors."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            # The two deepest stages are checkpointed to bound activation memory.
            x = self.checkpoint(self.layer2, 20, x)
            x = self.checkpoint(self.layer3, 100, x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        # The FC and feature BN always run in fp32, even when the trunk used fp16.
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
model = IResNet(block, layers, **kwargs)
if pretrained:
raise ValueError()
return model
def iresnet2060(pretrained=False, progress=True, **kwargs):
    """IResNet-2060: an extremely deep variant with stage depths 3/128/896/3."""
    stage_depths = [3, 128, 1024 - 128, 3]
    return _iresnet('iresnet2060', IBasicBlock, stage_depths, pretrained, progress, **kwargs)

@ -0,0 +1,130 @@
'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author cavalleria
'''
import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch
class Flatten(Module):
    """Collapse all dimensions after the batch dimension into one."""
    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)
class ConvBlock(Module):
    """Conv2d -> BatchNorm2d -> PReLU."""
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(ConvBlock, self).__init__()
        conv = Conv2d(in_c, out_c, kernel, groups=groups, stride=stride,
                      padding=padding, bias=False)
        self.layers = nn.Sequential(conv,
                                    BatchNorm2d(num_features=out_c),
                                    PReLU(num_parameters=out_c))

    def forward(self, x):
        return self.layers(x)
class LinearBlock(Module):
    """Conv2d -> BatchNorm2d with no activation (a "linear" conv block)."""
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        conv = Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False)
        self.layers = nn.Sequential(conv, BatchNorm2d(num_features=out_c))

    def forward(self, x):
        return self.layers(x)
class DepthWise(Module):
    """Pointwise expand -> depthwise conv -> pointwise project, with optional residual.

    Note: `groups` doubles as the expanded channel count of the middle stage.
    """
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(DepthWise, self).__init__()
        self.residual = residual
        expand = ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        depthwise = ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride)
        project = LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.layers = nn.Sequential(expand, depthwise, project)

    def forward(self, x):
        y = self.layers(x)
        # Add the skip connection only when configured as a residual block.
        return x + y if self.residual else y
class Residual(Module):
    """A stack of `num_block` residual DepthWise blocks sharing the same settings."""
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        blocks = [DepthWise(c, c, True, kernel, stride, padding, groups)
                  for _ in range(num_block)]
        self.layers = Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
class GDC(Module):
    """Global depthwise conv head: 7x7 depthwise conv -> flatten -> FC -> BN."""
    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        # 7x7 depthwise conv collapses the final feature map to 1x1.
        dw7x7 = LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0))
        self.layers = nn.Sequential(dw7x7,
                                    Flatten(),
                                    Linear(512, embedding_size, bias=False),
                                    BatchNorm1d(embedding_size))

    def forward(self, x):
        return self.layers(x)
class MobileFaceNet(Module):
    """MobileFaceNet backbone (width x2): depthwise-separable trunk + GDC head.

    fp16 -- run the trunk under torch.cuda.amp autocast; conv_sep and the GDC
            head always run in fp32.
    num_features -- dimensionality of the output embedding.
    """
    def __init__(self, fp16=False, num_features=512):
        super(MobileFaceNet, self).__init__()
        # Channel multiplier: this is the "large" (x2) MobileFaceNet.
        scale = 2
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        # 1x1 conv to 512 channels before the global-depthwise-conv head.
        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()
    def _initialize_weights(self):
        """Kaiming init for convs/linears; BN scales to 1, all biases to 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
    def forward(self, x):
        """Map an image batch to `num_features`-dim embeddings."""
        with torch.cuda.amp.autocast(self.fp16):
            x = self.layers(x)
        # Head runs in fp32 even when the trunk used fp16.
        x = self.conv_sep(x.float() if self.fp16 else x)
        x = self.features(x)
        return x
def get_mbf(fp16, num_features):
    """Factory for MobileFaceNet with the given precision flag and embedding size."""
    return MobileFaceNet(fp16=fp16, num_features=num_features)

@ -0,0 +1,23 @@
from easydict import EasyDict as edict
# configs for test speed
# Speed-benchmark config: synthetic data with full softmax (sample_rate = 1.0).
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0  # fraction of class centers used per batch (1.0 = full softmax)
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512
config.rec = "synthetic"  # synthetic data source -- no record file needed for the speed test
config.num_classes = 300 * 10000  # 3M identities to stress the classification layer
config.num_epoch = 30
config.warmup_epoch = -1  # -1 disables LR warmup
config.decay_epoch = [10, 16, 22]
config.val_targets = []  # no validation during the speed test

@ -0,0 +1,23 @@
from easydict import EasyDict as edict
# configs for test speed
# Speed-benchmark config: synthetic data with Partial FC sampling (sample_rate = 0.1).
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1  # Partial FC: sample 10% of class centers per batch
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512
config.rec = "synthetic"  # synthetic data source -- no record file needed for the speed test
config.num_classes = 300 * 10000  # 3M identities to stress the classification layer
config.num_epoch = 30
config.warmup_epoch = -1  # -1 disables LR warmup
config.decay_epoch = [10, 16, 22]
config.val_targets = []  # no validation during the speed test

@ -0,0 +1,56 @@
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
# ArcFace r50 training config; dataset-specific paths and schedules are chosen below.
config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = "ms1mv3_arcface_r50"
config.dataset = "ms1m-retinaface-t1"
config.embedding_size = 512
config.sample_rate = 1  # Partial FC sample rate; 1 = full softmax
config.fp16 = False
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512
# Per-dataset record path, class/image counts, and LR decay schedule.
if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_image = 5822653
    config.num_epoch = 16
    config.warmup_epoch = -1
    config.decay_epoch = [8, 14, ]
    config.val_targets = ["lfw", ]
elif config.dataset == "ms1m-retinaface-t1":
    config.rec = "/train_tmp/ms1m-retinaface-t1"
    config.num_classes = 93431
    config.num_image = 5179510
    config.num_epoch = 25
    config.warmup_epoch = -1
    config.decay_epoch = [11, 17, 22]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 20
    config.warmup_epoch = -1
    config.decay_epoch = [8, 12, 15, 18]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
elif config.dataset == "webface":
    config.rec = "/train_tmp/faces_webface_112x112"
    config.num_classes = 10572
    config.num_image = "forget"  # NOTE(review): placeholder, not a real count -- confirm before use
    config.num_epoch = 34
    config.warmup_epoch = -1
    config.decay_epoch = [20, 28, 32]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

@ -0,0 +1,26 @@
from easydict import EasyDict as edict
# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp
# CosFace + MobileFaceNet on Glint360k with Partial FC sampling at 10%.
config = edict()
config.loss = "cosface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1  # Partial FC: sample 10% of class centers per batch
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512
config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1  # -1 disables LR warmup
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save