You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

195 lines
6.1 KiB
Python

1 month ago
import os
import tempfile
import cv2
import numpy as np
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from .rapid_table_pipeline.main import table2md_pipeline
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
from ..image_helper import text_rec
1 month ago
def scanning_document_classify(image):
    """Heuristically decide whether *image* is a scanned document.

    The image is converted from BGR to HSV. Pixels are counted when they fall
    in either red band:
      * hue 0-10 or 170-180,
      * saturation 70-255,
      * value 50-255.
    The image is classified as a scan when that count lies strictly between
    1 and 1000.

    NOTE(review): the rationale for the (1, 1000) bounds is not visible in
    this module. Presumably a small amount of red (e.g. a stamp) is expected
    on scans; confirm.

    Args:
        image: BGR image array accepted by ``cv2.cvtColor(..., COLOR_BGR2HSV)``.

    Returns:
        bool: True when the red-pixel count is in the open interval (1, 1000).
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    # Red sits at both ends of OpenCV's 0-180 hue scale, so two ranges are
    # masked separately and then OR-ed together.
    red_ranges = (
        (np.array([0, 70, 50]), np.array([10, 255, 255])),
        (np.array([170, 70, 50]), np.array([180, 255, 255])),
    )
    low_hue_mask, high_hue_mask = (
        cv2.inRange(hsv_image, lower, upper) for lower, upper in red_ranges
    )
    red_mask = cv2.bitwise_or(low_hue_mask, high_hue_mask)
    red_pixel_count = cv2.countNonZero(red_mask)
    return red_pixel_count > 1 and red_pixel_count < 1000
def remove_watermark(image):
    """Suppress a red seal/stamp by keeping only the image's red channel.

    The processing steps are:
      1. Take the third plane returned by ``cv2.split`` (red, for a BGR image).
      2. Set every value above 210 to 255 (white).
      3. Expand the single plane back to a 3-channel BGR image.

    Args:
        image: 3-channel BGR image array. ``cv2.split`` must yield exactly
            three planes, otherwise the unpacking below raises ValueError.

    Returns:
        A 3-channel image whose three channels all equal the thresholded red
        channel.
    """
    _blue, _green, red = cv2.split(image)
    # Presumably red ink is bright in the red channel, so pushing bright
    # values to pure white hides the stamp while darker content is kept.
    np.putmask(red, red > 210, 255)
    return cv2.cvtColor(red, cv2.COLOR_GRAY2BGR)
def html2md(html_content):
    """Convert HTML to Markdown, then drop escaping backslashes.

    Args:
        html_content: HTML string passed to markdownify's ``md``.

    Returns:
        str: The converted Markdown. Any backslash directly in front of a
        ``#``, ``*``, ``_`` or backtick character is removed.
    """
    converted = md(html_content)
    # md() may backslash-escape Markdown control characters; keep the
    # character and drop the escaping backslash.
    unescape_pattern = r'\\([#*_`])'
    return re.sub(unescape_pattern, r'\1', converted)
def markdown_rec(image):
    """Recognise an image with magic_pdf's OCR pipeline and return Markdown.

    The image is written to a temporary JPEG because ``read_local_images`` is
    called with a file path. The temporary file is always removed afterwards.

    Args:
        image: image array writable by ``cv2.imwrite``.

    Returns:
        str: The pipeline's Markdown output, post-processed by ``html2md``.

    Raises:
        OSError: if the temporary image cannot be written.
    """
    # TODO: could accept a folder of images instead of a single image.
    # mkstemp creates the file atomically and privately, unlike the
    # deprecated, race-prone mktemp, which only returns an unused name.
    fd, image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(image_path, image):
            raise OSError(f'failed to write temporary image {image_path}')
        ds = read_local_images(image_path)[0]
        infer_result = ds.apply(doc_analyze, ocr=True)
        pipe_result = infer_result.pipe_ocr_mode(None)
        page_markdown = pipe_result.get_markdown(None)
    finally:
        os.remove(image_path)
    # NOTE(review): the markdown is passed through html2md, presumably to
    # convert HTML fragments (e.g. tables) it contains; confirm.
    return html2md(page_markdown)
def table_rec(image):
    """Recognise a table image via OCR and ``table2md_pipeline``.

    ``text_rec`` returns three parallel sequences: boxes, texts and
    confidences. They are regrouped into a list of
    ``(box, text, confidence)`` tuples and passed to ``table2md_pipeline``
    together with the image.

    Args:
        image: image array accepted by ``text_rec``.

    Returns:
        Whatever ``table2md_pipeline`` returns (presumably Markdown, per its
        name; confirm).
    """
    boxes, texts, confidences = text_rec(image)
    ocr_items = [
        (box, text, confidence)
        for box, text, confidence in zip(boxes, texts, confidences)
    ]
    return table2md_pipeline(image, ocr_items)
1 month ago
# Module-level marker TableConverter, built once when this module is
# imported and reused by scanning_document_rec().
# NOTE(review): create_model_dict() presumably loads marker's models, which
# would make importing this module slow; confirm.
_marker_artifacts = create_model_dict()
table_converter = TableConverter(artifact_dict=_marker_artifacts)
def scanning_document_rec(image):
    """Recognise a scanned document image with marker's table converter.

    The red stamp is first suppressed with ``remove_watermark``. The cleaned
    image is then written to a temporary JPEG, because ``table_converter`` is
    called with a file path. The temporary file is always removed afterwards.

    Args:
        image: image array accepted by ``remove_watermark`` (assumed BGR as
            produced by cv2; confirm against callers).

    Returns:
        tuple: ``(text, unwatermarked_image)``. ``text`` is what marker's
        ``text_from_rendered`` extracted, and ``unwatermarked_image`` is the
        image fed to the converter.

    Raises:
        OSError: if the temporary image cannot be written.
    """
    # Done before creating the temp file so a failure here leaves nothing behind.
    unwatermarked_image = remove_watermark(image)
    # mkstemp creates the file atomically (unlike the deprecated mktemp), so
    # the path is guaranteed to exist when the finally block removes it.
    # Previously a failed write made os.remove raise FileNotFoundError and
    # mask the real error.
    fd, tmp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(tmp_image_path, unwatermarked_image):
            raise OSError(f'failed to write temporary image {tmp_image_path}')
        rendered = table_converter(tmp_image_path)
        text, _, _ = text_from_rendered(rendered)
    finally:
        os.remove(tmp_image_path)
    return text, unwatermarked_image
1 month ago
def compute_box_distance(box1, box2):
    """Return the edge-to-edge distance between two axis-aligned boxes.

    Boxes are ``(x1, y1, x2, y2)`` with x1 <= x2 and y1 <= y2 (assumed). Per
    the original comments, y grows downwards, so y2 is the bottom edge.

    Returns:
        * A negative number when the boxes overlap on both axes: minus the
          smaller overlap depth (the minimum penetration).
        * A non-negative number when the projections overlap on exactly one
          axis: the gap between the facing edges along the other axis. This
          is 0 when the edges touch.
        * ``None`` when neither projection overlaps. The boxes are then only
          diagonally related and no pair of edges faces each other.
    """
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # Overlap amount along each axis (0 when the projections do not overlap).
    x_overlap = max(0, min(x12, x22) - max(x11, x21))
    y_overlap = max(0, min(y12, y22) - max(y11, y21))
    if x_overlap > 0 and y_overlap > 0:
        return -min(x_overlap, y_overlap)

    distances = []
    # Horizontal projections overlap: measure the vertical gap.
    if x12 > x21 and x11 < x22:
        dist_top = y21 - y12     # bottom edge of box1 to top edge of box2
        dist_bottom = y11 - y22  # top edge of box1 to bottom edge of box2
        # ">= 0" rather than "> 0" so that boxes sharing an edge score 0
        # instead of being reported as unalignable (None).
        distances.extend(d for d in (dist_top, dist_bottom) if d >= 0)
    # Vertical projections overlap: measure the horizontal gap.
    if y12 > y21 and y11 < y22:
        dist_left = x11 - x22   # left edge of box1 to right edge of box2
        dist_right = x21 - x12  # right edge of box1 to left edge of box2
        distances.extend(d for d in (dist_left, dist_right) if d >= 0)
    # No facing edges at all means the boxes cannot be aligned.
    return min(distances) if distances else None
def assign_tables_to_titles(layout_results, max_distance=200):
    """Attach the nearest title to each table by setting ``table.table_title``.

    Layout items with ``clsid == 4`` are treated as tables and items with
    ``clsid == 5`` as titles. Every item must expose ``box``, in the format
    accepted by ``compute_box_distance``. Titles must also expose ``content``.

    Matching is greedy:
      * Every title picks its single nearest table, using
        ``compute_box_distance``. Pairs with no facing edges (``None``) or a
        distance above ``max_distance`` are ignored, and ties go to the
        earlier table.
      * If that table already has a title, the newcomer replaces it only when
        strictly closer.
    Passes repeat until nothing changes. The loop terminates because every
    replacement strictly lowers a table's bound distance.

    NOTE(review): a title's nearest table does not depend on the bindings, so
    a title that loses its table is never re-matched to another table and
    stays unassigned; confirm this is intended.

    Args:
        layout_results: iterable of layout items (project type; attributes
            above). It is consumed once, so generators are accepted.
        max_distance: largest distance at which a title may still be bound.

    Returns:
        None. Every table item is mutated: ``table_title`` is set to the
        bound title's ``content``, or to ``None``.
    """
    items = list(layout_results)  # materialise once; it is scanned twice below
    tables = [item for item in items if item.clsid == 4]
    titles = [item for item in items if item.clsid == 5]

    # Each title's nearest table is independent of the bindings, so it is
    # computed once instead of on every pass of the loop below.
    candidates = []  # (title, best_table, distance)
    for title in titles:
        best_table = None
        min_dist = float('inf')
        for table in tables:
            dist = compute_box_distance(title.box, table.box)
            if dist is None or dist > max_distance:
                continue
            if dist < min_dist:
                min_dist = dist
                best_table = table
        if best_table is not None:
            candidates.append((title, best_table, min_dist))

    # Keyed by id() because the layout objects may not be hashable.
    table_to_title = {}  # id(table) -> bound title
    title_to_table = {}  # id(title) -> bound table
    changed = True
    while changed:
        changed = False
        for title, best_table, min_dist in candidates:
            if title_to_table.get(id(title)) is best_table:
                continue  # already bound to its best table
            prev_title = table_to_title.get(id(best_table))
            if prev_title is not None:
                prev_dist = compute_box_distance(prev_title.box, best_table.box)
                if prev_dist is not None and prev_dist <= min_dist:
                    continue  # the current title is at least as close; keep it
                # Unbind the title being displaced.
                title_to_table.pop(id(prev_title), None)
            title_to_table[id(title)] = best_table
            table_to_title[id(best_table)] = title
            changed = True

    # Write the final bindings back onto the table items.
    for table in tables:
        title = table_to_title.get(id(table))
        table.table_title = title.content if title is not None else None
if __name__ == '__main__':
    # Manual smoke test: recognise one image and print the extracted text.
    # The relative imports at the top of this module mean it must be run
    # with `python -m ...`; an optional first argument overrides the image.
    import sys

    default_path = '/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg'
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_path
    # The recognisers take a decoded image array, not a path, so the file is
    # read first. cv2.imread returns None instead of raising on failure.
    image = cv2.imread(image_path)
    if image is None:
        raise SystemExit(f'could not read image: {image_path}')
    # markdown_rec(image) and table_rec(image) can be tried the same way.
    text, _ = scanning_document_rec(image)
    print(text)