You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

69 lines
2.7 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from typing import List
from .pdf_detection import Pipeline
from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult
from tqdm import tqdm
"""
0 - Text
1 - Title
2 - Figure
3 - Figure caption
4 - Table
5 - Table caption
6 - Header
7 - Footer
8 - Reference
9 - Equation
"""
pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer')
effective_labels = [0, 1, 2, 4, 5]
# nms优先级索引越低优先级越低box重叠时优先保留表格
label_scores = [1, 5, 0, 2, 4]
expand_pixel = 10
def _box_area(box):
    """Return the area of a LayoutBox computed from its ``pos`` coordinates."""
    # pos appears to be (x1, y1, x2, y2): width is pos[2]-pos[0], height pos[3]-pos[1].
    return (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0])


def _suppress_overlapping_boxes(boxes, iou_threshold=0.2):
    """Run non-max suppression over ``boxes`` and return the surviving boxes.

    Each box is scored as its index in ``label_scores`` (higher index wins, so
    Table boxes are preferred when boxes overlap) plus an area term. The area
    term only breaks ties between overlapping boxes of the same label, keeping
    the larger one.

    Args:
        boxes: list of LayoutBox objects for one page.
        iou_threshold: overlap threshold forwarded to ``non_max_suppression``.

    Returns:
        The boxes selected by ``non_max_suppression``, in the order of the
        indices it returns.
    """
    if not boxes:
        return []
    areas = [_box_area(b) for b in boxes]
    # Normalise by the largest area on this page (+1 keeps the term strictly
    # below 1 for non-negative areas). A fixed divisor lets a large box on a
    # high-resolution scan gain >= 1 and overtake a higher-priority label.
    area_norm = max(areas) + 1
    scores = [label_scores.index(b.clsid) + area / area_norm
              for b, area in zip(boxes, areas)]
    keep = non_max_suppression([b.pos for b in boxes], scores, iou_threshold)
    return [boxes[i] for i in keep]


def _is_strictly_inside(inner, outer):
    """Return True if ``inner`` lies strictly within ``outer`` on all four sides."""
    return (inner.pos[0] > outer.pos[0] and inner.pos[1] > outer.pos[1]
            and inner.pos[2] < outer.pos[2] and inner.pos[3] < outer.pos[3])


def _drop_boxes_inside_figures_and_tables(boxes):
    """Drop Text (0) and Table caption (5) boxes lying inside a Figure (2) or Table (4) box.

    Some scanned pages are detected as a single Figure, so text found inside it
    is redundant. A new list is built instead of calling ``list.remove`` during
    the scan. That way a box nested in several containers is dropped exactly
    once, where repeated removal would raise ``ValueError``. The relative order
    of the kept boxes is preserved.
    """
    containers = [b for b in boxes if b.clsid in (2, 4)]
    return [
        b for b in boxes
        if not (b.clsid in (0, 5)
                and any(_is_strictly_inside(b, c) for c in containers))
    ]


def layout_analysis(image_paths) -> List[PageDetectionResult]:
    """Detect and clean up layout boxes for each page image.

    For every image path:
      1. run the layout ``pipeline`` and keep detections whose class id is in
         ``effective_labels``;
      2. apply label-prioritised non-max suppression;
      3. drop Text / Table-caption boxes nested inside Figure / Table boxes;
      4. merge Text and Title boxes (to ease later conversion to markdown);
      5. sort boxes top-to-bottom, then left-to-right.

    Args:
        image_paths: iterable of page image paths accepted by ``pipeline``.

    Returns:
        One PageDetectionResult per input path, in input order.
    """
    results = []
    for image_path in tqdm(image_paths, desc='版面分析'):
        layout_boxes = [
            LayoutBox(clsid, box, confidence)
            for clsid, box, confidence in pipeline(image_path)
            if clsid in effective_labels
        ]
        page_result = PageDetectionResult(layout_boxes, image_path)
        page_result.boxes = _suppress_overlapping_boxes(page_result.boxes)
        page_result.boxes = _drop_boxes_inside_figures_and_tables(page_result.boxes)
        page_result.boxes = merge_text_and_title_boxes(page_result.boxes, (0, 1))
        # Reading order: by top edge, then by left edge.
        page_result.boxes.sort(key=lambda b: (b.pos[1], b.pos[0]))
        results.append(page_result)
    return results