更新接口代码

main
fanpt 9 months ago
parent 9373cba5ab
commit 76cdb696ca

@ -2,16 +2,22 @@
import os
import subprocess
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
import uuid
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
from fastapi.responses import FileResponse, JSONResponse
from hashlib import md5
from typing import Dict
app = FastAPI()
tasks: Dict[str, dict] = {} # To store the status and result of each task
md5_to_uid: Dict[str, str] = {} # To map md5 to uid
def save_upload_file(upload_file: UploadFile, filename: str):
with open(filename, "wb") as buffer:
buffer.write(upload_file.file.read())
def generate_video_command(result_dir: str, img_path: str, audio_path: str, video_path: str):
return [
"python", "script.py",
@ -21,12 +27,15 @@ def generate_video_command(result_dir: str, img_path: str, audio_path: str, vide
"--ref_eyeblink", video_path,
]
def get_latest_sub_dir(result_dir: str):
sub_dirs = [os.path.join(result_dir, d) for d in os.listdir(result_dir) if os.path.isdir(os.path.join(result_dir, d))]
sub_dirs = [os.path.join(result_dir, d) for d in os.listdir(result_dir) if
os.path.isdir(os.path.join(result_dir, d))]
if not sub_dirs:
return None
return max(sub_dirs, key=os.path.getmtime)
def get_video_duration(video_path: str):
video_duration_command = [
"ffprobe",
@ -38,6 +47,7 @@ def get_video_duration(video_path: str):
result = subprocess.run(video_duration_command, capture_output=True, text=True)
return float(result.stdout.strip())
def trim_video(input_video_path: str, output_video_path: str, duration: float):
trim_command = [
"ffmpeg",
@ -48,6 +58,7 @@ def trim_video(input_video_path: str, output_video_path: str, duration: float):
]
subprocess.run(trim_command, check=True)
def remove_audio(input_video_path: str, output_video_path: str):
remove_audio_command = [
"ffmpeg",
@ -58,12 +69,14 @@ def remove_audio(input_video_path: str, output_video_path: str):
]
subprocess.run(remove_audio_command, check=True)
def save_to_final_destination(source_path: str, destination_dir: str):
os.makedirs(destination_dir, exist_ok=True)
destination_path = os.path.join(destination_dir, os.path.basename(source_path))
os.rename(source_path, destination_path)
return destination_path
def get_file_md5(file_path: str):
hash_md5 = md5()
with open(file_path, "rb") as f:
@ -71,9 +84,11 @@ def get_file_md5(file_path: str):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def record_video_mapping(image_path: str, video_path: str, record_file: str):
def record_video_mapping(image_md5: str, video_path: str, record_file: str):
with open(record_file, "a") as f:
f.write(f"{image_path} {video_path}\n")
f.write(f"{image_md5} {video_path}\n")
def check_existing_video(image_md5: str, record_file: str):
if not os.path.exists(record_file):
@ -85,17 +100,9 @@ def check_existing_video(image_md5: str, record_file: str):
return video_path
return None
@app.post("/dynamic-video")
async def generate_video(image: UploadFile = File(...)):
img_path = os.path.join("dynamic", image.filename)
save_upload_file(image, img_path)
image_md5 = get_file_md5(img_path)
def process_dynamic_video(uid: str, image_md5: str, img_path: str):
record_file = "dynamic_video_record.txt"
existing_video = check_existing_video(image_md5, record_file)
if existing_video:
return FileResponse(existing_video, media_type='video/mp4')
audio_path = "./examples/driven_audio/dynamic_audio.wav"
video_path = "./examples/ref_video/dynamic.mp4"
@ -107,30 +114,26 @@ async def generate_video(image: UploadFile = File(...)):
latest_sub_dir = get_latest_sub_dir(result_dir)
if not latest_sub_dir:
return {"error": "No subdirectory found in result directory"}
tasks[uid]['status'] = 'failed'
return
result_video_path = os.path.join(latest_sub_dir, f"{os.path.splitext(image.filename)[0]}##dynamic_audio_enhanced.mp4")
silent_video_path = os.path.join(latest_sub_dir, f"{os.path.splitext(image.filename)[0]}##dynamic_audio_enhanced_dynamic.mp4")
result_video_path = os.path.join(latest_sub_dir,
f"{os.path.splitext(os.path.basename(img_path))[0]}##dynamic_audio_enhanced.mp4")
silent_video_path = os.path.join(latest_sub_dir,
f"{os.path.splitext(os.path.basename(img_path))[0]}##dynamic_audio_enhanced_dynamic.mp4")
if os.path.exists(result_video_path):
remove_audio(result_video_path, silent_video_path)
final_destination = save_to_final_destination(silent_video_path, "results/dynamic-video")
record_video_mapping(image_md5, final_destination, record_file)
return FileResponse(final_destination, media_type='video/mp4')
tasks[uid]['status'] = 'completed'
tasks[uid]['video_path'] = final_destination
else:
return {"error": "Video file not found"}
tasks[uid]['status'] = 'failed'
@app.post("/silent-video")
async def generate_and_trim_video(image: UploadFile = File(...)):
img_path = os.path.join("silent", image.filename)
save_upload_file(image, img_path)
image_md5 = get_file_md5(img_path)
def process_silent_video(uid: str, image_md5: str, img_path: str):
record_file = "silent_video_record.txt"
existing_video = check_existing_video(image_md5, record_file)
if existing_video:
return FileResponse(existing_video, media_type='video/mp4')
audio_path = "./examples/driven_audio/silent_audio.wav"
video_path = "./examples/ref_video/silent.mp4"
@ -142,20 +145,84 @@ async def generate_and_trim_video(image: UploadFile = File(...)):
latest_sub_dir = get_latest_sub_dir(result_dir)
if not latest_sub_dir:
return {"error": "No subdirectory found in result directory"}
tasks[uid]['status'] = 'failed'
return
result_video_path = os.path.join(latest_sub_dir, f"{os.path.splitext(image.filename)[0]}##silent_audio_enhanced.mp4")
trimmed_video_path = os.path.join(latest_sub_dir, f"{os.path.splitext(image.filename)[0]}##silent_audio_enhanced_trimmed.mp4")
result_video_path = os.path.join(latest_sub_dir,
f"{os.path.splitext(os.path.basename(img_path))[0]}##silent_audio_enhanced.mp4")
trimmed_video_path = os.path.join(latest_sub_dir,
f"{os.path.splitext(os.path.basename(img_path))[0]}##silent_audio_enhanced_trimmed.mp4")
if os.path.exists(result_video_path):
video_duration = get_video_duration(result_video_path)
trim_video(result_video_path, trimmed_video_path, video_duration)
final_destination = save_to_final_destination(trimmed_video_path, "results/silent-video")
record_video_mapping(image_md5, final_destination, record_file)
return FileResponse(final_destination, media_type='video/mp4')
tasks[uid]['status'] = 'completed'
tasks[uid]['video_path'] = final_destination
else:
tasks[uid]['status'] = 'failed'
@app.post("/dynamic-video")
async def generate_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
img_path = os.path.join("dynamic", image.filename)
save_upload_file(image, img_path)
image_md5 = get_file_md5(img_path)
record_file = "dynamic_video_record.txt"
existing_video = check_existing_video(image_md5, record_file)
if existing_video:
return FileResponse(existing_video, media_type='video/mp4')
if image_md5 in md5_to_uid:
uid = md5_to_uid[image_md5]
return {"message": "Video is being generated, please check back later.", "uid": uid}
uid = str(uuid.uuid4())
tasks[uid] = {'status': 'processing'}
md5_to_uid[image_md5] = uid
background_tasks.add_task(process_dynamic_video, uid, image_md5, img_path)
return {"message": "Video is being generated, please check back later.", "uid": uid}
@app.post("/silent-video")
async def generate_and_trim_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
img_path = os.path.join("silent", image.filename)
save_upload_file(image, img_path)
image_md5 = get_file_md5(img_path)
record_file = "silent_video_record.txt"
existing_video = check_existing_video(image_md5, record_file)
if existing_video:
return FileResponse(existing_video, media_type='video/mp4')
if image_md5 in md5_to_uid:
uid = md5_to_uid[image_md5]
return {"message": "Video is being generated, please check back later.", "uid": uid}
uid = str(uuid.uuid4())
tasks[uid] = {'status': 'processing'}
md5_to_uid[image_md5] = uid
background_tasks.add_task(process_silent_video, uid, image_md5, img_path)
return {"message": "Video is being generated, please check back later.", "uid": uid}
@app.get("/status/{uid}")
async def get_status(uid: str):
task = tasks.get(uid)
if not task:
return JSONResponse(status_code=404, content={"message": "Task not found"})
if task['status'] == 'completed':
return FileResponse(task['video_path'], media_type='video/mp4')
elif task['status'] == 'failed':
return JSONResponse(status_code=500, content={"message": "Video generation failed"})
else:
return {"error": "Video file not found"}
return {"status": task['status']}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)

Loading…
Cancel
Save