You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
69 lines
2.3 KiB
Python
69 lines
2.3 KiB
Python
4 weeks ago
|
# Copyright (c) Opendatalab. All rights reserved.
|
||
|
import os
|
||
|
|
||
|
from magic_pdf.data.data_reader_writer import FileBasedDataWriter, FileBasedDataReader
|
||
|
from magic_pdf.data.dataset import PymuDocDataset
|
||
|
from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
|
||
|
from magic_pdf.config.enums import SupportedPdfParseMethod
|
||
|
|
||
|
# args
# NOTE(review): a module-level name `__dir__` is the PEP 562 hook that
# dir(<module>) calls; binding a str here makes dir() on this module raise
# TypeError. Kept because later code references it; consider renaming.
__dir__ = os.path.dirname(os.path.abspath(__file__))
pdf_file_name = os.path.join(__dir__, "pdfs", "demo1.pdf")  # replace with the real pdf path
# Strip only the final extension ("a.b.pdf" -> "a.b"). Splitting on the first
# dot would truncate dotted names ("a.b.pdf" -> "a"), letting distinct inputs
# collide in the same output/<name>/ directory.
name_without_extension = os.path.splitext(os.path.basename(pdf_file_name))[0]
|
||
|
|
||
|
# prepare env
# Every artifact for this PDF goes under output/<name_without_extension>/,
# with extracted images in an "images" subfolder of that directory.
local_md_dir = os.path.join(__dir__, "output", name_without_extension)
local_image_dir = os.path.join(local_md_dir, "images")
# Bare folder name ("images") handed to the markdown/content-list calls below;
# presumably used as the relative image path prefix in the output — confirm.
image_dir = os.path.basename(local_image_dir)

# exist_ok lets the demo be re-run; creating the images folder also creates
# local_md_dir, since it is the parent.
os.makedirs(local_image_dir, exist_ok=True)

image_writer = FileBasedDataWriter(local_image_dir)
md_writer = FileBasedDataWriter(local_md_dir)

# read bytes
# NOTE(review): the reader is rooted at "", so the already-absolute
# pdf_file_name is presumably read as given — confirm FileBasedDataReader semantics.
reader1 = FileBasedDataReader("")
pdf_bytes = reader1.read(pdf_file_name)  # raw PDF content
|
||
|
|
||
|
# proc
## Create Dataset Instance
ds = PymuDocDataset(pdf_bytes)

## inference
# classify() picks the parse method: OCR mode is used only when it reports
# SupportedPdfParseMethod.OCR, text mode otherwise. The flag drives both the
# model call and the pipeline choice, so the apply() call is written once.
use_ocr = ds.classify() == SupportedPdfParseMethod.OCR
infer_result = ds.apply(doc_analyze, ocr=use_ocr)

## pipeline
if use_ocr:
    pipe_result = infer_result.pipe_ocr_mode(image_writer)
else:
    pipe_result = infer_result.pipe_txt_mode(image_writer)
|
||
|
|
||
|
### get model inference result
# Raw model output; fetched only to demonstrate the API — nothing below uses it.
model_inference_result = infer_result.get_infer_res()

### debug overlays: layout boxes, then spans, drawn on each page
layout_pdf_path = os.path.join(local_md_dir, f"{name_without_extension}_layout.pdf")
pipe_result.draw_layout(layout_pdf_path)
spans_pdf_path = os.path.join(local_md_dir, f"{name_without_extension}_spans.pdf")
pipe_result.draw_span(spans_pdf_path)

### markdown: returned as a value, then dumped separately through md_writer
md_file_name = f"{name_without_extension}.md"
md_content = pipe_result.get_markdown(image_dir)
pipe_result.dump_md(md_writer, md_file_name, image_dir)

### content list: returned as a value, then dumped separately through md_writer
content_list_file_name = f"{name_without_extension}_content_list.json"
content_list_content = pipe_result.get_content_list(image_dir)
pipe_result.dump_content_list(md_writer, content_list_file_name, image_dir)

### middle json: returned as a value, then dumped separately through md_writer
middle_json_file_name = f'{name_without_extension}_middle.json'
middle_json_content = pipe_result.get_middle_json()
pipe_result.dump_middle_json(md_writer, middle_json_file_name)
|