You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

232 lines
7.4 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
import asyncio
import json
import logging
import os
import re
import subprocess
import time
from typing import List, Optional

import boto3
import requests
from botocore.client import Config
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from snowflake import SnowflakeGenerator
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Matches ESC ... m escape sequences; used to strip terminal colour codes from
# the OCR script's stdout.
ansi_escape = re.compile(r'\x1b[^m]*m')
# MinIO (S3-compatible) connection settings. Each value can be overridden via an
# environment variable so credentials need not live in source control; the
# fallbacks preserve the previously hard-coded configuration.
# NOTE(review): drop the fallback credentials once all deployments set the env vars.
minio_config = {
    "endpoint_url": os.environ.get("MINIO_ENDPOINT", "http://192.168.10.137:9002"),  # MinIO endpoint
    "aws_access_key_id": os.environ.get("MINIO_ACCESS_KEY", "admin"),  # MinIO accessKey
    "aws_secret_access_key": os.environ.get("MINIO_SECRET_KEY", "12345678"),  # MinIO secretKey
}
bucket_name = os.environ.get("MINIO_BUCKET", "nxfuhsi")  # MinIO bucket name
s3_client = boto3.client("s3", **minio_config)
# Generates unique ids for temp files and uploaded objects; 42 is presumably the
# generator's instance/worker id — TODO confirm it is unique per deployment.
gen = SnowflakeGenerator(42)
# Chat API URL of the large language model used for title extraction.
model_api_url = os.environ.get("MODEL_API_URL", "http://112.81.86.50:11434/api/chat")
class PictureRequest(BaseModel):
    """Request body for /ocr: a batch of images that share one file suffix."""

    file_ids: List[str]  # MinIO object keys of the images to OCR
    # File extension without the leading dot (e.g. "jpg", "png"); process_image
    # appends it after a "." and falls back to "jpg" when it is empty, so the
    # field may now be omitted.
    suffix: str = ""
class PictureResponse(BaseModel):
    """Per-image OCR result returned by the /ocr endpoint."""

    file_id: str  # MinIO key of the source image, echoed from the request
    # MinIO key of the uploaded OCR result file; process_image sets "" on failure.
    draw_img_id: str
    ocr_text: str  # OCR script stdout with terminal escape codes removed
    status: int  # process_image uses 0 for success and 2 for failure
    # Failure detail; process_image sends "" on success. Optional so the
    # None default matches the declared type.
    error_msg: Optional[str] = None
class TitleRequest(BaseModel):
    """Request body for /retrieve."""

    # Raw text whose first characters may contain a title.
    text: str
class TitleResponse(BaseModel):
    """Result of LLM title extraction returned by /retrieve."""

    title: str  # extracted title, or "" when none was found
    status: int  # 0 on success, 2 when the model request/response failed
    # Failure detail; "" on success. Optional so the None default matches the type.
    error_msg: Optional[str] = None
def send_request_to_model(text: str) -> TitleResponse:
prompt = """
提取下面文本中前40个字中明显的标题
### 注意事项:
1. 将结果以JSON格式返回。不需要进行解释。
2. 如果某字段提取不到,则返回""
3. 文本原文可能为空,返回""
4. 标题只会出现在前40个字中
5. 非常明确是标题的文本,才能返回结果。
6. 有可能很大概率是没有标题的。
文本原文:
{text}
输出json格式{{"title":"*****"}},如果没有明显标题:{{"title":""}}。
回溯你输出的结果确保你的输出结果符合json格式。
"""
prompt_t = prompt.format(text=text)
payload = {
"model": "qwen2:72b",
"messages": [
{"role": "user", "content": f"{prompt_t}"}
],
"type": "json_object",
"stream": False
}
headers = {"Content-Type": "application/json"}
try:
response = requests.post(model_api_url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
json_pattern = re.compile(r'\{"title":\s*"([^"]*)"\}')
content = result.get('message', {}).get('content', '')
if not isinstance(content, str):
raise ValueError("Invalid response content")
matches = json_pattern.findall(content)
if len(matches) == 1:
title_value = matches[0]
return TitleResponse(title=title_value, status=0, error_msg="")
else:
return TitleResponse(title="", status=0, error_msg="")
except requests.exceptions.RequestException as e:
logger.error(f"Failed to request model API: {e}")
return TitleResponse(title="", status=2, error_msg=f"Request failed: {e}")
except ValueError as e:
logger.error(f"Invalid result returned by model API: {e}")
return TitleResponse(title="", status=2, error_msg=f"Invalid result: {e}")
def _remove_quietly(path: str) -> bool:
    """Delete *path* if it exists; return True when a file was removed.

    Removal errors are logged rather than raised, so cleanup cannot turn a
    finished request into an exception (the /ocr endpoint gathers tasks
    without return_exceptions, so one raise would fail the whole batch).
    """
    if not os.path.exists(path):
        return False
    try:
        os.remove(path)
    except OSError:
        logger.warning(f"Failed to remove temp file: {path}", exc_info=True)
        return False
    return True


async def process_image(file_id: str, suffix: str) -> PictureResponse:
    """OCR one image stored in MinIO and upload the OCR script's result file.

    Steps:
    1. Download object ``file_id`` from the MinIO bucket.
    2. Write it to a local temp file named ``<snowflake-id>.<suffix>``.
    3. Run ``tools/infer/predict_system_1.py`` on it as a subprocess.
    4. On success, upload ``inference_results/<temp file>`` to MinIO under a
       fresh snowflake id.
    5. Remove the local temp and result files.

    Args:
        file_id: MinIO object key of the source image.
        suffix: File extension without the dot; "" falls back to "jpg".

    Returns:
        PictureResponse with ``status=0``, the OCR text and the uploaded
        object id on success. Otherwise ``status=2`` and ``error_msg`` are
        set; ordinary exceptions are caught and reported, not raised.
    """
    start_time = time.time()  # start of the total processing time
    step_start_time = start_time
    loop = asyncio.get_running_loop()
    temp_image_file = None
    processed_file_id = ""
    ocr_str = ""
    status = 2
    error_msg = ""
    try:
        logger.info(f"图片后缀: “{suffix}")
        if not suffix:
            suffix = "jpg"  # default suffix when none was supplied
            logger.info(f"无后缀,更改图片后缀为: {suffix}")

        # Step 1: fetch the image from MinIO. boto3 calls are synchronous, so
        # run them in a worker thread to keep the event loop free for the
        # other images being processed concurrently.
        step_start_time = time.time()
        pic_name = f"{file_id}"

        def _download() -> bytes:
            obj = s3_client.get_object(Bucket=bucket_name, Key=pic_name)
            return obj['Body'].read()

        image_data = await loop.run_in_executor(None, _download)
        logger.info(f"从 MinIO 获取图片时间: {time.time() - step_start_time:.2f}")

        # Step 2: write the image to a uniquely named local temp file.
        step_start_time = time.time()
        temp_image_file = f"{next(gen)}.{suffix}"
        with open(temp_image_file, "wb") as f:
            f.write(image_data)
        logger.info(f"写入临时文件时间: {time.time() - step_start_time:.2f}")

        # Step 3: run the OCR script as a child process.
        step_start_time = time.time()
        command = [
            'python', 'tools/infer/predict_system_1.py',
            '--use_gpu=False',
            '--cls_model_dir=./models/cls',
            '--rec_model_dir=./models/rec',
            '--det_model_dir=./models/det',
            f'--image_dir={temp_image_file}',
        ]
        logger.info(f"正在处理file_id: {file_id}的图像 ")
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode == 0:
            # Strip terminal colour codes; undecodable bytes are replaced rather
            # than failing an otherwise successful OCR run.
            ocr_str = ansi_escape.sub('', stdout.decode(errors="replace")).strip()
            logger.info(f"对file_id:{file_id}的OCR成功")
            logger.info(f"OCR 脚本识别时间: {time.time() - step_start_time:.2f}")

            # Step 4: upload the script's result file to MinIO (key has no suffix).
            step_start_time = time.time()
            processed_file_name = str(next(gen))
            result_file_path = f"inference_results/{temp_image_file}"

            def _upload() -> None:
                with open(result_file_path, "rb") as data:
                    s3_client.upload_fileobj(data, bucket_name, processed_file_name)

            await loop.run_in_executor(None, _upload)
            logger.info(f"文件: {processed_file_name} 存储在MinIO中。")
            logger.info(f"上传至 MinIO 时间: {time.time() - step_start_time:.2f}")
            processed_file_id = processed_file_name
            status = 0
        else:
            ocr_str = ""
            status = 2
            error_msg = stderr.decode(errors="replace").strip()
            logger.error(f"对file_id:{file_id}的OCR失败, return code: {process.returncode}")
    except Exception as e:
        ocr_str = ""
        status = 2
        error_msg = str(e)
        logger.exception(f"处理file_id: {file_id}时发生异常")
    finally:
        # Step 5: remove local temp files. The result file is removed even when
        # OCR or upload failed, so failed runs do not leave files behind.
        step_start_time = time.time()  # time only the cleanup itself
        if temp_image_file:
            _remove_quietly(temp_image_file)
            result_file_path = f"inference_results/{temp_image_file}"
            if _remove_quietly(result_file_path):
                logger.info(f"临时文件: {result_file_path} 已被删除")
        logger.info(f"删除临时文件时间: {time.time() - step_start_time:.2f}")
        logger.info(f"总处理时间: {time.time() - start_time:.2f}")
    return PictureResponse(
        file_id=file_id,
        draw_img_id=processed_file_id if status == 0 else "",
        ocr_text=ocr_str,
        status=status,
        error_msg=error_msg,
    )
@app.post("/ocr", response_model=List[PictureResponse])
async def ocr_endpoint(picture: PictureRequest):
    """Run process_image for every requested file id and return all results.

    The coroutines are awaited together with asyncio.gather, which returns
    results in the same order as ``picture.file_ids``.
    """
    jobs = (process_image(fid, picture.suffix) for fid in picture.file_ids)
    return await asyncio.gather(*jobs)
@app.post("/retrieve", response_model=TitleResponse)
def get_title(request: TitleRequest):
    """Extract a title from the posted text using the LLM."""
    # Declared as a plain ``def``: FastAPI runs such endpoints in its threadpool,
    # so the blocking HTTP call made by send_request_to_model stays off the event loop.
    source_text = request.text
    return send_request_to_model(source_text)
if __name__ == "__main__":
    # Serve the API directly with uvicorn on all interfaces, port 8000.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)