|
|
|
@ -4,8 +4,8 @@ import os
|
|
|
|
|
import subprocess
|
|
|
|
|
import uuid
|
|
|
|
|
import base64
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, BackgroundTasks
|
|
|
|
|
from fastapi.responses import JSONResponse
|
|
|
|
|
from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException
|
|
|
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
|
from hashlib import md5
|
|
|
|
|
from typing import Dict
|
|
|
|
|
|
|
|
|
@ -15,15 +15,18 @@ md5_to_uid: Dict[str, str] = {} # 将md5映射到uid
|
|
|
|
|
|
|
|
|
|
ALLOWED_IMAGE_EXTENSIONS = {'png', 'jpg', 'jpeg', 'gif'} # 允许的图片扩展名
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 保存上传的文件
|
|
|
|
|
def save_upload_file(upload_file: UploadFile, filename: str):
|
|
|
|
|
with open(filename, "wb") as buffer:
|
|
|
|
|
buffer.write(upload_file.file.read())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查文件是否是允许的类型
|
|
|
|
|
def is_allowed_file(filename: str):
|
|
|
|
|
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_IMAGE_EXTENSIONS
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 生成视频处理命令
|
|
|
|
|
def generate_video_command(result_dir: str, img_path: str, audio_path: str, video_path: str):
|
|
|
|
|
return [
|
|
|
|
@ -34,23 +37,29 @@ def generate_video_command(result_dir: str, img_path: str, audio_path: str, vide
|
|
|
|
|
"--ref_eyeblink", video_path,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取最新的子目录
|
|
|
|
|
def get_latest_sub_dir(result_dir: str):
|
|
|
|
|
sub_dirs = [os.path.join(result_dir, d) for d in os.listdir(result_dir) if os.path.isdir(os.path.join(result_dir, d))]
|
|
|
|
|
sub_dirs = [os.path.join(result_dir, d) for d in os.listdir(result_dir) if
|
|
|
|
|
os.path.isdir(os.path.join(result_dir, d))]
|
|
|
|
|
return max(sub_dirs, key=os.path.getmtime, default=None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取视频时长
|
|
|
|
|
def get_video_duration(video_path: str):
|
|
|
|
|
result = subprocess.run(
|
|
|
|
|
["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", video_path],
|
|
|
|
|
["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1",
|
|
|
|
|
video_path],
|
|
|
|
|
capture_output=True, text=True
|
|
|
|
|
)
|
|
|
|
|
return float(result.stdout.strip())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 运行FFmpeg命令
|
|
|
|
|
def run_ffmpeg_command(command: list):
|
|
|
|
|
subprocess.run(command, check=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 将文件保存到最终目的地
|
|
|
|
|
def save_to_final_destination(source_path: str, destination_dir: str):
|
|
|
|
|
os.makedirs(destination_dir, exist_ok=True)
|
|
|
|
@ -58,6 +67,7 @@ def save_to_final_destination(source_path: str, destination_dir: str):
|
|
|
|
|
os.rename(source_path, destination_path)
|
|
|
|
|
return destination_path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 计算文件的MD5
|
|
|
|
|
def get_file_md5(file_path: str):
|
|
|
|
|
hash_md5 = md5()
|
|
|
|
@ -66,11 +76,13 @@ def get_file_md5(file_path: str):
|
|
|
|
|
hash_md5.update(chunk)
|
|
|
|
|
return hash_md5.hexdigest()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 记录视频映射关系
|
|
|
|
|
def record_video_mapping(image_md5: str, video_path: str, record_file: str):
|
|
|
|
|
with open(record_file, "a") as f:
|
|
|
|
|
f.write(f"{image_md5} {video_path}\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 检查是否存在已处理的视频
|
|
|
|
|
def check_existing_video(image_md5: str, record_file: str):
|
|
|
|
|
if not os.path.exists(record_file):
|
|
|
|
@ -82,6 +94,7 @@ def check_existing_video(image_md5: str, record_file: str):
|
|
|
|
|
return video_path
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 处理视频生成
|
|
|
|
|
def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
|
|
|
|
|
record_file = f"{audio_type}_video_record.txt"
|
|
|
|
@ -108,7 +121,8 @@ def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
|
|
|
|
|
run_ffmpeg_command(["ffmpeg", "-i", result_video_path, "-an", "-vcodec", "copy", processed_video_path])
|
|
|
|
|
else:
|
|
|
|
|
video_duration = get_video_duration(result_video_path)
|
|
|
|
|
run_ffmpeg_command(["ffmpeg", "-i", result_video_path, "-t", str(video_duration - 2), "-c", "copy", processed_video_path])
|
|
|
|
|
run_ffmpeg_command(
|
|
|
|
|
["ffmpeg", "-i", result_video_path, "-t", str(video_duration - 2), "-c", "copy", processed_video_path])
|
|
|
|
|
|
|
|
|
|
final_destination = save_to_final_destination(processed_video_path, f"results/{audio_type}-video")
|
|
|
|
|
record_video_mapping(image_md5, final_destination, record_file)
|
|
|
|
@ -117,60 +131,164 @@ def process_video(uid: str, image_md5: str, img_path: str, audio_type: str):
|
|
|
|
|
else:
|
|
|
|
|
tasks[uid]['status'] = 'failed'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 将视频编码为base64
|
|
|
|
|
def encode_video_to_base64(video_path: str):
|
|
|
|
|
with open(video_path, "rb") as video_file:
|
|
|
|
|
return base64.b64encode(video_file.read()).decode('utf-8')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 处理视频请求
|
|
|
|
|
async def handle_video_request(background_tasks: BackgroundTasks, image: UploadFile, audio_type: str):
|
|
|
|
|
if not is_allowed_file(image.filename):
|
|
|
|
|
return JSONResponse(status_code=400, content={"code": 500, "message": "Invalid file type. Only images (png, jpg, jpeg, gif) are allowed.", "uid": "", "video": ""})
|
|
|
|
|
return JSONResponse(status_code=400, content={"code": 500,
|
|
|
|
|
"message": "Invalid file type. Only images (png, jpg, jpeg, gif) are allowed.",
|
|
|
|
|
"uid": "", "video": ""})
|
|
|
|
|
|
|
|
|
|
img_path = os.path.join(audio_type, image.filename)
|
|
|
|
|
save_upload_file(image, img_path)
|
|
|
|
|
|
|
|
|
|
# 对同一个图片内容会生成同一个md5,所以会导致同时调用静、动两个接口时只会生成一个的问题
|
|
|
|
|
image_md5 = get_file_md5(img_path)
|
|
|
|
|
# 所以对同一个图片内容生成的md5加类型做区分
|
|
|
|
|
image_md5 = audio_type + image_md5
|
|
|
|
|
record_file = f"{audio_type}_video_record.txt"
|
|
|
|
|
existing_video = check_existing_video(image_md5, record_file)
|
|
|
|
|
if existing_video:
|
|
|
|
|
video_base64 = encode_video_to_base64(existing_video)
|
|
|
|
|
return JSONResponse(content={"code": 500, "message": "Video retrieved successfully.", "uid": "", "video": video_base64})
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
content={"code": 500, "message": "Video retrieved successfully.", "uid": "", "video": video_base64})
|
|
|
|
|
|
|
|
|
|
if image_md5 in md5_to_uid:
|
|
|
|
|
uid = md5_to_uid[image_md5]
|
|
|
|
|
return JSONResponse(content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid,
|
|
|
|
|
"video": ""})
|
|
|
|
|
|
|
|
|
|
uid = str(uuid.uuid4())
|
|
|
|
|
tasks[uid] = {'status': 'processing'}
|
|
|
|
|
md5_to_uid[image_md5] = uid
|
|
|
|
|
background_tasks.add_task(process_video, uid, image_md5, img_path, audio_type)
|
|
|
|
|
return JSONResponse(content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 动态视频生成端点
|
|
|
|
|
@app.post("/dynamic-video")
|
|
|
|
|
async def generate_dynamic_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
|
|
|
|
|
return await handle_video_request(background_tasks, image, "dynamic")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 静态视频生成端点
|
|
|
|
|
@app.post("/silent-video")
|
|
|
|
|
async def generate_silent_video(background_tasks: BackgroundTasks, image: UploadFile = File(...)):
|
|
|
|
|
return await handle_video_request(background_tasks, image, "silent")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# 获取任务状态端点
|
|
|
|
|
@app.get("/status/{uid}")
|
|
|
|
|
async def get_status(uid: str):
|
|
|
|
|
task = tasks.get(uid)
|
|
|
|
|
if not task:
|
|
|
|
|
return JSONResponse(status_code=404, content={"code": 500, "status": "not found", "message": "Task not found", "video": ""})
|
|
|
|
|
return JSONResponse(status_code=404,
|
|
|
|
|
content={"code": 500, "status": "not found", "message": "Task not found", "video": ""})
|
|
|
|
|
if uid.startswith("szr"):
|
|
|
|
|
if not os.path.exists("./results/" + uid):
|
|
|
|
|
task['status'] = 'failed'
|
|
|
|
|
else:
|
|
|
|
|
entries = os.listdir("./results/" + uid)
|
|
|
|
|
if entries != []:
|
|
|
|
|
filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
|
|
|
|
|
mp4 = os.path.exists("./results/" + uid + "/" + entries[0] + "/" + filename)
|
|
|
|
|
if mp4:
|
|
|
|
|
task['status'] = 'completed'
|
|
|
|
|
task['video_path'] = os.path.join("./results/" + uid, entries[0], filename)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if task['status'] == 'completed':
|
|
|
|
|
video_base64 = encode_video_to_base64(task['video_path'])
|
|
|
|
|
return JSONResponse(content={"code": 200, "status": task['status'], "message": "Video generation completed.", "video": video_base64})
|
|
|
|
|
return JSONResponse(content={"code": 200, "status": task['status'], "message": "Video generation completed.",
|
|
|
|
|
"video": video_base64})
|
|
|
|
|
elif task['status'] == 'failed':
|
|
|
|
|
return JSONResponse(status_code=500, content={"code": 500, "status": task['status'], "message": "Video generation failed", "video": ""})
|
|
|
|
|
return JSONResponse(status_code=500,
|
|
|
|
|
content={"code": 500, "status": task['status'], "message": "Video generation failed",
|
|
|
|
|
"video": ""})
|
|
|
|
|
else:
|
|
|
|
|
return JSONResponse(content={"code": 200, "status": task['status'], "message": "Video is being generated.", "video": ""})
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
content={"code": 200, "status": task['status'], "message": "Video is being generated.", "video": ""})
|
|
|
|
|
|
|
|
|
|
# 假设你有一个函数可以从某处获取文件流(这里只是示例)
|
|
|
|
|
def get_file_stream(file_path):
|
|
|
|
|
try:
|
|
|
|
|
file_size = os.path.getsize(file_path)
|
|
|
|
|
with open(file_path, 'rb') as file:
|
|
|
|
|
chunk_size = 8192 # 每次读取的字节数
|
|
|
|
|
while True:
|
|
|
|
|
chunk = file.read(chunk_size)
|
|
|
|
|
if not chunk:
|
|
|
|
|
break
|
|
|
|
|
yield chunk
|
|
|
|
|
except FileNotFoundError:
|
|
|
|
|
raise HTTPException(status_code=404, detail="File not found")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/videostream")
|
|
|
|
|
def get_stream_video(uid: str):
|
|
|
|
|
print("uid:::::", uid)
|
|
|
|
|
task = tasks.get(uid)
|
|
|
|
|
filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
|
|
|
|
|
entries = os.listdir("./results/" + uid)
|
|
|
|
|
file_stream = get_file_stream(os.path.join("./results/" + uid, entries[0], filename))
|
|
|
|
|
# 设置适当的HTTP头
|
|
|
|
|
headers = {
|
|
|
|
|
"Content-Disposition": f"attachment; filename={filename}",
|
|
|
|
|
"Content-Type": "application/octet-stream", # 或者根据文件类型设置具体的MIME类型
|
|
|
|
|
# "Content-Length": file_size, # 如果事先知道文件大小,可以设置这个头
|
|
|
|
|
}
|
|
|
|
|
return StreamingResponse(file_stream, headers=headers)
|
|
|
|
|
|
|
|
|
|
# 生成视频
|
|
|
|
|
@app.post("/get-video")
|
|
|
|
|
async def get_video(background_tasks: BackgroundTasks, image: UploadFile = File(...), audio: UploadFile = File(...)):
|
|
|
|
|
return await get_end_video(background_tasks, image, audio)
|
|
|
|
|
|
|
|
|
|
def runcommand(command):
|
|
|
|
|
try:
|
|
|
|
|
result = subprocess.run(command, capture_output=True, text=True, check=True)
|
|
|
|
|
except subprocess.CalledProcessError as e:
|
|
|
|
|
print(e.stderr)
|
|
|
|
|
print("cmd::", e.cmd)
|
|
|
|
|
|
|
|
|
|
# 处理视频
|
|
|
|
|
async def get_end_video(background_tasks: BackgroundTasks, image: UploadFile, audio: UploadFile):
|
|
|
|
|
source_image = os.path.join(image.filename)
|
|
|
|
|
save_upload_file(image, source_image)
|
|
|
|
|
driven_audio = os.path.join(audio.filename)
|
|
|
|
|
save_upload_file(audio, driven_audio)
|
|
|
|
|
|
|
|
|
|
image_md5 = get_file_md5(source_image) + get_file_md5(driven_audio)
|
|
|
|
|
uid = "szr" + str(uuid.uuid4())
|
|
|
|
|
tasks[uid] = {'status': 'processing',
|
|
|
|
|
"image_name": image.filename.split('.')[0],
|
|
|
|
|
'audio_name': audio.filename.split('.')[0]}
|
|
|
|
|
md5_to_uid[image_md5] = uid
|
|
|
|
|
|
|
|
|
|
command = [
|
|
|
|
|
"python", "inference.py",
|
|
|
|
|
"--driven_audio", driven_audio,
|
|
|
|
|
"--source_image", source_image,
|
|
|
|
|
"--result_dir", "./results/"+uid,
|
|
|
|
|
"--still",
|
|
|
|
|
"--preprocess", "full",
|
|
|
|
|
"--enhancer", "gfpgan",
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
background_tasks.add_task(runcommand, command)
|
|
|
|
|
return JSONResponse(
|
|
|
|
|
content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
|
|
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|
|
|
|
|