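"""Benchmark PDF to markdown conversion quality across methods (marker, mathpix, llamaparse, docling, mistral, olmocr) on the marker benchmark dataset."""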
import json
import os
import traceback
from collections import defaultdict
from pathlib import Path
from typing import List

import click
import datasets
import torch
from tqdm import tqdm

from benchmarks.overall.display.dataset import build_dataset
from benchmarks.overall.display.table import print_scores
from benchmarks.overall.registry import SCORE_REGISTRY, METHOD_REGISTRY
from benchmarks.overall.schema import FullResult
from marker.logger import configure_logging
from marker.models import create_model_dict
from marker.settings import settings

configure_logging()

def get_method_scores(benchmark_dataset: datasets.Dataset, methods: List[str], score_types: List[str], artifacts: dict, max_rows=None) -> FullResult:
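    """Run each conversion method over the benchmark dataset and score its markdown against the ground truth.

    Returns a FullResult dict with per-sample scores, the generated markdown, score averages by
    document type and by block type, and per-method conversion times.
    """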
    bench_scores = {}
    averages_by_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    averages_by_block_type = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
    average_times = defaultdict(list)
    markdown_by_method = defaultdict(dict)

    total_rows = len(benchmark_dataset)
    if max_rows:
        total_rows = min(max_rows, total_rows)

    for idx, sample in tqdm(enumerate(benchmark_dataset), desc="Running benchmark", total=total_rows):
        if max_rows is not None and idx >= max_rows:
            break

        doc_type = sample["classification"]
        gt_cls = METHOD_REGISTRY["gt"]
        gt_blocks = json.loads(sample["gt_blocks"])
        gt_md = gt_cls(**artifacts)(sample)["markdown"]
        markdown_by_method[idx]["gt"] = gt_md

        out_data = defaultdict(dict)

        try:
            for method in methods:
                method_cls = METHOD_REGISTRY[method](**artifacts)
                method_info = method_cls(sample)
                method_md = method_info["markdown"]
                if method_md is None:
                    method_md = ""  # Avoid None values

                average_times[method].append(method_info["time"])
                markdown_by_method[idx][method] = method_md

                for score_type in score_types:
                    score_cls = SCORE_REGISTRY[score_type]()
                    try:
                        scores = score_cls(sample, gt_md, method_md)
                    except Exception as e:
                        # Some scorers can fail, like the LLM one
                        print(f"Failed to score {method} with {score_type}: {e}")
                        continue

                    out_data[method][score_type] = scores
                    averages_by_type[method][score_type][doc_type].append(scores["score"])

                    if "by_block" in scores["specific_scores"]:  # Not all scorers support this
                        for score, gt_block in zip(scores["specific_scores"]["by_block"], gt_blocks):
                            averages_by_block_type[method][score_type][gt_block["block_type"]].append(score)
        except Exception as e:
            print(f"Failed to process {idx}: {e}")
            traceback.print_exc()
            if idx in markdown_by_method:
                del markdown_by_method[idx]
            continue

        bench_scores[idx] = out_data

    return {
        "scores": bench_scores,
        "markdown": markdown_by_method,
        "averages_by_type": averages_by_type,
        "averages_by_block_type": averages_by_block_type,
        "average_times": average_times,
    }

@click.command(help="Benchmark PDF to MD conversion.")
@click.option("--dataset", type=str, help="Path to the benchmark dataset", default="datalab-to/marker_benchmark")
@click.option("--out_dataset", type=str, help="Path to the output dataset", default=None)
@click.option("--methods", type=str, help="Comma separated list of methods to compare. Possible values: marker,mathpix,llamaparse,docling,mistral,olmocr", default="marker")
@click.option("--scores", type=str, help="Comma separated list of scoring functions to use. Possible values: heuristic,llm", default="heuristic")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "overall"), help="Output path for results.")
@click.option("--max_rows", type=int, default=None, help="Maximum number of rows to process.")
@click.option("--use_llm", is_flag=True, help="Use the LLM model for better marker quality.")
@click.option("--languages", type=str, help="Comma separated list of languages to use for LLM", default=None)
def main(
    dataset: str,
    out_dataset: str,
    methods: str,
    scores: str,
    result_path: str,
    max_rows: int,
    use_llm: bool,
    languages: str,
):
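    """Load the benchmark dataset, convert each sample with every requested method, score the results, print the scoring tables, and write result.json under --result_path."""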
    out_path = Path(result_path)
    out_path.mkdir(parents=True, exist_ok=True)

    methods = methods.split(",")
    for method in methods:
        if method not in METHOD_REGISTRY:
            raise ValueError(f"Method {method} not allowed. Allowed methods are {METHOD_REGISTRY.keys()}")

    # Ensure marker is always first
    all_methods = list(set(methods))
    methods = ["marker"] if "marker" in all_methods else []
    methods += [m for m in all_methods if m != "marker"]

    score_types = scores.split(",")
    for score_type in score_types:
        if score_type not in SCORE_REGISTRY:
            raise ValueError(f"Score type {score_type} not allowed. Allowed types are {SCORE_REGISTRY.keys()}")

    languages = languages.split(",") if languages else None

    benchmark_dataset = datasets.load_dataset(dataset, split="train")
    if languages:
        benchmark_dataset = benchmark_dataset.filter(lambda x: x["language"] in languages)

    artifacts = {
        "model_dict": create_model_dict(),
        "use_llm": use_llm,
        "mathpix_ds": None,
        "llamaparse_ds": None,
    }
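
    # Method-specific artifacts (pre-converted datasets or models) are only loaded
    # when the corresponding method was requested.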
    if "mathpix" in methods:
        artifacts["mathpix_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mathpix", split="train")

    if "llamaparse" in methods:
        artifacts["llamaparse_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_llamaparse", split="train")

    if "mistral" in methods:
        artifacts["mistral_ds"] = datasets.load_dataset("datalab-to/marker_benchmark_mistral", split="train")

    if "olmocr" in methods:
        from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

        model = Qwen2VLForConditionalGeneration.from_pretrained(
            "allenai/olmOCR-7B-0225-preview", torch_dtype=torch.bfloat16
        ).eval()
        processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")
        model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        artifacts["olmocr_model"] = {"model": model, "processor": processor}

    print(f"Running benchmark with methods: {methods} and scores: {score_types}")
    result = get_method_scores(benchmark_dataset, methods, score_types, artifacts, max_rows=max_rows)

    # Display benchmark scoring tables
    print_scores(result, out_path, methods, score_types, default_method=methods[0], default_score_type=score_types[0])

    # Write to json
    with open(out_path / "result.json", "w") as f:
        json.dump(result, f)

    if out_dataset:
        if use_llm:
            out_dataset += "_llm"
        dataset = build_dataset(benchmark_dataset, result, score_types, max_rows=max_rows)
        dataset.push_to_hub(out_dataset, private=True)


if __name__ == "__main__":
    main()
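
# Example invocation (assuming this file lives at benchmarks/overall/overall.py in the marker repo;
# adjust the path, methods, and scorers to your setup):
#   python benchmarks/overall/overall.py --methods marker,docling --scores heuristic --max_rows 100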