You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.7 KiB
Python
155 lines
4.7 KiB
Python
1 month ago
|
from marker.providers.pdf import PdfProvider
|
||
|
import tempfile
|
||
|
from typing import Dict, Type
|
||
|
|
||
|
from PIL import Image, ImageDraw
|
||
|
|
||
|
import datasets
|
||
|
import pytest
|
||
|
|
||
|
from marker.builders.document import DocumentBuilder
|
||
|
from marker.builders.layout import LayoutBuilder
|
||
|
from marker.builders.line import LineBuilder
|
||
|
from marker.builders.ocr import OcrBuilder
|
||
|
from marker.converters.pdf import PdfConverter
|
||
|
from marker.models import create_model_dict
|
||
|
from marker.providers.registry import provider_from_filepath
|
||
|
from marker.schema import BlockTypes
|
||
|
from marker.schema.blocks import Block
|
||
|
from marker.renderers.markdown import MarkdownRenderer
|
||
|
from marker.renderers.json import JSONRenderer
|
||
|
from marker.schema.registry import register_block_class
|
||
|
from marker.services.gemini import GoogleGeminiService
|
||
|
from marker.util import classes_to_strings, strings_to_classes
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def model_dict():
|
||
|
model_dict = create_model_dict()
|
||
|
yield model_dict
|
||
|
del model_dict
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def layout_model(model_dict):
|
||
|
yield model_dict["layout_model"]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def detection_model(model_dict):
|
||
|
yield model_dict["detection_model"]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def texify_model(model_dict):
|
||
|
yield model_dict["texify_model"]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def recognition_model(model_dict):
|
||
|
yield model_dict["recognition_model"]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def table_rec_model(model_dict):
|
||
|
yield model_dict["table_rec_model"]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def ocr_error_model(model_dict):
|
||
|
yield model_dict["ocr_error_model"]
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def inline_detection_model(model_dict):
|
||
|
yield model_dict["inline_detection_model"]
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def config(request):
|
||
|
config_mark = request.node.get_closest_marker("config")
|
||
|
config = config_mark.args[0] if config_mark else {}
|
||
|
|
||
|
override_map: Dict[BlockTypes, Type[Block]] = config.get("override_map", {})
|
||
|
for block_type, override_block_type in override_map.items():
|
||
|
register_block_class(block_type, override_block_type)
|
||
|
|
||
|
return config
|
||
|
|
||
|
@pytest.fixture(scope="session")
|
||
|
def pdf_dataset():
|
||
|
return datasets.load_dataset("datalab-to/pdfs", split="train")
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def temp_doc(request, pdf_dataset):
|
||
|
filename_mark = request.node.get_closest_marker("filename")
|
||
|
filename = filename_mark.args[0] if filename_mark else "adversarial.pdf"
|
||
|
|
||
|
idx = pdf_dataset['filename'].index(filename)
|
||
|
suffix = filename.split(".")[-1]
|
||
|
|
||
|
temp_pdf = tempfile.NamedTemporaryFile(suffix=f".{suffix}")
|
||
|
temp_pdf.write(pdf_dataset['pdf'][idx])
|
||
|
temp_pdf.flush()
|
||
|
yield temp_pdf
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def doc_provider(request, config, temp_doc):
|
||
|
provider_cls = provider_from_filepath(temp_doc.name)
|
||
|
yield provider_cls(temp_doc.name, config)
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def pdf_document(request, config, doc_provider, layout_model, ocr_error_model, recognition_model, detection_model, inline_detection_model):
|
||
|
layout_builder = LayoutBuilder(layout_model, config)
|
||
|
line_builder = LineBuilder(detection_model, inline_detection_model, ocr_error_model, config)
|
||
|
ocr_builder = OcrBuilder(recognition_model, config)
|
||
|
builder = DocumentBuilder(config)
|
||
|
document = builder(doc_provider, layout_builder, line_builder, ocr_builder)
|
||
|
yield document
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def pdf_converter(request, config, model_dict, renderer, llm_service):
|
||
|
if llm_service:
|
||
|
llm_service = classes_to_strings([llm_service])[0]
|
||
|
yield PdfConverter(
|
||
|
artifact_dict=model_dict,
|
||
|
processor_list=None,
|
||
|
renderer=classes_to_strings([renderer])[0],
|
||
|
config=config,
|
||
|
llm_service=llm_service
|
||
|
)
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def renderer(request, config):
|
||
|
if request.node.get_closest_marker("output_format"):
|
||
|
output_format = request.node.get_closest_marker("output_format").args[0]
|
||
|
if output_format == "markdown":
|
||
|
return MarkdownRenderer
|
||
|
elif output_format == "json":
|
||
|
return JSONRenderer
|
||
|
else:
|
||
|
raise ValueError(f"Unknown output format: {output_format}")
|
||
|
else:
|
||
|
return MarkdownRenderer
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def llm_service(request, config):
|
||
|
llm_service = config.get("llm_service")
|
||
|
if not llm_service:
|
||
|
yield None
|
||
|
else:
|
||
|
yield strings_to_classes([llm_service])[0]
|
||
|
|
||
|
|
||
|
@pytest.fixture(scope="function")
|
||
|
def temp_image():
|
||
|
img = Image.new("RGB", (512, 512), color="white")
|
||
|
draw = ImageDraw.Draw(img)
|
||
|
draw.text((10, 10), "Hello, World!", fill="black", font_size=24)
|
||
|
with tempfile.NamedTemporaryFile(suffix=".png") as f:
|
||
|
img.save(f.name)
|
||
|
f.flush()
|
||
|
yield f
|