from typing import List

import cv2
from tqdm import tqdm

from .utils import (
    scanning_document_classify,
    table_rec,
    scanning_document_rec,
    markdown_rec,
    assign_tables_to_titles,
    remove_watermark,
)
from ..image_helper import text_rec


class LayoutRecognitionResult:
    """Recognized content of a single layout region on a page.

    Attributes:
        clsid: Layout class id (see ``rec`` for the values it produces).
        content: Recognized content of the region (text / markdown / table).
        box: Region box ``[x1, y1, x2, y2]`` in page-image pixels.
        table_title: Title attached to the region, or None.
    """

    def __init__(self, clsid, content, box, table_title=None):
        self.clsid = clsid
        self.content = content
        self.box = box
        self.table_title = table_title

    def __repr__(self):
        return f"[{self.clsid}] {self.content}"


# Pixels added on every side of a detected box before cropping, so OCR does
# not lose glyphs touching the box border.
expand_pixel = 10

# Height (pixels) of the top strip of a scanned page that is OCR'd as its
# title: portrait pages (h > w) and landscape pages respectively.
_TITLE_STRIP_HEIGHT_PORTRAIT = 360
_TITLE_STRIP_HEIGHT_LANDSCAPE = 410

# Detector class ids handled by ``rec``.
_CLS_TEXT = 0
_CLS_FIGURE = 2
_CLS_TABLE = 4
_CLS_TABLE_CAPTION = 5
# Database enum value that figures and tables are reported as ("table").
_DB_CLS_TABLE = 1


def _read_image(path):
    """Read an image with OpenCV, raising instead of silently returning None."""
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f'Cannot read image: {path}')
    return img


def _expand_box(pos, w, h):
    """Pad ``pos`` ([x1, y1, x2, y2]) in place by ``expand_pixel``, clamped to a w x h image."""
    pos[0] = max(0, pos[0] - expand_pixel)
    pos[1] = max(0, pos[1] - expand_pixel)
    pos[2] = min(w, pos[2] + expand_pixel)
    pos[3] = min(h, pos[3] + expand_pixel)


def _caption_text(layout_img):
    """OCR a caption crop and return its lines joined by newlines, without trailing newlines."""
    _, ocr_lines, _ = text_rec(layout_img)
    return '\n'.join(f'{line}' for line in ocr_lines).rstrip('\n')


def _scanned_page_title(page_img):
    """OCR the top strip of a scanned page image and return the lines joined by newlines."""
    page_h, page_w = page_img.shape[:2]
    strip_h = _TITLE_STRIP_HEIGHT_PORTRAIT if page_h > page_w else _TITLE_STRIP_HEIGHT_LANDSCAPE
    _, title_lines, _ = text_rec(page_img[:strip_h])
    return '\n'.join(title_lines)


def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
    """Recognize the content of every detected layout region on every page.

    Args:
        page_detection_results: One detection result per page. Each item must
            expose ``image_path`` (path of the page image) and ``boxes``; each
            box has a ``clsid`` and a mutable ``pos`` ``[x1, y1, x2, y2]``.
        tmp_dir: Directory containing page images named ``{page_number}.jpg``
            (1-based). They are re-read, with the watermark removed, when a
            page turns out to be a scanned document.

    Returns:
        One list of LayoutRecognitionResult per page, in input order. Regions
        with empty content are dropped, table captions (class 5) are not
        reported, and figures/tables (classes 2 and 4) are reported with
        clsid 1.

    Raises:
        FileNotFoundError: if a page image cannot be read.

    Note:
        Every ``layout.pos`` in the input is padded and clamped *in place*, and
        that same object is stored as ``LayoutRecognitionResult.box``.
    """
    page_recognition_results = []
    for page_idx, results in enumerate(tqdm(page_detection_results, '文本识别')):
        if not results.boxes:
            page_recognition_results.append([])
            continue

        img = _read_image(results.image_path)
        img_h, img_w = img.shape[:2]
        # Expand every box a little to make OCR easier (see expand_pixel).
        for layout in results.boxes:
            _expand_box(layout.pos, img_w, img_h)

        outputs = []
        is_scanning_document = False
        # Watermark-free page image; loaded lazily, at most once per page,
        # when the first scanned region is found.
        unwatermarked_page_img = None
        for layout in results.boxes:
            x1, y1, x2, y2 = (int(v) for v in layout.pos)
            layout_img = img[y1:y2, x1:x2]
            if layout_img.size == 0:
                # Box collapsed to nothing after clamping: nothing to recognize.
                continue

            content = None
            if layout.clsid == _CLS_TEXT:
                content = markdown_rec(layout_img)
            elif layout.clsid in (_CLS_FIGURE, _CLS_TABLE):
                if scanning_document_classify(layout_img):
                    # Scanned document.
                    is_scanning_document = True
                    content, _ = scanning_document_rec(layout_img)
                    if unwatermarked_page_img is None:
                        unwatermarked_page_img = remove_watermark(
                            _read_image(f'{tmp_dir}/{page_idx + 1}.jpg'))
                elif layout.clsid == _CLS_TABLE:
                    content = table_rec(layout_img)
                # A non-scanned figure yields no content and is skipped below.
            elif layout.clsid == _CLS_TABLE_CAPTION:
                content = _caption_text(layout_img)

            if not content:
                continue
            content = content.replace('\\', '')
            outputs.append(LayoutRecognitionResult(layout.clsid, content, layout.pos))

        if is_scanning_document and len(outputs) == 1:
            # Scanned document: additionally extract a title from the top of the page.
            outputs[0].table_title = _scanned_page_title(unwatermarked_page_img)
        else:
            # Automatically assign each table the title nearest to it.
            assign_tables_to_titles(outputs)

        # Table captions are no longer needed once titles are assigned.
        outputs = [o for o in outputs if o.clsid != _CLS_TABLE_CAPTION]
        # Map 2 (figure) and 4 (table) onto the database enum value 1 (table).
        for o in outputs:
            if o.clsid in (_CLS_FIGURE, _CLS_TABLE):
                o.clsid = _DB_CLS_TABLE
        page_recognition_results.append(outputs)
    return page_recognition_results