You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
SadTalker/main.py

161 lines
6.8 KiB
Python

# -*- coding: utf-8 -*-
import base64
import logging
import os
import shutil
import subprocess
import uuid
from hashlib import md5
from typing import Dict

from fastapi import BackgroundTasks, FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
# File extensions (lower-case, without the dot) accepted for uploaded images.
ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'}

# FastAPI application; the route handlers in this module are registered on it.
app = FastAPI()

# In-memory, per-process state (plain dicts, nothing persists them):
#   tasks      - task uid -> task record, e.g. {'status': ..., 'video_path': ...}
#   md5_to_uid - image-md5-based key -> uid of the task generating that video
tasks: Dict[str, dict] = dict()
md5_to_uid: Dict[str, str] = dict()
def save_upload_file(upload_file: "UploadFile", filename: str):
    """Write the contents of an uploaded file to *filename*.

    The data is streamed in chunks instead of being read fully into memory.
    Missing parent directories of *filename* are created first.

    Args:
        upload_file: Object exposing a readable binary ``.file`` attribute
            (e.g. FastAPI's ``UploadFile``).
        filename: Destination path; overwritten if it already exists.
    """
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, "wb") as buffer:
        shutil.copyfileobj(upload_file.file, buffer)
def is_allowed_file(filename: str, *, allowed_extensions=None):
    """Return True if *filename* has an allowed image extension.

    The check is case-insensitive and looks only at the text after the last
    dot. ``None`` or an empty name is rejected rather than raising.

    Args:
        filename: Client-supplied file name (may be ``None``).
        allowed_extensions: Optional set of lower-case extensions to accept;
            defaults to ``ALLOWED_IMAGE_EXTENSIONS``.
    """
    if not filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_IMAGE_EXTENSIONS
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in allowed_extensions
def generate_video_command(result_dir: str, img_path: str, audio_path: str, video_path: str):
    """Build the argv list that runs SadTalker's ``script.py`` for one job.

    Args:
        result_dir: Passed as ``--result_dir``.
        img_path: Passed as ``--source_image``.
        audio_path: Passed as ``--driven_audio``.
        video_path: Passed as ``--ref_eyeblink``.

    Returns:
        A list suitable for ``subprocess.run`` (no shell involved).
    """
    flag_values = (
        ("--source_image", img_path),
        ("--result_dir", result_dir),
        ("--driven_audio", audio_path),
        ("--ref_eyeblink", video_path),
    )
    argv = ["python", "script.py"]
    for flag, value in flag_values:
        argv.extend((flag, value))
    return argv
def get_latest_sub_dir(result_dir: str):
    """Return the path of the most recently modified subdirectory of *result_dir*.

    Only immediate children are considered. Regular files are ignored.

    Returns:
        The joined path of the newest subdirectory, or None if there is none.
    """
    candidates = []
    for entry in os.listdir(result_dir):
        full_path = os.path.join(result_dir, entry)
        if os.path.isdir(full_path):
            candidates.append(full_path)
    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)
def get_video_duration(video_path: str):
    """Return the container duration of *video_path* in seconds, via ffprobe.

    Raises:
        subprocess.CalledProcessError: if ffprobe exits with a non-zero status.
        ValueError: if ffprobe's output is not a number (e.g. ``N/A``).
    """
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", video_path],
        capture_output=True, text=True, check=True,
    )
    output = result.stdout.strip()
    try:
        return float(output)
    except ValueError as err:
        # Give the caller something more useful than "could not convert string to float: ''".
        raise ValueError(f"ffprobe returned no usable duration for {video_path!r}: {output!r}") from err
def run_ffmpeg_command(command: list):
    """Run an ffmpeg argv list and raise if it exits with a non-zero status.

    stdin is redirected to /dev/null. Otherwise ffmpeg inherits the server's
    stdin and can block forever on an interactive prompt (such as "Overwrite?
    [y/N]") while running inside a background task.

    Raises:
        subprocess.CalledProcessError: if the command exits non-zero.
    """
    subprocess.run(command, check=True, stdin=subprocess.DEVNULL)
def save_to_final_destination(source_path: str, destination_dir: str):
    """Move *source_path* into *destination_dir*, keeping its file name.

    *destination_dir* is created if needed. ``shutil.move`` is used instead of
    ``os.rename`` so that the move also works across filesystems, where it
    falls back to copy-and-delete.

    Returns:
        The new path of the file.
    """
    os.makedirs(destination_dir, exist_ok=True)
    destination_path = os.path.join(destination_dir, os.path.basename(source_path))
    shutil.move(source_path, destination_path)
    return destination_path
def get_file_md5(file_path: str):
    """Return the hex MD5 digest of the file at *file_path*.

    The file is read in 4096-byte chunks, so large files are not loaded into
    memory at once.
    """
    digest = md5()
    with open(file_path, "rb") as stream:
        while True:
            piece = stream.read(4096)
            if not piece:
                break
            digest.update(piece)
    return digest.hexdigest()
def record_video_mapping(image_md5: str, video_path: str, record_file: str):
    """Append an ``"<image_md5> <video_path>"`` line to *record_file*.

    The file is opened in append mode, so it is created on first use.
    """
    entry = "{} {}\n".format(image_md5, video_path)
    with open(record_file, "a") as log_file:
        log_file.write(entry)
def check_existing_video(image_md5: str, record_file: str):
    """Look up a previously generated video for *image_md5* in *record_file*.

    Each line of *record_file* has the form ``"<md5> <video_path>"``. Only the
    first space separates the fields, so video paths may contain spaces. Blank
    or malformed lines are skipped instead of raising.

    Returns:
        The video path of the first matching line, or None if there is no
        match or *record_file* does not exist.
    """
    if not os.path.exists(record_file):
        return None
    with open(record_file, "r") as f:
        for line in f:
            recorded_image_md5, separator, video_path = line.rstrip("\r\n").partition(" ")
            if not separator or not video_path:
                continue
            if recorded_image_md5 == image_md5:
                return video_path
    return None
def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
    """Background job: render the video for one task and publish the result.

    On success, ``tasks[uid]['video_path']`` is set and the status becomes
    ``'completed'``. On any failure the status becomes ``'failed'``. Failures
    include a SadTalker or ffmpeg error and missing output. Clients polling
    ``/status/{uid}`` are therefore never left waiting forever.

    Args:
        uid: Key of this task in ``tasks``.
        image_md5: MD5 of the source image, recorded against the final video.
        img_path: Path of the uploaded source image.
        audio_type: ``"dynamic"`` or ``"silent"``; selects inputs and post-processing.
    """
    try:
        final_destination = _render_video(uid, image_md5, img_path, audio_type)
    except Exception:
        # Top-level boundary of a background task. Without this, an uncaught
        # exception would leave the task stuck in 'processing'.
        logging.getLogger(__name__).exception("Video generation failed for task %s", uid)
        final_destination = None
    if final_destination is None:
        tasks[uid]['status'] = 'failed'
        return
    # Publish the path before the status. A concurrent /status poll that sees
    # 'completed' then always finds 'video_path'.
    tasks[uid]['video_path'] = final_destination
    tasks[uid]['status'] = 'completed'


def _render_video(uid: str, image_md5: str, img_path: str, audio_type: str):
    """Run SadTalker and ffmpeg for one task; return the final video path, or None if SadTalker produced no output."""
    record_file = f"{audio_type}_video_record.txt"
    audio_path = f"./examples/driven_audio/{audio_type}_audio.wav"
    video_path = f"./examples/ref_video/{audio_type}.mp4"
    # Use a per-task result directory. With one shared "results" folder,
    # another task's output, or the results/<type>-video folders written
    # below, could be picked up as this run's "latest" subdirectory.
    result_dir = os.path.join("results", uid)
    os.makedirs(result_dir, exist_ok=True)
    command = generate_video_command(result_dir, img_path, audio_path, video_path)
    subprocess.run(command, check=True)
    latest_sub_dir = get_latest_sub_dir(result_dir)
    if not latest_sub_dir:
        return None
    base_filename = os.path.splitext(os.path.basename(img_path))[0]
    result_video_path = os.path.join(latest_sub_dir, f"{base_filename}##{audio_type}_audio_enhanced.mp4")
    if not os.path.exists(result_video_path):
        return None
    # Including the md5 makes the final file name unique per image content.
    # Two different uploads that share a file name then cannot overwrite
    # each other's video in results/<type>-video.
    processed_video_path = os.path.join(
        latest_sub_dir, f"{base_filename}_{image_md5}##{audio_type}_audio_enhanced_processed.mp4"
    )
    if audio_type == "dynamic":
        # Drop the audio track (-an) and copy the video stream unchanged.
        ffmpeg_command = ["ffmpeg", "-i", result_video_path, "-an", "-vcodec", "copy", processed_video_path]
    else:
        # Cut the final 2 seconds. Skip trimming when the clip is not longer
        # than that, because a non-positive -t is not a meaningful duration.
        trimmed_duration = get_video_duration(result_video_path) - 2
        ffmpeg_command = ["ffmpeg", "-i", result_video_path]
        if trimmed_duration > 0:
            ffmpeg_command += ["-t", str(trimmed_duration)]
        ffmpeg_command += ["-c", "copy", processed_video_path]
    run_ffmpeg_command(ffmpeg_command)
    final_destination = save_to_final_destination(processed_video_path, f"results/{audio_type}-video")
    record_video_mapping(image_md5, final_destination, record_file)
    return final_destination
def encode_video_to_base64(video_path: str):
    """Read the file at *video_path* and return its bytes as a base64 ``str``."""
    with open(video_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
async def handle_video_request(background_tasks: BackgroundTasks, image: UploadFile, audio_type: str):
    """Shared implementation of the ``/dynamic-video`` and ``/silent-video`` endpoints.

    Saves the uploaded image under ``<audio_type>/``. It then does one of three things:

    - returns a previously recorded video for the same image content
      (matched by md5);
    - returns the uid of a task that is already generating the video;
    - schedules ``process_video`` as a new background task.

    Args:
        background_tasks: FastAPI's per-request background task collection.
        image: The uploaded image file.
        audio_type: ``"dynamic"`` or ``"silent"``; selects the upload folder,
            record file and task key.

    Returns:
        A JSONResponse with ``code``, ``message``, ``uid`` and ``video`` keys.
        The HTTP status is 400 for a disallowed file type.
    """
    # Keep only the final path component. The client-supplied name must not
    # be able to escape the audio_type directory (e.g. "../../x.png").
    # UploadFile.filename may also be None.
    filename = os.path.basename(image.filename or "")
    if not is_allowed_file(filename):
        return JSONResponse(status_code=400, content={"code": 400, "message": "Invalid file type. Only images (png, jpg, jpeg, gif) are allowed.", "uid": "", "video": ""})
    # Nothing visible in this file creates the upload directory, so create it here.
    os.makedirs(audio_type, exist_ok=True)
    img_path = os.path.join(audio_type, filename)
    save_upload_file(image, img_path)
    image_md5 = get_file_md5(img_path)
    record_file = f"{audio_type}_video_record.txt"
    existing_video = check_existing_video(image_md5, record_file)
    # A recorded path whose file has since been removed counts as a miss.
    if existing_video and os.path.exists(existing_video):
        video_base64 = encode_video_to_base64(existing_video)
        return JSONResponse(content={"code": 200, "message": "Video retrieved successfully.", "uid": "", "video": video_base64})
    # Key in-flight tasks by audio type as well as md5. The same image can be
    # submitted to both endpoints, and each endpoint produces a different video.
    task_key = f"{audio_type}:{image_md5}"
    uid = md5_to_uid.get(task_key)
    if uid is not None and tasks.get(uid, {}).get('status') == 'processing':
        return JSONResponse(content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
    # No task is in flight (never started, or a previous attempt failed or
    # its video vanished), so start a new one.
    uid = str(uuid.uuid4())
    tasks[uid] = {'status': 'processing'}
    md5_to_uid[task_key] = uid
    background_tasks.add_task(process_video, uid, image_md5, img_path, audio_type)
    return JSONResponse(content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
@app.post("/dynamic-video")
async def generate_dynamic_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
    """POST /dynamic-video: start (or reuse) a "dynamic" video job for the uploaded image."""
    response = await handle_video_request(background_tasks, image, "dynamic")
    return response
@app.post("/silent-video")
async def generate_silent_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
    """POST /silent-video: start (or reuse) a "silent" video job for the uploaded image."""
    response = await handle_video_request(background_tasks, image, "silent")
    return response
@app.get("/status/{uid}")
async def get_status(uid: str):
    """GET /status/{uid}: report a task's state and, once finished, its video.

    Responses are JSON with ``code``, ``status``, ``message`` and ``video``:

    - 404 if *uid* is unknown.
    - 200 with the base64-encoded video once the task has completed.
    - 500 if the task failed, or its recorded video file no longer exists.
    - 200 with an empty video while the task is still running.
    """
    task = tasks.get(uid)
    if task is None:
        return JSONResponse(status_code=404, content={"code": 404, "status": "not found", "message": "Task not found", "video": ""})
    status = task['status']
    video_path = task.get('video_path')
    if status == 'completed' and not video_path:
        # The status may be published before video_path; report the task as
        # still in progress rather than crashing with a KeyError.
        status = 'processing'
    if status == 'completed':
        if not os.path.exists(video_path):
            return JSONResponse(status_code=500, content={"code": 500, "status": status, "message": "Video file not found", "video": ""})
        video_base64 = encode_video_to_base64(video_path)
        return JSONResponse(content={"code": 200, "status": status, "message": "Video generation completed.", "video": video_base64})
    if status == 'failed':
        return JSONResponse(status_code=500, content={"code": 500, "status": status, "message": "Video generation failed", "video": ""})
    return JSONResponse(content={"code": 200, "status": status, "message": "Video is being generated.", "video": ""})
if __name__ == "__main__":
    # uvicorn is only needed when this module is executed directly.
    import uvicorn

    host, port = "0.0.0.0", 8000
    uvicorn.run(app, host=host, port=port)