表格识别和扫描件识别内部的ocr改为paddleocr

pull/3/head
zhangzhichao 1 month ago
parent 6bf15b0c77
commit 0fc49db2f0

1
.gitignore vendored

@ -1,5 +1,6 @@
venv
*.pdf
*.PDF
.vscode
visual_images/*.jpg
__pycache__

@ -1,7 +1,8 @@
from typing import List
import cv2
from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from .utils import scanning_document_classify, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
from tqdm import tqdm
from ..image_helper import text_rec
class LayoutRecognitionResult(object):
@ -60,18 +61,18 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
# 扫描件
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
elif layout.clsid == 4:
# table
if scanning_document_classify(layout_img):
is_scanning_document = True
content, layout_img = scanning_document_rec(layout_img)
source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
else:
content = table_rec(layout_img)
elif layout.clsid == 5:
# table caption
ocr_results = text_rec(layout_img)
_, ocr_results, _ = text_rec(layout_img)
content = ''
for o in ocr_results:
content += f'{o}\n'
@ -81,25 +82,26 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
if not content:
continue
content = content.replace('\\', '')
result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
outputs.append(result)
if is_scanning_document and len(outputs) == 1:
# 扫描件额外提取标题
h, w = source_page_no_watermark_img.shape[:2]
h, w = source_page_unwatermarked_img.shape[:2]
if h > w:
title_img = source_page_no_watermark_img[:360, :w, ...]
title_img = source_page_unwatermarked_img[:360, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
# vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
else:
title_img = source_page_no_watermark_img[:410, :w, ...]
title_img = source_page_unwatermarked_img[:410, :w, ...]
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
# vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
# vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
# cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
title = text_rec(title_img)
_, title, _ = text_rec(title_img)
outputs[0].table_title = '\n'.join(title)
else:
# 自动给表格分配距离它最近的标题

@ -199,21 +199,19 @@ def parse_args(arg_list: Optional[List[str]] = None):
return args
try:
ocr_engine = importlib.import_module("rapidocr").RapidOCR()
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"Please install the rapidocr by pip install rapidocr"
) from exc
# try:
# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
# except ModuleNotFoundError as exc:
# raise ModuleNotFoundError(
# "Please install the rapidocr by pip install rapidocr"
# ) from exc
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(input_args)
def table2md_pipeline(img):
rapid_ocr_output = ocr_engine(img)
ocr_result = list(
zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
)
def table2md_pipeline(img, ocr_result):
# rapid_ocr_output = ocr_engine(img)
# ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores))
table_results = table_engine(img, ocr_result)
html_content = table_results.pred_html
md_content = md(html_content)

@ -1,18 +0,0 @@
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
# Demo script: run doc_analyze on one local image and print the converted text.
## Create Dataset Instance
input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"
# read_local_images returns an indexable collection; only its first entry is used.
ds = read_local_images(input_file)[0]
# Apply doc_analyze with OCR enabled, then select the OCR pipeline mode.
x = ds.apply(doc_analyze, ocr=True)
x = x.pipe_ocr_mode(None)
# NOTE(review): the result is fed to markdownify below, so it presumably contains HTML — confirm.
html = x.get_markdown(None)
content = md(html)
# Strip the backslash in front of '#', '*', '_' and '`'.
content = re.sub(r'\\([#*_`])', r'\1', content)
print(content)

@ -2,7 +2,6 @@ import os
import tempfile
import cv2
import numpy as np
from paddleocr import PaddleOCR
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.output import text_from_rendered
@ -11,6 +10,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
from magic_pdf.data.read_api import read_local_images
from markdownify import markdownify as md
import re
from ..image_helper import text_rec
def scanning_document_classify(image):
@ -66,46 +66,27 @@ def markdown_rec(image):
return html2md(html)
ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory
def text_rec(image):
    """Return the recognized text strings found in ``image``.

    The module-level ``ocr`` engine is called with ``cls=True``. Its result is
    walked group by group; empty groups and empty line entries are skipped,
    and ``entry[1][0]`` of every remaining line entry is collected.

    Args:
        image: Image passed unchanged to ``ocr.ocr``.

    Returns:
        A list with one text string per recognized line, in engine order.
    """
    groups = ocr.ocr(image, cls=True)
    return [
        entry[1][0]
        for group in groups
        if group
        for entry in group
        if entry
    ]
def table_rec(image):
    """Recognize a table image using OCR results produced by ``text_rec``.

    ``text_rec`` supplies three parallel lists (boxes, texts, confidences).
    They are zipped into ``(box, text, confidence)`` triples and handed,
    together with the image, to ``table2md_pipeline``.

    Args:
        image: Table image passed to both ``text_rec`` and
            ``table2md_pipeline``.

    Returns:
        Whatever ``table2md_pipeline`` returns for the image.
    """
    boxes, texts, confidences = text_rec(image)
    ocr_result = list(zip(boxes, texts, confidences))
    # The pipeline no longer runs its own OCR, so the triples must be passed in.
    return table2md_pipeline(image, ocr_result)
table_converter = TableConverter(artifact_dict=create_model_dict())
def scanning_document_rec(image):
    """Recognize a scanned-document image through ``table_converter``.

    The watermark is removed from ``image``, the cleaned image is written to
    a temporary ``.jpg`` file, that file is given to ``table_converter``, and
    the text is extracted from the rendered result with
    ``text_from_rendered``. The temporary file is always deleted.

    Args:
        image: Image passed to ``remove_watermark``.

    Returns:
        A ``(text, unwatermarked_image)`` tuple: the extracted text and the
        image returned by ``remove_watermark``.
    """
    # TODO: the converter's internal OCR could be replaced with PaddleOCR to
    # improve text recognition accuracy.
    # mkstemp creates the file atomically (tempfile.mktemp is deprecated and
    # racy). Only the path is needed, so the descriptor is closed right away.
    fd, tmp_image_path = tempfile.mkstemp(suffix='.jpg')
    os.close(fd)
    try:
        unwatermarked_image = remove_watermark(image)
        cv2.imwrite(tmp_image_path, unwatermarked_image)
        rendered = table_converter(tmp_image_path)
        text, _, _ = text_from_rendered(rendered)
    finally:
        # The file already exists at this point, so cleanup cannot fail with
        # FileNotFoundError and hide an earlier exception.
        os.remove(tmp_image_path)
    return text, unwatermarked_image
def compute_box_distance(box1, box2):

@ -22,7 +22,7 @@ def create_connection():
return conn
except OperationalError as e:
logger.error(f"连接数据库失败: {e}")
return None
raise e
# 插入数据的函数

@ -4,6 +4,7 @@ import os
import paddleclas
import cv2
from .page_detection.utils import PageDetectionResult
from paddleocr import PaddleOCR
paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")
@ -45,3 +46,26 @@ def page_detection_visual(page_detection_result: PageDetectionResult):
img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
return img
ocr = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=True, show_log=False)
def text_rec(image):
    """Run the module-level ``ocr`` engine on ``image`` and split its output.

    The engine is called with ``cls=False``. Empty groups and empty line
    entries in the result are skipped.

    Args:
        image: Image passed unchanged to ``ocr.ocr``.

    Returns:
        A ``(boxes, texts, confidences)`` tuple of three parallel lists. For
        every recognized line they hold ``line[0]``, ``line[1][0]`` and
        ``line[1][1]`` respectively. All three lists are empty when nothing
        is recognized.
    """
    result = ocr.ocr(image, cls=False)
    boxes = []
    texts = []
    confidences = []
    # `or []` keeps a missing (None) result from raising; callers then get
    # three empty lists, the same as for an image without text.
    for res in result or []:
        if not res:
            continue
        for line in res:
            if not line:
                continue
            boxes.append(line[0])
            texts.append(line[1][0])
            confidences.append(line[1][1])
    return boxes, texts, confidences

@ -4,7 +4,6 @@ from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, Pa
from tqdm import tqdm
"""
0 - Text
1 - Title

@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
sys.path.insert(0, parent_path)
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
from picodet_postprocess import PicoDetPostProcess
from .picodet_postprocess import PicoDetPostProcess
from clrnet_postprocess import CLRNetPostProcess
from visualize import visualize_box_mask, imshow_lanes
from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
@ -254,7 +254,6 @@ class Detector(object):
self.pred_config.labels,
output_dir=self.output_dir,
threshold=self.threshold)
# TODO 在这里处理batch
results.append(result)
results = self.merge_batch_result(results)
boxes = results['boxes']

@ -0,0 +1 @@
Subproject commit 9b31f5b9cb6271cf6e1f8d4cf04e8d2f29b804f4

@ -40,7 +40,7 @@ def _pdf2markdown_pipeline(pdf_path, tmp_dir):
filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
filepaths = [f'{tmp_dir}/{_}' for _ in filepaths]
# filepaths = filepaths[:75]
# filepaths = filepaths[250:251]
# 3. 版面分析
t5 = time.time()
@ -91,4 +91,4 @@ def pdf2markdown_pipeline(pdf_path: str):
if __name__ == '__main__':
pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf')
pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2023年年度审计报告.PDF')

@ -73,7 +73,7 @@ magic-pdf==1.3.10
mammoth==1.9.0
markdown2==2.5.3
markdownify==0.13.1
marker-pdf==1.6.2
-e marker
MarkupSafe==3.0.2
matplotlib==3.10.1
modelscope==1.25.0
@ -139,7 +139,6 @@ pytz==2025.2
PyYAML==6.0.2
rapid-table==1.0.5
RapidFuzz==3.13.0
rapidocr==2.0.7
rarfile==4.2
regex==2024.11.6
requests==2.32.3

Loading…
Cancel
Save