Compare commits


3 Commits

Author SHA1 Message Date
zhangzhichao b5321f477e Merge pull request 'Add scanned-document classification' (#8) from zzc into main (Reviewed-on: #8) 3 weeks ago
zhangzhichao d7e095d0e2 Remove some code 3 weeks ago
zhangzhichao a8c3125446 Service 4 weeks ago

@@ -87,8 +87,10 @@ def rec(page_detection_results: List[PageDetectionResult]) -> List[List[LayoutRe
    outputs = [_ for _ in outputs if _.clsid != E.TABLE_CAPTION.value]
    # Convert table-like classes to the database enum: 1 = table
    for o in outputs:
-        if o.clsid == E.TABLE.value:
+        if o.clsid == E.TABLE.value or o.clsid == E.SCANNED_DOCUMENT.value:
            o.clsid = 1
        else:
            o.clsid = 0
    page_recognition_results.append(outputs)
    return page_recognition_results

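The change above folds the new scanned-document class into the same database category as tables. A minimal sketch of the mapping rule, with hypothetical numeric ids for the PageDetectionEnum members (the real values live in .constants):

from enum import Enum

class PageDetectionEnum(Enum):
    TEXT = 0               # hypothetical id
    TABLE = 3              # hypothetical id
    SCANNED_DOCUMENT = 7   # hypothetical id for the new class

def to_db_clsid(clsid: int) -> int:
    # Tables and scanned documents both map to database enum 1, everything else to 0.
    table_like = {PageDetectionEnum.TABLE.value, PageDetectionEnum.SCANNED_DOCUMENT.value}
    return 1 if clsid in table_like else 0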
@@ -2,12 +2,12 @@ from typing import List
from typing_extensions import deprecated
import numpy as np
from pdf2image import convert_from_path
import os
import cv2
from .page_detection.utils import PageDetectionResult
from paddleocr import PaddleOCR
from .constants import PageDetectionEnum as E
from paddlex import create_model
from tqdm import tqdm
ocr = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=True, show_log=False)
@@ -82,7 +82,7 @@ def image_orient_cls(images):
        images = [images]
    angles = []
-    for img in images:
+    for img in tqdm(images, '文本方向分类'):
        h, w = img.shape[:2]
        det_res = ocr.ocr(img, det=True, rec=False, cls=False)[0]
        boxes = []

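The orientation-classification loop now reports progress through tqdm; the second positional argument is the progress bar's description. A minimal sketch of the pattern, with dummy data in place of the PaddleOCR call:

from tqdm import tqdm

images = [None] * 100  # stand-in for decoded page images
for img in tqdm(images, 'text orientation classification'):
    pass  # per-image OCR-based angle estimation goes here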
@@ -89,6 +89,8 @@ def layout_analysis(images) -> List[PageDetectionResult]:
        # Expand the boxes outward to ease the subsequent OCR
        h, w = image.shape[:2]
        for layout in page_detecion_outputs.boxes:
+            if layout.clsid != E.TEXT.value:
+                continue
            layout.pos[0] -= expand_pixel
            layout.pos[1] -= expand_pixel
            layout.pos[2] += expand_pixel

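The hunk above now expands only TEXT boxes before OCR. The visible lines do not show clamping, but the preceding h, w = image.shape[:2] suggests the expanded coordinates are clipped to the image bounds; a sketch under that assumption (the expand_pixel value is illustrative):

def expand_box(pos, w, h, expand_pixel=5):
    # Expand an [x1, y1, x2, y2] box outward and clamp it to the image bounds.
    x1 = max(0, pos[0] - expand_pixel)
    y1 = max(0, pos[1] - expand_pixel)
    x2 = min(w, pos[2] + expand_pixel)
    y2 = min(h, pos[3] + expand_pixel)
    return [x1, y1, x2, y2]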
@@ -14,9 +14,6 @@
import os
import yaml
import glob
import json
from pathlib import Path
import cv2
import numpy as np
@@ -34,7 +31,7 @@ from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, L
from .picodet_postprocess import PicoDetPostProcess
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
-from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
+from utils import Timer
# Global dictionary
SUPPORT_MODELS = {
@@ -45,7 +42,6 @@ SUPPORT_MODELS = {
}

class Detector(object):
    """
    Args:
@@ -770,13 +766,10 @@ def load_predictor(model_dir,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)
-        if FLAGS.collect_trt_shape_info:
-            config.collect_shape_range_info(FLAGS.tuned_trt_shape_file)
-        elif os.path.exists(FLAGS.tuned_trt_shape_file):
-            print(f'Use dynamic shape file: '
-                  f'{FLAGS.tuned_trt_shape_file} for TRT...')
+        if os.path.exists('shape_range_info.pbtxt'):
+            print('Use dynamic shape file: shape_range_info.pbtxt for TRT...')
            config.enable_tuned_tensorrt_dynamic_shape(
-                FLAGS.tuned_trt_shape_file, True)
+                'shape_range_info.pbtxt', True)
    if use_dynamic_shape:
        min_input_shape = {
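The rewrite drops the FLAGS-driven shape collection and simply reuses shape_range_info.pbtxt when it exists. A hedged sketch of the Paddle Inference calls involved (model file names are placeholders):

import os
from paddle.inference import Config

config = Config('model.pdmodel', 'model.pdiparams')  # placeholder file names
shape_file = 'shape_range_info.pbtxt'
if os.path.exists(shape_file):
    # Reuse previously recorded shape ranges so the TRT engine needs no re-tuning.
    config.enable_tuned_tensorrt_dynamic_shape(shape_file, True)
else:
    # One-off pass that records min/max/opt shapes; rerun once the file exists.
    config.collect_shape_range_info(shape_file)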
@@ -849,19 +842,9 @@ def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
        print("save result to: " + out_path)

def print_arguments(args):
    print('----------- Running Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------')

class Pipeline(object):
    def __init__(self, model_dir):
        if FLAGS.use_fd_format:
            deploy_file = os.path.join(model_dir, 'inference.yml')
        else:
            deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
@@ -876,40 +859,27 @@ class Pipeline(object):
        self.detector = eval(detector_func)(
            model_dir,
-            device=FLAGS.device,
-            run_mode=FLAGS.run_mode,
-            batch_size=FLAGS.batch_size,
-            trt_min_shape=FLAGS.trt_min_shape,
-            trt_max_shape=FLAGS.trt_max_shape,
-            trt_opt_shape=FLAGS.trt_opt_shape,
-            trt_calib_mode=FLAGS.trt_calib_mode,
-            cpu_threads=FLAGS.cpu_threads,
-            enable_mkldnn=FLAGS.enable_mkldnn,
-            enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
-            threshold=FLAGS.threshold,
-            output_dir=FLAGS.output_dir,
-            use_fd_format=FLAGS.use_fd_format)
+            device='GPU',
+            run_mode='paddle',
+            batch_size=1,
+            trt_min_shape=1,
+            trt_max_shape=1280,
+            trt_opt_shape=640,
+            trt_calib_mode=False,
+            cpu_threads=1,
+            enable_mkldnn=False,
+            enable_mkldnn_bfloat16=False,
+            threshold=0.5,
+            output_dir='output',
+            use_fd_format=False)

    def __call__(self, image):
        if isinstance(image, np.ndarray):
            image = [image]
        results = self.detector.predict_image(
            image,
-            visual=FLAGS.save_images)
+            visual=False)
        return results
paddle.enable_static()
parser = argsparser()
FLAGS = parser.parse_args()
# print_arguments(FLAGS)
FLAGS.device = 'GPU'
FLAGS.save_images = False
FLAGS.device = FLAGS.device.upper()
assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU'
], "device should be CPU, GPU, XPU, MLU, NPU or GCU"
assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
assert not (
FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'

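With argsparser() removed from utils (see the deletion below), the remaining FLAGS references have to be satisfied some other way; this commit hardcodes the detector arguments instead. If attribute-style access were still wanted without argparse, one illustrative alternative (not what the commit does) is a plain namespace:

from types import SimpleNamespace

# Hypothetical stand-in for parsed CLI flags; only attributes the surviving
# code still reads need to exist.
FLAGS = SimpleNamespace(
    device='GPU',
    save_images=False,
    use_gpu=False,
    enable_mkldnn=False,
    enable_mkldnn_bfloat16=False)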
@@ -33,199 +33,6 @@ class PageDetectionResult(object):
        self.image = image
def argsparser():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--image_dir",
        type=str,
        default=None,
        help="Dir of image file, `image_file` has a higher priority.")
    parser.add_argument(
        "--batch_size", type=int, default=1, help="batch_size for inference.")
    parser.add_argument(
        "--video_file",
        type=str,
        default=None,
        help="Path of video file, `video_file` or `camera_id` has the highest priority."
    )
    parser.add_argument(
        "--camera_id",
        type=int,
        default=-1,
        help="device id of camera to predict.")
    parser.add_argument(
        "--threshold", type=float, default=0.5, help="Threshold of score.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory of output visualization files.")
    parser.add_argument(
        "--run_mode",
        type=str,
        default='paddle',
        help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
    parser.add_argument(
        "--device",
        type=str,
        default='cpu',
        help="Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU."
    )
    parser.add_argument(
        "--use_gpu",
        type=ast.literal_eval,
        default=False,
        help="Deprecated, please use `--device`.")
    parser.add_argument(
        "--run_benchmark",
        type=ast.literal_eval,
        default=False,
        help="Whether to predict an image_file repeatedly for benchmark")
    parser.add_argument(
        "--enable_mkldnn",
        type=ast.literal_eval,
        default=False,
        help="Whether to use mkldnn with CPU.")
    parser.add_argument(
        "--enable_mkldnn_bfloat16",
        type=ast.literal_eval,
        default=False,
        help="Whether to use mkldnn bfloat16 inference with CPU.")
    parser.add_argument(
        "--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
    parser.add_argument(
        "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
    parser.add_argument(
        "--trt_max_shape",
        type=int,
        default=1280,
        help="max_shape for TensorRT.")
    parser.add_argument(
        "--trt_opt_shape",
        type=int,
        default=640,
        help="opt_shape for TensorRT.")
    parser.add_argument(
        "--trt_calib_mode",
        type=bool,
        default=False,
        help="If the model is produced by TRT offline quantitative "
        "calibration, trt_calib_mode needs to be set True.")
    parser.add_argument(
        '--save_images',
        type=ast.literal_eval,
        default=True,
        help='Save visualization image results.')
    parser.add_argument(
        '--save_mot_txts',
        action='store_true',
        help='Save tracking results (txt).')
    parser.add_argument(
        '--save_mot_txt_per_img',
        action='store_true',
        help='Save tracking results (txt) for each image.')
    parser.add_argument(
        '--scaled',
        type=bool,
        default=False,
        help="Whether coords after detector outputs are scaled, False in JDE YOLOv3, "
        "True in general detectors.")
    parser.add_argument(
        "--tracker_config", type=str, default=None, help="tracker config")
    parser.add_argument(
        "--reid_model_dir",
        type=str,
        default=None,
        help=("Directory include:'model.pdiparams', 'model.pdmodel', "
              "'infer_cfg.yml', created by tools/export_model.py."))
    parser.add_argument(
        "--reid_batch_size",
        type=int,
        default=50,
        help="max batch_size for reid model inference.")
    parser.add_argument(
        '--use_dark',
        type=ast.literal_eval,
        default=True,
        help='whether to use darkpose to get better keypoint position predictions')
    parser.add_argument(
        "--action_file",
        type=str,
        default=None,
        help="Path of input file for action recognition.")
    parser.add_argument(
        "--window_size",
        type=int,
        default=50,
        help="Temporal size of skeleton feature for action recognition.")
    parser.add_argument(
        "--random_pad",
        type=ast.literal_eval,
        default=False,
        help="Whether to do random padding for action recognition.")
    parser.add_argument(
        "--save_results",
        action='store_true',
        default=False,
        help="Whether to save detection results to file in coco format")
    parser.add_argument(
        '--use_coco_category',
        action='store_true',
        default=False,
        help='Whether to use the coco format dictionary `clsid2catid`')
    parser.add_argument(
        "--slice_infer",
        action='store_true',
        help="Whether to slice the image and merge the inference results for small object detection."
    )
    parser.add_argument(
        '--slice_size',
        nargs='+',
        type=int,
        default=[640, 640],
        help="Height of the sliced image.")
    parser.add_argument(
        "--overlap_ratio",
        nargs='+',
        type=float,
        default=[0.25, 0.25],
        help="Overlap height ratio of the sliced image.")
    parser.add_argument(
        "--combine_method",
        type=str,
        default='nms',
        help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
    )
    parser.add_argument(
        "--match_threshold",
        type=float,
        default=0.6,
        help="Combine method matching threshold.")
    parser.add_argument(
        "--match_metric",
        type=str,
        default='ios',
        help="Combine method matching metric, choose in ['iou', 'ios'].")
    parser.add_argument(
        "--collect_trt_shape_info",
        action='store_true',
        default=False,
        help="Whether to collect dynamic shape before using tensorrt.")
    parser.add_argument(
        "--tuned_trt_shape_file",
        type=str,
        default="shape_range_info.pbtxt",
        help="Path of a dynamic shape file for tensorrt.")
    parser.add_argument("--use_fd_format", action="store_true")
    parser.add_argument(
        "--task_type",
        type=str,
        default='Detection',
        help="How to save the coco result, it only works with save_results==True. Optional inputs are Rotate or Detection, default is Detection."
    )
    return parser

class Times(object):
    def __init__(self):
        self.time = 0.
@@ -325,212 +132,6 @@ class Timer(Times):
        return dic
def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
    final_boxes = []
    for c in range(num_classes):
        idxs = bboxs[:, 0] == c
        if np.count_nonzero(idxs) == 0:
            continue
        r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
        final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
    return final_boxes
def nms(dets, match_threshold=0.6, match_metric='iou'):
    """Apply NMS to avoid detecting too many overlapping bounding boxes.
    Args:
        dets: shape [N, 5], [score, x1, y1, x2, y2]
        match_metric: 'iou' or 'ios'
        match_threshold: overlap thresh for match metric.
    """
    if dets.shape[0] == 0:
        return dets[[], :]
    scores = dets[:, 0]
    x1 = dets[:, 1]
    y1 = dets[:, 2]
    x2 = dets[:, 3]
    y2 = dets[:, 4]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    order = scores.argsort()[::-1]
    ndets = dets.shape[0]
    suppressed = np.zeros((ndets), dtype=np.int32)
    for _i in range(ndets):
        i = order[_i]
        if suppressed[i] == 1:
            continue
        ix1 = x1[i]
        iy1 = y1[i]
        ix2 = x2[i]
        iy2 = y2[i]
        iarea = areas[i]
        for _j in range(_i + 1, ndets):
            j = order[_j]
            if suppressed[j] == 1:
                continue
            xx1 = max(ix1, x1[j])
            yy1 = max(iy1, y1[j])
            xx2 = min(ix2, x2[j])
            yy2 = min(iy2, y2[j])
            w = max(0.0, xx2 - xx1 + 1)
            h = max(0.0, yy2 - yy1 + 1)
            inter = w * h
            if match_metric == 'iou':
                union = iarea + areas[j] - inter
                match_value = inter / union
            elif match_metric == 'ios':
                smaller = min(iarea, areas[j])
                match_value = inter / smaller
            else:
                raise ValueError()
            if match_value >= match_threshold:
                suppressed[j] = 1
    keep = np.where(suppressed == 0)[0]
    dets = dets[keep, :]
    return dets
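nms and multiclass_nms are deleted by this commit (the trimmed import keeps only Timer). For reference, a small worked example of the removed routine:

import numpy as np

# Rows are [score, x1, y1, x2, y2].
dets = np.array([[0.9, 10, 10, 50, 50],
                 [0.8, 12, 12, 52, 52],      # IoU ~0.83 with the first box -> suppressed
                 [0.7, 100, 100, 150, 150]])
kept = nms(dets, match_threshold=0.6, match_metric='iou')
# kept holds the first and third rows only.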
coco_clsid2catid = {
    0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10,
    10: 11, 11: 13, 12: 14, 13: 15, 14: 16, 15: 17, 16: 18, 17: 19, 18: 20, 19: 21,
    20: 22, 21: 23, 22: 24, 23: 25, 24: 27, 25: 28, 26: 31, 27: 32, 28: 33, 29: 34,
    30: 35, 31: 36, 32: 37, 33: 38, 34: 39, 35: 40, 36: 41, 37: 42, 38: 43, 39: 44,
    40: 46, 41: 47, 42: 48, 43: 49, 44: 50, 45: 51, 46: 52, 47: 53, 48: 54, 49: 55,
    50: 56, 51: 57, 52: 58, 53: 59, 54: 60, 55: 61, 56: 62, 57: 63, 58: 64, 59: 65,
    60: 67, 61: 70, 62: 72, 63: 73, 64: 74, 65: 75, 66: 76, 67: 77, 68: 78, 69: 79,
    70: 80, 71: 81, 72: 82, 73: 84, 74: 85, 75: 86, 76: 87, 77: 88, 78: 89, 79: 90
}
def gaussian_radius(bbox_size, min_overlap):
    height, width = bbox_size
    a1 = 1
    b1 = (height + width)
    c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
    sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
    radius1 = (b1 + sq1) / (2 * a1)
    a2 = 4
    b2 = 2 * (height + width)
    c2 = (1 - min_overlap) * width * height
    sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
    radius2 = (b2 + sq2) / 2
    a3 = 4 * min_overlap
    b3 = -2 * min_overlap * (height + width)
    c3 = (min_overlap - 1) * width * height
    sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
    radius3 = (b3 + sq3) / 2
    return min(radius1, radius2, radius3)

def gaussian2D(shape, sigma_x=1, sigma_y=1):
    m, n = [(ss - 1.) / 2. for ss in shape]
    y, x = np.ogrid[-m:m + 1, -n:n + 1]
    h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y * sigma_y)))
    h[h < np.finfo(h.dtype).eps * h.max()] = 0
    return h

def draw_umich_gaussian(heatmap, center, radius, k=1):
    """
    draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
    """
    diameter = 2 * radius + 1
    gaussian = gaussian2D((diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
    x, y = int(center[0]), int(center[1])
    height, width = heatmap.shape[0:2]
    left, right = min(x, radius), min(width - x, radius + 1)
    top, bottom = min(y, radius), min(height - y, radius + 1)
    masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
    masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
    if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
        np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
    return heatmap
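The CenterNet-style heatmap helpers are also part of the deleted block; a short usage sketch of how they fit together:

import numpy as np

heatmap = np.zeros((128, 128), dtype=np.float32)
# Radius chosen so boxes overlapping a 20x30 ground-truth box by >= 0.7 IoU cover the peak.
radius = max(0, int(gaussian_radius((20, 30), min_overlap=0.7)))
draw_umich_gaussian(heatmap, center=(64, 64), radius=radius)
# heatmap now peaks at 1.0 at (64, 64) and decays as a 2-D Gaussian around it.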
def iou(box1, box2):
    """Compute the IoU (intersection over union) of two boxes."""
    x1 = max(box1[0], box2[0])

@@ -3,8 +3,9 @@ import os
from loguru import logger
env = os.environ.get('env', 'dev')
logger.info(f'Configure using this environment: {env}')
-load_dotenv(dotenv_path='.env.dev' if env == 'dev' else '.env', override=True)
+dotenv_path = '.env.dev' if env == 'dev' else '.env'
+logger.info(f'Configure using this dotenv path: {dotenv_path}')
+load_dotenv(dotenv_path=dotenv_path, override=True)
import time
import traceback
@@ -42,7 +43,7 @@ def _pdf2markdown_pipeline(pdf_path, visual):
            img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
            images[i] = img
-    # images = images[90: 123]
+    # images = images[:2]
    # 3. Layout analysis
    t5 = time.time()
@@ -76,6 +77,7 @@ def pdf2markdown_pipeline(pdf_path: str, visual=False, insert_db=True):
    pdf_name = pdf_path.split('/')[-1]
    start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
    process_status = 0
+    pdf_id = None
    try:
        results = _pdf2markdown_pipeline(pdf_path, visual)
    except Exception:
@@ -84,7 +86,6 @@ def pdf2markdown_pipeline(pdf_path: str, visual=False, insert_db=True):
        end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        if insert_db:
            insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, None)
-        pdf_id = None
    else:
        process_status = PDFAnalysisStatus.SUCCESS.value
        end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
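Hoisting pdf_id = None above the try block (previous hunk) matters because the failure branch used to assign it only after the point where it was already needed; initializing it up front keeps the final return well-defined on every path. A condensed sketch of the resulting flow, assuming a FAIL member on PDFAnalysisStatus:

def run_one(pdf_path, visual=False):
    pdf_id = None  # defined on every path, including failures
    try:
        _pdf2markdown_pipeline(pdf_path, visual)
        process_status = PDFAnalysisStatus.SUCCESS.value
    except Exception:
        process_status = PDFAnalysisStatus.FAIL.value  # assumed member name
    return process_status, pdf_id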
@@ -94,6 +95,7 @@ def pdf2markdown_pipeline(pdf_path: str, visual=False, insert_db=True):
if __name__ == '__main__':
-    pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2021年年度报告.PDF', visual=True, insert_db=True)
-    # pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2022年年度报告.PDF', visual=True, insert_db=True)
-    # pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2023年年度报告.PDF', visual=True, insert_db=True)
+    insert_db = False
+    pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2021年年度报告.PDF', visual=True, insert_db=insert_db)
+    # pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2022年年度报告.PDF', visual=True, insert_db=insert_db)
+    # pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2023年年度报告.PDF', visual=True, insert_db=insert_db)

@@ -1,19 +1,27 @@
from flask import Flask, request, jsonify
import requests
from pipeline import pdf2markdown_pipeline
+import concurrent.futures
from loguru import logger

app = Flask(__name__)
+executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)

+def pdf2markdown_task(pdf_paths, callback_url):
+    for pdf_path in pdf_paths:
+        process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
+        requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})

@app.route('/pdf-qa-server/pdf-to-md', methods=['POST'])
def pdf2markdown():
    data = request.json
    logger.info(f'request params: {data}')
    pdf_paths = data['pathList']
    callback_url = data['webhookUrl']
-    for pdf_path in pdf_paths:
-        process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
-        requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})
+    executor.submit(pdf2markdown_task, pdf_paths, callback_url)
    return jsonify({})
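The route now returns immediately while the single-worker executor serializes conversions in the background. One caveat: an exception on any PDF, or an unreachable webhook, silently kills the rest of the batch, since nothing consumes the submitted Future. A defensive variant of the task, as a sketch:

def pdf2markdown_task(pdf_paths, callback_url):
    for pdf_path in pdf_paths:
        try:
            process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
            requests.post(callback_url,
                          json={'pdfId': pdf_id, 'processStatus': process_status},
                          timeout=30)
        except Exception:
            # Keep going with the remaining PDFs instead of dying silently.
            logger.exception(f'pdf-to-md failed for {pdf_path}')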

@@ -0,0 +1,3 @@
+#! /bin/bash
+gunicorn -w 1 -b 0.0.0.0:8000 server:app
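With the service running under gunicorn (a single worker, matching the single-worker executor), the endpoint can be exercised like this; paths and the webhook URL are placeholders:

import requests

resp = requests.post(
    'http://localhost:8000/pdf-qa-server/pdf-to-md',
    json={'pathList': ['/mnt/pdf2markdown/sample.pdf'],     # placeholder path
          'webhookUrl': 'http://localhost:9000/callback'})  # placeholder webhook
print(resp.json())  # {} immediately; per-PDF status arrives later via the webhook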