You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

69 lines
2.7 KiB
Python

1 month ago
from typing import List
from .pdf_detection import Pipeline
from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult
from tqdm import tqdm
"""
0 - Text
1 - Title
2 - Figure
3 - Figure caption
4 - Table
5 - Table caption
6 - Header
7 - Footer
8 - Reference
9 - Equation
"""
# Layout detector, constructed once at import time from the model directory
# below. layout_analysis calls it with an image path and unpacks each returned
# item as (clsid, box, confidence).
pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer')
# Class ids that layout_analysis keeps; detections with any other class id are
# dropped. Per the legend string in this file: Text, Title, Figure, Table,
# Table caption.
effective_labels = [0, 1, 2, 4, 5]
# NMS priority: a class id's position in this list is its base score, so a
# lower index means a lower priority. Table (4) is last, so tables are kept
# in preference to other boxes when boxes overlap.
label_scores = [1, 5, 0, 2, 4]
# NOTE(review): not referenced anywhere in the visible code; presumably a
# margin in pixels used when expanding boxes elsewhere — confirm before removing.
expand_pixel = 10
def _is_strictly_inside(inner, outer) -> bool:
    """Return True when position *inner* lies strictly within position *outer*.

    Both arguments are indexable with four entries; entries 0 and 1 are
    compared as lower bounds and entries 2 and 3 as upper bounds
    (presumably (x0, y0, x1, y1) — confirm against LayoutBox).
    """
    return bool(
        inner[0] > outer[0]
        and inner[1] > outer[1]
        and inner[2] < outer[2]
        and inner[3] < outer[3]
    )


def layout_analysis(image_paths) -> List[PageDetectionResult]:
    """Run layout detection on each image and return one cleaned result per page.

    For every path in ``image_paths`` the function:

    1. Runs the module-level ``pipeline`` and keeps only detections whose
       class id is in ``effective_labels``.
    2. Applies ``non_max_suppression`` with a score derived from the class
       priority in ``label_scores`` plus a small area-based term.
    3. Drops Text (0) and Table caption (5) boxes that lie strictly inside a
       Figure (2) or Table (4) box.
    4. Merges text and title boxes via ``merge_text_and_title_boxes``.
    5. Sorts the remaining boxes by ``(pos[1], pos[0])``.

    Args:
        image_paths: Iterable of page image paths; each one is passed to
            ``pipeline`` and stored on the resulting ``PageDetectionResult``.

    Returns:
        A list with one ``PageDetectionResult`` per input path, in input order.
    """
    layout_analysis_results = []
    for image_path in tqdm(image_paths, '版面分析'):
        detections = pipeline(image_path)
        layout_boxes = [
            LayoutBox(clsid, box, confidence)
            for clsid, box, confidence in detections
            if clsid in effective_labels
        ]
        page_result = PageDetectionResult(layout_boxes, image_path)

        scores = []
        poses = []
        for box in page_result.boxes:
            # When boxes with the same label overlap, keep the larger one:
            # the area term breaks ties between equal class priorities.
            # NOTE(review): assumes box areas stay below 5,000,000 so this
            # term never outweighs a one-step priority difference — confirm.
            area = (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0])
            area_score = area / 5000000
            scores.append(label_scores.index(box.clsid) + area_score)
            poses.append(box.pos)
        # 0.2 is the threshold passed to non_max_suppression (presumably an
        # overlap/IoU threshold — see utils.non_max_suppression).
        kept_indices = non_max_suppression(poses, scores, 0.2)
        page_result.boxes = [page_result.boxes[i] for i in kept_indices]

        # Remove Table caption and Text boxes that sit inside a Table or
        # Figure box (some scanned pages get detected as a Figure).
        # A filtered list is built instead of calling list.remove() during
        # iteration: the old in-place removal raised ValueError when a box
        # was contained in more than one Figure/Table box.
        containers = [box for box in page_result.boxes if box.clsid in (2, 4)]
        page_result.boxes = [
            box
            for box in page_result.boxes
            if not (
                box.clsid in (0, 5)
                and any(
                    _is_strictly_inside(box.pos, container.pos)
                    for container in containers
                )
            )
        ]

        # Merge text and title boxes so the page converts to markdown more easily.
        page_result.boxes = merge_text_and_title_boxes(page_result.boxes, (0, 1))
        # Order the boxes by pos[1], then pos[0].
        page_result.boxes.sort(key=lambda x: (x.pos[1], x.pos[0]))
        layout_analysis_results.append(page_result)
    return layout_analysis_results