From 0fc49db2f07a758eaa9a5fbc38e192d45ba06ada Mon Sep 17 00:00:00 2001 From: zhangzhichao Date: Thu, 8 May 2025 14:05:43 +0800 Subject: [PATCH] =?UTF-8?q?=E8=A1=A8=E6=A0=BC=E8=AF=86=E5=88=AB=E5=92=8C?= =?UTF-8?q?=E6=89=AB=E6=8F=8F=E4=BB=B6=E8=AF=86=E5=88=AB=E5=86=85=E9=83=A8?= =?UTF-8?q?=E7=9A=84ocr=E6=94=B9=E4=B8=BApaddleocr?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + helper/content_recognition/main.py | 22 ++++++----- .../rapid_table_pipeline/main.py | 20 +++++----- helper/content_recognition/test.py | 18 --------- helper/content_recognition/utils.py | 39 +++++-------------- helper/db_helper.py | 2 +- helper/image_helper.py | 24 ++++++++++++ helper/page_detection/main.py | 1 - helper/page_detection/pdf_detection.py | 3 +- marker | 1 + pipeline.py | 4 +- requirements.txt | 3 +- 12 files changed, 62 insertions(+), 76 deletions(-) delete mode 100644 helper/content_recognition/test.py create mode 160000 marker diff --git a/.gitignore b/.gitignore index fe192d2..213918e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ venv *.pdf +*.PDF .vscode visual_images/*.jpg __pycache__ \ No newline at end of file diff --git a/helper/content_recognition/main.py b/helper/content_recognition/main.py index b51f2c9..a538c59 100644 --- a/helper/content_recognition/main.py +++ b/helper/content_recognition/main.py @@ -1,7 +1,8 @@ from typing import List import cv2 -from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark +from .utils import scanning_document_classify, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark from tqdm import tqdm +from ..image_helper import text_rec class LayoutRecognitionResult(object): @@ -60,18 +61,18 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]: # 扫描件 is_scanning_document = True content, layout_img = scanning_document_rec(layout_img) - source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg')) + source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg')) elif layout.clsid == 4: # table if scanning_document_classify(layout_img): is_scanning_document = True content, layout_img = scanning_document_rec(layout_img) - source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg')) + source_page_unwatermarked_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg')) else: content = table_rec(layout_img) elif layout.clsid == 5: # table caption - ocr_results = text_rec(layout_img) + _, ocr_results, _ = text_rec(layout_img) content = '' for o in ocr_results: content += f'{o}\n' @@ -81,25 +82,26 @@ def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]: if not content: continue + content = content.replace('\\', '') result = LayoutRecognitionResult(layout.clsid, content, layout.pos) outputs.append(result) if is_scanning_document and len(outputs) == 1: # 扫描件额外提取标题 - h, w = source_page_no_watermark_img.shape[:2] + h, w = source_page_unwatermarked_img.shape[:2] if h > w: - title_img = source_page_no_watermark_img[:360, :w, ...] + title_img = source_page_unwatermarked_img[:360, :w, ...] # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img) - # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3) + # vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3) # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis) else: - title_img = source_page_no_watermark_img[:410, :w, ...] + title_img = source_page_unwatermarked_img[:410, :w, ...] # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img) - # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3) + # vis = cv2.rectangle(source_page_unwatermarked_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3) # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis) - title = text_rec(title_img) + _, title, _ = text_rec(title_img) outputs[0].table_title = '\n'.join(title) else: # 自动给表格分配距离它最近的标题 diff --git a/helper/content_recognition/rapid_table_pipeline/main.py b/helper/content_recognition/rapid_table_pipeline/main.py index 3dcfef4..81202ea 100644 --- a/helper/content_recognition/rapid_table_pipeline/main.py +++ b/helper/content_recognition/rapid_table_pipeline/main.py @@ -199,21 +199,19 @@ def parse_args(arg_list: Optional[List[str]] = None): return args -try: - ocr_engine = importlib.import_module("rapidocr").RapidOCR() -except ModuleNotFoundError as exc: - raise ModuleNotFoundError( - "Please install the rapidocr by pip install rapidocr" - ) from exc +# try: +# ocr_engine = importlib.import_module("rapidocr").RapidOCR() +# except ModuleNotFoundError as exc: +# raise ModuleNotFoundError( +# "Please install the rapidocr by pip install rapidocr" +# ) from exc input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value) table_engine = RapidTable(input_args) -def table2md_pipeline(img): - rapid_ocr_output = ocr_engine(img) - ocr_result = list( - zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores) - ) +def table2md_pipeline(img, ocr_result): + # rapid_ocr_output = ocr_engine(img) + # ocr_result = list(zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)) table_results = table_engine(img, ocr_result) html_content = table_results.pred_html md_content = md(html_content) diff --git a/helper/content_recognition/test.py b/helper/content_recognition/test.py deleted file mode 100644 index 1474bb8..0000000 --- a/helper/content_recognition/test.py +++ /dev/null @@ -1,18 +0,0 @@ -from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze -from magic_pdf.data.read_api import read_local_images -from markdownify import markdownify as md -import re - - -# proc -## Create Dataset Instance -input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg" - -ds = read_local_images(input_file)[0] - -x = ds.apply(doc_analyze, ocr=True) -x = x.pipe_ocr_mode(None) -html = x.get_markdown(None) -content = md(html) -content = re.sub(r'\\([#*_`])', r'\1', content) -print(content) diff --git a/helper/content_recognition/utils.py b/helper/content_recognition/utils.py index 7e9cb03..d168dcc 100644 --- a/helper/content_recognition/utils.py +++ b/helper/content_recognition/utils.py @@ -2,7 +2,6 @@ import os import tempfile import cv2 import numpy as np -from paddleocr import PaddleOCR from marker.converters.table import TableConverter from marker.models import create_model_dict from marker.output import text_from_rendered @@ -11,6 +10,7 @@ from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze from magic_pdf.data.read_api import read_local_images from markdownify import markdownify as md import re +from ..image_helper import text_rec def scanning_document_classify(image): @@ -66,46 +66,27 @@ def markdown_rec(image): return html2md(html) -ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory - - -def text_rec(image): - result = ocr.ocr(image, cls=True) - output = [] - for idx in range(len(result)): - res = result[idx] - if not res: - continue - for line in res: - if not line: - continue - output.append(line[1][0]) - return output - - def table_rec(image): - return table2md_pipeline(image) + boxes, texts, conficences = text_rec(image) + ocr_result = list(zip(boxes, texts, conficences)) + return table2md_pipeline(image, ocr_result) table_converter = TableConverter(artifact_dict=create_model_dict()) def scanning_document_rec(image): - # TODO 内部的ocr可以替换为paddleocr以提升文字识别精度 - image_path = f'{tempfile.mktemp()}.jpg' - cv2.imwrite(image_path, image) + tmp_image_path = f'{tempfile.mktemp()}.jpg' try: - no_watermark_image = remove_watermark(cv2.imread(image_path)) - prefix, suffix = image_path.split('.') - new_image_path = f'{prefix}_remove_watermark.{suffix}' - cv2.imwrite(new_image_path, no_watermark_image) + unwatermarked_image = remove_watermark(image) + cv2.imwrite(tmp_image_path, unwatermarked_image) - rendered = table_converter(new_image_path) + rendered = table_converter(tmp_image_path) text, _, _ = text_from_rendered(rendered) finally: - os.remove(image_path) - return text, no_watermark_image + os.remove(tmp_image_path) + return text, unwatermarked_image def compute_box_distance(box1, box2): diff --git a/helper/db_helper.py b/helper/db_helper.py index fe0242a..445ff74 100644 --- a/helper/db_helper.py +++ b/helper/db_helper.py @@ -22,7 +22,7 @@ def create_connection(): return conn except OperationalError as e: logger.error(f"连接数据库失败: {e}") - return None + raise e # 插入数据的函数 diff --git a/helper/image_helper.py b/helper/image_helper.py index ec4c429..9a2181e 100644 --- a/helper/image_helper.py +++ b/helper/image_helper.py @@ -4,6 +4,7 @@ import os import paddleclas import cv2 from .page_detection.utils import PageDetectionResult +from paddleocr import PaddleOCR paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation") @@ -45,3 +46,26 @@ def page_detection_visual(page_detection_result: PageDetectionResult): img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2) cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2) return img + + +ocr = PaddleOCR(use_angle_cls=False, lang='ch', use_gpu=True, show_log=False) + +def text_rec(image): + result = ocr.ocr(image, cls=False) + boxes = [] + texts = [] + conficences = [] + for idx in range(len(result)): + res = result[idx] + if not res: + continue + for line in res: + if not line: + continue + box = line[0] + text = line[1][0] + confidence = line[1][1] + boxes.append(box) + texts.append(text) + conficences.append(confidence) + return boxes, texts, conficences diff --git a/helper/page_detection/main.py b/helper/page_detection/main.py index fd23166..eadbaeb 100644 --- a/helper/page_detection/main.py +++ b/helper/page_detection/main.py @@ -4,7 +4,6 @@ from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, Pa from tqdm import tqdm - """ 0 - Text 1 - Title diff --git a/helper/page_detection/pdf_detection.py b/helper/page_detection/pdf_detection.py index 44b484a..d4a1a16 100644 --- a/helper/page_detection/pdf_detection.py +++ b/helper/page_detection/pdf_detection.py @@ -31,7 +31,7 @@ parent_path = os.path.abspath(os.path.join(__file__, *(['..']))) sys.path.insert(0, parent_path) from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize -from picodet_postprocess import PicoDetPostProcess +from .picodet_postprocess import PicoDetPostProcess from clrnet_postprocess import CLRNetPostProcess from visualize import visualize_box_mask, imshow_lanes from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid @@ -254,7 +254,6 @@ class Detector(object): self.pred_config.labels, output_dir=self.output_dir, threshold=self.threshold) - # TODO 在这里处理batch results.append(result) results = self.merge_batch_result(results) boxes = results['boxes'] diff --git a/marker b/marker new file mode 160000 index 0000000..9b31f5b --- /dev/null +++ b/marker @@ -0,0 +1 @@ +Subproject commit 9b31f5b9cb6271cf6e1f8d4cf04e8d2f29b804f4 diff --git a/pipeline.py b/pipeline.py index ccf62e4..007c019 100644 --- a/pipeline.py +++ b/pipeline.py @@ -40,7 +40,7 @@ def _pdf2markdown_pipeline(pdf_path, tmp_dir): filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0])) filepaths = [f'{tmp_dir}/{_}' for _ in filepaths] - # filepaths = filepaths[:75] + # filepaths = filepaths[250:251] # 3. 版面分析 t5 = time.time() @@ -91,4 +91,4 @@ def pdf2markdown_pipeline(pdf_path: str): if __name__ == '__main__': - pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf') + pdf2markdown_pipeline('/mnt/pdf2markdown/龙源电力:2023年年度审计报告.PDF') diff --git a/requirements.txt b/requirements.txt index db03014..51f7126 100644 --- a/requirements.txt +++ b/requirements.txt @@ -73,7 +73,7 @@ magic-pdf==1.3.10 mammoth==1.9.0 markdown2==2.5.3 markdownify==0.13.1 -marker-pdf==1.6.2 +-e marker MarkupSafe==3.0.2 matplotlib==3.10.1 modelscope==1.25.0 @@ -139,7 +139,6 @@ pytz==2025.2 PyYAML==6.0.2 rapid-table==1.0.5 RapidFuzz==3.13.0 -rapidocr==2.0.7 rarfile==4.2 regex==2024.11.6 requests==2.32.3