support Ultralight-Digital-Human
parent
07ed664c78
commit
30c812ef73
@ -0,0 +1,116 @@
|
|||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch import optim
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from unet import Model
|
||||||
|
import pickle
|
||||||
|
# from unet2 import Model
|
||||||
|
# from unet_att import Model
|
||||||
|
|
||||||
|
import time
|
||||||
|
def osmakedirs(path_list):
|
||||||
|
for path in path_list:
|
||||||
|
os.makedirs(path) if not os.path.exists(path) else None
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description='Train',
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||||
|
|
||||||
|
parser.add_argument('--dataset', type=str, default="")
|
||||||
|
#parser.add_argument('--save_path', type=str, default="") # end with .mp4 please
|
||||||
|
parser.add_argument('--checkpoint', type=str, default="")
|
||||||
|
parser.add_argument('--avatar_id', default='ultralight_avatar1', type=str)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
checkpoint = args.checkpoint
|
||||||
|
dataset_dir = args.dataset
|
||||||
|
|
||||||
|
img_dir = os.path.join(dataset_dir, "full_body_img/")
|
||||||
|
lms_dir = os.path.join(dataset_dir, "landmarks/")
|
||||||
|
|
||||||
|
avatar_path = f"./results/avatars/{args.avatar_id}"
|
||||||
|
full_imgs_path = f"{avatar_path}/full_imgs"
|
||||||
|
face_imgs_path = f"{avatar_path}/face_imgs"
|
||||||
|
coords_path = f"{avatar_path}/coords.pkl"
|
||||||
|
pth_path = f"{avatar_path}/ultralight.pth"
|
||||||
|
osmakedirs([avatar_path,full_imgs_path,face_imgs_path])
|
||||||
|
|
||||||
|
len_img = len(os.listdir(img_dir)) - 1
|
||||||
|
exm_img = cv2.imread(img_dir+"0.jpg")
|
||||||
|
h, w = exm_img.shape[:2]
|
||||||
|
|
||||||
|
step_stride = 0
|
||||||
|
img_idx = 0
|
||||||
|
coord_list = []
|
||||||
|
|
||||||
|
net = Model(6, 'hubert').cuda()
|
||||||
|
net.load_state_dict(torch.load(checkpoint))
|
||||||
|
net.eval()
|
||||||
|
for i in range(len_img):
|
||||||
|
if img_idx>len_img - 1:
|
||||||
|
step_stride = -1
|
||||||
|
if img_idx<1:
|
||||||
|
step_stride = 1
|
||||||
|
img_idx += step_stride
|
||||||
|
img_path = img_dir + str(img_idx)+'.jpg'
|
||||||
|
lms_path = lms_dir + str(img_idx)+'.lms'
|
||||||
|
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
lms_list = []
|
||||||
|
with open(lms_path, "r") as f:
|
||||||
|
lines = f.read().splitlines()
|
||||||
|
for line in lines:
|
||||||
|
arr = line.split(" ")
|
||||||
|
arr = np.array(arr, dtype=np.float32)
|
||||||
|
lms_list.append(arr)
|
||||||
|
lms = np.array(lms_list, dtype=np.int32)
|
||||||
|
xmin = lms[1][0]
|
||||||
|
ymin = lms[52][1]
|
||||||
|
|
||||||
|
xmax = lms[31][0]
|
||||||
|
width = xmax - xmin
|
||||||
|
ymax = ymin + width
|
||||||
|
crop_img = img[ymin:ymax, xmin:xmax]
|
||||||
|
h, w = crop_img.shape[:2]
|
||||||
|
crop_img = cv2.resize(crop_img, (168, 168), cv2.INTER_AREA)
|
||||||
|
crop_img_ori = crop_img.copy()
|
||||||
|
img_real_ex = crop_img[4:164, 4:164].copy()
|
||||||
|
img_real_ex_ori = img_real_ex.copy()
|
||||||
|
img_masked = cv2.rectangle(img_real_ex_ori,(5,5,150,145),(0,0,0),-1)
|
||||||
|
|
||||||
|
img_masked = img_masked.transpose(2,0,1).astype(np.float32)
|
||||||
|
img_real_ex = img_real_ex.transpose(2,0,1).astype(np.float32)
|
||||||
|
|
||||||
|
img_real_ex_T = torch.from_numpy(img_real_ex / 255.0)
|
||||||
|
img_masked_T = torch.from_numpy(img_masked / 255.0)
|
||||||
|
img_concat_T = torch.cat([img_real_ex_T, img_masked_T], axis=0)[None]
|
||||||
|
|
||||||
|
audio_feat = torch.zeros(1, 32, 32, 32)
|
||||||
|
#print('audio_feat:',audio_feat.shape)
|
||||||
|
audio_feat = audio_feat.cuda()
|
||||||
|
img_concat_T = img_concat_T.cuda()
|
||||||
|
#print('img_concat_T:',img_concat_T.shape)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
pred = net(img_concat_T, audio_feat)[0]
|
||||||
|
|
||||||
|
pred = pred.cpu().numpy().transpose(1,2,0)*255
|
||||||
|
pred = np.array(pred, dtype=np.uint8)
|
||||||
|
crop_img_ori[4:164, 4:164] = pred
|
||||||
|
crop_img_ori = cv2.resize(crop_img_ori, (w, h))
|
||||||
|
img[ymin:ymax, xmin:xmax] = crop_img_ori
|
||||||
|
|
||||||
|
cv2.putText(img, "LiveTalking", (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.3, (128,128,128), 1)
|
||||||
|
cv2.imwrite(f"{full_imgs_path}/{img_idx:08d}.png", img)
|
||||||
|
cv2.imwrite(f"{face_imgs_path}/{img_idx:08d}.png", crop_img)
|
||||||
|
coord_list.append((xmin, ymin, xmin+w, ymin+h))
|
||||||
|
|
||||||
|
with open(coords_path, 'wb') as f:
|
||||||
|
pickle.dump(coord_list, f)
|
||||||
|
os.system(f"cp {checkpoint} {pth_path}")
|
||||||
|
|
||||||
|
# ffmpeg -i test_video.mp4 -i test_audio.pcm -c:v libx264 -c:a aac result_test.mp4
|
Loading…
Reference in New Issue