You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

195 lines
6.1 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import tempfile
import cv2
import numpy as np
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
from .rapid_table_pipeline.main import table2md_pipeline
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
from ..image_helper import text_rec
def scanning_document_classify(image):
    """Heuristically decide whether an image is a scanned document.

    The image is converted from BGR to HSV and the pixels that fall inside
    either of two red hue bands (a low-hue band and a high-hue band) are
    counted.

    Args:
        image: Image array passed to ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2HSV``, i.e. it is treated as BGR.

    Returns:
        ``True`` when the number of red pixels is strictly between 1 and
        1000, ``False`` otherwise.
    """
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

    # (lower, upper) HSV bounds: one band for low-hue red, one for high-hue red.
    red_bands = (
        (np.array([0, 70, 50]), np.array([10, 255, 255])),
        (np.array([170, 70, 50]), np.array([180, 255, 255])),
    )
    low_hue_mask, high_hue_mask = (
        cv2.inRange(hsv_image, lower, upper) for lower, upper in red_bands
    )

    # A pixel counts as red when it lies in either band.
    red_mask = cv2.bitwise_or(low_hue_mask, high_hue_mask)
    red_pixel_count = cv2.countNonZero(red_mask)
    return 1 < red_pixel_count < 1000
def remove_watermark(image):
    """Remove red stamp marks by keeping only a brightened third channel.

    The image is split into its channels and only the third one is kept
    (the red channel, assuming BGR input -- confirm with callers). Every
    value above 210 in that channel is pushed to 255, and the single-channel
    result is expanded back to three channels with ``cv2.COLOR_GRAY2BGR``.

    Args:
        image: Image array that splits into exactly three channels.

    Returns:
        The three-channel image produced from the thresholded channel.
    """
    channels = cv2.split(image)
    _, _, red = channels

    # Saturate bright values so light red marks blend into the background.
    bright_pixels = red > 210
    red[bright_pixels] = 255

    return cv2.cvtColor(red, cv2.COLOR_GRAY2BGR)
def html2md(html_content):
    """Convert HTML to Markdown and undo backslash escapes on ``# * _ ```.

    Args:
        html_content: Input handed to ``markdownify``.

    Returns:
        The converted Markdown with every ``\\#``, ``\\*``, ``\\_`` and
        ``\\``` replaced by the bare character.
    """
    escaped_markdown = md(html_content)
    # Drop the backslash in front of '#', '*', '_' and '`'.
    return re.sub(r'\\([#*_`])', r'\1', escaped_markdown)
def markdown_rec(image):
    """Run OCR document analysis on an image and return Markdown text.

    The image is written to a temporary JPEG file, loaded through
    ``read_local_images``, analysed with ``doc_analyze`` (``ocr=True``) and
    the resulting markdown is post-processed by ``html2md``. The temporary
    file is always removed.

    Args:
        image: Image array accepted by ``cv2.imwrite``.

    Returns:
        The Markdown string returned by ``html2md``.

    Raises:
        OSError: If the temporary image could not be written.
    """
    # TODO: could accept a folder of images as input.
    # mkstemp creates the file securely (mktemp is deprecated and race-prone)
    # and guarantees the path exists for the cleanup in ``finally``.
    fd, image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        # imwrite reports failure through its return value; surface it here
        # rather than as an unrelated error from the reader below.
        if not cv2.imwrite(image_path, image):
            raise OSError(f'Failed to write temporary image: {image_path}')
        dataset = read_local_images(image_path)[0]
        inference = dataset.apply(doc_analyze, ocr=True)
        # NOTE(review): None is passed where the pipeline presumably expects
        # an image writer / image directory -- confirm this is intended.
        pipe_result = inference.pipe_ocr_mode(None)
        markdown = pipe_result.get_markdown(None)
    finally:
        os.remove(image_path)
    return html2md(markdown)
def table_rec(image):
    """Recognise a table image via text OCR plus the table pipeline.

    ``text_rec`` is run on the image and its three parallel outputs are
    zipped into per-item ``(box, text, confidence)`` tuples, which are passed
    together with the image to ``table2md_pipeline``.

    Args:
        image: Image handed to both ``text_rec`` and ``table2md_pipeline``.

    Returns:
        Whatever ``table2md_pipeline`` returns (presumably Markdown, going by
        its name -- confirm).
    """
    boxes, texts, confidences = text_rec(image)
    ocr_items = [
        (box, text, confidence)
        for box, text, confidence in zip(boxes, texts, confidences)
    ]
    return table2md_pipeline(image, ocr_items)
# Shared TableConverter instance, built once when this module is imported and
# reused by scanning_document_rec().
# NOTE(review): create_model_dict() presumably loads model weights, which
# would make importing this module expensive -- confirm.
table_converter = TableConverter(artifact_dict=create_model_dict())
def scanning_document_rec(image):
    """Extract text from a scanned-document image after removing red stamps.

    The image is cleaned with ``remove_watermark``, written to a temporary
    JPEG file, converted with the module-level ``table_converter`` and the
    text is pulled out with ``text_from_rendered``. The temporary file is
    always removed.

    Args:
        image: Image array accepted by ``remove_watermark``.

    Returns:
        A ``(text, unwatermarked_image)`` tuple: the extracted text and the
        cleaned image that was fed to the converter.

    Raises:
        OSError: If the temporary image could not be written.
    """
    # mkstemp creates the file up front (mktemp is deprecated and race-prone),
    # so the cleanup below cannot fail with FileNotFoundError and hide an
    # earlier error from remove_watermark or the converter.
    fd, tmp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        unwatermarked_image = remove_watermark(image)
        # imwrite reports failure through its return value.
        if not cv2.imwrite(tmp_image_path, unwatermarked_image):
            raise OSError(f'Failed to write temporary image: {tmp_image_path}')
        rendered = table_converter(tmp_image_path)
        text, _, _ = text_from_rendered(rendered)
    finally:
        os.remove(tmp_image_path)
    return text, unwatermarked_image
def compute_box_distance(box1, box2):
    """Return the edge-to-edge gap between two axis-aligned boxes.

    Each box is an ``(x1, y1, x2, y2)`` sequence with ``(x1, y1)`` the
    top-left and ``(x2, y2)`` the bottom-right corner.

    Args:
        box1: First box.
        box2: Second box.

    Returns:
        * A negative number when the boxes overlap on both axes; its
          magnitude is the smaller of the two overlap extents.
        * ``0`` when the boxes share part of an edge (they touch).
        * A positive number when the boxes are separated along one axis while
          their projections on the other axis overlap; the value is the gap.
        * ``None`` when the projections overlap on neither axis, so no pair
          of edges faces each other.
    """
    x11, y11, x12, y12 = box1
    x21, y21, x22, y22 = box2

    # Overlap extent along each axis (0 when the projections are disjoint).
    x_overlap = max(0, min(x12, x22) - max(x11, x21))
    y_overlap = max(0, min(y12, y22) - max(y11, y21))

    # Overlapping on both axes: report the penetration depth as a negative.
    if x_overlap > 0 and y_overlap > 0:
        return -min(x_overlap, y_overlap)

    gaps = []
    # Projections overlap on x: the boxes are stacked vertically.
    if x12 > x21 and x11 < x22:
        below = y21 - y12  # box1 bottom edge to box2 top edge
        above = y11 - y22  # box1 top edge to box2 bottom edge
        # ``>= 0`` so that boxes sharing an edge yield 0 instead of None.
        gaps.extend(gap for gap in (below, above) if gap >= 0)
    # Projections overlap on y: the boxes sit side by side.
    if y12 > y21 and y11 < y22:
        left = x11 - x22   # box1 left edge to box2 right edge
        right = x21 - x12  # box1 right edge to box2 left edge
        gaps.extend(gap for gap in (left, right) if gap >= 0)

    return min(gaps) if gaps else None
def assign_tables_to_titles(layout_results, max_distance=200):
    """Attach the nearest title to each table in a set of layout results.

    Items with ``clsid == 4`` are treated as tables and items with
    ``clsid == 5`` as titles. Each title competes only for its single nearest
    table, measured with ``compute_box_distance`` on the ``box`` attributes.
    Pairs whose distance is ``None`` or greater than ``max_distance`` are
    ignored. When several titles pick the same table, the closest one wins;
    on a tie the title that appears first in ``layout_results`` wins. A title
    that loses does not fall back to another table.

    Args:
        layout_results: Iterable of objects exposing ``clsid`` and ``box``;
            titles must also expose ``content``.
        max_distance: Largest distance at which a title may be bound to a
            table.

    Returns:
        ``None``. Every table gets a ``table_title`` attribute set in place:
        the winning title's ``content``, or ``None`` when no title was bound.
    """
    tables = [item for item in layout_results if item.clsid == 4]
    titles = [item for item in layout_results if item.clsid == 5]

    # id(table) -> (distance, title) for the best title seen so far. Objects
    # are keyed by identity because they may not be hashable.
    best_claim = {}
    for title in titles:
        nearest_table = None
        nearest_dist = float('inf')
        for table in tables:
            dist = compute_box_distance(title.box, table.box)
            if dist is None or dist > max_distance:
                continue
            if dist < nearest_dist:
                nearest_dist = dist
                nearest_table = table
        if nearest_table is None:
            continue

        # A title's nearest table does not depend on other bindings, so one
        # pass with a strict "closer wins" rule reaches the final assignment.
        claim = best_claim.get(id(nearest_table))
        if claim is None or nearest_dist < claim[0]:
            best_claim[id(nearest_table)] = (nearest_dist, title)

    # Write the result back onto every table, bound or not.
    for table in tables:
        claim = best_claim.get(id(table))
        table.table_title = claim[1].content if claim is not None else None
if __name__ == '__main__':
    # Manual smoke test for scanning_document_rec.
    # The recognisers in this module take image arrays, not paths, so the
    # sample file is decoded with cv2.imread first.
    # Other samples previously used here (load them the same way):
    #   text_rec     -> /mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/5.jpg
    #   markdown_rec -> /mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/3.jpg
    #   table_rec    -> /mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/6.jpg
    sample_path = '/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg'
    sample_image = cv2.imread(sample_path)
    # imread returns None instead of raising when the file cannot be read.
    if sample_image is None:
        raise FileNotFoundError(f'Could not read image: {sample_path}')
    # scanning_document_rec returns (text, cleaned_image); print only the text.
    content, _ = scanning_document_rec(sample_image)
    print(content)