@ -4,8 +4,8 @@ import os
import subprocess
import uuid
import base64
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException
from fastapi.responses import JSONResponse, StreamingResponse
from hashlib import md5
from typing import Dict
@ -15,15 +15,18 @@ md5_to_uid: Dict[str, str] = {} # 将md5映射到uid
ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} # 允许的图片扩展名
# 保存上传的文件
def save_upload_file(upload_file: UploadFile, filename: str):
with open(filename, "wb") as buffer:
# Check whether the file has an allowed type
def is_allowed_file(filename: str):
    """Return True when *filename* carries an allowed image extension.

    The extension is the text after the last dot, compared
    case-insensitively against ALLOWED_IMAGE_EXTENSIONS.

    Args:
        filename: Client-supplied file name; may be empty or None.

    Returns:
        bool: False for a missing name or a name without a dot, otherwise
        whether the extension is in ALLOWED_IMAGE_EXTENSIONS.
    """
    # An upload can arrive without a name; treat that as "not allowed"
    # instead of raising on the membership test below.
    if not filename or '.' not in filename:
        return False
    return filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
# 生成视频处理命令
def generate_video_command(result_dir: str, img_path: str, audio_path: str, video_path: str):
return [
@ -34,23 +37,29 @@ def generate_video_command(result_dir: str, img_path: str, audio_path: str, vide
"--ref_eyeblink", video_path,
# Get the most recently modified sub-directory
def get_latest_sub_dir(result_dir: str):
    """Return the sub-directory of *result_dir* with the newest mtime.

    Args:
        result_dir: Directory whose immediate children are inspected.

    Returns:
        The path (``result_dir`` joined with the child name) of the most
        recently modified sub-directory, or None when there is none.

    Raises:
        FileNotFoundError: If *result_dir* does not exist.
    """
    candidates = (os.path.join(result_dir, name) for name in os.listdir(result_dir))
    # Plain files are ignored; only directories can be a result folder.
    sub_dirs = [path for path in candidates if os.path.isdir(path)]
    return max(sub_dirs, key=os.path.getmtime, default=None)
# Get the duration of a video
def get_video_duration(video_path: str):
    """Return the duration of *video_path* in seconds, as reported by ffprobe.

    Args:
        video_path: Path of the media file to probe.

    Returns:
        float: The value ffprobe prints for ``format=duration``.

    Raises:
        ValueError: If ffprobe printed nothing usable (for example the file
            is missing or is not a media file); the message carries
            ffprobe's stderr.
        FileNotFoundError: If the ``ffprobe`` executable cannot be found.
    """
    result = subprocess.run(
        ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1",
         video_path],
        capture_output=True, text=True
    )
    raw_duration = result.stdout.strip()
    if not raw_duration:
        # float("") would also raise ValueError, but without saying why;
        # surface ffprobe's own diagnostics instead.
        raise ValueError(f"ffprobe returned no duration for {video_path!r}: {result.stderr.strip()}")
    return float(raw_duration)
# Run an FFmpeg command
def run_ffmpeg_command(command: list):
    """Run *command* as a subprocess and wait for it to finish.

    Args:
        command: Full argument vector, executable first.

    Raises:
        subprocess.CalledProcessError: If the process exits non-zero.
    """
    subprocess.run(command, check=True)
# 将文件保存到最终目的地
def save_to_final_destination(source_path: str, destination_dir: str):
os.makedirs(destination_dir, exist_ok=True)
@ -58,6 +67,7 @@ def save_to_final_destination(source_path: str, destination_dir: str):
os.rename(source_path, destination_path)
return destination_path
# 计算文件的MD5
def get_file_md5(file_path: str):
hash_md5 = md5()
@ -66,11 +76,13 @@ def get_file_md5(file_path: str):
return hash_md5.hexdigest()
# Record the md5 -> video mapping
def record_video_mapping(image_md5: str, video_path: str, record_file: str):
    """Append one ``"<image_md5> <video_path>"`` line to *record_file*.

    Args:
        image_md5: Key identifying the source image.
        video_path: Path of the video generated for that image.
        record_file: Text file the mapping is appended to; created if absent.
    """
    with open(record_file, "a") as f:
        f.write(f"{image_md5} {video_path}\n")
# 检查是否存在已处理的视频
def check_existing_video(image_md5: str, record_file: str):
if not os.path.exists(record_file):
@ -82,6 +94,7 @@ def check_existing_video(image_md5: str, record_file: str):
return video_path
return None
# 处理视频生成
def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
record_file = f"{audio_type}_video_record.txt"
@ -108,7 +121,8 @@ def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
run_ffmpeg_command(["ffmpeg", "-i", result_video_path, "-an", "-vcodec", "copy", processed_video_path])
video_duration = get_video_duration(result_video_path)
run_ffmpeg_command(["ffmpeg", "-i", result_video_path, "-t", str(video_duration - 2), "-c", "copy", processed_video_path])
["ffmpeg", "-i", result_video_path, "-t", str(video_duration - 2), "-c", "copy", processed_video_path])
final_destination = save_to_final_destination(processed_video_path, f"results/{audio_type}-video")
record_video_mapping(image_md5, final_destination, record_file)
@ -117,60 +131,164 @@ def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
tasks[uid]['status'] = 'failed'
# Encode a video file as base64
def encode_video_to_base64(video_path: str):
    """Read *video_path* fully and return its contents as base64 text.

    Args:
        video_path: Path of the file to encode.

    Returns:
        str: Base64 encoding of the file's bytes.

    Raises:
        FileNotFoundError: If *video_path* does not exist.
    """
    with open(video_path, "rb") as video_file:
        return base64.b64encode(video_file.read()).decode('utf-8')
# Handle a video generation request
async def handle_video_request(background_tasks: BackgroundTasks, image: UploadFile, audio_type: str):
    """Serve a cached video for the uploaded image or schedule its generation.

    Flow: validate the extension, save the upload under ``<audio_type>/``,
    derive a key from ``audio_type`` plus the file's md5, then either return
    the recorded video, return the uid of a task already registered for that
    key, or register a new task and queue ``process_video``.

    Args:
        background_tasks: Queue the generation job is added to.
        image: Uploaded source image.
        audio_type: Request flavour (the visible callers pass "dynamic" or
            "silent"); used as the save directory, the record-file prefix
            and the key prefix.

    Returns:
        JSONResponse with keys ``code``, ``message``, ``uid`` and ``video``.
    """
    if not is_allowed_file(image.filename):
        return JSONResponse(status_code=400, content={"code": 500,
                                                      "message": "Invalid file type. Only images (png, jpg, jpeg, gif) are allowed.",
                                                      "uid": "", "video": ""})

    # The filename is client-controlled: keep only its last path component
    # so the upload cannot be written outside the audio_type directory.
    safe_name = os.path.basename(image.filename.replace("\\", "/"))
    img_path = os.path.join(audio_type, safe_name)
    save_upload_file(image, img_path)

    # Identical image content yields the same md5, so calling the silent and
    # dynamic endpoints with one image would produce only a single video...
    image_md5 = get_file_md5(img_path)
    # ...therefore the md5 is prefixed with the type to keep them apart.
    image_md5 = audio_type + image_md5

    record_file = f"{audio_type}_video_record.txt"
    existing_video = check_existing_video(image_md5, record_file)
    if existing_video:
        video_base64 = encode_video_to_base64(existing_video)
        # NOTE(review): "code": 500 on a successful retrieval looks
        # unintended, but is kept because clients may rely on it - confirm.
        return JSONResponse(
            content={"code": 500, "message": "Video retrieved successfully.", "uid": "", "video": video_base64})

    if image_md5 in md5_to_uid:
        # A task for this image/type is already registered; hand back its uid.
        uid = md5_to_uid[image_md5]
        return JSONResponse(
            content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid,
                     "video": ""})

    uid = str(uuid.uuid4())
    tasks[uid] = {'status': 'processing'}
    md5_to_uid[image_md5] = uid

    background_tasks.add_task(process_video, uid, image_md5, img_path, audio_type)

    return JSONResponse(
        content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
# Dynamic video generation endpoint
async def generate_dynamic_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
    """Handle an upload for the "dynamic" video flavour.

    Delegates to ``handle_video_request`` with ``audio_type="dynamic"`` and
    returns its response unchanged.
    """
    return await handle_video_request(background_tasks, image, "dynamic")
# Silent video generation endpoint
async def generate_silent_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
    """Handle an upload for the "silent" video flavour.

    Delegates to ``handle_video_request`` with ``audio_type="silent"`` and
    returns its response unchanged.
    """
    return await handle_video_request(background_tasks, image, "silent")
# Task status endpoint
async def get_status(uid: str):
    """Report the state of task *uid*, including the video once it is ready.

    For uids starting with "szr" the state is refreshed from disk first: the
    task is marked failed when ``./results/<uid>`` is absent, and completed
    when ``<image_name>##<audio_name>_enhanced.mp4`` exists inside the first
    entry of that directory.

    Args:
        uid: Task identifier handed out when the job was submitted.

    Returns:
        JSONResponse with keys ``code``, ``status``, ``message`` and
        ``video`` (base64 text when completed, otherwise empty). HTTP 404 for
        an unknown uid and HTTP 500 for a failed task.
    """
    task = tasks.get(uid)
    if not task:
        return JSONResponse(status_code=404,
                            content={"code": 500, "status": "not found", "message": "Task not found", "video": ""})

    if uid.startswith("szr"):
        result_dir = "./results/" + uid
        if not os.path.exists(result_dir):
            # NOTE(review): this also triggers when the job simply has not
            # created its output directory yet - confirm that is intended.
            task['status'] = 'failed'
        else:
            # Only list the directory once it is known to exist; listing a
            # missing directory would raise and turn the poll into a crash.
            entries = os.listdir(result_dir)
            if entries:
                filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
                video_path = os.path.join(result_dir, entries[0], filename)
                if os.path.exists(video_path):
                    task['status'] = 'completed'
                    task['video_path'] = video_path

    if task['status'] == 'completed':
        video_base64 = encode_video_to_base64(task['video_path'])
        return JSONResponse(content={"code": 200, "status": task['status'], "message": "Video generation completed.",
                                     "video": video_base64})
    elif task['status'] == 'failed':
        return JSONResponse(status_code=500,
                            content={"code": 500, "status": task['status'], "message": "Video generation failed",
                                     "video": ""})
    return JSONResponse(
        content={"code": 200, "status": task['status'], "message": "Video is being generated.", "video": ""})
# Obtain a file's contents as a stream of chunks (example implementation)
def get_file_stream(file_path):
    """Open *file_path* and return an iterator over its bytes in 8192-byte chunks.

    The file is opened when this function is called, not when iteration
    starts, so a missing file is reported before any response body has been
    sent.

    Args:
        file_path: Path of the file to stream.

    Returns:
        An iterator yielding ``bytes`` chunks; the file is closed once the
        iterator is exhausted or closed.

    Raises:
        HTTPException: With status 404 when the file does not exist.
    """
    try:
        file = open(file_path, 'rb')
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail="File not found") from exc

    def _iter_chunks():
        chunk_size = 8192  # number of bytes read per call
        with file:
            while True:
                chunk = file.read(chunk_size)
                if not chunk:
                    break
                yield chunk

    return _iter_chunks()
def get_stream_video(uid: str):
    """Stream the enhanced video produced for task *uid* as a download.

    The file is looked up as ``<image_name>##<audio_name>_enhanced.mp4``
    inside the first entry of ``./results/<uid>``.

    Args:
        uid: Task identifier handed out when the job was submitted.

    Returns:
        StreamingResponse carrying the file as an attachment.

    Raises:
        HTTPException: With status 404 when the task is unknown, lacks the
            recorded image/audio names, or has no result directory entries.
    """
    print("uid:::::", uid)
    task = tasks.get(uid)
    # Unknown uids, and tasks that never stored image/audio names, cannot be
    # mapped to a file; answer 404 instead of failing on the lookups below.
    if not task or 'image_name' not in task or 'audio_name' not in task:
        raise HTTPException(status_code=404, detail="Task not found")

    filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
    result_dir = "./results/" + uid
    entries = os.listdir(result_dir) if os.path.isdir(result_dir) else []
    if not entries:
        raise HTTPException(status_code=404, detail="File not found")
    file_stream = get_file_stream(os.path.join(result_dir, entries[0], filename))

    # Set the appropriate HTTP headers
    headers = {
        "Content-Disposition": f"attachment; filename={filename}",
        "Content-Type": "application/octet-stream",  # or a specific MIME type chosen from the file type
        # A "Content-Length" header could be added here when the file size
        # is known in advance.
    }
    return StreamingResponse(file_stream, headers=headers)
# Generate a video
async def get_video(background_tasks: BackgroundTasks, image: UploadFile = File(...), audio: UploadFile = File(...)):
    """Accept an image and an audio upload and start video generation.

    Delegates to ``get_end_video`` and returns its response unchanged.
    """
    return await get_end_video(background_tasks, image, audio)
def runcommand(command):
    """Run *command* to completion, reporting a non-zero exit on stdout.

    A failing process does not raise: its command line, return code and
    captured stderr are printed instead.

    Args:
        command: Full argument vector, executable first.
    """
    try:
        subprocess.run(command, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print("cmd::", e.cmd)
        # The process output is captured, so without these lines the reason
        # for the failure would be lost.
        print("returncode::", e.returncode)
        print("stderr::", e.stderr)
# Process a video request built from an image and an audio track
async def get_end_video(background_tasks: BackgroundTasks, image: UploadFile, audio: UploadFile):
    """Save both uploads, register a task and queue the inference command.

    Args:
        background_tasks: Queue the inference job is added to.
        image: Uploaded source image.
        audio: Uploaded driving audio.

    Returns:
        JSONResponse with keys ``code``, ``message``, ``uid`` and ``video``;
        ``uid`` starts with "szr" and identifies the queued task. HTTP 400
        when either upload has no usable file name.
    """
    # File names are client-controlled and the uploads are written to the
    # working directory: keep only the last path component so a name cannot
    # point at another location.
    source_image = os.path.basename(image.filename.replace("\\", "/"))
    driven_audio = os.path.basename(audio.filename.replace("\\", "/"))
    if not source_image or not driven_audio:
        return JSONResponse(status_code=400,
                            content={"code": 500, "message": "Invalid file name.", "uid": "", "video": ""})

    save_upload_file(image, source_image)
    save_upload_file(audio, driven_audio)

    image_md5 = get_file_md5(source_image) + get_file_md5(driven_audio)

    uid = "szr" + str(uuid.uuid4())
    # The name stems are stored so the output file name can be rebuilt later
    # as "<image_name>##<audio_name>_enhanced.mp4".
    tasks[uid] = {'status': 'processing',
                  "image_name": source_image.split('.')[0],
                  'audio_name': driven_audio.split('.')[0]}
    md5_to_uid[image_md5] = uid

    command = [
        "python", "inference.py",
        "--driven_audio", driven_audio,
        "--source_image", source_image,
        "--result_dir", "./results/" + uid,
        "--preprocess", "full",
        "--enhancer", "gfpgan",
    ]
    background_tasks.add_task(runcommand, command)

    return JSONResponse(
        content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
if __name__ == "__main__":
    # Start the ASGI server only when the file is executed as a script.
    import uvicorn

    uvicorn.run(app, host="", port=8000)