You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
5.2 KiB
Python
155 lines
5.2 KiB
Python
import copy
|
|
import json
|
|
import os
|
|
from typing import Callable
|
|
|
|
from magic_pdf.config.constants import PARSE_TYPE_OCR, PARSE_TYPE_TXT
|
|
from magic_pdf.config.enums import SupportedPdfParseMethod
|
|
from magic_pdf.data.data_reader_writer import DataWriter
|
|
from magic_pdf.data.dataset import Dataset
|
|
from magic_pdf.libs.draw_bbox import draw_model_bbox
|
|
from magic_pdf.libs.version import __version__
|
|
from magic_pdf.operators.pipes import PipeResult
|
|
from magic_pdf.pdf_parse_union_core_v2 import pdf_parse_union
|
|
from magic_pdf.operators import InferenceResultBase
|
|
|
|
class InferenceResult(InferenceResultBase):
    """Holds a model inference result together with the dataset it belongs to."""

    def __init__(self, inference_results: list, dataset: Dataset):
        """Initialize the wrapper.

        Args:
            inference_results (list): the inference result generated by model
            dataset (Dataset): the dataset related with model inference result
        """
        self._infer_res = inference_results
        self._dataset = dataset

    def draw_model(self, file_path: str) -> None:
        """Draw model inference result.

        Args:
            file_path (str): the output file path. A bare file name (no
                directory component) is written relative to the current
                working directory.
        """
        # os.path.dirname() returns '' for a bare file name, and
        # os.makedirs('') raises FileNotFoundError; use '.' instead.
        dir_name = os.path.dirname(file_path) or '.'
        base_name = os.path.basename(file_path)
        # exist_ok=True already tolerates an existing directory, so a separate
        # os.path.exists() pre-check is unnecessary.
        os.makedirs(dir_name, exist_ok=True)
        # Pass a deep copy so draw_model_bbox cannot mutate the stored result.
        draw_model_bbox(
            copy.deepcopy(self._infer_res), self._dataset, dir_name, base_name
        )

    def dump_model(self, writer: DataWriter, file_path: str):
        """Dump model inference result to file as indented JSON.

        Args:
            writer (DataWriter): writer handle
            file_path (str): the location of target file
        """
        writer.write_string(
            file_path, json.dumps(self._infer_res, ensure_ascii=False, indent=4)
        )

    def get_infer_res(self):
        """Get the inference result.

        Returns:
            list: the inference result generated by model (not a copy)
        """
        return self._infer_res

    def apply(self, proc: Callable, *args, **kwargs):
        """Apply a callable to a deep copy of the inference result.

        Args:
            proc (Callable): invoked as
                proc(copy_of_inference_result, *args, **kwargs)

        Returns:
            Any: the result generated by proc
        """
        # The deep copy keeps self._infer_res untouched by whatever proc does.
        return proc(copy.deepcopy(self._infer_res), *args, **kwargs)

    def _pipe(
        self,
        parse_method,
        parse_type,
        imageWriter: DataWriter,
        start_page_id,
        end_page_id,
        debug_mode,
        lang,
    ) -> PipeResult:
        """Shared implementation of pipe_txt_mode and pipe_ocr_mode.

        Runs pdf_parse_union on a deep copy of the inference result (via
        apply) and tags the returned dict with the parse type, the library
        version and, when given, the language.

        Args:
            parse_method: the SupportedPdfParseMethod member passed to
                pdf_parse_union
            parse_type: the value stored under the '_parse_type' key
            imageWriter (DataWriter): the image writer handle
            start_page_id (int): first page index to process
            end_page_id (int | None): last page index to process, or None
            debug_mode (bool): forwarded to pdf_parse_union
            lang (str | None): forwarded to pdf_parse_union; also recorded
                under the 'lang' key when not None

        Returns:
            PipeResult: the parse result wrapped with this dataset
        """

        def proc(*args, **kwargs) -> PipeResult:
            res = pdf_parse_union(*args, **kwargs)
            res['_parse_type'] = parse_type
            res['_version_name'] = __version__
            if lang is not None:
                res['lang'] = lang
            return PipeResult(res, self._dataset)

        return self.apply(
            proc,
            self._dataset,
            imageWriter,
            parse_method,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            debug_mode=debug_mode,
            lang=lang,
        )

    def pipe_txt_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result, extracting text with a
        third-party library such as `pymupdf`.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. First page index to
                process, letting the user select a page range.
            end_page_id (int, optional): Defaults to the last page index of
                dataset. Last page index to process.
            debug_mode (bool, optional): Defaults to False. Dumps more logs
                if enabled.
            lang (str, optional): Defaults to None. Recorded under the
                'lang' key of the result when provided.

        Returns:
            PipeResult: the result
        """
        return self._pipe(
            SupportedPdfParseMethod.TXT,
            PARSE_TYPE_TXT,
            imageWriter,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            debug_mode=debug_mode,
            lang=lang,
        )

    def pipe_ocr_mode(
        self,
        imageWriter: DataWriter,
        start_page_id=0,
        end_page_id=None,
        debug_mode=False,
        lang=None,
    ) -> PipeResult:
        """Post-process the model inference result, extracting text with
        `OCR`.

        Args:
            imageWriter (DataWriter): the image writer handle
            start_page_id (int, optional): Defaults to 0. First page index to
                process, letting the user select a page range.
            end_page_id (int, optional): Defaults to the last page index of
                dataset. Last page index to process.
            debug_mode (bool, optional): Defaults to False. Dumps more logs
                if enabled.
            lang (str, optional): Defaults to None. Recorded under the
                'lang' key of the result when provided.

        Returns:
            PipeResult: the result
        """
        return self._pipe(
            SupportedPdfParseMethod.OCR,
            PARSE_TYPE_OCR,
            imageWriter,
            start_page_id=start_page_id,
            end_page_id=end_page_id,
            debug_mode=debug_mode,
            lang=lang,
        )
|