Switch the OCR used inside table recognition and scanned-document recognition to PaddleOCR

pull/3/head
zhangzhichao 1 month ago
parent 6bf15b0c77
commit 0fc49db2f0

.gitignore

@@ -1,5 +1,6 @@
 venv
 *.pdf
+*.PDF
 .vscode
 visual_images/*.jpg
 __pycache__

@@ -1,7 +1,8 @@
 from typing import List
 import cv2
-from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
+from .utils import scanning_document_classify, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
 from tqdm import tqdm
+from ..image_helper import text_rec


 class LayoutRecognitionResult(object):
@@ -60,18 +61,18 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
                 # 扫描件
                 is_scanning_document = True
                 content, layout_img = scanning_document_rec(layout_img)
-                source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+                source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
             elif layout.clsid == 4:
                 # table
                 if scanning_document_classify(layout_img):
                     is_scanning_document = True
                     content, layout_img = scanning_document_rec(layout_img)
-                    source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+                    source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
                 else:
                     content = table_rec(layout_img)
             elif layout.clsid == 5:
                 # table caption
-                ocr_results = text_rec(layout_img)
+                _, ocr_results, _ = text_rec(layout_img)
                 content = ''
                 for o in ocr_results:
                     content += f'{o}\n'
@@ -81,25 +82,26 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
             if not content:
                 continue
+            content = content.replace('\\', '')
             result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
             outputs.append(result)

         if is_scanning_document and len(outputs) == 1:
             # 扫描件额外提取标题
-            h, w = source_page_no_watermark_img.shape[:2]
+            h, w = source_page_unwatermarked_img.shape[:2]
             if h > w:
-                title_img = source_page_no_watermark_img[:360, :w, ...]
+                title_img = source_page_unwatermarked_img[:360, :w, ...]
                 # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
-                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
+                # vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
                 # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
             else:
-                title_img = source_page_no_watermark_img[:410, :w, ...]
+                title_img = source_page_unwatermarked_img[:410, :w, ...]
                 # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
-                # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
+                # vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
                 # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
-            title = text_rec(title_img)
+            _, title, _ = text_rec(title_img)
             outputs[0].table_title = '\n'.join(title)
         else:
             # 自动给表格分配距离它最近的标题

@@ -199,21 +199,19 @@ def parse_args(arg_list: Optional[List[str]] = None):
     return args

-try:
-    ocr_engine = importlib.import_module("rapidocr").RapidOCR()
-except ModuleNotFoundError as exc:
-    raise ModuleNotFoundError(
-        "Please install the rapidocr by pip install rapidocr"
-    ) from exc
+# try:
+#     ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+# except ModuleNotFoundError as exc:
+#     raise ModuleNotFoundError(
+#         "Please install the rapidocr by pip install rapidocr"
+#     ) from exc

 input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
 table_engine = RapidTable(input_args)


-def table2md_pipeline(img):
-    rapid_ocr_output = ocr_engine(img)
-    ocr_result = list(
-        zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
-    )
+def table2md_pipeline(img, ocr_result):
+    # rapid_ocr_output = ocr_engine(img)
+    # ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores))
     table_results = table_engine(img, ocr_result)
     html_content = table_results.pred_html
     md_content = md(html_content)
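For readers following the signature change above, here is a minimal usage sketch, assuming the PaddleOCR-backed text_rec helper added later in this commit and that table2md_pipeline returns the Markdown string; the import paths and image path are illustrative, not code from the repository.

# Hypothetical usage sketch (not part of this commit).
import cv2

from image_helper import text_rec            # assumed import path for the new helper
from table_to_md import table2md_pipeline    # assumed module name for the function above

img = cv2.imread('table_crop.jpg')            # any cropped table image (illustrative path)
boxes, texts, scores = text_rec(img)          # three parallel lists from PaddleOCR
ocr_result = list(zip(boxes, texts, scores))  # (box, text, score) tuples for RapidTable
md_content = table2md_pipeline(img, ocr_result)
print(md_content)

This mirrors how table_rec wires the two pieces together further down in this diff: the caller now owns OCR, and RapidTable only does table structure recognition.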

@@ -1,18 +0,0 @@
-from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
-from magic_pdf.data.read_api import read_local_images
-from markdownify import markdownify as md
-import re
-
-# proc
-## Create Dataset Instance
-input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"
-ds = read_local_images(input_file)[0]
-
-x = ds.apply(doc_analyze, ocr=True)
-x = x.pipe_ocr_mode(None)
-html = x.get_markdown(None)
-content = md(html)
-content = re.sub(r'\\([#*_`])', r'\1', content)
-print(content)

@@ -2,7 +2,6 @@ import os
 import tempfile
 import cv2
 import numpy as np
-from paddleocr import PaddleOCR
 from marker.converters.table import TableConverter
 from marker.models import create_model_dict
 from marker.output import text_from_rendered
@@ -11,6 +10,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
 from magic_pdf.data.read_api import read_local_images
 from markdownify import markdownify as md
 import re
+from ..image_helper import text_rec


 def scanning_document_classify(image):
@@ -66,46 +66,27 @@ def markdown_rec(image):
     return html2md(html)


-ocr = PaddleOCR(lang='ch')  # need to run only once to download and load model into memory
-
-
-def text_rec(image):
-    result = ocr.ocr(image, cls=True)
-    output = []
-    for idx in range(len(result)):
-        res = result[idx]
-        if not res:
-            continue
-        for line in res:
-            if not line:
-                continue
-            output.append(line[1][0])
-    return output
-
-
 def table_rec(image):
-    return table2md_pipeline(image)
+    boxes, texts, conficences = text_rec(image)
+    ocr_result = list(zip(boxes, texts, conficences))
+    return table2md_pipeline(image, ocr_result)


 table_converter = TableConverter(artifact_dict=create_model_dict())


 def scanning_document_rec(image):
-    # TODO 内部的ocr可以替换为paddleocr以提升文字识别精度
-    image_path = f'{tempfile.mktemp()}.jpg'
-    cv2.imwrite(image_path, image)
+    tmp_image_path = f'{tempfile.mktemp()}.jpg'
     try:
-        no_watermark_image = remove_watermark(cv2.imread(image_path))
-        prefix, suffix = image_path.split('.')
-        new_image_path = f'{prefix}_remove_watermark.{suffix}'
-        cv2.imwrite(new_image_path, no_watermark_image)
-        rendered = table_converter(new_image_path)
+        unwatermarked_image = remove_watermark(image)
+        cv2.imwrite(tmp_image_path, unwatermarked_image)
+        rendered = table_converter(tmp_image_path)
         text, _, _ = text_from_rendered(rendered)
     finally:
-        os.remove(image_path)
+        os.remove(tmp_image_path)
-    return text, no_watermark_image
+    return text, unwatermarked_image


 def compute_box_distance(box1, box2):

@@ -22,7 +22,7 @@ def create_connection():
         return conn
     except OperationalError as e:
         logger.error(f"连接数据库失败: {e}")
-        return None
+        raise e


 # 插入数据的函数

@@ -4,6 +4,7 @@ import os
 import paddleclas
 import cv2
 from .page_detection.utils import PageDetectionResult
+from paddleocr import PaddleOCR


 paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")
@@ -45,3 +46,26 @@ def page_detection_visual(page_detection_result: PageDetectionResult):
         img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
         cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
     return img
+
+
+ocr = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=True, show_log=False)
+
+
+def text_rec(image):
+    result = ocr.ocr(image, cls=False)
+    boxes = []
+    texts = []
+    conficences = []
+    for idx in range(len(result)):
+        res = result[idx]
+        if not res:
+            continue
+        for line in res:
+            if not line:
+                continue
+            box = line[0]
+            text = line[1][0]
+            confidence = line[1][1]
+            boxes.append(box)
+            texts.append(text)
+            conficences.append(confidence)
+    return boxes, texts, conficences
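A minimal sketch of how the helper's three return values are consumed elsewhere in this commit; the flat import path and image path are illustrative assumptions. Callers that only need the recognized strings discard the boxes and confidences, exactly as the table-caption branch of rec() does above, while table_rec zips all three lists back into (box, text, score) tuples.

# Hypothetical consumption sketch; import and image paths are illustrative.
import cv2

from image_helper import text_rec  # assumed flat import path for the helper above

img = cv2.imread('page.jpg')                 # illustrative input image
boxes, texts, confidences = text_rec(img)    # three parallel lists from PaddleOCR

# Text-only callers unpack just the middle list:
_, lines, _ = text_rec(img)
caption = '\n'.join(lines)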

@@ -4,7 +4,6 @@ from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, Pa
 from tqdm import tqdm

 """
 0 - Text
 1 - Title

@@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
 sys.path.insert(0, parent_path)
 from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
-from picodet_postprocess import PicoDetPostProcess
+from .picodet_postprocess import PicoDetPostProcess
 from clrnet_postprocess import CLRNetPostProcess
 from visualize import visualize_box_mask, imshow_lanes
 from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
@@ -254,7 +254,6 @@ class Detector(object):
                     self.pred_config.labels,
                     output_dir=self.output_dir,
                     threshold=self.threshold)
-            # TODO 在这里处理batch
             results.append(result)
         results = self.merge_batch_result(results)
         boxes = results['boxes']

@@ -0,0 +1 @@
+Subproject commit 9b31f5b9cb6271cf6e1f8d4cf04e8d2f29b804f4

@@ -40,7 +40,7 @@ def _pdf2markdown_pipeline(pdf_path, tmp_dir):
     filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
     filepaths = [f'{tmp_dir}/{_}' for _ in filepaths]
-    # filepaths = filepaths[:75]
+    # filepaths = filepaths[250:251]

     # 3. 版面分析
     t5 = time.time()
@@ -91,4 +91,4 @@ def pdf2markdown_pipeline(pdf_path: str):


 if __name__ == '__main__':
-    pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf')
+    pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力2023年年度审计报告.PDF')

@@ -73,7 +73,7 @@ magic-pdf==1.3.10
 mammoth==1.9.0
 markdown2==2.5.3
 markdownify==0.13.1
-marker-pdf==1.6.2
+-e marker
 MarkupSafe==3.0.2
 matplotlib==3.10.1
 modelscope==1.25.0
@@ -139,7 +139,6 @@ pytz==2025.2
 PyYAML==6.0.2
 rapid-table==1.0.5
 RapidFuzz==3.13.0
-rapidocr==2.0.7
 rarfile==4.2
 regex==2024.11.6
 requests==2.32.3
