|
|
# flake8: noqa
|
|
|
import os
|
|
|
import time
|
|
|
|
|
|
import cv2
|
|
|
import torch
|
|
|
import yaml
|
|
|
from loguru import logger
|
|
|
|
|
|
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # disable the albumentations update check
|
|
|
|
|
|
from magic_pdf.config.constants import *
|
|
|
from magic_pdf.model.model_list import AtomicModel
|
|
|
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
|
|
from magic_pdf.model.sub_modules.model_utils import (
|
|
|
clean_vram, crop_img, get_res_list_from_layout_res)
|
|
|
from magic_pdf.model.sub_modules.ocr.paddleocr2pytorch.ocr_utils import (
|
|
|
get_adjusted_mfdetrec_res, get_ocr_result_list)
|
|
|
|
|
|
|
|
|
class CustomPEKModel:
    """Single-page document analysis pipeline.

    Construction loads the layout, OCR and (optionally) formula and table
    models through ``AtomModelSingleton``.  Calling the instance on a page
    image runs layout detection, formula detection/recognition, OCR and table
    recognition in that order and returns the accumulated result list.
    """

    def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
        """Load the model configuration and initialise every enabled model.

        Args:
            ocr: When True, text regions are detected *and* recognised; when
                False the OCR model is called with ``rec=False``.
            show_log: Forwarded to the OCR model as ``ocr_show_log``.
            **kwargs: Optional settings read by this constructor:
                ``layout_config`` (mapping, key ``model``),
                ``formula_config`` (mapping, keys ``mfd_model``,
                ``mfr_model``, ``enable``),
                ``table_config`` (mapping, keys ``enable``, ``max_time``,
                ``model``, ``sub_model``),
                ``lang`` (default ``None``),
                ``device`` (default ``'cpu'``),
                ``models_dir`` (default ``<package root>/resources/models``).
                A missing or ``None`` config mapping is treated as empty so
                that the per-key defaults apply.
        """
        # This file lives in <package root>/model/, so two dirname() calls
        # lead to the package root that holds resources/.
        current_file_path = os.path.abspath(__file__)
        root_dir = os.path.dirname(os.path.dirname(current_file_path))
        model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
        config_path = os.path.join(model_config_dir, 'model_configs.yaml')
        with open(config_path, 'r', encoding='utf-8') as f:
            # NOTE(review): FullLoader is used on a config file shipped next
            # to the code; prefer yaml.safe_load if this path can ever point
            # at untrusted input.
            self.configs = yaml.load(f, Loader=yaml.FullLoader)

        # layout config
        self.layout_config = self._config_section(kwargs, 'layout_config')
        self.layout_model_name = self.layout_config.get(
            'model', MODEL_NAME.DocLayout_YOLO
        )

        # formula config
        self.formula_config = self._config_section(kwargs, 'formula_config')
        self.mfd_model_name = self.formula_config.get(
            'mfd_model', MODEL_NAME.YOLO_V8_MFD
        )
        self.mfr_model_name = self.formula_config.get(
            'mfr_model', MODEL_NAME.UniMerNet_v2_Small
        )
        self.apply_formula = self.formula_config.get('enable', True)

        # table config
        self.table_config = self._config_section(kwargs, 'table_config')
        self.apply_table = self.table_config.get('enable', False)
        self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
        self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
        self.table_sub_model_name = self.table_config.get('sub_model', None)

        # ocr config
        self.apply_ocr = ocr
        self.lang = kwargs.get('lang', None)

        logger.info(
            f'DocAnalysis init, this may take some times, layout_model: {self.layout_model_name}, '
            f'apply_formula: {self.apply_formula}, apply_ocr: {self.apply_ocr}, '
            f'apply_table: {self.apply_table}, table_model: {self.table_model_name}, lang: {self.lang}'
        )

        # runtime settings
        self.device = kwargs.get('device', 'cpu')
        logger.info(f'using device: {self.device}')

        models_dir = kwargs.get(
            'models_dir', os.path.join(root_dir, 'resources', 'models')
        )
        logger.info(f'using models_dir: {models_dir}')

        atom_model_manager = AtomModelSingleton()

        # formula models (detection + recognition)
        if self.apply_formula:
            self.mfd_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFD,
                mfd_weights=self._weights_path(models_dir, self.mfd_model_name),
                device=self.device,
            )

            mfr_weight_dir = self._weights_path(models_dir, self.mfr_model_name)
            mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
            self.mfr_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.MFR,
                mfr_weight_dir=mfr_weight_dir,
                mfr_cfg_path=mfr_cfg_path,
                device=self.device,
            )

        # layout model
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.LAYOUTLMv3,
                layout_weights=self._weights_path(models_dir, self.layout_model_name),
                layout_config_file=str(
                    os.path.join(
                        model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                    )
                ),
                # An "mps" device is replaced by CPU for this model only.
                device='cpu' if str(self.device).startswith("mps") else self.device,
            )
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            self.layout_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Layout,
                layout_model_name=MODEL_NAME.DocLayout_YOLO,
                doclayout_yolo_weights=self._weights_path(
                    models_dir, self.layout_model_name
                ),
                device=self.device,
            )
        else:
            # No layout model is created for any other name; __call__ then
            # yields no layout regions, so make that visible in the log.
            logger.warning(
                f'unsupported layout model: {self.layout_model_name}, '
                'layout detection will return no results'
            )

        # ocr model (always created; also handed to the table model below)
        self.ocr_model = atom_model_manager.get_atom_model(
            atom_model_name=AtomicModel.OCR,
            ocr_show_log=show_log,
            det_db_box_thresh=0.3,
            lang=self.lang
        )

        # table model
        if self.apply_table:
            self.table_model = atom_model_manager.get_atom_model(
                atom_model_name=AtomicModel.Table,
                table_model_name=self.table_model_name,
                table_model_path=self._weights_path(models_dir, self.table_model_name),
                table_max_time=self.table_max_time,
                device=self.device,
                ocr_engine=self.ocr_model,
                table_sub_model_name=self.table_sub_model_name
            )

        logger.info('DocAnalysis init done!')

    @staticmethod
    def _config_section(kwargs, key):
        """Return ``kwargs[key]``, or an empty dict when it is missing or None."""
        section = kwargs.get(key)
        return {} if section is None else section

    @staticmethod
    def _has_expected_html_ending(html_code):
        """Return True when the stripped HTML ends with ``</html>`` or ``</table>``."""
        return html_code.strip().endswith(('</html>', '</table>'))

    def _weights_path(self, models_dir, model_name):
        """Join ``models_dir`` with the weights entry configured for ``model_name``."""
        return str(os.path.join(models_dir, self.configs['weights'][model_name]))

    def __call__(self, image):
        """Analyse one page image and return the combined result list.

        Args:
            image: Page image passed unchanged to the layout / formula models
                and to ``crop_img``.  Crops are converted with
                ``cv2.COLOR_RGB2BGR`` before OCR, so the image is presumably
                an RGB array -- TODO confirm against callers.

        Returns:
            The layout result list, extended in place with formula results
            and OCR results.  Table entries receive an ``'html'`` key when
            table recognition produced well-terminated HTML.
        """
        layout_res = self._detect_layout(image)

        if self.apply_formula:
            self._append_formulas(image, layout_res)

        # Free device memory before the OCR / table stages.
        clean_vram(self.device, vram_threshold=6)

        # Split the results into OCR regions, table regions and formula boxes.
        ocr_res_list, table_res_list, single_page_mfdetrec_res = (
            get_res_list_from_layout_res(layout_res)
        )

        self._append_ocr_results(
            image, layout_res, ocr_res_list, single_page_mfdetrec_res
        )

        if self.apply_table:
            self._annotate_tables(image, table_res_list)

        return layout_res

    def _detect_layout(self, image):
        """Run the configured layout model; return ``[]`` for an unknown model name."""
        layout_start = time.time()
        layout_res = []
        if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
            layout_res = self.layout_model(image, ignore_catids=[])
        elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
            layout_res = self.layout_model.predict(image)

        layout_cost = round(time.time() - layout_start, 2)
        logger.info(f'layout detection time: {layout_cost}')
        return layout_res

    def _append_formulas(self, image, layout_res):
        """Detect and recognise formulas, appending them to ``layout_res`` in place."""
        # formula detection
        mfd_start = time.time()
        mfd_res = self.mfd_model.predict(image)
        logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

        # formula recognition
        mfr_start = time.time()
        formula_list = self.mfr_model.predict(mfd_res, image)
        layout_res.extend(formula_list)
        mfr_cost = round(time.time() - mfr_start, 2)
        logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

    def _append_ocr_results(self, image, layout_res, ocr_res_list, single_page_mfdetrec_res):
        """Run OCR (or detection only) on each OCR region and extend ``layout_res``."""
        ocr_start = time.time()
        for res in ocr_res_list:
            new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
            adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
                single_page_mfdetrec_res, useful_list
            )

            new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)

            if self.apply_ocr:
                ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
            else:
                # Detection only: recognition is switched off.
                ocr_res = self.ocr_model.ocr(
                    new_image, mfd_res=adjusted_mfdetrec_res, rec=False
                )[0]

            # Merge this region's results into the page-level list.
            if ocr_res:
                layout_res.extend(get_ocr_result_list(ocr_res, useful_list))

        ocr_cost = round(time.time() - ocr_start, 2)
        if self.apply_ocr:
            logger.info(f"ocr time: {ocr_cost}")
        else:
            logger.info(f"det time: {ocr_cost}")

    def _table_to_html(self, table_image):
        """Run the configured table model on one crop; return HTML or None."""
        html_code = None
        if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
            with torch.no_grad():
                table_result = self.table_model.predict(table_image, 'html')
                if len(table_result) > 0:
                    html_code = table_result[0]
        elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
            html_code = self.table_model.img2html(table_image)
        elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
            # Only the HTML is used; the other three returned values are ignored.
            html_code, _cell_bboxes, _logic_points, _elapse = self.table_model.predict(
                table_image
            )
        return html_code

    def _annotate_tables(self, image, table_res_list):
        """Recognise every table region and store valid HTML under ``res['html']``."""
        table_start = time.time()
        for res in table_res_list:
            new_image, _ = crop_img(res, image)
            single_table_start_time = time.time()
            html_code = self._table_to_html(new_image)

            # Exceeding the time budget is only reported; the result is still used.
            run_time = time.time() - single_table_start_time
            if run_time > self.table_max_time:
                logger.warning(
                    f'table recognition processing exceeds max time {self.table_max_time}s'
                )

            # Accept the HTML only when it is non-empty and properly terminated.
            if html_code:
                if self._has_expected_html_ending(html_code):
                    res['html'] = html_code
                else:
                    logger.warning(
                        'table recognition processing fails, not found expected HTML table end'
                    )
            else:
                logger.warning(
                    'table recognition processing fails, not get html return'
                )
        logger.info(f'table time: {round(time.time() - table_start, 2)}')
|