You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

232 lines
7.4 KiB
Python

# -*- coding: utf-8 -*-
import asyncio
import json
import logging
import os
import re
import subprocess
import time
from typing import List, Optional

import boto3
import requests
from botocore.client import Config
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from snowflake import SnowflakeGenerator

app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Matches ANSI escape sequences ending in "m" (e.g. terminal colour codes) so
# they can be stripped from the OCR script's stdout.
ansi_escape = re.compile(r'\x1b[^m]*m')

# MinIO connection settings. Each value can be supplied through the environment;
# the literal fallbacks keep the previously hard-coded values so existing
# deployments behave exactly as before.
# NOTE(review): the fallbacks are credentials committed to source control —
# rotate them and remove the defaults once every deployment sets the env vars.
minio_config = {
    "endpoint_url": os.environ.get("MINIO_ENDPOINT", "http://192.168.10.137:9002"),  # MinIO endpoint
    "aws_access_key_id": os.environ.get("MINIO_ACCESS_KEY", "admin"),  # MinIO accessKey
    "aws_secret_access_key": os.environ.get("MINIO_SECRET_KEY", "12345678"),  # MinIO secretKey
}
bucket_name = os.environ.get("MINIO_BUCKET", "nxfuhsi")  # MinIO bucket name
s3_client = boto3.client("s3", **minio_config)

# Snowflake ID generator; used for unique temp-file names and result object keys.
gen = SnowflakeGenerator(42)

# API URL of the large model (port 11434 + /api/chat suggests an Ollama server).
model_api_url = os.environ.get("MODEL_API_URL", "http://112.81.86.50:11434/api/chat")

class PictureRequest(BaseModel):
    """Request body for ``POST /ocr``: a batch of images sharing one suffix."""

    # MinIO object keys of the images to OCR.
    file_ids: List[str]
    # File extension given to the temporary local copy of every image in the
    # batch; an empty string falls back to "jpg" during processing.
    suffix: str

class PictureResponse(BaseModel):
    """OCR outcome for a single image of a ``POST /ocr`` batch.

    ``process_image`` sets ``status`` to 0 on success and 2 on failure.
    """

    # MinIO key of the source image, echoed back from the request.
    file_id: str
    # MinIO key of the uploaded OCR result file; "" when processing failed.
    draw_img_id: str
    # Recognised text (OCR script stdout with ANSI codes stripped); "" on failure.
    ocr_text: str
    # 0 = success, 2 = failure.
    status: int
    # Failure description. Annotated Optional because the default is None;
    # the previous ``str = None`` annotation contradicted its own default.
    error_msg: Optional[str] = None

class TitleRequest(BaseModel):
    """Request body for ``POST /retrieve``."""

    # Text whose opening characters are searched for a title.
    text: str

class TitleResponse(BaseModel):
    """Result of title extraction for ``POST /retrieve``.

    ``send_request_to_model`` sets ``status`` to 0 when the model answered and
    to 2 when the request failed or the answer was malformed.
    """

    # Extracted title; "" when the model reported none or on failure.
    title: str
    # 0 = model answered, 2 = request/response error.
    status: int
    # Failure description. Optional because the default is None.
    error_msg: Optional[str] = None

# Prompt template for title extraction. ``{text}`` is filled in with
# ``str.format``; the doubled braces render as literal JSON braces.
_TITLE_PROMPT_TEMPLATE = """
提取下面文本中前40个字中明显的标题
### 注意事项:
1. 将结果以JSON格式返回不需要进行解释
2. 如果某字段提取不到则返回""
3. 文本原文可能为空返回""
4. 标题只会出现在前40个字中
5. 非常明确是标题的文本才能返回结果
6. 有可能很大概率是没有标题的
文本原文
{text}
输出json格式{{"title":"*****"}},如果没有明显标题{{"title":""}}
回溯你输出的结果确保你的输出结果符合json格式
"""

# Fallback matcher for a bare {"title": "..."} object embedded in other text.
_TITLE_JSON_PATTERN = re.compile(r'\{"title":\s*"([^"]*)"\}')


def _extract_title(content: str) -> str:
    """Return the title the model reported in *content*, or "" if there is none.

    The content is parsed as JSON first, which tolerates arbitrary whitespace
    and escaped quotes. If that fails (e.g. the model wrapped the object in
    prose), the regex fallback accepts exactly one ``{"title": "..."}`` match.
    """
    try:
        parsed = json.loads(content)
    except ValueError:
        parsed = None
    if isinstance(parsed, dict):
        title = parsed.get("title", "")
        return title if isinstance(title, str) else ""
    matches = _TITLE_JSON_PATTERN.findall(content)
    return matches[0] if len(matches) == 1 else ""


def send_request_to_model(text: str, timeout: float = 300.0) -> TitleResponse:
    """Ask the chat model at ``model_api_url`` for an obvious title in *text*.

    Args:
        text: Source text; the prompt tells the model that a title can only
            appear in the first 40 characters.
        timeout: Seconds ``requests`` waits for the connection and for each
            read from the model server. Previously there was no timeout, so a
            stalled server blocked the worker thread indefinitely.

    Returns:
        ``TitleResponse`` with ``status=0`` and the extracted title ("" when the
        model reports none or its answer cannot be parsed), or ``status=2`` with
        ``error_msg`` set when the HTTP request fails or the response body is
        not the expected ``{"message": {"content": str}}`` shape.
    """
    payload = {
        "model": "qwen2:72b",
        "messages": [
            {"role": "user", "content": _TITLE_PROMPT_TEMPLATE.format(text=text)},
        ],
        # Ollama's /api/chat enables JSON-constrained output through "format".
        # The former "type": "json_object" key is not one of its parameters and
        # was silently ignored.
        "format": "json",
        "stream": False,
    }
    headers = {"Content-Type": "application/json"}
    try:
        response = requests.post(model_api_url, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        result = response.json()
        # Validate the shape instead of chaining .get() calls, which raised an
        # uncaught AttributeError (HTTP 500) on a non-dict body or "message": null.
        message = result.get("message", {}) if isinstance(result, dict) else None
        if not isinstance(message, dict):
            raise ValueError("Invalid response content")
        content = message.get("content", "")
        if not isinstance(content, str):
            raise ValueError("Invalid response content")
        return TitleResponse(title=_extract_title(content), status=0, error_msg="")
    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to request model API: {e}")
        return TitleResponse(title="", status=2, error_msg=f"Request failed: {e}")
    except ValueError as e:
        logger.error(f"Invalid result returned by model API: {e}")
        return TitleResponse(title="", status=2, error_msg=f"Invalid result: {e}")

# Accepted image suffixes: 1-16 ASCII letters/digits. The suffix comes straight
# from the HTTP request and becomes part of a local path that is written and
# later deleted, so anything that could carry "/" or ".." must be rejected.
_SAFE_SUFFIX = re.compile(r"[A-Za-z0-9]{1,16}")


def _remove_temp_file(path) -> bool:
    """Delete *path* if it exists; return True when a file was removed.

    Errors are logged instead of raised so that cleanup running inside a
    ``finally`` clause cannot replace the per-image result with an exception.
    """
    if not path or not os.path.exists(path):
        return False
    try:
        os.remove(path)
    except OSError:
        logger.warning("Failed to remove temporary file: %s", path, exc_info=True)
        return False
    return True


async def process_image(file_id: str, suffix: str) -> PictureResponse:
    """OCR one image stored in MinIO and upload the OCR script's result file.

    Pipeline:
    1. Download object ``file_id`` from ``bucket_name``.
    2. Write it to a uniquely named local temporary file.
    3. Run ``tools/infer/predict_system_1.py`` on that file.
    4. Upload the file the script leaves in ``inference_results/`` under a new key.
    5. Delete the local files.

    Args:
        file_id: MinIO object key of the source image.
        suffix: Extension for the temporary copy. A leading "." is ignored, an
            empty value means "jpg", and anything other than 1-16 ASCII
            letters/digits is rejected.

    Returns:
        On success: ``PictureResponse`` with ``status=0``, the uploaded key in
        ``draw_img_id`` and the recognised text in ``ocr_text``.
        On any failure: ``status=2``, empty ``draw_img_id`` and ``ocr_text``,
        and an explanation in ``error_msg``. Per-image failures never propagate.
    """
    loop = asyncio.get_running_loop()
    start_time = time.time()  # start of the whole pipeline
    temp_image_file = None
    result_file_path = None
    processed_file_id = ""
    ocr_str = ""
    status = 2
    error_msg = ""
    try:
        logger.info(f"图片后缀: {suffix}")
        suffix = (suffix or "").lstrip(".")
        if not suffix:
            suffix = "jpg"  # default suffix when none is given
            logger.info(f"无后缀,更改图片后缀为: {suffix}")
        if not _SAFE_SUFFIX.fullmatch(suffix):
            raise ValueError(f"Invalid image suffix: {suffix!r}")

        # Step 1: fetch the image from MinIO. boto3 is synchronous, so it runs in
        # the default executor; otherwise every download would block the event
        # loop and the images of a batch could not be processed concurrently.
        step_start_time = time.time()
        image_data = await loop.run_in_executor(
            None,
            lambda: s3_client.get_object(Bucket=bucket_name, Key=file_id)["Body"].read(),
        )
        logger.info(f"从 MinIO 获取图片时间: {time.time() - step_start_time:.2f}")

        # Step 2: write the image to a uniquely named temporary file.
        step_start_time = time.time()
        temp_image_file = f"{next(gen)}.{suffix}"
        # The OCR script's output is expected under this path (read in step 4).
        result_file_path = f"inference_results/{temp_image_file}"
        with open(temp_image_file, "wb") as f:
            f.write(image_data)
        logger.info(f"写入临时文件时间: {time.time() - step_start_time:.2f}")

        # Step 3: run the OCR script on the temporary file.
        step_start_time = time.time()
        command = [
            'python', 'tools/infer/predict_system_1.py',
            '--use_gpu=False',
            '--cls_model_dir=./models/cls',
            '--rec_model_dir=./models/rec',
            '--det_model_dir=./models/det',
            f'--image_dir={temp_image_file}',
        ]
        logger.info(f"正在处理file_id: {file_id}的图像 ")
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode != 0:
            # errors="replace": undecodable bytes must not mask the real failure.
            error_msg = stderr.decode(errors="replace").strip()
            logger.error(f"对file_id:{file_id}的OCR失败, return code: {process.returncode}")
        else:
            ocr_text = ansi_escape.sub('', stdout.decode(errors="replace")).strip()
            logger.info(f"对file_id:{file_id}的OCR成功")
            logger.info(f"OCR 脚本识别时间: {time.time() - step_start_time:.2f}")

            # Step 4: upload the script's result file to MinIO under a fresh key
            # (no suffix), again off the event loop.
            step_start_time = time.time()
            uploaded_key = str(next(gen))
            with open(result_file_path, "rb") as data:
                await loop.run_in_executor(
                    None, s3_client.upload_fileobj, data, bucket_name, uploaded_key
                )
            logger.info(f"文件: {uploaded_key} 存储在MinIO中。")
            logger.info(f"上传至 MinIO 时间: {time.time() - step_start_time:.2f}")
            # Report success only after the upload has completed.
            ocr_str, processed_file_id, status = ocr_text, uploaded_key, 0
    except Exception as e:
        # Report the failure in the response instead of raising, so one bad
        # image does not abort the whole asyncio.gather() batch in /ocr.
        ocr_str = ""
        processed_file_id = ""
        status = 2
        error_msg = str(e)
        logger.exception(f"处理file_id: {file_id}时发生异常")
    finally:
        # Step 5: delete the temporary input image and any OCR result file,
        # including one left behind by a failed OCR run.
        cleanup_start_time = time.time()
        _remove_temp_file(temp_image_file)
        if _remove_temp_file(result_file_path):
            logger.info(f"临时文件: {result_file_path} 已被删除")
        logger.info(f"删除临时文件时间: {time.time() - cleanup_start_time:.2f}")
        logger.info(f"总处理时间: {time.time() - start_time:.2f}")
    return PictureResponse(
        file_id=file_id,
        draw_img_id=processed_file_id,
        ocr_text=ocr_str,
        status=status,
        error_msg=error_msg,
    )

@app.post("/ocr", response_model=List[PictureResponse])
async def ocr_endpoint(picture: PictureRequest):
    """OCR every image listed in ``picture.file_ids`` concurrently.

    Each image gets its own ``process_image`` coroutine with the shared
    ``picture.suffix``. ``asyncio.gather`` returns the results in the same
    order as ``picture.file_ids``.
    """
    pending = (process_image(fid, picture.suffix) for fid in picture.file_ids)
    return await asyncio.gather(*pending)

@app.post("/retrieve", response_model=TitleResponse)
def get_title(request: TitleRequest):
    """Extract an obvious title from ``request.text`` via the language model.

    Declared as a plain ``def`` (not ``async def``), so FastAPI runs it in its
    threadpool and the blocking HTTP call inside ``send_request_to_model`` does
    not stall the event loop.
    """
    title_response = send_request_to_model(request.text)
    return title_response

if __name__ == "__main__":
    # Serve the API directly with uvicorn on all interfaces, port 8000.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=8000)