You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

119 lines
4.6 KiB
Python

from typing import List
import cv2
from .utils import scanning_document_classify, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from tqdm import tqdm
from ..image_helper import text_rec
class LayoutRecognitionResult:
    """Recognised content of a single layout region on a page.

    Attributes:
        clsid: Numeric layout class id of the region.
        content: Recognised content of the region.
        box: Position of the region, stored exactly as passed in.
        table_title: Optional title associated with the region; ``None``
            until one is assigned.
    """

    def __init__(self, clsid, content, box, table_title=None):
        self.clsid = clsid
        self.content = content
        self.box = box
        self.table_title = table_title

    def __repr__(self):
        return f"[{self.clsid}] {self.content}"
# Margin, in pixels, added to every side of a layout box before it is cropped
# for recognition (the enlarged box is clamped to the image bounds in rec()).
expand_pixel = 10
def _expand_boxes(boxes, width, height):
    """Grow each layout box by ``expand_pixel`` per side, clamped to the image.

    ``layout.pos`` is updated in place.
    """
    for layout in boxes:
        pos = layout.pos
        pos[0] = max(0, pos[0] - expand_pixel)
        pos[1] = max(0, pos[1] - expand_pixel)
        pos[2] = min(width, pos[2] + expand_pixel)
        pos[3] = min(height, pos[3] + expand_pixel)


def _ocr_lines(layout_img):
    """OCR a crop with ``text_rec`` and return its lines joined by newlines."""
    _, ocr_results, _ = text_rec(layout_img)
    # Trailing newlines (e.g. from empty trailing OCR lines) are dropped.
    return '\n'.join(f'{line}' for line in ocr_results).rstrip('\n')


def _recognize_layout(layout, layout_img):
    """Recognise one layout crop.

    Returns:
        A ``(content, is_scan)`` pair. ``content`` is ``None`` for layout
        classes that are not recognised; ``is_scan`` is True when the crop was
        classified as a scanned document.
    """
    clsid = layout.clsid
    if clsid == 0:
        # text
        return markdown_rec(layout_img), False
    if clsid in (2, 4):
        # figure (2) or table (4): both may actually be a scanned document
        if scanning_document_classify(layout_img):
            content, _ = scanning_document_rec(layout_img)
            return content, True
        if clsid == 4:
            return table_rec(layout_img), False
        return None, False
    if clsid == 5:
        # table caption
        return _ocr_lines(layout_img), False
    return None, False


def _scan_title(tmp_dir, page_idx):
    """OCR the title strip at the top of a scanned page.

    The page image is read from ``{tmp_dir}/{page_idx + 1}.jpg`` and passed
    through ``remove_watermark`` before the top strip is cropped.
    """
    page_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
    page_h, page_w = page_img.shape[:2]
    # Pages taller than wide use a 360 px strip, all others a 410 px strip.
    strip_h = 360 if page_h > page_w else 410
    _, title, _ = text_rec(page_img[:strip_h, :page_w, ...])
    return '\n'.join(title)


def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
    """Recognise the content of every detected layout on every page.

    Layout class ids handled: 0 text, 2 figure, 4 table, 5 table caption.
    Layouts of any other class, and layouts whose recognised content is empty,
    produce no result.

    Args:
        page_detection_results: Per-page detection results. Each item must
            expose ``boxes`` (layouts with ``clsid`` and a mutable 4-element
            ``pos`` of x1, y1, x2, y2) and ``image_path``.
        tmp_dir: Directory containing the page images named
            ``<1-based page number>.jpg``; read only when a scanned page's
            title has to be extracted.

    Returns:
        One list of ``LayoutRecognitionResult`` per page, in page order.
        Table captions (class 5) are removed after titles have been assigned,
        and classes 2 and 4 are rewritten to 1.

    Raises:
        OSError: If a page image at ``image_path`` cannot be read.

    Note:
        Every ``layout.pos`` is enlarged in place, and each result's ``box``
        is that same object.
    """
    page_recognition_results = []
    for page_idx, results in enumerate(tqdm(page_detection_results, '文本识别')):
        if not results.boxes:
            page_recognition_results.append([])
            continue

        img = cv2.imread(results.image_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'cannot read page image: {results.image_path}')
        height, width = img.shape[:2]

        # Enlarge the boxes a little so OCR sees some margin around the text.
        _expand_boxes(results.boxes, width, height)

        outputs = []
        is_scanning_document = False
        for layout in results.boxes:
            x1, y1, x2, y2 = (int(v) for v in layout.pos)
            content, is_scan = _recognize_layout(layout, img[y1:y2, x1:x2])
            is_scanning_document = is_scanning_document or is_scan
            if not content:
                continue
            content = content.replace('\\', '')
            outputs.append(
                LayoutRecognitionResult(layout.clsid, content, layout.pos)
            )

        if is_scanning_document and len(outputs) == 1:
            # Scanned page: additionally extract its title. The de-watermarked
            # page image is loaded only here, once, because nothing else uses it.
            outputs[0].table_title = _scan_title(tmp_dir, page_idx)
        else:
            # Assign each table the title closest to it.
            assign_tables_to_titles(outputs)

        # Table captions are no longer needed once titles are assigned.
        outputs = [o for o in outputs if o.clsid != 5]

        # Map 2 (figure) and 4 (table) to the database enum value 1 (table).
        for o in outputs:
            if o.clsid in (2, 4):
                o.clsid = 1

        page_recognition_results.append(outputs)
    return page_recognition_results