You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

107 lines
4.3 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from typing import List
from .pdf_detection import Pipeline
# NOTE(review): this import is absolute while the neighbouring project imports
# are package-relative (`from .` / `from ..`); confirm `utils` resolves to the
# intended module when this file is imported as part of its package.
from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult
from tqdm import tqdm
from ..constants import PageDetectionEnum as E
from ..image_helper import remove_watermark
# Class ids of the layout detector, as listed in the note string below:
#   0 Text, 1 Title, 2 Figure, 3 Figure caption, 4 Table, 5 Table caption,
#   6 Header, 7 Footer, 8 Reference, 9 Equation.
# The last line of the note (in Chinese) says: "when using the trained weights,
# ids must be shifted by +1, i.e. TEXT starts at 1".
# NOTE(review): E.SCANNED_DOCUMENT is used below but does not appear in this
# list; confirm its id against the PageDetectionEnum definition.
"""
0 - Text
1 - Title
2 - Figure
3 - Figure caption
4 - Table
5 - Table caption
6 - Header
7 - Footer
8 - Reference
9 - Equation
使用训练后的权重时id需要+1即TEXT从1开始
"""
# Layout-detection model, constructed once when this module is imported.
# The path is relative, so it is presumably resolved against the current
# working directory; confirm how Pipeline loads it.
pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer_v2')
# Detector classes kept by layout_analysis(); detections of any other class are discarded.
effective_labels = [E.TEXT.value, E.TITLE.value, E.TABLE.value, E.TABLE_CAPTION.value, E.SCANNED_DOCUMENT.value]
# NMS priority: a label's index in this list is used as the integer part of its
# NMS score, so a lower index means a lower priority.
label_scores = [E.TITLE.value, E.TABLE_CAPTION.value, E.TEXT.value, E.TABLE.value, E.SCANNED_DOCUMENT.value]
# Padding in pixels added on every side of TEXT boxes to make the later OCR step easier.
expand_pixel = 10
def _detect_layout_boxes(image):
    """Detect layout boxes on one page image.

    Returns ``(layout_boxes, is_scanned_document, image)``. ``image`` is the page
    after ``remove_watermark`` has been applied. ``layout_boxes`` holds every
    detection whose class is in ``effective_labels``. For scanned pages it also
    holds the TABLE_CAPTION detections from a second pass on the watermark-free
    image.
    """
    layout_boxes = []
    is_scanned_document = False
    for clsid, box, confidence in pipeline(image):
        if clsid in effective_labels:
            layout_boxes.append(LayoutBox(clsid, box, confidence))
        if clsid == E.SCANNED_DOCUMENT.value:
            is_scanned_document = True
    image = remove_watermark(image)
    if is_scanned_document:
        # Scanned pages are re-detected after watermark removal because the
        # detector was trained on watermark-free images.
        # NOTE(review): the original comment said this pass is meant to recover
        # titles, but only TABLE_CAPTION detections are kept; confirm intent.
        for clsid, box, confidence in pipeline(image):
            if clsid == E.TABLE_CAPTION.value:
                layout_boxes.append(LayoutBox(clsid, box, confidence))
    return layout_boxes, is_scanned_document, image


def _suppress_overlaps(boxes, page_area):
    """Apply NMS, ranking boxes by label priority first and box area second."""
    if not boxes:
        # Nothing to suppress; do not hand empty inputs to non_max_suppression.
        return []
    scores = []
    poses = []
    for box in boxes:
        # When boxes of the same label overlap, keep the larger one.
        area = (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0])
        # Normalising by (page_area + 1) keeps the area term strictly below 1.
        # It can then only break ties within a label and never outrank the
        # label priority. A fixed divisor exceeds 1 for large boxes on
        # high-resolution pages.
        area_score = min(area, page_area) / (page_area + 1)
        scores.append(label_scores.index(box.clsid) + area_score)
        poses.append(box.pos)
    keep = non_max_suppression(poses, scores, 0.2)
    return [boxes[i] for i in keep]


def _strictly_inside(inner, outer):
    """Return True if box ``inner`` lies strictly inside box ``outer`` (x0, y0, x1, y1)."""
    return (inner[0] > outer[0] and inner[1] > outer[1]
            and inner[2] < outer[2] and inner[3] < outer[3])


def _drop_boxes_inside_tables(boxes):
    """Drop Text / Table-caption boxes that lie strictly inside a Table or Figure box.

    Some scanned pages get detected as Figure, hence the Figure check.
    NOTE(review): FIGURE is not in ``effective_labels``, so no Figure boxes
    currently reach this function.
    Boxes are filtered into a new list instead of being removed during iteration.
    This keeps the original order and ensures a box nested in several containers
    is dropped exactly once.
    """
    containers = [b for b in boxes if b.clsid in (E.FIGURE.value, E.TABLE.value)]
    return [
        box for box in boxes
        if not (box.clsid in (E.TEXT.value, E.TABLE_CAPTION.value)
                and any(_strictly_inside(box.pos, c.pos) for c in containers))
    ]


def _expand_text_boxes(boxes, h, w):
    """Pad TEXT boxes in place by ``expand_pixel`` on every side, clamped to the w x h page."""
    for layout in boxes:
        if layout.clsid != E.TEXT.value:
            continue
        pos = layout.pos
        pos[0] = max(0, pos[0] - expand_pixel)
        pos[1] = max(0, pos[1] - expand_pixel)
        pos[2] = min(w, pos[2] + expand_pixel)
        pos[3] = min(h, pos[3] + expand_pixel)


def layout_analysis(images) -> List[PageDetectionResult]:
    """Run layout analysis on a sequence of page images.

    For each page this:
      1. detects layout boxes, adding a second watermark-free pass for pages
         flagged as scanned documents;
      2. applies NMS, preferring higher-priority labels (see ``label_scores``)
         and, within a label, larger boxes;
      3. on non-scanned pages, drops Text / Table-caption boxes nested inside a
         Table or Figure box;
      4. merges Text and Title boxes via ``merge_text_and_title_boxes`` so the
         result converts easily to markdown;
      5. sorts boxes top-to-bottom, then left-to-right;
      6. pads Text boxes by ``expand_pixel`` (clamped to the page) for OCR.

    Args:
        images: iterable of page images. Each must expose ``.shape`` whose first
            two entries are (height, width); presumably numpy arrays, so confirm
            against the caller.

    Returns:
        One ``PageDetectionResult`` per input image, holding the processed boxes
        and the watermark-removed image.
    """
    # Labels merged into TEXT vs. kept separate. These do not depend on the page,
    # so they are computed once.
    merged_labels = [E.TEXT.value, E.TITLE.value]
    other_labels = list(set(effective_labels) - set(merged_labels))
    layout_analysis_results = []
    for image in tqdm(images, '版面分析'):
        layout_boxes, is_scanned_document, image = _detect_layout_boxes(image)
        page_result = PageDetectionResult(layout_boxes, image)
        h, w = image.shape[:2]
        page_result.boxes = _suppress_overlaps(page_result.boxes, h * w)
        if not is_scanned_document:
            page_result.boxes = _drop_boxes_inside_tables(page_result.boxes)
        page_result.boxes = merge_text_and_title_boxes(
            page_result.boxes, merged_labels, other_labels, E.TEXT.value)
        # Reading order: top-to-bottom, then left-to-right.
        page_result.boxes.sort(key=lambda x: (x.pos[1], x.pos[0]))
        _expand_text_boxes(page_result.boxes, h, w)
        layout_analysis_results.append(page_result)
    return layout_analysis_results