You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
97 lines
3.6 KiB
Python
97 lines
3.6 KiB
Python
from typing import List
|
|
from .utils import table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles
|
|
from tqdm import tqdm
|
|
from ..image_helper import text_rec
|
|
from ..page_detection.utils import PageDetectionResult
|
|
from ..constants import PageDetectionEnum as E
|
|
|
|
|
|
class LayoutRecognitionResult:
    """Recognised content of a single layout region on a page.

    Attributes are stored exactly as passed in:
      clsid       -- class id of the region.
      content     -- recognised content of the region.
      box         -- position of the region.
      table_title -- optional title associated with the region; defaults to None.
    """

    def __init__(self, clsid, content, box, table_title=None):
        """Store the given values on the instance without any conversion."""
        self.table_title = table_title
        self.box = box
        self.content = content
        self.clsid = clsid

    def __repr__(self):
        """Return a short debug form: the class id in brackets, then the content."""
        return f"[{self.clsid}] {self.content}"
|
|
|
|
|
|
# Values written back into ``clsid`` once recognition has finished. The original
# comment describes them as the enum stored in the database, where 1 means "table".
_DB_CLSID_TABLE = 1
_DB_CLSID_OTHER = 0


def _recognize_layout(layout_img, clsid):
    """Return the recognised content of one cropped region, or None if ``clsid`` is not handled."""
    if clsid == E.TEXT.value:
        return markdown_rec(layout_img)
    if clsid == E.TABLE.value:
        return table_rec(layout_img)
    if clsid == E.SCANNED_DOCUMENT.value:
        return scanning_document_rec(layout_img)
    if clsid == E.TABLE_CAPTION.value:
        # Only the middle element of text_rec's 3-tuple is used; it is iterated
        # as a sequence of OCR lines, joined one per line with no trailing newline.
        _, ocr_results, _ = text_rec(layout_img)
        return '\n'.join(f'{line}' for line in ocr_results).rstrip('\n')
    return None


def rec(page_detection_results: List[PageDetectionResult]) -> List[List[LayoutRecognitionResult]]:
    """Recognise the content of every detected layout region, page by page.

    For each page, every box in ``results.boxes`` is cropped out of
    ``results.image`` and passed to the recogniser matching its ``clsid``
    (text, table, scanned document or table caption). Regions with any other
    class id, with a zero-area crop, or with empty recognised content are
    skipped. Backslashes are removed from the recognised content.

    After recognition, ``assign_tables_to_titles`` is called on the page's
    results (per the original comment, it gives each table the title closest
    to it), table-caption entries are dropped, and ``clsid`` is rewritten to
    1 for tables and scanned documents and to 0 for everything else.

    Args:
        page_detection_results: Detection results, one per page.

    Returns:
        One list of ``LayoutRecognitionResult`` per input page, in input
        order. A page without boxes yields an empty list.
    """
    page_recognition_results = []

    for results in tqdm(page_detection_results, '文本识别'):
        if not results.boxes:
            page_recognition_results.append([])
            continue

        img = results.image
        outputs = []

        for layout in results.boxes:
            # Clamp to 0: a negative coordinate used as a slice bound would
            # count from the far edge instead of the origin (assumes ``img``
            # is indexed NumPy-style -- TODO confirm) and crop the wrong area.
            x1, y1, x2, y2 = (max(0, int(v)) for v in layout.pos)
            if x2 <= x1 or y2 <= y1:
                # Zero-area region: nothing to recognise.
                continue

            content = _recognize_layout(img[y1:y2, x1:x2], layout.clsid)
            if not content:
                continue

            content = content.replace('\\', '')
            # The result keeps the original, unclamped position.
            outputs.append(LayoutRecognitionResult(layout.clsid, content, layout.pos))

        # NOTE: a former special case that OCR'd a title from the top strip of
        # a lone scanned-document region was disabled and has been removed.

        # Automatically assign each table the title closest to it.
        assign_tables_to_titles(outputs)

        # The captions are no longer needed once they have been assigned.
        outputs = [item for item in outputs if item.clsid != E.TABLE_CAPTION.value]

        # Convert the detection class ids to the database enum.
        table_like = (E.TABLE.value, E.SCANNED_DOCUMENT.value)
        for item in outputs:
            item.clsid = _DB_CLSID_TABLE if item.clsid in table_like else _DB_CLSID_OTHER

        page_recognition_results.append(outputs)

    return page_recognition_results
|