Generate a video from a person's image
@@ -0,0 +1,167 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

examples/results/*
gfpgan/*
checkpoints/
results/*
Dockerfile
start_docker.sh
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
  <component name="NewModuleRootManager">
    <content url="file://$MODULE_DIR$" />
    <orderEntry type="inheritedJdk" />
    <orderEntry type="sourceFolder" forTests="false" />
  </component>
  <component name="PyDocumentationSettings">
    <option name="format" value="PLAIN" />
    <option name="myDocStringFormat" value="Plain" />
  </component>
</module>
@@ -0,0 +1,56 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="PublishConfigData" remoteFilesAllowedToDisappearOnAutoupload="false">
    <serverData>
      <paths name="fanpt@192.168.0.102:22 password">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (2)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (3)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (4)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (5)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (6)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
      <paths name="fanpt@192.168.0.102:22 password (7)">
        <serverdata>
          <mappings>
            <mapping local="$PROJECT_DIR$" web="/" />
          </mappings>
        </serverdata>
      </paths>
    </serverData>
  </component>
</project>
@@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
  <profile version="1.0">
    <option name="myName" value="Project Default" />
    <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
    <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
      <option name="ignoredPackages">
        <value>
          <list size="13">
            <item index="0" class="java.lang.String" itemvalue="protobuf" />
            <item index="1" class="java.lang.String" itemvalue="transformers" />
            <item index="2" class="java.lang.String" itemvalue="tensorboard" />
            <item index="3" class="java.lang.String" itemvalue="icetk" />
            <item index="4" class="java.lang.String" itemvalue="cpm_kernels" />
            <item index="5" class="java.lang.String" itemvalue="peft" />
            <item index="6" class="java.lang.String" itemvalue="accelerate" />
            <item index="7" class="java.lang.String" itemvalue="torch" />
            <item index="8" class="java.lang.String" itemvalue="datasets" />
            <item index="9" class="java.lang.String" itemvalue="bitsandbytes" />
            <item index="10" class="java.lang.String" itemvalue="ConcurrentLogHandler" />
            <item index="11" class="java.lang.String" itemvalue="uwsgi" />
            <item index="12" class="java.lang.String" itemvalue="ultralytics" />
          </list>
        </value>
      </option>
    </inspection_tool>
  </profile>
</component>
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
  <settings>
    <option name="USE_PROJECT_PROFILE" value="false" />
    <version value="1.0" />
  </settings>
</component>
@@ -0,0 +1,4 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectRootManager" version="2" project-jdk-name="Sadtalker" project-jdk-type="Python SDK" />
</project>
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="ProjectModuleManager">
    <modules>
      <module fileurl="file://$PROJECT_DIR$/.idea/SadTalker.iml" filepath="$PROJECT_DIR$/.idea/SadTalker.iml" />
    </modules>
  </component>
</project>
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
  <component name="VcsDirectoryMappings">
    <mapping directory="" vcs="Git" />
  </component>
</project>
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2023 Tencent AI Lab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
@@ -0,0 +1,107 @@
import os, sys
import tempfile
import gradio as gr
from src.gradio_demo import SadTalker
from src.utils.text2speech import TTSTalker


def get_source_image(image):
    return image


def sadtalker_demo():

    sad_talker = SadTalker()
    tts_talker = TTSTalker()

    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
        gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
                     <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> \
                     <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> \
                     <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")

        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            source_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256, width=256)

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload OR TTS'):
                        with gr.Column(variant='panel'):
                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

                        with gr.Column(variant='panel'):
                            input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we generate the audio from text using @Coqui.ai TTS.")
                            tts = gr.Button('Generate audio', elem_id="sadtalker_audio_generate", variant='primary')
                            tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])

            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        with gr.Column(variant='panel'):
                            is_still_mode = gr.Checkbox(label="w/ Still Mode (fewer head motion, works on full body)")
                            enhancer = gr.Checkbox(label="w/ GFPGAN as Face enhancer")
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

                with gr.Tabs(elem_id="sadtalker_genearted"):
                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        with gr.Row():
            examples = [
                [
                    'examples/source_image/full_body_1.png',
                    'examples/driven_audio/bus_chinese.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/full_body_2.png',
                    'examples/driven_audio/itosinger1.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/art_13.png',
                    'examples/driven_audio/fayu.wav',
                    True,
                    False
                ],
                [
                    'examples/source_image/art_5.png',
                    'examples/driven_audio/chinese_news.wav',
                    True,
                    False
                ],
            ]
            gr.Examples(examples=examples,
                        inputs=[
                            source_image,
                            driven_audio,
                            is_still_mode,
                            enhancer],
                        outputs=[gen_video],
                        fn=sad_talker.test,
                        cache_examples=os.getenv('SYSTEM') == 'spaces')

        submit.click(
            fn=sad_talker.test,
            inputs=[source_image,
                    driven_audio,
                    is_still_mode,
                    enhancer],
            outputs=[gen_video]
        )

    return sadtalker_interface


if __name__ == "__main__":

    demo = sadtalker_demo()
    demo.launch()
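A minimal way to launch this Gradio demo locally, assuming the script above is saved as `app.py` at the repository root (the file name is not shown in this dump) and the checkpoints have already been downloaded:

```bash
python app.py
# Gradio prints a local URL (by default http://127.0.0.1:7860) to open in a browser.
```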
@@ -0,0 +1,14 @@
## changelogs

- __[2023.03.22]__: Launch new feature: generating the 3D face animation from a single image. New applications based on it will be updated.

- __[2023.03.22]__: Launch new feature: `still mode`, where only a small head pose will be produced via `python inference.py --still`.

- __[2023.03.18]__: Support `expression intensity`, now you can change the intensity of the generated motion: `python inference.py --expression_scale 1.3` (some value > 1).

- __[2023.03.18]__: Reconfigure the data folders, now you can download the checkpoints automatically using `bash scripts/download_models.sh`.
- __[2023.03.18]__: We have officially integrated [GFPGAN](https://github.com/TencentARC/GFPGAN) for face enhancement; use `python inference.py --enhancer gfpgan` for better visual quality.
- __[2023.03.14]__: Specify the version of package `joblib` to remove the errors in using `librosa`; the [Colab notebook (quick_demo.ipynb)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb) is online!
- __[2023.03.06]__: Solve some bugs in the code and errors in installation.
- __[2023.03.03]__: Release the test code for audio-driven single image animation!
- __[2023.02.28]__: SadTalker has been accepted by CVPR 2023!
@@ -0,0 +1,48 @@
## 3D Face visualization

We use pytorch3d to visualize the produced 3D face from a single image.

Since it is not easy to install, we provide a new installation guide here:

```bash
git clone https://github.com/Winfredy/SadTalker.git
cd SadTalker
conda create -n sadtalker3d python=3.8
source activate sadtalker3d

conda install ffmpeg
conda install -c fvcore -c iopath -c conda-forge fvcore iopath
conda install libgcc gmp

pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113

# install pytorch3d
pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu113_pyt1110/download.html

pip install -r requirements3d.txt

### install gfpgan for the enhancer
pip install git+https://github.com/TencentARC/GFPGAN

### if a gcc version problem occurs when importing `_C` from pytorch3d, add the anaconda lib path to LD_LIBRARY_PATH
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/$YOUR_ANACONDA_PATH/lib/
```

Then, generate the results via:

```bash
python inference.py --driven_audio <audio.wav> \
                    --source_image <video.mp4 or picture.png> \
                    --result_dir <a folder to store results> \
                    --face3dvis
```

The results will be saved in the result folder with the file name `face3d.mp4`.

More applications of the 3D face will be released.
@@ -0,0 +1,25 @@

### Windows Native

- Make sure you have `ffmpeg` in the `%PATH%`, as suggested in [#54](https://github.com/Winfredy/SadTalker/issues/54); follow [this](https://www.geeksforgeeks.org/how-to-install-ffmpeg-on-windows/) guide to install `ffmpeg` (a quick check is shown below).
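A quick way to confirm the requirement in the bullet above, i.e. that `ffmpeg` is actually reachable through `%PATH%` (run it in a freshly opened terminal):

```bash
# Prints the ffmpeg version banner if ffmpeg is on the PATH; otherwise the command is not found.
ffmpeg -version
```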
### Windows WSL

- Make sure the environment variable is set: `export LD_LIBRARY_PATH=/usr/lib/wsl/lib:$LD_LIBRARY_PATH`
### Docker installation

A Dockerfile is also provided by [@thegenerativegeneration](https://github.com/thegenerativegeneration) on [Docker Hub](https://hub.docker.com/repository/docker/wawa9000/sadtalker), which can be used directly as:

```bash
docker run --gpus "all" --rm -v $(pwd):/host_dir wawa9000/sadtalker \
    --driven_audio /host_dir/deyu.wav \
    --source_image /host_dir/image.jpg \
    --expression_scale 1.0 \
    --still \
    --result_dir /host_dir
```
@@ -0,0 +1,159 @@
import torch
from time import strftime
import os, sys, time
from argparse import ArgumentParser

from src.utils.preprocess import CropAndExtract
from src.test_audio2coeff import Audio2Coeff
from src.facerender.animate import AnimateFromCoeff
from src.generate_batch import get_data
from src.generate_facerender_batch import get_facerender_data


def main(args):
    # torch.backends.cudnn.enabled = False

    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
    os.makedirs(save_dir, exist_ok=True)
    pose_style = args.pose_style
    device = args.device
    batch_size = args.batch_size
    input_yaw_list = args.input_yaw
    input_pitch_list = args.input_pitch
    input_roll_list = args.input_roll
    ref_eyeblink = args.ref_eyeblink
    ref_pose = args.ref_pose

    current_code_path = sys.argv[0]
    current_root_path = os.path.split(current_code_path)[0]

    os.environ['TORCH_HOME'] = os.path.join(current_root_path, args.checkpoint_dir)

    path_of_lm_croper = os.path.join(current_root_path, args.checkpoint_dir, 'shape_predictor_68_face_landmarks.dat')
    path_of_net_recon_model = os.path.join(current_root_path, args.checkpoint_dir, 'epoch_20.pth')
    dir_of_BFM_fitting = os.path.join(current_root_path, args.checkpoint_dir, 'BFM_Fitting')
    wav2lip_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'wav2lip.pth')

    audio2pose_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2pose_00140-model.pth')
    audio2pose_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2pose.yaml')

    audio2exp_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'auido2exp_00300-model.pth')
    audio2exp_yaml_path = os.path.join(current_root_path, 'src', 'config', 'auido2exp.yaml')

    free_view_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'facevid2vid_00189-model.pth.tar')

    if args.preprocess == 'full':
        mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00109-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender_still.yaml')
    else:
        mapping_checkpoint = os.path.join(current_root_path, args.checkpoint_dir, 'mapping_00229-model.pth.tar')
        facerender_yaml_path = os.path.join(current_root_path, 'src', 'config', 'facerender.yaml')

    # init models
    print(path_of_net_recon_model)
    preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

    print(audio2pose_checkpoint)
    print(audio2exp_checkpoint)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, device)

    print(free_view_checkpoint)
    print(mapping_checkpoint)
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, device)

    # crop the image and extract the 3DMM coefficients from it
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    print('3DMM Extraction for source image')
    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess)
    if first_coeff_path is None:
        print("Can't get the coeffs of the input")
        return

    if ref_eyeblink is not None:
        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
        print('3DMM Extraction for the reference video providing eye blinking')
        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir)
    else:
        ref_eyeblink_coeff_path = None

    if ref_pose is not None:
        if ref_pose == ref_eyeblink:
            ref_pose_coeff_path = ref_eyeblink_coeff_path
        else:
            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
            os.makedirs(ref_pose_frame_dir, exist_ok=True)
            print('3DMM Extraction for the reference video providing pose')
            ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir)
    else:
        ref_pose_coeff_path = None

    # audio2coeff
    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)

    # 3d face render
    if args.face3dvis:
        from src.face3d.visualize import gen_composed_video
        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))

    # coeff2video
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, input_yaw_list, input_pitch_list, input_roll_list,
                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess)

    animate_from_coeff.generate(data, save_dir, pic_path, crop_info,
                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess)


if __name__ == '__main__':

    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/driven_audio/eluosi.wav', help="path to driven audio")
    parser.add_argument("--source_image", default='./examples/source_image/full3.png', help="path to source image")
    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to the checkpoints")
    parser.add_argument("--result_dir", default='./results', help="path to output")
    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
    parser.add_argument("--expression_scale", type=float, default=1., help="the intensity scale of the generated expression")
    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user")
    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
    parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
    parser.add_argument("--cpu", dest="cpu", action="store_true")
    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
    parser.add_argument("--still", action="store_true", help="crop back to the original image for full-body animation")
    parser.add_argument("--preprocess", default='crop', choices=['crop', 'resize', 'full'], help="how to preprocess the images")

    # net structure and parameters
    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='unused')
    parser.add_argument('--init_path', type=str, default=None, help='unused')
    parser.add_argument('--use_last_fc', default=False, help='zero initialize the last fc')
    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')

    # default renderer parameters
    parser.add_argument('--focal', type=float, default=1015.)
    parser.add_argument('--center', type=float, default=112.)
    parser.add_argument('--camera_d', type=float, default=10.)
    parser.add_argument('--z_near', type=float, default=5.)
    parser.add_argument('--z_far', type=float, default=15.)

    args = parser.parse_args()

    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    main(args)
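For reference, a typical invocation of the script above, assembled only from the argument names defined in its parser and the example assets that already appear elsewhere in this dump (`--still` and `--enhancer gfpgan` are optional):

```bash
# Generate a talking-head video from one image and one audio clip (a sketch; paths are the bundled examples).
python inference.py --driven_audio ./examples/driven_audio/bus_chinese.wav \
                    --source_image ./examples/source_image/full_body_1.png \
                    --result_dir ./results \
                    --still \
                    --enhancer gfpgan
```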
@@ -0,0 +1,20 @@
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
gradio
gfpgan
dlib-bin
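Assuming the pinned list above is saved as `requirements.txt` at the repository root (the file name itself is not shown in this dump), it is installed in the usual way:

```bash
pip install -r requirements.txt
```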
@@ -0,0 +1,21 @@
numpy==1.23.4
face_alignment==1.3.5
imageio==2.19.3
imageio-ffmpeg==0.4.7
librosa==0.9.2
numba
resampy==0.3.1
pydub==0.25.1
scipy==1.5.3
kornia==0.6.8
tqdm
yacs==0.1.8
pyyaml
joblib==1.1.0
scikit-image==0.19.3
basicsr==1.4.2
facexlib==0.2.5
trimesh==3.9.20
dlib-bin
gradio
gfpgan
@@ -0,0 +1,14 @@
mkdir ./checkpoints
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip

unzip -n ./checkpoints/hub.zip -d ./checkpoints/
unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
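The changelog above refers to this download script as `scripts/download_models.sh`; assuming that path, run it from the repository root so the checkpoints land in `./checkpoints/`:

```bash
bash scripts/download_models.sh
```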
@@ -0,0 +1,133 @@
import os, sys
from pathlib import Path
import tempfile
import gradio as gr
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
from modules.shared import opts, OptionInfo
from modules import shared, paths, script_callbacks
import launch
import glob


def get_source_image(image):
    return image

def get_img_from_txt2img(x):
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_txt_dir = str(talker_path / "txt2img-images/")
    imgs = glob.glob(imgs_from_txt_dir + '/*/*.png')
    imgs.sort(key=lambda x: os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))
    img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])
    return img_from_txt_path, img_from_txt_path

def get_img_from_img2img(x):
    talker_path = Path(paths.script_path) / "outputs"
    imgs_from_img_dir = str(talker_path / "img2img-images/")
    imgs = glob.glob(imgs_from_img_dir + '/*/*.png')
    imgs.sort(key=lambda x: os.path.getmtime(os.path.join(imgs_from_img_dir, x)))
    img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])
    return img_from_img_path, img_from_img_path

def install():

    kv = {
        "face-alignment": "face-alignment==1.3.5",
        "imageio": "imageio==2.19.3",
        "imageio-ffmpeg": "imageio-ffmpeg==0.4.7",
        "librosa": "librosa==0.8.0",
        "pydub": "pydub==0.25.1",
        "scipy": "scipy==1.8.1",
        "tqdm": "tqdm",
        "yacs": "yacs==0.1.8",
        "pyyaml": "pyyaml",
        "dlib": "dlib-bin",
        "gfpgan": "gfpgan",
    }

    for k, v in kv.items():
        print(k, launch.is_installed(k))
        if not launch.is_installed(k):
            launch.run_pip("install " + v, "requirements for SadTalker")

    if os.getenv('SADTALKER_CHECKPOINTS'):
        print('load SadTalker checkpoints from ' + os.getenv('SADTALKER_CHECKPOINTS'))
    else:
        ### run the script to download the models to the correct location.
        print('download models for SadTalker')
        launch.run("cd " + paths.script_path + "/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
        print('SadTalker is successfully installed!')


def on_ui_tabs():
    install()

    sys.path.extend([paths.script_path + '/extensions/SadTalker'])

    repo_dir = paths.script_path + '/extensions/SadTalker/'

    result_dir = opts.sadtalker_result_dir
    os.makedirs(result_dir, exist_ok=True)

    from src.gradio_demo import SadTalker

    if os.getenv('SADTALKER_CHECKPOINTS'):
        checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
    else:
        checkpoint_path = repo_dir + 'checkpoints/'

    sad_talker = SadTalker(checkpoint_path=checkpoint_path, config_path=repo_dir + 'src/config', lazy_load=True)

    with gr.Blocks(analytics_enabled=False) as audio_to_video:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_source_image"):
                    with gr.TabItem('Upload image'):
                        with gr.Row():
                            input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=512, width=512)

                        with gr.Row():
                            submit_image2 = gr.Button('load from txt2img', variant='primary')
                            submit_image2.click(fn=get_img_from_txt2img, inputs=input_image, outputs=[input_image, input_image])

                            submit_image3 = gr.Button('load from img2img', variant='primary')
                            submit_image3.click(fn=get_img_from_img2img, inputs=input_image, outputs=[input_image, input_image])

                with gr.Tabs(elem_id="sadtalker_driven_audio"):
                    with gr.TabItem('Upload'):
                        with gr.Column(variant='panel'):

                            with gr.Row():
                                driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")

            with gr.Column(variant='panel'):
                with gr.Tabs(elem_id="sadtalker_checkbox"):
                    with gr.TabItem('Settings'):
                        with gr.Column(variant='panel'):
                            is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion)").style(container=True)
                            is_enhance_mode = gr.Checkbox(label="Enhance Mode (better face quality)").style(container=True)
                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')

                with gr.Tabs(elem_id="sadtalker_genearted"):
                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)

        ### the gradio gpu call will always return the html
        submit.click(
            fn=wrap_queued_call(sad_talker.test),
            inputs=[input_image,
                    driven_audio,
                    is_still_mode,
                    is_enhance_mode],
            outputs=[gen_video, ]
        )

    return [(audio_to_video, "SadTalker", "extension")]

def on_ui_settings():
    talker_path = Path(paths.script_path) / "outputs"
    section = ('extension', "SadTalker")
    opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section))

script_callbacks.on_ui_settings(on_ui_settings)
script_callbacks.on_ui_tabs(on_ui_tabs)
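As `install()` above shows, the extension first checks the `SADTALKER_CHECKPOINTS` environment variable and only falls back to downloading models into `extensions/SadTalker/checkpoints/` when it is unset; to reuse checkpoints you already have, export the variable before starting the web UI (the path below is a placeholder):

```bash
export SADTALKER_CHECKPOINTS=/path/to/your/SadTalker/checkpoints
```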
@@ -0,0 +1,41 @@
from tqdm import tqdm
import torch
from torch import nn


class Audio2Exp(nn.Module):
    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        super(Audio2Exp, self).__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):

        mel_input = batch['indiv_mels']  # bs T 1 80 16
        bs = mel_input.shape[0]
        T = mel_input.shape[1]

        exp_coeff_pred = []

        for i in tqdm(range(0, T, 10), 'audio2exp:'):  # every 10 frames

            current_mel_input = mel_input[:, i:i+10]

            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
            ref = batch['ref'][:, :, :64][:, i:i+10]
            ratio = batch['ratio_gt'][:, i:i+10]  #bs T

            audiox = current_mel_input.view(-1, 1, 80, 16)  # bs*T 1 80 16

            curr_exp_coeff_pred = self.netG(audiox, ref, ratio)  # bs T 64

            exp_coeff_pred += [curr_exp_coeff_pred]

        # BS x T x 64
        results_dict = {
            'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
        }
        return results_dict
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.residual = residual
|
||||
self.use_act = use_act
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
|
||||
if self.use_act:
|
||||
return self.act(out)
|
||||
else:
|
||||
return out
|
||||
|
||||
class SimpleWrapperV2(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.audio_encoder = nn.Sequential(
|
||||
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
||||
)
|
||||
|
||||
#### load the pre-trained audio_encoder
|
||||
#self.audio_encoder = self.audio_encoder.to(device)
|
||||
'''
|
||||
wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
|
||||
state_dict = self.audio_encoder.state_dict()
|
||||
|
||||
for k,v in wav2lip_state_dict.items():
|
||||
if 'audio_encoder' in k:
|
||||
print('init:', k)
|
||||
state_dict[k.replace('module.audio_encoder.', '')] = v
|
||||
self.audio_encoder.load_state_dict(state_dict)
|
||||
'''
|
||||
|
||||
self.mapping1 = nn.Linear(512+64+1, 64)
|
||||
#self.mapping2 = nn.Linear(30, 64)
|
||||
#nn.init.constant_(self.mapping1.weight, 0.)
|
||||
nn.init.constant_(self.mapping1.bias, 0.)
|
||||
|
||||
def forward(self, x, ref, ratio):
|
||||
x = self.audio_encoder(x).view(x.size(0), -1)
|
||||
ref_reshape = ref.reshape(x.size(0), -1)
|
||||
ratio = ratio.reshape(x.size(0), -1)
|
||||
|
||||
y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
|
||||
out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
|
||||
return out
|
@@ -0,0 +1,94 @@
import torch
from torch import nn
from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder

class Audio2Pose(nn.Module):
    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)


    def forward(self, x):

        batch = {}
        coeff_gt = x['gt'].cuda().squeeze(0)  #bs frame_len+1 73
        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3]  #bs frame_len 6
        batch['ref'] = coeff_gt[:, 0, -9:-3]  #bs 6
        batch['class'] = x['class'].squeeze(0).cuda()  # bs
        indiv_mels = x['indiv_mels'].cuda().squeeze(0)  # bs seq_len+1 80 16

        # forward
        audio_emb_list = []
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))  #bs seq_len 512
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']  # bs frame_len 6
        pose_gt = coeff_gt[:, 1:, -9:-3].clone()  # bs frame_len 6
        pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred  # bs frame_len 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):

        batch = {}
        ref = x['ref']  #bs 1 70
        batch['ref'] = x['ref'][:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']  # bs T 1 80 16
        indiv_mels_use = indiv_mels[:, 1:]  # we regard the ref as the first frame
        num_frames = x['num_frames']
        num_frames = int(num_frames) - 1

        div = num_frames // self.seq_len
        re = num_frames % self.seq_len
        audio_emb_list = []
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(div):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len, :, :, :])  #bs seq_len 512
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6

        if re != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:, :, :, :])  #bs seq_len 512
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1*re:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6

        batch['pose_pred'] = pose_pred
        return batch
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F

class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout)
        )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

class AudioEncoder(nn.Module):
    def __init__(self, wav2lip_checkpoint, device):
        super(AudioEncoder, self).__init__()

        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),

            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)

        #### load the pre-trained audio_encoder
        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
        state_dict = self.audio_encoder.state_dict()

        for k, v in wav2lip_state_dict.items():
            if 'audio_encoder' in k:
                state_dict[k.replace('module.audio_encoder.', '')] = v
        self.audio_encoder.load_state_dict(state_dict)


    def forward(self, audio_sequences):
        # audio_sequences = (B, T, 1, 80, 16)
        B = audio_sequences.size(0)

        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)

        audio_embedding = self.audio_encoder(audio_sequences)  # B, 512, 1, 1
        dim = audio_embedding.shape[1]
        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))

        return audio_embedding.squeeze(-1).squeeze(-1)  # B seq_len+1 512
@@ -0,0 +1,149 @@
import torch
import torch.nn.functional as F
from torch import nn
from src.audio2pose_models.res_unet import ResUnet

def class2onehot(idx, class_num):

    assert torch.max(idx).item() < class_num
    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
    onehot.scatter_(1, idx, 1)
    return onehot

class CVAE(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
        num_classes = cfg.DATASET.NUM_CLASSES
        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
        seq_len = cfg.MODEL.CVAE.SEQ_LEN

        self.latent_size = latent_size

        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)
        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
                               audio_emb_in_size, audio_emb_out_size, seq_len)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, batch):
        batch = self.encoder(batch)
        mu = batch['mu']
        logvar = batch['logvar']
        z = self.reparameterize(mu, logvar)
        batch['z'] = z
        return self.decoder(batch)

    def test(self, batch):
        '''
        class_id = batch['class']
        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
        batch['z'] = z
        '''
        return self.decoder(batch)

class ENCODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())

        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):
        class_id = batch['class']
        pose_motion_gt = batch['pose_motion_gt']  #bs seq_len 6
        ref = batch['ref']                        #bs 6
        bs = pose_motion_gt.shape[0]
        audio_in = batch['audio_emb']             # bs seq_len audio_emb_in_size

        #pose encode
        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))  #bs 1 seq_len 6
        pose_emb = pose_emb.reshape(bs, -1)                   #bs seq_len*6

        #audio mapping
        print(audio_in.shape)
        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        audio_out = audio_out.reshape(bs, -1)

        class_bias = self.classbias[class_id]    #bs latent_size
        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1)  #bs seq_len*(audio_emb_out_size+6)+latent_size
        x_out = self.MLP(x_in)

        mu = self.linear_means(x_out)
        logvar = self.linear_logvar(x_out)       #bs latent_size

        batch.update({'mu': mu, 'logvar': logvar})
        return batch

class DECODER(nn.Module):
    def __init__(self, layer_sizes, latent_size, num_classes,
                 audio_emb_in_size, audio_emb_out_size, seq_len):
        super().__init__()

        self.resunet = ResUnet()
        self.num_classes = num_classes
        self.seq_len = seq_len

        self.MLP = nn.Sequential()
        input_size = latent_size + seq_len*audio_emb_out_size + 6
        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
            self.MLP.add_module(
                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
            if i+1 < len(layer_sizes):
                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
            else:
                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())

        self.pose_linear = nn.Linear(6, 6)
        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)

        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))

    def forward(self, batch):

        z = batch['z']                 #bs latent_size
        bs = z.shape[0]
        class_id = batch['class']
        ref = batch['ref']             #bs 6
        audio_in = batch['audio_emb']  # bs seq_len audio_emb_in_size
        #print('audio_in: ', audio_in[:, :, :10])

        audio_out = self.linear_audio(audio_in)  # bs seq_len audio_emb_out_size
        #print('audio_out: ', audio_out[:, :, :10])
        audio_out = audio_out.reshape([bs, -1])  # bs seq_len*audio_emb_out_size
        class_bias = self.classbias[class_id]    #bs latent_size

        z = z + class_bias
        x_in = torch.cat([ref, z, audio_out], dim=-1)
        x_out = self.MLP(x_in)  # bs layer_sizes[-1]
        x_out = x_out.reshape((bs, self.seq_len, -1))

        #print('x_out: ', x_out)

        pose_emb = self.resunet(x_out.unsqueeze(1))  #bs 1 seq_len 6

        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))  #bs seq_len 6

        batch.update({'pose_motion_pred': pose_motion_pred})
        return batch
@@ -0,0 +1,76 @@
import torch
import torch.nn.functional as F
from torch import nn

class ConvNormRelu(nn.Module):
    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
        super().__init__()
        if kernel_size is None:
            if downsample:
                kernel_size, stride, padding = 4, 2, 1
            else:
                kernel_size, stride, padding = 3, 1, 1

        if conv_type == '2d':
            self.conv = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm2d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm2d(out_channels)
            else:
                raise NotImplementedError
        elif conv_type == '1d':
            self.conv = nn.Conv1d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            )
            if norm == 'BN':
                self.norm = nn.BatchNorm1d(out_channels)
            elif norm == 'IN':
                self.norm = nn.InstanceNorm1d(out_channels)
            else:
                raise NotImplementedError
        nn.init.kaiming_normal_(self.conv.weight)

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if isinstance(self.norm, nn.InstanceNorm1d):
            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
        else:
            x = self.norm(x)
        x = self.act(x)
        return x


class PoseSequenceDiscriminator(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU

        self.seq = nn.Sequential(
            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
        )

    def forward(self, x):
        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
        x = self.seq(x)
        x = x.squeeze(1)
        return x
@@ -0,0 +1,140 @@
import torch.nn as nn
import torch


class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(
                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
            ),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):

        return self.conv_block(x) + self.conv_skip(x)


class Upsample(nn.Module):
    def __init__(self, input_dim, output_dim, kernel, stride):
        super(Upsample, self).__init__()

        self.upsample = nn.ConvTranspose2d(
            input_dim, output_dim, kernel_size=kernel, stride=stride
        )

    def forward(self, x):
        return self.upsample(x)


class Squeeze_Excite_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(Squeeze_Excite_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


class ASPP(nn.Module):
    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
        super(ASPP, self).__init__()

        self.aspp_block1 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block2 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )
        self.aspp_block3 = nn.Sequential(
            nn.Conv2d(
                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
            ),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_dims),
        )

        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
        self._init_weights()

    def forward(self, x):
        x1 = self.aspp_block1(x)
        x2 = self.aspp_block2(x)
        x3 = self.aspp_block3(x)
        out = torch.cat([x1, x2, x3], dim=1)
        return self.output(out)

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Upsample_(nn.Module):
    def __init__(self, scale=2):
        super(Upsample_, self).__init__()

        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)

    def forward(self, x):
        return self.upsample(x)


class AttentionBlock(nn.Module):
    def __init__(self, input_encoder, input_decoder, output_dim):
        super(AttentionBlock, self).__init__()

        self.conv_encoder = nn.Sequential(
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        )

        self.conv_decoder = nn.Sequential(
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        )

        self.conv_attn = nn.Sequential(
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        )

    def forward(self, x1, x2):
        out = self.conv_encoder(x1) + self.conv_decoder(x2)
        out = self.conv_attn(out)
        return out * x2
@ -0,0 +1,65 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from src.audio2pose_models.networks import ResidualConv, Upsample
|
||||
|
||||
|
||||
class ResUnet(nn.Module):
|
||||
def __init__(self, channel=1, filters=[32, 64, 128, 256]):
|
||||
super(ResUnet, self).__init__()
|
||||
|
||||
self.input_layer = nn.Sequential(
|
||||
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
|
||||
nn.BatchNorm2d(filters[0]),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
|
||||
)
|
||||
self.input_skip = nn.Sequential(
|
||||
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
|
||||
)
|
||||
|
||||
self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
|
||||
self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
|
||||
|
||||
self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
|
||||
|
||||
self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
|
||||
self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
|
||||
|
||||
self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
|
||||
self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
|
||||
|
||||
self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
|
||||
self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
|
||||
|
||||
self.output_layer = nn.Sequential(
|
||||
nn.Conv2d(filters[0], 1, 1, 1),
|
||||
nn.Sigmoid(),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
# Encode
|
||||
x1 = self.input_layer(x) + self.input_skip(x)
|
||||
x2 = self.residual_conv_1(x1)
|
||||
x3 = self.residual_conv_2(x2)
|
||||
# Bridge
|
||||
x4 = self.bridge(x3)
|
||||
|
||||
# Decode
|
||||
x4 = self.upsample_1(x4)
|
||||
x5 = torch.cat([x4, x3], dim=1)
|
||||
|
||||
x6 = self.up_residual_conv1(x5)
|
||||
|
||||
x6 = self.upsample_2(x6)
|
||||
x7 = torch.cat([x6, x2], dim=1)
|
||||
|
||||
x8 = self.up_residual_conv2(x7)
|
||||
|
||||
x8 = self.upsample_3(x8)
|
||||
x9 = torch.cat([x8, x1], dim=1)
|
||||
|
||||
x10 = self.up_residual_conv3(x9)
|
||||
|
||||
output = self.output_layer(x10)
|
||||
|
||||
return output
|
@ -0,0 +1,58 @@
|
||||
DATASET:
|
||||
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
|
||||
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
|
||||
TRAIN_BATCH_SIZE: 32
|
||||
EVAL_BATCH_SIZE: 32
|
||||
EXP: True
|
||||
EXP_DIM: 64
|
||||
FRAME_LEN: 32
|
||||
COEFF_LEN: 73
|
||||
NUM_CLASSES: 46
|
||||
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
||||
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
|
||||
LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
||||
DEBUG: True
|
||||
NUM_REPEATS: 2
|
||||
T: 40
|
||||
|
||||
|
||||
MODEL:
|
||||
FRAMEWORK: V2
|
||||
AUDIOENCODER:
|
||||
LEAKY_RELU: True
|
||||
NORM: 'IN'
|
||||
DISCRIMINATOR:
|
||||
LEAKY_RELU: False
|
||||
INPUT_CHANNELS: 6
|
||||
CVAE:
|
||||
AUDIO_EMB_IN_SIZE: 512
|
||||
AUDIO_EMB_OUT_SIZE: 128
|
||||
SEQ_LEN: 32
|
||||
LATENT_SIZE: 256
|
||||
ENCODER_LAYER_SIZES: [192, 1024]
|
||||
DECODER_LAYER_SIZES: [1024, 192]
|
||||
|
||||
|
||||
TRAIN:
|
||||
MAX_EPOCH: 300
|
||||
GENERATOR:
|
||||
LR: 2.0e-5
|
||||
DISCRIMINATOR:
|
||||
LR: 1.0e-5
|
||||
LOSS:
|
||||
W_FEAT: 0
|
||||
W_COEFF_EXP: 2
|
||||
W_LM: 1.0e-2
|
||||
W_LM_MOUTH: 0
|
||||
W_REG: 0
|
||||
W_SYNC: 0
|
||||
W_COLOR: 0
|
||||
W_EXPRESSION: 0
|
||||
W_LIPREADING: 0.01
|
||||
W_LIPREADING_VV: 0
|
||||
W_EYE_BLINK: 4
|
||||
|
||||
TAG:
|
||||
NAME: small_dataset
|
||||
|
||||
|
@ -0,0 +1,49 @@
|
||||
DATASET:
|
||||
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
|
||||
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
|
||||
TRAIN_BATCH_SIZE: 64
|
||||
EVAL_BATCH_SIZE: 1
|
||||
EXP: True
|
||||
EXP_DIM: 64
|
||||
FRAME_LEN: 32
|
||||
COEFF_LEN: 73
|
||||
NUM_CLASSES: 46
|
||||
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
||||
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
||||
DEBUG: True
|
||||
|
||||
|
||||
MODEL:
|
||||
AUDIOENCODER:
|
||||
LEAKY_RELU: True
|
||||
NORM: 'IN'
|
||||
DISCRIMINATOR:
|
||||
LEAKY_RELU: False
|
||||
INPUT_CHANNELS: 6
|
||||
CVAE:
|
||||
AUDIO_EMB_IN_SIZE: 512
|
||||
AUDIO_EMB_OUT_SIZE: 6
|
||||
SEQ_LEN: 32
|
||||
LATENT_SIZE: 64
|
||||
ENCODER_LAYER_SIZES: [192, 128]
|
||||
DECODER_LAYER_SIZES: [128, 192]
|
||||
|
||||
|
||||
TRAIN:
|
||||
MAX_EPOCH: 150
|
||||
GENERATOR:
|
||||
LR: 1.0e-4
|
||||
DISCRIMINATOR:
|
||||
LR: 1.0e-4
|
||||
LOSS:
|
||||
LAMBDA_REG: 1
|
||||
LAMBDA_LANDMARKS: 0
|
||||
LAMBDA_VERTICES: 0
|
||||
LAMBDA_GAN_MOTION: 0.7
|
||||
LAMBDA_GAN_COEFF: 0
|
||||
LAMBDA_KL: 1
|
||||
|
||||
TAG:
|
||||
NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
|
||||
|
||||
|
@ -0,0 +1,45 @@
|
||||
model_params:
|
||||
common_params:
|
||||
num_kp: 15
|
||||
image_channel: 3
|
||||
feature_channel: 32
|
||||
estimate_jacobian: False # True
|
||||
kp_detector_params:
|
||||
temperature: 0.1
|
||||
block_expansion: 32
|
||||
max_features: 1024
|
||||
scale_factor: 0.25 # 0.25
|
||||
num_blocks: 5
|
||||
reshape_channel: 16384 # 16384 = 1024 * 16
|
||||
reshape_depth: 16
|
||||
he_estimator_params:
|
||||
block_expansion: 64
|
||||
max_features: 2048
|
||||
num_bins: 66
|
||||
generator_params:
|
||||
block_expansion: 64
|
||||
max_features: 512
|
||||
num_down_blocks: 2
|
||||
reshape_channel: 32
|
||||
reshape_depth: 16 # 512 = 32 * 16
|
||||
num_resblocks: 6
|
||||
estimate_occlusion_map: True
|
||||
dense_motion_params:
|
||||
block_expansion: 32
|
||||
max_features: 1024
|
||||
num_blocks: 5
|
||||
reshape_depth: 16
|
||||
compress: 4
|
||||
discriminator_params:
|
||||
scales: [1]
|
||||
block_expansion: 32
|
||||
max_features: 512
|
||||
num_blocks: 4
|
||||
sn: True
|
||||
mapping_params:
|
||||
coeff_nc: 70
|
||||
descriptor_nc: 1024
|
||||
layer: 3
|
||||
num_kp: 15
|
||||
num_bins: 66
|
||||
|
@ -0,0 +1,45 @@
|
||||
model_params:
|
||||
common_params:
|
||||
num_kp: 15
|
||||
image_channel: 3
|
||||
feature_channel: 32
|
||||
estimate_jacobian: False # True
|
||||
kp_detector_params:
|
||||
temperature: 0.1
|
||||
block_expansion: 32
|
||||
max_features: 1024
|
||||
scale_factor: 0.25 # 0.25
|
||||
num_blocks: 5
|
||||
reshape_channel: 16384 # 16384 = 1024 * 16
|
||||
reshape_depth: 16
|
||||
he_estimator_params:
|
||||
block_expansion: 64
|
||||
max_features: 2048
|
||||
num_bins: 66
|
||||
generator_params:
|
||||
block_expansion: 64
|
||||
max_features: 512
|
||||
num_down_blocks: 2
|
||||
reshape_channel: 32
|
||||
reshape_depth: 16 # 512 = 32 * 16
|
||||
num_resblocks: 6
|
||||
estimate_occlusion_map: True
|
||||
dense_motion_params:
|
||||
block_expansion: 32
|
||||
max_features: 1024
|
||||
num_blocks: 5
|
||||
reshape_depth: 16
|
||||
compress: 4
|
||||
discriminator_params:
|
||||
scales: [1]
|
||||
block_expansion: 32
|
||||
max_features: 512
|
||||
num_blocks: 4
|
||||
sn: True
|
||||
mapping_params:
|
||||
coeff_nc: 73
|
||||
descriptor_nc: 1024
|
||||
layer: 3
|
||||
num_kp: 15
|
||||
num_bins: 66
|
||||
|
@ -0,0 +1,116 @@
|
||||
"""This package includes all the modules related to data loading and preprocessing
|
||||
|
||||
To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
|
||||
You need to implement four functions:
|
||||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
||||
-- <__len__>: return the size of dataset.
|
||||
-- <__getitem__>: get a data point from data loader.
|
||||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
||||
|
||||
Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
|
||||
See our template dataset class 'template_dataset.py' for more details.
|
||||
"""
|
||||
import numpy as np
|
||||
import importlib
|
||||
import torch.utils.data
|
||||
from face3d.data.base_dataset import BaseDataset
|
||||
|
||||
|
||||
def find_dataset_using_name(dataset_name):
|
||||
"""Import the module "data/[dataset_name]_dataset.py".
|
||||
|
||||
In the file, the class called DatasetNameDataset() will
|
||||
be instantiated. It has to be a subclass of BaseDataset,
|
||||
and it is case-insensitive.
|
||||
"""
|
||||
dataset_filename = "data." + dataset_name + "_dataset"
|
||||
datasetlib = importlib.import_module(dataset_filename)
|
||||
|
||||
dataset = None
|
||||
target_dataset_name = dataset_name.replace('_', '') + 'dataset'
|
||||
for name, cls in datasetlib.__dict__.items():
|
||||
if name.lower() == target_dataset_name.lower() \
|
||||
and issubclass(cls, BaseDataset):
|
||||
dataset = cls
|
||||
|
||||
if dataset is None:
|
||||
raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
def get_option_setter(dataset_name):
|
||||
"""Return the static method <modify_commandline_options> of the dataset class."""
|
||||
dataset_class = find_dataset_using_name(dataset_name)
|
||||
return dataset_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_dataset(opt, rank=0):
|
||||
"""Create a dataset given the option.
|
||||
|
||||
This function wraps the class CustomDatasetDataLoader.
|
||||
This is the main interface between this package and 'train.py'/'test.py'
|
||||
|
||||
Example:
|
||||
>>> from data import create_dataset
|
||||
>>> dataset = create_dataset(opt)
|
||||
"""
|
||||
data_loader = CustomDatasetDataLoader(opt, rank=rank)
|
||||
dataset = data_loader.load_data()
|
||||
return dataset
|
||||
|
||||
class CustomDatasetDataLoader():
|
||||
"""Wrapper class of Dataset class that performs multi-threaded data loading"""
|
||||
|
||||
def __init__(self, opt, rank=0):
|
||||
"""Initialize this class
|
||||
|
||||
Step 1: create a dataset instance given the name [dataset_mode]
|
||||
Step 2: create a multi-threaded data loader.
|
||||
"""
|
||||
self.opt = opt
|
||||
dataset_class = find_dataset_using_name(opt.dataset_mode)
|
||||
self.dataset = dataset_class(opt)
|
||||
self.sampler = None
|
||||
print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
|
||||
if opt.use_ddp and opt.isTrain:
|
||||
world_size = opt.world_size
|
||||
self.sampler = torch.utils.data.distributed.DistributedSampler(
|
||||
self.dataset,
|
||||
num_replicas=world_size,
|
||||
rank=rank,
|
||||
shuffle=not opt.serial_batches
|
||||
)
|
||||
self.dataloader = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
sampler=self.sampler,
|
||||
num_workers=int(opt.num_threads / world_size),
|
||||
batch_size=int(opt.batch_size / world_size),
|
||||
drop_last=True)
|
||||
else:
|
||||
self.dataloader = torch.utils.data.DataLoader(
|
||||
self.dataset,
|
||||
batch_size=opt.batch_size,
|
||||
shuffle=(not opt.serial_batches) and opt.isTrain,
|
||||
num_workers=int(opt.num_threads),
|
||||
drop_last=True
|
||||
)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.dataset.current_epoch = epoch
|
||||
if self.sampler is not None:
|
||||
self.sampler.set_epoch(epoch)
|
||||
|
||||
def load_data(self):
|
||||
return self
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of data in the dataset"""
|
||||
return min(len(self.dataset), self.opt.max_dataset_size)
|
||||
|
||||
def __iter__(self):
|
||||
"""Return a batch of data"""
|
||||
for i, data in enumerate(self.dataloader):
|
||||
if i * self.opt.batch_size >= self.opt.max_dataset_size:
|
||||
break
|
||||
yield data
|
@ -0,0 +1,125 @@
|
||||
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
|
||||
|
||||
It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
|
||||
"""
|
||||
import random
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
from PIL import Image
|
||||
import torchvision.transforms as transforms
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseDataset(data.Dataset, ABC):
|
||||
"""This class is an abstract base class (ABC) for datasets.
|
||||
|
||||
To create a subclass, you need to implement the following four functions:
|
||||
-- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
|
||||
-- <__len__>: return the size of dataset.
|
||||
-- <__getitem__>: get a data point.
|
||||
-- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize the class; save the options in the class
|
||||
|
||||
Parameters:
|
||||
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
self.opt = opt
|
||||
# self.root = opt.dataroot
|
||||
self.current_epoch = 0
|
||||
|
||||
@staticmethod
|
||||
def modify_commandline_options(parser, is_train):
|
||||
"""Add new dataset-specific options, and rewrite default values for existing options.
|
||||
|
||||
Parameters:
|
||||
parser -- original option parser
|
||||
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
|
||||
|
||||
Returns:
|
||||
the modified parser.
|
||||
"""
|
||||
return parser
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset."""
|
||||
return 0
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
            index (int) -- a random integer for data indexing
|
||||
|
||||
Returns:
|
||||
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def get_transform(grayscale=False):
|
||||
transform_list = []
|
||||
if grayscale:
|
||||
transform_list.append(transforms.Grayscale(1))
|
||||
transform_list += [transforms.ToTensor()]
|
||||
return transforms.Compose(transform_list)
|
||||
|
||||
def get_affine_mat(opt, size):
|
||||
    shift_x, shift_y, scale, rot_angle, rot_rad, flip = 0., 0., 1., 0., 0., False  # default rot_rad too, in case 'rot' is not requested
|
||||
w, h = size
|
||||
|
||||
if 'shift' in opt.preprocess:
|
||||
shift_pixs = int(opt.shift_pixs)
|
||||
shift_x = random.randint(-shift_pixs, shift_pixs)
|
||||
shift_y = random.randint(-shift_pixs, shift_pixs)
|
||||
if 'scale' in opt.preprocess:
|
||||
scale = 1 + opt.scale_delta * (2 * random.random() - 1)
|
||||
if 'rot' in opt.preprocess:
|
||||
rot_angle = opt.rot_angle * (2 * random.random() - 1)
|
||||
rot_rad = -rot_angle * np.pi/180
|
||||
if 'flip' in opt.preprocess:
|
||||
flip = random.random() > 0.5
|
||||
|
||||
shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
|
||||
flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
|
||||
shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
|
||||
rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
|
||||
scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
|
||||
shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
|
||||
|
||||
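    # The 3x3 matrices compose right-to-left: move the image center to the origin, optionally
    # mirror, then translate, rotate and scale about the origin, and finally move the origin
    # back to the image center.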
affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
|
||||
affine_inv = np.linalg.inv(affine)
|
||||
return affine, affine_inv, flip
|
||||
|
||||
def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
|
||||
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=method)
|
||||
|
||||
def apply_lm_affine(landmark, affine, flip, size):
|
||||
_, h = size
|
||||
lm = landmark.copy()
|
||||
lm[:, 1] = h - 1 - lm[:, 1]
|
||||
lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
|
||||
lm = lm @ np.transpose(affine)
|
||||
lm[:, :2] = lm[:, :2] / lm[:, 2:]
|
||||
lm = lm[:, :2]
|
||||
lm[:, 1] = h - 1 - lm[:, 1]
|
||||
if flip:
|
||||
lm_ = lm.copy()
|
||||
lm_[:17] = lm[16::-1]
|
||||
lm_[17:22] = lm[26:21:-1]
|
||||
lm_[22:27] = lm[21:16:-1]
|
||||
lm_[31:36] = lm[35:30:-1]
|
||||
lm_[36:40] = lm[45:41:-1]
|
||||
lm_[40:42] = lm[47:45:-1]
|
||||
lm_[42:46] = lm[39:35:-1]
|
||||
lm_[46:48] = lm[41:39:-1]
|
||||
lm_[48:55] = lm[54:47:-1]
|
||||
lm_[55:60] = lm[59:54:-1]
|
||||
lm_[60:65] = lm[64:59:-1]
|
||||
lm_[65:68] = lm[67:64:-1]
|
||||
lm = lm_
|
||||
return lm
|
@ -0,0 +1,125 @@
|
||||
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
|
||||
"""
|
||||
|
||||
import os.path
|
||||
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
|
||||
from data.image_folder import make_dataset
|
||||
from PIL import Image
|
||||
import random
|
||||
import util.util as util
|
||||
import numpy as np
|
||||
import json
|
||||
import torch
|
||||
from scipy.io import loadmat, savemat
|
||||
import pickle
|
||||
from util.preprocess import align_img, estimate_norm
|
||||
from util.load_mats import load_lm3d
|
||||
|
||||
|
||||
def default_flist_reader(flist):
|
||||
"""
|
||||
    flist format: one image path per line (similar to caffe's file lists, but without labels)
|
||||
"""
|
||||
imlist = []
|
||||
with open(flist, 'r') as rf:
|
||||
for line in rf.readlines():
|
||||
impath = line.strip()
|
||||
imlist.append(impath)
|
||||
|
||||
return imlist
|
||||
|
||||
def jason_flist_reader(flist):
|
||||
with open(flist, 'r') as fp:
|
||||
info = json.load(fp)
|
||||
return info
|
||||
|
||||
def parse_label(label):
|
||||
return torch.tensor(np.array(label).astype(np.float32))
|
||||
|
||||
|
||||
class FlistDataset(BaseDataset):
|
||||
"""
|
||||
    It requires one directory to host the training images, e.g. '/path/to/data/train'.
|
||||
You can train the model with the dataset flag '--dataroot /path/to/data'.
|
||||
"""
|
||||
|
||||
def __init__(self, opt):
|
||||
"""Initialize this dataset class.
|
||||
|
||||
Parameters:
|
||||
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
||||
"""
|
||||
BaseDataset.__init__(self, opt)
|
||||
|
||||
self.lm3d_std = load_lm3d(opt.bfm_folder)
|
||||
|
||||
msk_names = default_flist_reader(opt.flist)
|
||||
self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
|
||||
|
||||
self.size = len(self.msk_paths)
|
||||
self.opt = opt
|
||||
|
||||
self.name = 'train' if opt.isTrain else 'val'
|
||||
if '_' in opt.flist:
|
||||
self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
|
||||
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""Return a data point and its metadata information.
|
||||
|
||||
Parameters:
|
||||
index (int) -- a random integer for data indexing
|
||||
|
||||
Returns a dictionary that contains A, B, A_paths and B_paths
|
||||
img (tensor) -- an image in the input domain
|
||||
msk (tensor) -- its corresponding attention mask
|
||||
lm (tensor) -- its corresponding 3d landmarks
|
||||
im_paths (str) -- image paths
|
||||
aug_flag (bool) -- a flag used to tell whether its raw or augmented
|
||||
"""
|
||||
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
|
||||
img_path = msk_path.replace('mask/', '')
|
||||
lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
|
||||
|
||||
raw_img = Image.open(img_path).convert('RGB')
|
||||
raw_msk = Image.open(msk_path).convert('RGB')
|
||||
raw_lm = np.loadtxt(lm_path).astype(np.float32)
|
||||
|
||||
_, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
|
||||
|
||||
aug_flag = self.opt.use_aug and self.opt.isTrain
|
||||
if aug_flag:
|
||||
img, lm, msk = self._augmentation(img, lm, self.opt, msk)
|
||||
|
||||
_, H = img.size
|
||||
M = estimate_norm(lm, H)
|
||||
transform = get_transform()
|
||||
img_tensor = transform(img)
|
||||
msk_tensor = transform(msk)[:1, ...]
|
||||
lm_tensor = parse_label(lm)
|
||||
M_tensor = parse_label(M)
|
||||
|
||||
|
||||
return {'imgs': img_tensor,
|
||||
'lms': lm_tensor,
|
||||
'msks': msk_tensor,
|
||||
'M': M_tensor,
|
||||
'im_paths': img_path,
|
||||
'aug_flag': aug_flag,
|
||||
'dataset': self.name}
|
||||
|
||||
def _augmentation(self, img, lm, opt, msk=None):
|
||||
affine, affine_inv, flip = get_affine_mat(opt, img.size)
|
||||
img = apply_img_affine(img, affine_inv)
|
||||
lm = apply_lm_affine(lm, affine, flip, img.size)
|
||||
if msk is not None:
|
||||
msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
|
||||
return img, lm, msk
|
||||
|
||||
|
||||
|
||||
|
||||
def __len__(self):
|
||||
"""Return the total number of images in the dataset.
|
||||
"""
|
||||
return self.size
|
@ -0,0 +1,66 @@
|
||||
"""A modified image folder class
|
||||
|
||||
We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
|
||||
so that this class can load images from both current directory and its subdirectories.
|
||||
"""
|
||||
import numpy as np
|
||||
import torch.utils.data as data
|
||||
|
||||
from PIL import Image
|
||||
import os
|
||||
import os.path
|
||||
|
||||
IMG_EXTENSIONS = [
|
||||
'.jpg', '.JPG', '.jpeg', '.JPEG',
|
||||
'.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
|
||||
'.tif', '.TIF', '.tiff', '.TIFF',
|
||||
]
|
||||
|
||||
|
||||
def is_image_file(filename):
|
||||
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
|
||||
|
||||
|
||||
def make_dataset(dir, max_dataset_size=float("inf")):
|
||||
images = []
|
||||
assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
|
||||
|
||||
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
|
||||
for fname in fnames:
|
||||
if is_image_file(fname):
|
||||
path = os.path.join(root, fname)
|
||||
images.append(path)
|
||||
return images[:min(max_dataset_size, len(images))]
|
||||
|
||||
|
||||
def default_loader(path):
|
||||
return Image.open(path).convert('RGB')
|
||||
|
||||
|
||||
class ImageFolder(data.Dataset):
|
||||
|
||||
def __init__(self, root, transform=None, return_paths=False,
|
||||
loader=default_loader):
|
||||
imgs = make_dataset(root)
|
||||
if len(imgs) == 0:
|
||||
raise(RuntimeError("Found 0 images in: " + root + "\n"
|
||||
"Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
|
||||
|
||||
self.root = root
|
||||
self.imgs = imgs
|
||||
self.transform = transform
|
||||
self.return_paths = return_paths
|
||||
self.loader = loader
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.imgs[index]
|
||||
img = self.loader(path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
if self.return_paths:
|
||||
return img, path
|
||||
else:
|
||||
return img
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
@ -0,0 +1,107 @@
|
||||
import os
|
||||
import cv2
|
||||
import time
|
||||
import glob
|
||||
import argparse
|
||||
import face_alignment
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from itertools import cycle
|
||||
|
||||
from torch.multiprocessing import Pool, Process, set_start_method
|
||||
|
||||
class KeypointExtractor():
|
||||
def __init__(self, device):
|
||||
self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, device=device)
|
||||
|
||||
def extract_keypoint(self, images, name=None, info=True):
|
||||
if isinstance(images, list):
|
||||
keypoints = []
|
||||
if info:
|
||||
i_range = tqdm(images,desc='landmark Det:')
|
||||
else:
|
||||
i_range = images
|
||||
|
||||
for image in i_range:
|
||||
current_kp = self.extract_keypoint(image)
|
||||
if np.mean(current_kp) == -1 and keypoints:
|
||||
keypoints.append(keypoints[-1])
|
||||
else:
|
||||
keypoints.append(current_kp[None])
|
||||
|
||||
keypoints = np.concatenate(keypoints, 0)
|
||||
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
||||
return keypoints
|
||||
else:
|
||||
while True:
|
||||
try:
|
||||
keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
|
||||
break
|
||||
except RuntimeError as e:
|
||||
if str(e).startswith('CUDA'):
|
||||
print("Warning: out of memory, sleep for 1s")
|
||||
time.sleep(1)
|
||||
else:
|
||||
print(e)
|
||||
break
|
||||
except TypeError:
|
||||
print('No face detected in this image')
|
||||
shape = [68, 2]
|
||||
keypoints = -1. * np.ones(shape)
|
||||
break
|
||||
if name is not None:
|
||||
np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
|
||||
return keypoints
|
||||
|
||||
def read_video(filename):
|
||||
frames = []
|
||||
cap = cv2.VideoCapture(filename)
|
||||
while cap.isOpened():
|
||||
ret, frame = cap.read()
|
||||
if ret:
|
||||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
||||
frame = Image.fromarray(frame)
|
||||
frames.append(frame)
|
||||
else:
|
||||
break
|
||||
cap.release()
|
||||
return frames
|
||||
|
||||
def run(data):
|
||||
filename, opt, device = data
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = device
|
||||
    kp_extractor = KeypointExtractor('cuda')  # device index is selected via CUDA_VISIBLE_DEVICES above
|
||||
images = read_video(filename)
|
||||
name = filename.split('/')[-2:]
|
||||
os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
|
||||
kp_extractor.extract_keypoint(
|
||||
images,
|
||||
name=os.path.join(opt.output_dir, name[-2], name[-1])
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
set_start_method('spawn')
|
||||
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument('--input_dir', type=str, help='the folder of the input files')
|
||||
parser.add_argument('--output_dir', type=str, help='the folder of the output files')
|
||||
parser.add_argument('--device_ids', type=str, default='0,1')
|
||||
parser.add_argument('--workers', type=int, default=4)
|
||||
|
||||
opt = parser.parse_args()
|
||||
filenames = list()
|
||||
VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
|
||||
VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
|
||||
extensions = VIDEO_EXTENSIONS
|
||||
|
||||
    for ext in extensions:
        print(f'{opt.input_dir}/*.{ext}')
        filenames.extend(glob.glob(f'{opt.input_dir}/*.{ext}'))
    filenames = sorted(filenames)
|
||||
print('Total number of videos:', len(filenames))
|
||||
pool = Pool(opt.workers)
|
||||
args_list = cycle([opt])
|
||||
device_ids = opt.device_ids.split(",")
|
||||
device_ids = cycle(device_ids)
|
||||
    for _ in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass
|
@ -0,0 +1,67 @@
|
||||
"""This package contains modules related to objective functions, optimizations, and network architectures.
|
||||
|
||||
To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
|
||||
You need to implement the following five functions:
|
||||
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
|
||||
-- <set_input>: unpack data from dataset and apply preprocessing.
|
||||
-- <forward>: produce intermediate results.
|
||||
-- <optimize_parameters>: calculate loss, gradients, and update network weights.
|
||||
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
|
||||
|
||||
In the function <__init__>, you need to define four lists:
|
||||
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
||||
-- self.model_names (str list): define networks used in our training.
|
||||
-- self.visual_names (str list): specify the images that you want to display and save.
|
||||
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
|
||||
|
||||
Now you can use the model class by specifying flag '--model dummy'.
|
||||
See our template model class 'template_model.py' for more details.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from src.face3d.models.base_model import BaseModel
|
||||
|
||||
|
||||
def find_model_using_name(model_name):
|
||||
"""Import the module "models/[model_name]_model.py".
|
||||
|
||||
In the file, the class called DatasetNameModel() will
|
||||
be instantiated. It has to be a subclass of BaseModel,
|
||||
and it is case-insensitive.
|
||||
"""
|
||||
model_filename = "face3d.models." + model_name + "_model"
|
||||
modellib = importlib.import_module(model_filename)
|
||||
model = None
|
||||
target_model_name = model_name.replace('_', '') + 'model'
|
||||
for name, cls in modellib.__dict__.items():
|
||||
if name.lower() == target_model_name.lower() \
|
||||
and issubclass(cls, BaseModel):
|
||||
model = cls
|
||||
|
||||
if model is None:
|
||||
print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
|
||||
exit(0)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_option_setter(model_name):
|
||||
"""Return the static method <modify_commandline_options> of the model class."""
|
||||
model_class = find_model_using_name(model_name)
|
||||
return model_class.modify_commandline_options
|
||||
|
||||
|
||||
def create_model(opt):
|
||||
"""Create a model given the option.
|
||||
|
||||
    This function instantiates the model class specified by 'opt.model'.
|
||||
This is the main interface between this package and 'train.py'/'test.py'
|
||||
|
||||
Example:
|
||||
>>> from models import create_model
|
||||
>>> model = create_model(opt)
|
||||
"""
|
||||
model = find_model_using_name(opt.model)
|
||||
instance = model(opt)
|
||||
print("model [%s] was created" % type(instance).__name__)
|
||||
return instance
|
@ -0,0 +1,164 @@
|
||||
# Distributed Arcface Training in Pytorch
|
||||
|
||||
This is a deep learning library that makes face recognition efficient and effective, and that can train tens of
millions of identities on a single server.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
|
||||
- `pip install -r requirements.txt`.
|
||||
- Download the dataset
  from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).
|
||||
|
||||
## How to Train
|
||||
|
||||
To train a model, run `train.py` with the path to the configs:
|
||||
|
||||
### 1. Single node, 8 GPUs:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
|
||||
```
|
||||
|
||||
### 2. Multiple nodes, each node 8 GPUs:
|
||||
|
||||
Node 0:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
|
||||
```
|
||||
|
||||
Node 1:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
|
||||
```
|
||||
|
||||
### 3. Training resnet2060 with 8 GPUs:
|
||||
|
||||
```shell
|
||||
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
|
||||
```
|
||||
|
||||
## Model Zoo
|
||||
|
||||
- The models are available for non-commercial research purposes only.
|
||||
- All models can be found here:
|
||||
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
|
||||
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
|
||||
|
||||
### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
|
||||
|
||||
The ICCV2021-MFR test set consists of non-celebrities, so we can ensure that it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
As a result, we can evaluate the fair performance of different algorithms.
|
||||
|
||||
For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
globalised multi-racial test set contains 242,143 identities and 1,624,305 images.

For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
The mask test set contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
There are 13,928 positive pairs and 96,983,824 negative pairs in total.
|
||||
|
||||
| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
|
||||
| :---: | :--- | :--- | :--- |:--- |:--- |
|
||||
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
|
||||
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
|
||||
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
|
||||
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
|
||||
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
|
||||
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
|
||||
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
|
||||
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
|
||||
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
|
||||
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |
|
||||
|
||||
### Performance on IJB-C and Verification Datasets
|
||||
|
||||
| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
|
||||
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
|
||||
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
|
||||
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
|
||||
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
|
||||
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
|
||||
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
|
||||
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
|
||||
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
|
||||
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
|
||||
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
|
||||
|
||||
[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)
|
||||
|
||||
|
||||
## [Speed Benchmark](docs/speed_benchmark.md)
|
||||
|
||||
**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 300K and the training is sufficient, the partial FC sampling strategy reaches the same
accuracy with several times faster training and a smaller GPU memory footprint.
Partial FC is a sparse variant of the model parallel architecture for large-scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of the class centers for training. In each iteration, only a
sparse part of the parameters is updated, which greatly reduces GPU memory and computation. With Partial FC,
we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed precision training.
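
Conceptually, the class-center sampling step looks like the minimal sketch below. It only illustrates the idea described above and is not the library's actual Partial FC implementation; the function name and defaults are made up for the example, while `sample_rate` mirrors the `config.sample_rate` option used in the configs.

```python
import torch

def sample_class_centers(weight, labels, sample_rate=0.1):
    """Keep every class that appears in this batch plus a random subset of the
    remaining classes, and remap labels into the sampled index space."""
    num_classes = weight.size(0)
    positive = labels.unique()
    num_sample = max(int(sample_rate * num_classes), positive.numel())

    keep = torch.zeros(num_classes, dtype=torch.bool, device=weight.device)
    keep[positive] = True                                   # always keep positive centers
    negatives = torch.nonzero(~keep, as_tuple=False).squeeze(1)
    perm = negatives[torch.randperm(negatives.numel(), device=weight.device)]
    keep[perm[: num_sample - positive.numel()]] = True      # add random negative centers

    index = torch.nonzero(keep, as_tuple=False).squeeze(1)
    remap = torch.full((num_classes,), -1, dtype=torch.long, device=weight.device)
    remap[index] = torch.arange(index.numel(), device=weight.device)
    return weight[index], remap[labels]

# usage sketch: logits are computed only against the sampled centers
# sub_weight, sub_labels = sample_class_centers(fc_weight, batch_labels, sample_rate=0.1)
# logits = embeddings @ sub_weight.t()   # followed by the margin-based softmax loss
```

Because the classification weight used in each iteration has only `num_sample` rows, memory and compute for the softmax layer scale with the sampled subset rather than with the total number of identities.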
|
||||
|
||||

|
||||
|
||||
For more details, see [speed_benchmark.md](docs/speed_benchmark.md) in docs.
|
||||
|
||||
### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
|
||||
|
||||
`-` means training failed because of GPU memory limitations.
|
||||
|
||||
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
||||
| :--- | :--- | :--- | :--- |
|
||||
|125000 | 4681 | 4824 | 5004 |
|
||||
|1400000 | **1672** | 3043 | 4738 |
|
||||
|5500000 | **-** | **1389** | 3975 |
|
||||
|8000000 | **-** | **-** | 3565 |
|
||||
|16000000 | **-** | **-** | 2679 |
|
||||
|29000000 | **-** | **-** | **1855** |
|
||||
|
||||
### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
|
||||
|
||||
| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
|
||||
| :--- | :--- | :--- | :--- |
|
||||
|125000 | 7358 | 5306 | 4868 |
|
||||
|1400000 | 32252 | 11178 | 6056 |
|
||||
|5500000 | **-** | 32188 | 9854 |
|
||||
|8000000 | **-** | **-** | 12310 |
|
||||
|16000000 | **-** | **-** | 19950 |
|
||||
|29000000 | **-** | **-** | 32324 |
|
||||
|
||||
## Evaluation ICCV2021-MFR and IJB-C
|
||||
|
||||
For more details, see [eval.md](docs/eval.md) in docs.
|
||||
|
||||
## Test
|
||||
|
||||
We tested many versions of PyTorch. Please create an issue if you are having trouble.
|
||||
|
||||
- [x] torch 1.6.0
|
||||
- [x] torch 1.7.1
|
||||
- [x] torch 1.8.0
|
||||
- [x] torch 1.9.0
|
||||
|
||||
## Citation
|
||||
|
||||
```
|
||||
@inproceedings{deng2019arcface,
|
||||
title={Arcface: Additive angular margin loss for deep face recognition},
|
||||
author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
|
||||
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
pages={4690--4699},
|
||||
year={2019}
|
||||
}
|
||||
@inproceedings{an2020partical_fc,
|
||||
title={Partial FC: Training 10 Million Identities on a Single Machine},
|
||||
author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
|
||||
Zhang, Debing and Fu, Ying},
|
||||
booktitle={Arxiv 2010.05222},
|
||||
year={2020}
|
||||
}
|
||||
```
|
@ -0,0 +1,25 @@
|
||||
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
|
||||
from .mobilefacenet import get_mbf
|
||||
|
||||
|
||||
def get_model(name, **kwargs):
|
||||
# resnet
|
||||
if name == "r18":
|
||||
return iresnet18(False, **kwargs)
|
||||
elif name == "r34":
|
||||
return iresnet34(False, **kwargs)
|
||||
elif name == "r50":
|
||||
return iresnet50(False, **kwargs)
|
||||
elif name == "r100":
|
||||
return iresnet100(False, **kwargs)
|
||||
elif name == "r200":
|
||||
return iresnet200(False, **kwargs)
|
||||
elif name == "r2060":
|
||||
from .iresnet2060 import iresnet2060
|
||||
return iresnet2060(False, **kwargs)
|
||||
elif name == "mbf":
|
||||
fp16 = kwargs.get("fp16", False)
|
||||
num_features = kwargs.get("num_features", 512)
|
||||
return get_mbf(fp16=fp16, num_features=num_features)
|
||||
else:
|
||||
raise ValueError()
|
@ -0,0 +1,187 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False)
|
||||
|
||||
|
||||
class IBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
groups=1, base_width=64, dilation=1):
|
||||
super(IBasicBlock, self).__init__()
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
|
||||
self.conv1 = conv3x3(inplanes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
|
||||
self.prelu = nn.PReLU(planes)
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = self.bn1(x)
|
||||
out = self.conv1(out)
|
||||
out = self.bn2(out)
|
||||
out = self.prelu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn3(out)
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
out += identity
|
||||
return out
|
||||
|
||||
|
||||
class IResNet(nn.Module):
|
||||
fc_scale = 7 * 7
|
||||
def __init__(self,
|
||||
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
||||
super(IResNet, self).__init__()
|
||||
self.fp16 = fp16
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
||||
self.prelu = nn.PReLU(self.inplanes)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block,
|
||||
128,
|
||||
layers[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block,
|
||||
256,
|
||||
layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block,
|
||||
512,
|
||||
layers[3],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
|
||||
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
||||
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
||||
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
||||
nn.init.constant_(self.features.weight, 1.0)
|
||||
self.features.weight.requires_grad = False
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.normal_(m.weight, 0, 0.1)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, IBasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
||||
)
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.prelu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.bn2(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x.float() if self.fp16 else x)
|
||||
x = self.features(x)
|
||||
return x
|
||||
|
||||
|
||||
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = IResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError()
|
||||
return model
|
||||
|
||||
|
||||
def iresnet18(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet34(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet50(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet100(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
|
||||
progress, **kwargs)
|
||||
|
||||
|
||||
def iresnet200(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
|
||||
progress, **kwargs)
|
||||
|
@ -0,0 +1,176 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
assert torch.__version__ >= "1.8.1"
|
||||
from torch.utils.checkpoint import checkpoint_sequential
|
||||
|
||||
__all__ = ['iresnet2060']
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=dilation,
|
||||
groups=groups,
|
||||
bias=False,
|
||||
dilation=dilation)
|
||||
|
||||
|
||||
def conv1x1(in_planes, out_planes, stride=1):
|
||||
"""1x1 convolution"""
|
||||
return nn.Conv2d(in_planes,
|
||||
out_planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False)
|
||||
|
||||
|
||||
class IBasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
||||
groups=1, base_width=64, dilation=1):
|
||||
super(IBasicBlock, self).__init__()
|
||||
if groups != 1 or base_width != 64:
|
||||
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
||||
if dilation > 1:
|
||||
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
||||
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
|
||||
self.conv1 = conv3x3(inplanes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
|
||||
self.prelu = nn.PReLU(planes)
|
||||
self.conv2 = conv3x3(planes, planes, stride)
|
||||
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
out = self.bn1(x)
|
||||
out = self.conv1(out)
|
||||
out = self.bn2(out)
|
||||
out = self.prelu(out)
|
||||
out = self.conv2(out)
|
||||
out = self.bn3(out)
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
out += identity
|
||||
return out
|
||||
|
||||
|
||||
class IResNet(nn.Module):
|
||||
fc_scale = 7 * 7
|
||||
|
||||
def __init__(self,
|
||||
block, layers, dropout=0, num_features=512, zero_init_residual=False,
|
||||
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
|
||||
super(IResNet, self).__init__()
|
||||
self.fp16 = fp16
|
||||
self.inplanes = 64
|
||||
self.dilation = 1
|
||||
if replace_stride_with_dilation is None:
|
||||
replace_stride_with_dilation = [False, False, False]
|
||||
if len(replace_stride_with_dilation) != 3:
|
||||
raise ValueError("replace_stride_with_dilation should be None "
|
||||
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
|
||||
self.groups = groups
|
||||
self.base_width = width_per_group
|
||||
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
|
||||
self.prelu = nn.PReLU(self.inplanes)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
|
||||
self.layer2 = self._make_layer(block,
|
||||
128,
|
||||
layers[1],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[0])
|
||||
self.layer3 = self._make_layer(block,
|
||||
256,
|
||||
layers[2],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[1])
|
||||
self.layer4 = self._make_layer(block,
|
||||
512,
|
||||
layers[3],
|
||||
stride=2,
|
||||
dilate=replace_stride_with_dilation[2])
|
||||
self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
|
||||
self.dropout = nn.Dropout(p=dropout, inplace=True)
|
||||
self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
|
||||
self.features = nn.BatchNorm1d(num_features, eps=1e-05)
|
||||
nn.init.constant_(self.features.weight, 1.0)
|
||||
self.features.weight.requires_grad = False
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.normal_(m.weight, 0, 0.1)
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
if zero_init_residual:
|
||||
for m in self.modules():
|
||||
if isinstance(m, IBasicBlock):
|
||||
nn.init.constant_(m.bn2.weight, 0)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
|
||||
downsample = None
|
||||
previous_dilation = self.dilation
|
||||
if dilate:
|
||||
self.dilation *= stride
|
||||
stride = 1
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
conv1x1(self.inplanes, planes * block.expansion, stride),
|
||||
nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
|
||||
)
|
||||
layers = []
|
||||
layers.append(
|
||||
block(self.inplanes, planes, stride, downsample, self.groups,
|
||||
self.base_width, previous_dilation))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(
|
||||
block(self.inplanes,
|
||||
planes,
|
||||
groups=self.groups,
|
||||
base_width=self.base_width,
|
||||
dilation=self.dilation))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def checkpoint(self, func, num_seg, x):
|
||||
if self.training:
|
||||
return checkpoint_sequential(func, num_seg, x)
|
||||
else:
|
||||
return func(x)
|
||||
|
||||
def forward(self, x):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.prelu(x)
|
||||
x = self.layer1(x)
|
||||
x = self.checkpoint(self.layer2, 20, x)
|
||||
x = self.checkpoint(self.layer3, 100, x)
|
||||
x = self.layer4(x)
|
||||
x = self.bn2(x)
|
||||
x = torch.flatten(x, 1)
|
||||
x = self.dropout(x)
|
||||
x = self.fc(x.float() if self.fp16 else x)
|
||||
x = self.features(x)
|
||||
return x
|
||||
|
||||
|
||||
def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
|
||||
model = IResNet(block, layers, **kwargs)
|
||||
if pretrained:
|
||||
raise ValueError()
|
||||
return model
|
||||
|
||||
|
||||
def iresnet2060(pretrained=False, progress=True, **kwargs):
|
||||
return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
|
@ -0,0 +1,130 @@
'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author: cavalleria
'''

import torch
import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module


class Flatten(Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ConvBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
            BatchNorm2d(num_features=out_c),
            PReLU(num_parameters=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class LinearBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            BatchNorm2d(num_features=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class DepthWise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(DepthWise, self).__init__()
        self.residual = residual
        # Pointwise expansion -> depthwise conv -> linear pointwise projection.
        self.layers = nn.Sequential(
            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        )

    def forward(self, x):
        short_cut = None
        if self.residual:
            short_cut = x
        x = self.layers(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
        self.layers = Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


class GDC(Module):
    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        # Global depthwise convolution head: a 7x7 grouped conv collapses the
        # final spatial map before the linear embedding layer.
        self.layers = nn.Sequential(
            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
            Flatten(),
            Linear(512, embedding_size, bias=False),
            BatchNorm1d(embedding_size))

    def forward(self, x):
        return self.layers(x)


class MobileFaceNet(Module):
    def __init__(self, fp16=False, num_features=512):
        super(MobileFaceNet, self).__init__()
        scale = 2
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.layers(x)
        # The separable conv and embedding head run in fp32 even with autocast enabled.
        x = self.conv_sep(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def get_mbf(fp16, num_features):
    return MobileFaceNet(fp16, num_features)
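# Usage sketch (not part of the original file): MobileFaceNet is the lightweight
# counterpart to the IResNet backbones. Under the same assumption of 112x112
# input crops (required for the 7x7 GDC head), it returns `num_features`-dim
# embeddings; the batch size below is illustrative only.
if __name__ == "__main__":
    model = get_mbf(fp16=False, num_features=512)
    model.eval()
    with torch.no_grad():
        emb = model(torch.randn(2, 3, 112, 112))
    print(emb.shape)  # expected: torch.Size([2, 512])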
@ -0,0 +1,23 @@
from easydict import EasyDict as edict

# config for speed testing with synthetic data

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # lr is set for a total batch size of 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []
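# Scaling sketch (an assumption, not part of the original config): the lr above
# appears to be tuned for a total batch of 512, e.g. 128 per GPU across 4 GPUs.
# A training script could rescale it linearly for a different world size; the
# `world_size` value below is purely illustrative.
if __name__ == "__main__":
    world_size = 4
    total_batch_size = config.batch_size * world_size
    print("scaled lr:", config.lr * total_batch_size / 512)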
@ -0,0 +1,23 @@
from easydict import EasyDict as edict

# config for speed testing with synthetic data (Partial FC variant)

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1  # sample 10% of class centers per iteration (Partial FC)
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # lr is set for a total batch size of 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []
@ -0,0 +1,56 @@
from easydict import EasyDict as edict

# To speed up training, the dataset is staged on a tmpfs RAM disk
# (the original machine has 256 GB of RAM):
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = "ms1mv3_arcface_r50"

config.dataset = "ms1m-retinaface-t1"
config.embedding_size = 512
config.sample_rate = 1
config.fp16 = False
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # lr is set for a total batch size of 512

if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_image = 5822653
    config.num_epoch = 16
    config.warmup_epoch = -1
    config.decay_epoch = [8, 14]
    config.val_targets = ["lfw"]

elif config.dataset == "ms1m-retinaface-t1":
    config.rec = "/train_tmp/ms1m-retinaface-t1"
    config.num_classes = 93431
    config.num_image = 5179510
    config.num_epoch = 25
    config.warmup_epoch = -1
    config.decay_epoch = [11, 17, 22]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 20
    config.warmup_epoch = -1
    config.decay_epoch = [8, 12, 15, 18]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "webface":
    config.rec = "/train_tmp/faces_webface_112x112"
    config.num_classes = 10572
    config.num_image = "forget"  # image count was not recorded in the original config
    config.num_epoch = 34
    config.warmup_epoch = -1
    config.decay_epoch = [20, 28, 32]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
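# Print-out sketch (not part of the original config): the dataset branch above
# is resolved at import time from `config.dataset`, so switching datasets means
# editing that single assignment; the fields printed below reflect whichever
# value was set when this module was imported.
if __name__ == "__main__":
    print(config.dataset, config.rec, config.num_classes, config.num_epoch)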
@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# To speed up training, the dataset is staged on a tmpfs RAM disk
# (the original machine has 256 GB of RAM):
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1  # sample 10% of class centers per iteration (Partial FC)
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # lr is set for a total batch size of 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
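# Loading sketch (hypothetical, not part of the original files): the module and
# package names below are assumptions, since the training entry point is not
# shown in this excerpt. An EasyDict config like the ones above is typically
# imported by module name and then read as plain attributes.
if __name__ == "__main__":
    import importlib

    cfg_module = importlib.import_module("configs.glint360k_mbf")  # assumed path
    cfg = cfg_module.config
    print(cfg.network, cfg.loss, cfg.num_classes)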