diff --git a/inference.py b/inference.py
index d5db6c3..8efa0e6 100644
--- a/inference.py
+++ b/inference.py
@@ -117,7 +117,7 @@ if __name__ == '__main__':
     parser = ArgumentParser()
     parser.add_argument("--driven_audio", default='./examples/driven_audio/20240315_154953.wav', help="path to driven audio")
     parser.add_argument("--source_image", default='./examples/source_image/17.png', help="path to source image")
-    parser.add_argument("--ref_eyeblink", default='./examples/ref_video/E05005.mp4', help="path to reference video providing eye blinking")
+    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
     parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
     parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
     parser.add_argument("--result_dir", default='./results', help="path to output")
diff --git a/main.py b/main.py
index 34a9940..a1e6eb6 100644
--- a/main.py
+++ b/main.py
@@ -4,8 +4,8 @@ import os
 import subprocess
 import uuid
 import base64
-from fastapi import FastAPI, File, UploadFile, BackgroundTasks
-from fastapi.responses import JSONResponse
+from fastapi import FastAPI, File, UploadFile, BackgroundTasks, HTTPException
+from fastapi.responses import JSONResponse, StreamingResponse
 from hashlib import md5
 from typing import Dict
 
@@ -192,6 +192,24 @@ async def get_status(uid: str):
     if not task:
         return JSONResponse(status_code=404, content={"code": 500, "status": "not found", "message": "Task not found", "video": ""})
 
+    # Tasks whose uid starts with "szr" are rendered by inference.py in the
+    # background, so their state is derived from the files under ./results/<uid>.
+    # NOTE(review): that directory is not created anywhere in this diff; if
+    # inference.py creates it late, an early poll reports 'failed' - confirm.
+    if uid.startswith("szr"):
+        result_root = os.path.join("./results", uid)
+        if not os.path.exists(result_root):
+            task['status'] = 'failed'
+        else:
+            entries = os.listdir(result_root)
+            if entries:
+                filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
+                video_path = os.path.join(result_root, entries[0], filename)
+                if os.path.exists(video_path):
+                    task['status'] = 'completed'
+                    task['video_path'] = video_path
+
+
     if task['status'] == 'completed':
         video_base64 = encode_video_to_base64(task['video_path'])
         return JSONResponse(content={"code": 200, "status": task['status'], "message": "Video generation completed.",
@@ -204,6 +222,110 @@ async def get_status(uid: str):
     return JSONResponse(
         content={"code": 200, "status": task['status'], "message": "Video is being generated.", "video": ""})
 
+# Stream a file from disk in chunks (generator used as a response body).
+def get_file_stream(file_path):
+    """Yield the contents of ``file_path`` in fixed-size binary chunks.
+
+    This is a generator, so the file is only opened once iteration starts.
+
+    Raises:
+        HTTPException: 404 if the file does not exist when iteration starts.
+    """
+    chunk_size = 8192  # bytes read per iteration
+    try:
+        with open(file_path, 'rb') as file:
+            while True:
+                chunk = file.read(chunk_size)
+                if not chunk:
+                    break
+                yield chunk
+    except FileNotFoundError:
+        raise HTTPException(status_code=404, detail="File not found")
+
+
+@app.post("/videostream")
+def get_stream_video(uid: str):
+    """Stream the generated video of task ``uid`` as a file download.
+
+    Raises:
+        HTTPException: 404 if the task is unknown or its video file is missing.
+    """
+    print("uid:::::", uid)
+    task = tasks.get(uid)
+    if not task:
+        raise HTTPException(status_code=404, detail="Task not found")
+    filename = task['image_name'] + "##" + task['audio_name'] + "_enhanced.mp4"
+    result_root = os.path.join("./results", uid)
+    try:
+        entries = os.listdir(result_root)
+    except FileNotFoundError:
+        entries = []
+    # Check the file before building the response: an exception raised inside
+    # the streaming generator surfaces only after the response has started.
+    file_path = os.path.join(result_root, entries[0], filename) if entries else None
+    if file_path is None or not os.path.isfile(file_path):
+        raise HTTPException(status_code=404, detail="File not found")
+    # Set the appropriate HTTP headers
+    headers = {
+        "Content-Disposition": f"attachment; filename={filename}",
+        "Content-Type": "application/octet-stream",  # or a MIME type specific to the file type
+        # "Content-Length": file_size,  # can be set when the file size is known in advance
+    }
+    return StreamingResponse(get_file_stream(file_path), headers=headers)
+
+
+# Generate a video
+@app.post("/get-video")
+async def get_video(background_tasks: BackgroundTasks, image: UploadFile = File(...), audio: UploadFile = File(...)):
+    """Queue video generation for the uploaded image and audio."""
+    return await get_end_video(background_tasks, image, audio)
+
+
+def runcommand(command):
+    """Run ``command`` and print its stderr and argv if it exits non-zero."""
+    try:
+        subprocess.run(command, capture_output=True, text=True, check=True)
+    except subprocess.CalledProcessError as e:
+        print(e.stderr)
+        print("cmd::", e.cmd)
+
+
+# Process the video
+async def get_end_video(background_tasks: BackgroundTasks, image: UploadFile, audio: UploadFile):
+    """Save both uploads, register a new "szr" task and queue inference.py.
+
+    Returns:
+        JSONResponse containing the uid of the new task; the video itself is
+        produced later by the background task.
+    """
+    # Upload filenames are client-controlled: keep only the last path
+    # component so a crafted name cannot point outside the working directory.
+    source_image = os.path.basename(image.filename)
+    save_upload_file(image, source_image)
+    driven_audio = os.path.basename(audio.filename)
+    save_upload_file(audio, driven_audio)
+
+    image_md5 = get_file_md5(source_image) + get_file_md5(driven_audio)
+    uid = "szr" + str(uuid.uuid4())
+    tasks[uid] = {'status': 'processing',
+                  "image_name": source_image.split('.')[0],
+                  'audio_name': driven_audio.split('.')[0]}
+    md5_to_uid[image_md5] = uid
+
+    command = [
+        "python", "inference.py",
+        "--driven_audio", driven_audio,
+        "--source_image", source_image,
+        "--result_dir", "./results/"+uid,
+        "--still",
+        "--preprocess", "full",
+        "--enhancer", "gfpgan",
+    ]
+
+    background_tasks.add_task(runcommand, command)
+    return JSONResponse(
+        content={"code": 200, "message": "Video is being generated, please check back later.", "uid": uid, "video": ""})
+
 
 if __name__ == "__main__":
     import uvicorn