|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
|
|
|
|
import asyncio
import json
import logging
import os
import re
import subprocess
import time
from typing import List, Optional

import boto3
import requests
from botocore.client import Config
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from snowflake import SnowflakeGenerator
|
|
|
|
|
|
|
|
|
|
# ASGI application; the /ocr and /retrieve routes are registered on it below.
app = FastAPI()

# Configure root logging at INFO so the per-step timing messages are emitted,
# then obtain this module's own logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
# Matches ANSI terminal escape sequences so they can be stripped from the OCR
# script's stdout. Covers CSI sequences (ESC '[' params intermediates final,
# e.g. colours "\x1b[32m" or erase-line "\x1b[K") and the two-byte ESC forms.
# The previous pattern r'\x1b[^m]*m' consumed everything up to the next
# literal "m", so any escape not ending in "m" silently ate real text.
ansi_escape = re.compile(r"\x1b(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
|
|
|
|
|
|
|
|
|
# MinIO connection settings passed straight to boto3.client("s3", ...).
# Each value can be overridden through an environment variable so that
# credentials need not live in source control; the defaults preserve the
# previously hard-coded values.
minio_config = {
    "endpoint_url": os.environ.get("MINIO_ENDPOINT_URL", "http://192.168.10.137:9002"),
    "aws_access_key_id": os.environ.get("MINIO_ACCESS_KEY", "admin"),
    "aws_secret_access_key": os.environ.get("MINIO_SECRET_KEY", "12345678"),
}

# Bucket that holds both the source images and the uploaded OCR output files.
bucket_name = os.environ.get("MINIO_BUCKET", "nxfuhsi")
|
|
|
|
|
|
|
|
|
|
# Shared S3-compatible client (pointed at MinIO by minio_config), used for
# both downloading source images and uploading OCR output.
s3_client = boto3.client("s3", **minio_config)

# Argument given to SnowflakeGenerator; presumably the generator's instance
# id in the snowflake-id package (confirm). Ids from `gen` name the local temp
# files and the uploaded result objects.
_SNOWFLAKE_INSTANCE_ID = 42
gen = SnowflakeGenerator(_SNOWFLAKE_INSTANCE_ID)
|
|
|
|
|
|
|
|
|
|
# Chat endpoint of the large model, presumably an Ollama server (port 11434,
# /api/chat) - confirm. Overridable via the MODEL_API_URL environment
# variable; the default keeps the previously hard-coded address.
model_api_url = os.environ.get("MODEL_API_URL", "http://112.81.86.50:11434/api/chat")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PictureRequest(BaseModel):
    """Request body for ``POST /ocr``."""

    # MinIO object keys of the images to OCR.
    file_ids: List[str]
    # File extension applied to every image (e.g. "jpg", "png"). May be
    # omitted: an empty value falls back to "jpg" during processing.
    suffix: str = ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PictureResponse(BaseModel):
    """Per-image result returned by ``POST /ocr``."""

    # MinIO key of the source image, echoed from the request.
    file_id: str
    # MinIO key under which the OCR script's output file was uploaded;
    # "" when processing failed.
    draw_img_id: str
    # Recognised text (script stdout with ANSI escapes removed); "" on failure.
    ocr_text: str
    # 0 = success, 2 = failure.
    status: int
    # Failure detail, "" on success. Declared Optional because the default is
    # None (the previous `str = None` contradicted its own annotation).
    error_msg: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TitleRequest(BaseModel):
    """Request body for ``POST /retrieve``: the text to extract a title from."""

    # Raw document text; it is substituted verbatim into the model prompt.
    text: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TitleResponse(BaseModel):
    """Response of ``POST /retrieve``."""

    # Extracted title, or "" when the model reported none.
    title: str
    # 0 = success, 2 = model request failed or returned an invalid result.
    status: int
    # Failure detail, "" on success. Optional because the default is None.
    error_msg: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Prompt sent to the chat model. `{text}` is filled in with str.format; the
# doubled braces are literal braces in the example JSON shown to the model.
_TITLE_PROMPT_TEMPLATE = """
提取下面文本中前40个字中明显的标题:

### 注意事项:
1. 将结果以JSON格式返回。不需要进行解释。
2. 如果某字段提取不到,则返回""。
3. 文本原文可能为空,返回""。
4. 标题只会出现在前40个字中
5. 非常明确是标题的文本,才能返回结果。
6. 有可能很大概率是没有标题的。

文本原文:
{text}
输出json格式:{{"title":"*****"}},如果没有明显标题:{{"title":""}}。

回溯你输出的结果,确保你的输出结果符合json格式。
"""

# Fallback matcher for a {"title": "..."} object embedded in extra text (for
# example Markdown code fences). Whitespace around braces/colon is tolerated.
_TITLE_JSON_PATTERN = re.compile(r'\{\s*"title"\s*:\s*"([^"]*)"\s*\}')


def _extract_title(content: str) -> str:
    """Return the title contained in the model's reply, or "" if none is found.

    The reply is first parsed as JSON, which also handles escaped quotes. If
    that fails, a regex scan is used, and a title is accepted only when exactly
    one ``{"title": "..."}`` object is present, as in the original logic.
    """
    try:
        parsed = json.loads(content)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        parsed = None
    if isinstance(parsed, dict) and isinstance(parsed.get("title"), str):
        return parsed["title"]
    matches = _TITLE_JSON_PATTERN.findall(content)
    return matches[0] if len(matches) == 1 else ""


def send_request_to_model(text: str, timeout: float = 300.0) -> TitleResponse:
    """Ask the chat model to extract an obvious title from *text*.

    Args:
        text: Document text substituted into the prompt.
        timeout: Seconds to wait for the model HTTP call (connect and read).
            Previously there was no timeout, so a stuck model server could
            block the worker forever.

    Returns:
        ``TitleResponse`` with ``status=0`` and the title ("" if none was
        found), or ``status=2`` with ``error_msg`` when the HTTP request
        fails or the reply has an unexpected shape. Errors are reported, not
        raised.
    """
    payload = {
        "model": "qwen2:72b",
        "messages": [
            {"role": "user", "content": _TITLE_PROMPT_TEMPLATE.format(text=text)}
        ],
        # Ollama's switch for JSON-constrained output. The former
        # "type": "json_object" is not an Ollama chat option and was ignored.
        "format": "json",
        "stream": False,
    }
    headers = {"Content-Type": "application/json"}

    try:
        response = requests.post(model_api_url, json=payload, headers=headers, timeout=timeout)
        response.raise_for_status()
        result = response.json()

        # Validate the reply shape explicitly instead of letting a non-dict
        # body or message escape as an uncaught AttributeError.
        if not isinstance(result, dict):
            raise ValueError("Invalid response content")
        message = result.get("message", {})
        content = message.get("content", "") if isinstance(message, dict) else None
        if not isinstance(content, str):
            raise ValueError("Invalid response content")

        return TitleResponse(title=_extract_title(content), status=0, error_msg="")

    except requests.exceptions.RequestException as e:
        logger.error(f"Failed to request model API: {e}")
        return TitleResponse(title="", status=2, error_msg=f"Request failed: {e}")
    except ValueError as e:
        logger.error(f"Invalid result returned by model API: {e}")
        return TitleResponse(title="", status=2, error_msg=f"Invalid result: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Accepted image suffix characters. The suffix comes from the request body
# and is interpolated into local file paths.
_SUFFIX_PATTERN = re.compile(r"[A-Za-z0-9]+")


def _normalize_suffix(suffix: str) -> str:
    """Return a filesystem-safe extension for the temporary image file.

    An empty suffix (or one made only of dots) defaults to ``"jpg"``. A
    leading dot, as in ``".png"``, is dropped. Anything other than ASCII
    letters and digits is rejected so that values like ``"jpg/../../x"``
    cannot escape the working directory.

    Raises:
        ValueError: if the suffix contains characters outside ``[A-Za-z0-9]``.
    """
    cleaned = (suffix or "").lstrip(".")
    if not cleaned:
        cleaned = "jpg"  # default suffix is jpg
        logger.info(f"无后缀,更改图片后缀为: {cleaned}")
        return cleaned
    if not _SUFFIX_PATTERN.fullmatch(cleaned):
        raise ValueError(f"Invalid image suffix: {suffix!r}")
    return cleaned


def _download_object(key: str) -> bytes:
    """Read the whole object *key* from the MinIO bucket into memory."""
    response = s3_client.get_object(Bucket=bucket_name, Key=key)
    return response["Body"].read()


def _write_bytes(path: str, data: bytes) -> None:
    """Write *data* to the local file *path*, replacing any existing file."""
    with open(path, "wb") as f:
        f.write(data)


def _upload_file(path: str, key: str) -> None:
    """Upload the local file *path* to the MinIO bucket under *key*."""
    with open(path, "rb") as data:
        s3_client.upload_fileobj(data, bucket_name, key)


def _remove_quietly(path: str) -> bool:
    """Delete *path* if it exists. Return True only if a file was removed.

    Cleanup is best-effort. Failures are logged instead of raised so that,
    coming from a ``finally`` block, they cannot replace the OCR result with
    an unhandled exception.
    """
    try:
        os.remove(path)
    except FileNotFoundError:
        return False
    except OSError:
        logger.warning(f"Failed to delete temporary file: {path}", exc_info=True)
        return False
    return True


async def process_image(file_id: str, suffix: str) -> PictureResponse:
    """Download one image from MinIO, OCR it, and upload the OCR script's output.

    Steps:

    1. Fetch object *file_id* from the bucket.
    2. Write it to a uniquely named local temp file.
    3. Run ``tools/infer/predict_system_1.py`` on it in a subprocess.
    4. On success, upload ``inference_results/<temp file name>`` back to the
       bucket under a freshly generated snowflake id.

    Local files are removed afterwards in every case.

    Blocking MinIO and file I/O runs in the loop's default executor, so
    concurrent calls (``ocr_endpoint`` gathers one per file id) do not stall
    the event loop.

    Args:
        file_id: MinIO object key of the source image.
        suffix: Image file extension. An empty value defaults to ``"jpg"``;
            unsafe values are reported as a failure.

    Returns:
        ``PictureResponse``. On success it has ``status`` 0, ``ocr_text`` and
        ``draw_img_id``. On any failure it has ``status`` 2 and
        ``error_msg``; exceptions are caught and reported, never raised.
    """
    start_time = time.time()  # start of total processing time
    loop = asyncio.get_running_loop()
    temp_image_file = None
    result_file_path = None
    processed_file_id = ""
    ocr_str = ""
    status = 2
    error_msg = ""

    try:
        logger.info(f"图片后缀: “{suffix}”")
        suffix = _normalize_suffix(suffix)

        # Step 1: fetch the image from MinIO.
        step_start_time = time.time()
        image_data = await loop.run_in_executor(None, _download_object, file_id)
        logger.info(f"从 MinIO 获取图片时间: {time.time() - step_start_time:.2f} 秒")

        # Step 2: write the image to a uniquely named temp file. The OCR script
        # writes its output under inference_results/ with the same file name.
        step_start_time = time.time()
        temp_image_file = f"{next(gen)}.{suffix}"
        result_file_path = f"inference_results/{temp_image_file}"
        await loop.run_in_executor(None, _write_bytes, temp_image_file, image_data)
        logger.info(f"写入临时文件时间: {time.time() - step_start_time:.2f} 秒")

        # Step 3: run the OCR script on the temp file.
        step_start_time = time.time()
        command = [
            'python', 'tools/infer/predict_system_1.py',
            '--use_gpu=False',
            '--cls_model_dir=./models/cls',
            '--rec_model_dir=./models/rec',
            '--det_model_dir=./models/det',
            f'--image_dir={temp_image_file}',
        ]
        logger.info(f"正在处理file_id: {file_id}的图像 ")
        process = await asyncio.create_subprocess_exec(
            *command,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()

        if process.returncode != 0:
            # errors="replace": undecodable bytes must not hide the real failure.
            error_msg = stderr.decode(errors="replace").strip()
            logger.error(f"对file_id:{file_id}的OCR失败, return code: {process.returncode}")
        else:
            ocr_str = ansi_escape.sub('', stdout.decode(errors="replace")).strip()
            logger.info(f"对file_id:{file_id}的OCR成功")
            logger.info(f"OCR 脚本识别时间: {time.time() - step_start_time:.2f} 秒")

            # Step 4: upload the OCR output file to MinIO (key has no suffix).
            step_start_time = time.time()
            processed_file_name = str(next(gen))
            await loop.run_in_executor(None, _upload_file, result_file_path, processed_file_name)
            logger.info(f"文件: {processed_file_name} 存储在MinIO中。")
            logger.info(f"上传至 MinIO 时间: {time.time() - step_start_time:.2f} 秒")

            processed_file_id = processed_file_name
            status = 0

    except Exception as e:
        ocr_str = ""
        status = 2
        error_msg = str(e)
        processed_file_id = ""
        logger.exception(f"处理file_id: {file_id}时发生异常")

    finally:
        # Step 5: remove local files. The result file is removed even when
        # OCR failed, so partial output does not accumulate on disk.
        cleanup_start_time = time.time()
        if temp_image_file:
            _remove_quietly(temp_image_file)
        if result_file_path and _remove_quietly(result_file_path):
            logger.info(f"临时文件: {result_file_path} 已被删除")
        logger.info(f"删除临时文件时间: {time.time() - cleanup_start_time:.2f} 秒")

    logger.info(f"总处理时间: {time.time() - start_time:.2f} 秒")
    return PictureResponse(
        file_id=file_id,
        draw_img_id=processed_file_id if status == 0 else "",
        ocr_text=ocr_str,
        status=status,
        error_msg=error_msg,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/ocr", response_model=List[PictureResponse])
async def ocr_endpoint(picture: PictureRequest):
    """Run OCR on every requested image concurrently.

    Each id in ``picture.file_ids`` is passed to ``process_image`` together
    with the shared ``picture.suffix``. ``asyncio.gather`` returns the results
    in the same order as the request ids.
    """
    shared_suffix = picture.suffix
    pending = (process_image(fid, shared_suffix) for fid in picture.file_ids)
    return await asyncio.gather(*pending)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/retrieve", response_model=TitleResponse)
def get_title(request: TitleRequest):
    """Extract a title from ``request.text`` using the chat model.

    This is a plain ``def``, not ``async``. FastAPI documents that such path
    operations run in its threadpool, so the blocking HTTP call made by
    ``send_request_to_model`` does not stall the event loop.
    """
    source_text = request.text
    title_response = send_request_to_model(source_text)
    return title_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported only when run as a script, so importing this module (e.g. from
    # an external ASGI server) does not require uvicorn.
    from uvicorn import run as run_server

    # Listen on all interfaces, port 8000.
    run_server(app, host="0.0.0.0", port=8000)
|