|
|
|
|
import os
|
|
|
|
|
import tempfile
|
|
|
|
|
import cv2
|
|
|
|
|
import numpy as np
|
|
|
|
|
from marker.converters.table import TableConverter
|
|
|
|
|
from marker.models import create_model_dict
|
|
|
|
|
from marker.output import text_from_rendered
|
|
|
|
|
from .rapid_table_pipeline.main import table2md_pipeline
|
|
|
|
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
|
|
|
|
from magic_pdf.data.read_api import read_local_images
|
|
|
|
|
from markdownify import markdownify as md
|
|
|
|
|
import re
|
|
|
|
|
from ..image_helper import text_rec
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scanning_document_classify(image):
    """Decide whether a BGR image should be treated as a scanned document.

    The image is converted to HSV and every pixel that falls inside either
    of two red hue bands is counted. The function returns True only when
    that count is strictly between 1 and 1000.
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # Two bands are used: one for low-hue red and one for high-hue red.
    # Each entry is (lower HSV bound, upper HSV bound).
    red_bands = (
        ((0, 70, 50), (10, 255, 255)),
        ((170, 70, 50), (180, 255, 255)),
    )
    low_hue_mask, high_hue_mask = (
        cv2.inRange(hsv_image, np.array(lower), np.array(upper))
        for lower, upper in red_bands
    )

    # A pixel counts as red when it matches either band.
    red_mask = cv2.bitwise_or(low_hue_mask, high_hue_mask)
    red_pixel_count = cv2.countNonZero(red_mask)

    return 1 < red_pixel_count < 1000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def remove_watermark(image):
    """Remove red stamp marks from a 3-channel BGR image.

    Only the red channel is kept. Every red value above 210 is forced to
    255, and the resulting single channel is expanded back into a
    3-channel image, which is returned. The input array is not modified
    by the thresholding step (it operates on the split-off channel).
    """
    _blue, _green, red = cv2.split(image)

    # Pixels that are bright in the red channel are pushed to pure white.
    bright = red > 210
    red[bright] = 255

    return cv2.cvtColor(red, cv2.COLOR_GRAY2BGR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def html2md(html_content):
    """Convert HTML to Markdown and strip backslash escapes.

    The text is first passed through markdownify. Any backslash that
    directly precedes one of ``#``, ``*``, ``_`` or a backtick is then
    removed, leaving the bare character.
    """
    escaped_markdown = md(html_content)
    return re.sub(r'\\([#*_`])', r'\1', escaped_markdown)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def markdown_rec(image):
    """Run the magic_pdf OCR pipeline on an image and return Markdown.

    The image is written to a temporary ``.jpg`` file because
    ``read_local_images`` is called with a file path. The temporary file
    is always removed before returning.

    Args:
        image: Image array accepted by ``cv2.imwrite``.

    Returns:
        The pipeline output after post-processing by ``html2md``.

    Raises:
        OSError: If the temporary image cannot be written.
    """
    # TODO: allow a folder to be passed in.
    # mkstemp creates the file atomically. The previous tempfile.mktemp()
    # only returned a name, which is deprecated because another process can
    # claim that name before it is used.
    fd, image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)  # cv2.imwrite reopens the file by path

    try:
        # imwrite reports failure through its return value; fail loudly here
        # rather than letting the pipeline read an empty file.
        if not cv2.imwrite(image_path, image):
            raise OSError(f'failed to write temporary image: {image_path}')

        dataset = read_local_images(image_path)[0]
        inference = dataset.apply(doc_analyze, ocr=True)
        pipeline = inference.pipe_ocr_mode(None)
        markdown = pipeline.get_markdown(None)
    finally:
        # The file is guaranteed to exist here, so cleanup cannot mask the
        # original exception with FileNotFoundError.
        os.remove(image_path)

    return html2md(markdown)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def table_rec(image):
    """Recognise a table image via OCR plus ``table2md_pipeline``.

    ``text_rec`` is expected to return three parallel sequences (boxes,
    texts, confidences). They are zipped into a list of per-detection
    triples and handed to ``table2md_pipeline`` together with the image.
    Returns whatever ``table2md_pipeline`` returns.
    """
    det_boxes, rec_texts, rec_scores = text_rec(image)
    ocr_triples = [*zip(det_boxes, rec_texts, rec_scores)]
    return table2md_pipeline(image, ocr_triples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Shared TableConverter instance, built once when this module is imported and
# called by scanning_document_rec with the path of a temporary image.
# NOTE(review): create_model_dict() presumably loads model weights, which
# would make importing this module slow — confirm before importing it eagerly.
table_converter = TableConverter(artifact_dict=create_model_dict())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def scanning_document_rec(image):
    """Recognise the content of a scanned-document image.

    The image is first cleaned with ``remove_watermark``. The cleaned
    image is then written to a temporary ``.jpg`` file, passed to the
    module-level ``table_converter``, and the rendered result is turned
    into text with ``text_from_rendered``. The temporary file is always
    removed before returning.

    Args:
        image: BGR image array (it is passed to ``remove_watermark``).

    Returns:
        A ``(text, unwatermarked_image)`` tuple.

    Raises:
        OSError: If the temporary image cannot be written.
    """
    # Clean the image before touching the filesystem, so a failure here
    # leaves nothing to clean up.
    unwatermarked_image = remove_watermark(image)

    # mkstemp creates the file atomically. The previous tempfile.mktemp()
    # only returned a name (deprecated, race-prone), and the file might not
    # have existed when the finally block tried to delete it.
    fd, tmp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)  # cv2.imwrite reopens the file by path

    try:
        # imwrite reports failure through its return value.
        if not cv2.imwrite(tmp_image_path, unwatermarked_image):
            raise OSError(
                f'failed to write temporary image: {tmp_image_path}'
            )

        rendered = table_converter(tmp_image_path)
        text, _, _ = text_from_rendered(rendered)
    finally:
        # The file always exists here, so cleanup cannot raise
        # FileNotFoundError and hide the real error.
        os.remove(tmp_image_path)

    return text, unwatermarked_image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_box_distance(box1, box2):
    """Return the edge-to-edge distance between two axis-aligned boxes.

    Each box is an ``(x1, y1, x2, y2)`` sequence where ``(x1, y1)`` is the
    top-left corner and ``(x2, y2)`` the bottom-right corner (y grows
    downwards).

    Returns:
        * a negative number when the boxes overlap in both axes: minus the
          smaller of the two overlap extents (the minimum penetration);
        * the non-negative gap between the facing edges when the boxes
          overlap in exactly one axis projection; boxes that touch along
          an edge yield ``0``;
        * ``None`` when the boxes are positioned diagonally, i.e. their
          projections overlap on neither axis.
    """
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # Overlap extent along each axis, clamped at zero.
    x_overlap = max(0, min(x12, x22) - max(x11, x21))
    y_overlap = max(0, min(y12, y22) - max(y11, y21))

    # Overlapping in both axes: report the negative penetration depth.
    if x_overlap > 0 and y_overlap > 0:
        return -min(x_overlap, y_overlap)

    candidates = []

    # The x projections overlap, so the boxes are stacked vertically.
    if x12 > x21 and x11 < x22:
        candidates.append(y21 - y12)  # box1 bottom edge to box2 top edge
        candidates.append(y11 - y22)  # box2 bottom edge to box1 top edge

    # The y projections overlap, so the boxes sit side by side.
    if y12 > y21 and y11 < y22:
        candidates.append(x11 - x22)  # box2 right edge to box1 left edge
        candidates.append(x21 - x12)  # box1 right edge to box2 left edge

    # A gap of exactly 0 means the boxes share an edge. That is a valid
    # (and the closest possible) distance, so keep it; a strict "> 0" test
    # would wrongly report adjacent boxes as unaligned (None).
    gaps = [gap for gap in candidates if gap >= 0]
    return min(gaps) if gaps else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assign_tables_to_titles(layout_results, max_distance=200):
    """Attach the nearest title to each table in a layout result list.

    Items with ``clsid == 4`` are treated as tables and items with
    ``clsid == 5`` as titles. Each title picks its single nearest table
    according to ``compute_box_distance``, ignoring tables for which the
    distance is ``None`` or larger than ``max_distance`` (the first table
    wins ties). Each table then keeps the closest title that picked it
    (the earliest title wins ties). A title that loses a table is not
    re-assigned to another one.

    Args:
        layout_results: Iterable of layout items exposing ``clsid`` and
            ``box``; titles must also expose ``content``.
        max_distance: Largest distance at which a title may be bound to a
            table.

    Returns:
        None. Every table gets a ``table_title`` attribute, set to the
        bound title's ``content`` or to ``None`` when no title was bound.
    """
    tables = [item for item in layout_results if item.clsid == 4]
    titles = [item for item in layout_results if item.clsid == 5]

    # id(table) -> (distance, title) for the closest title found so far.
    # Keyed by object identity so layout items need not be hashable.
    claims = {}

    # One pass is enough: a title's nearest table depends only on geometry,
    # so repeating the scan can never change a binding.
    for title in titles:
        nearest_table = None
        nearest_dist = float('inf')
        for table in tables:
            dist = compute_box_distance(title.box, table.box)
            if dist is None or dist > max_distance:
                continue
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_table = table

        if nearest_table is None:
            continue

        # Replace the current holder only when this title is strictly
        # closer. Compare against None explicitly so the result does not
        # depend on the truthiness of the layout objects.
        holder = claims.get(id(nearest_table))
        if holder is None or nearest_dist < holder[0]:
            claims[id(nearest_table)] = (nearest_dist, title)

    # Write the final bindings back onto the table objects.
    for table in tables:
        holder = claims.get(id(table))
        table.table_title = holder[1].content if holder is not None else None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test against a sample image on the development machine.
    # Earlier experiments, kept for reference:
    # content = text_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/5.jpg')
    # content = markdown_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/3.jpg')
    # content = table_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/6.jpg')
    sample_path = '/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg'

    # scanning_document_rec works on a decoded image (it splits colour
    # channels), so the file must be loaded first; passing the path string
    # directly fails inside remove_watermark.
    sample_image = cv2.imread(sample_path)
    if sample_image is None:
        # cv2.imread returns None instead of raising when it cannot read.
        raise SystemExit(f'cannot read image: {sample_path}')

    # The function returns (text, unwatermarked_image); only print the text.
    content, _ = scanning_document_rec(sample_image)
    print(content)
|