You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

154 lines
4.7 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import tempfile
import cv2
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from .rapid_table_pipeline.main import table2md_pipeline
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
from ..image_helper import text_rec
from ..constants import PageDetectionEnum as E
def html2md(html_content):
    """Convert an HTML string to Markdown with ``#``, ``*``, ``_`` and backticks un-escaped.

    The HTML is converted with ``markdownify``; afterwards any backslash that
    directly precedes ``#``, ``*``, ``_`` or a backtick is removed, so those
    characters appear literally in the returned text.
    """
    converted = md(html_content)
    # Keep only the captured character, dropping the backslash in front of it.
    return re.sub(r'\\([#*_`])', r'\1', converted)
def markdown_rec(image):
    """Recognise the content of an image and return it as Markdown.

    The image is written to a temporary JPEG file, loaded with
    ``read_local_images``, analysed with ``doc_analyze`` (``ocr=True``) and
    rendered through ``pipe_ocr_mode`` / ``get_markdown``. The rendered text
    is then passed through ``html2md``. The temporary file is always removed.

    Args:
        image: Image data accepted by ``cv2.imwrite``.

    Returns:
        The Markdown string produced by ``html2md``.

    Raises:
        OSError: If the image could not be written to the temporary file.
    """
    # TODO: a folder could be passed in instead of a single file.
    # mkstemp creates the file atomically, unlike the deprecated, race-prone
    # mktemp. The descriptor is closed right away so OpenCV can reopen the
    # path on every platform.
    fd, image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(image_path, image):
            raise OSError(f'failed to write temporary image: {image_path}')
        ds = read_local_images(image_path)[0]
        x = ds.apply(doc_analyze, ocr=True)
        # NOTE(review): the None arguments presumably disable image output;
        # confirm against the magic_pdf API.
        x = x.pipe_ocr_mode(None)
        html = x.get_markdown(None)
    finally:
        os.remove(image_path)
    return html2md(html)
def table_rec(image):
    """Recognise a table image and return the result of ``table2md_pipeline``.

    ``text_rec`` is run on the image first. When it yields no texts the
    function returns ``None``; otherwise the boxes, texts and confidences are
    zipped into a list of triples and handed to ``table2md_pipeline`` together
    with the image.
    """
    boxes, texts, confidences = text_rec(image)
    if texts:
        ocr_triples = list(zip(boxes, texts, confidences))
        return table2md_pipeline(image, ocr_triples)
    # Nothing was recognised, so there is no table to build.
    return None
# Module-level TableConverter, built once when this module is imported from
# the dictionary returned by create_model_dict(); scanning_document_rec calls it.
table_converter = TableConverter(artifact_dict=create_model_dict())
def scanning_document_rec(scanning_document_image):
    """Run the module-level ``table_converter`` on an image and return its text.

    The image is written to a temporary JPEG file, converted with
    ``table_converter``, and the text is taken from the first element of
    ``text_from_rendered``. The temporary file is always removed.

    Args:
        scanning_document_image: Image data accepted by ``cv2.imwrite``.

    Returns:
        The text extracted from the rendered conversion result.

    Raises:
        OSError: If the image could not be written to the temporary file.
    """
    # mkstemp creates the file atomically, unlike the deprecated, race-prone
    # mktemp. Because the file already exists, the cleanup in ``finally``
    # cannot fail with FileNotFoundError and hide the real error.
    fd, tmp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(tmp_image_path, scanning_document_image):
            raise OSError(f'failed to write temporary image: {tmp_image_path}')
        rendered = table_converter(tmp_image_path)
        text, _, _ = text_from_rendered(rendered)
    finally:
        os.remove(tmp_image_path)
    return text
def compute_box_distance(box1, box2):
    """Return the edge-to-edge distance between two axis-aligned boxes.

    Both boxes are unpacked as ``(x1, y1, x2, y2)``; the code treats
    ``(x1, y1)`` as the top-left and ``(x2, y2)`` as the bottom-right corner.

    Args:
        box1: First box as a 4-item sequence.
        box2: Second box as a 4-item sequence.

    Returns:
        * A negative number when the boxes overlap on both axes: minus the
          smaller of the two overlap extents (the minimum penetration).
        * ``0`` when the boxes touch along an edge.
        * The positive gap between the facing edges when the boxes overlap in
          projection on exactly one axis.
        * ``None`` when the projections overlap on neither axis (for example
          diagonal neighbours), i.e. no pair of edges faces each other.
    """
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # Overlap extent along each axis (0 when the projections are disjoint).
    x_overlap = max(0, min(x12, x22) - max(x11, x21))
    y_overlap = max(0, min(y12, y22) - max(y11, y21))

    # Overlapping on both axes: report the negative overlap depth, using the
    # smaller extent as the minimum penetration.
    if x_overlap > 0 and y_overlap > 0:
        return -min(x_overlap, y_overlap)

    gaps = []

    # Projections overlap on x: measure the vertical gap. The candidates are
    # box1's bottom edge to box2's top edge, and box1's top edge to box2's
    # bottom edge. A gap of exactly 0 (touching edges) is a valid distance.
    if x12 > x21 and x11 < x22:
        gaps.extend(gap for gap in (y21 - y12, y11 - y22) if gap >= 0)

    # Projections overlap on y: measure the horizontal gap. The candidates are
    # box1's left edge to box2's right edge, and box1's right edge to box2's
    # left edge.
    if y12 > y21 and y11 < y22:
        gaps.extend(gap for gap in (x11 - x22, x21 - x12) if gap >= 0)

    # No facing edges means the boxes cannot be aligned: return None.
    return min(gaps) if gaps else None
def assign_tables_to_titles(layout_results, max_distance=200):
    """Attach the nearest caption to each table-like layout element.

    Elements whose ``clsid`` is ``E.TABLE`` or ``E.SCANNED_DOCUMENT`` are
    treated as tables; elements whose ``clsid`` is ``E.TABLE_CAPTION`` are
    treated as captions. Each caption picks its nearest table according to
    ``compute_box_distance`` (ignoring tables that cannot be aligned or are
    farther than ``max_distance``). When several captions pick the same table,
    the closest caption wins; on a tie the caption that comes first in
    ``layout_results`` is kept. A caption that loses is not re-assigned to
    another table.

    Each caption's nearest table does not depend on the other bindings, so a
    single pass over the captions is sufficient.

    Args:
        layout_results: Iterable of layout elements exposing ``clsid`` and
            ``box``; captions additionally expose ``content``.
        max_distance: Largest distance at which a caption may be bound to a
            table.

    Returns:
        None. Every table-like element gets a ``table_title`` attribute set
        to the bound caption's ``content``, or ``None`` if no caption is bound.
    """
    table_clsids = (E.TABLE.value, E.SCANNED_DOCUMENT.value)
    tables = [item for item in layout_results if item.clsid in table_clsids]
    titles = [item for item in layout_results if item.clsid == E.TABLE_CAPTION.value]

    # id(table) -> (distance, caption) for the closest caption found so far.
    # Keys are object ids because the elements are not assumed to be hashable.
    closest_caption = {}

    for title in titles:
        best_table = None
        min_dist = float('inf')
        for table in tables:
            dist = compute_box_distance(title.box, table.box)
            if dist is None or dist > max_distance:
                continue
            if dist < min_dist:
                min_dist = dist
                best_table = table
        if best_table is None:
            continue

        table_id = id(best_table)
        bound = closest_caption.get(table_id)
        # Replace an existing binding only when this caption is strictly
        # closer; on equal distance the earlier caption is kept.
        if bound is None or min_dist < bound[0]:
            closest_caption[table_id] = (min_dist, title)

    # Write the final bindings back onto the tables.
    for table in tables:
        bound = closest_caption.get(id(table))
        table.table_title = bound[1].content if bound is not None else None