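"""Benchmark table-to-HTML conversion.

Runs marker's table recognition over a benchmark dataset
(datalab-to/fintabnet_bench_marker by default), scores each predicted table
against its ground-truth HTML, optionally does the same for Gemini output,
and writes the per-table results to ``table.json`` under ``--result_path``.
"""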

import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" # Transformers uses .isin for an op, which is not supported on MPS
from pathlib import Path
from itertools import repeat
from typing import List
import time
import datasets
from tqdm import tqdm
import click
from tabulate import tabulate
import json
from concurrent.futures import ProcessPoolExecutor
from marker.settings import settings
from benchmarks.table.inference import inference_tables
from scoring import wrap_table_html, similarity_eval_html

def update_teds_score(result, prefix: str = "marker"):
    # Compare the predicted table HTML against the ground-truth HTML and
    # store the similarity score back on the result dict.
    prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
    prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
    score = similarity_eval_html(prediction, ground_truth)
    result.update({f'{prefix}_score': score})
    return result

@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use")
@click.option("--max_rows", type=int, default=None, help="Maximum number of PDFs to process")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use")
@click.option("--use_llm", is_flag=True, help="Use LLM for improving table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
@click.option("--use_gemini", is_flag=True, help="Evaluate Gemini for table recognition.")
def main(
    result_path: str,
    dataset: str,
    max_rows: int,
    max_workers: int,
    use_llm: bool,
    table_rec_batch_size: int | None,
    use_gemini: bool = False
):
    start = time.time()

    dataset = datasets.load_dataset(dataset, split='train')
    dataset = dataset.shuffle(seed=0)

    # Run table inference over the benchmark dataset.
    results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)

    print(f"Total time: {time.time() - start}.")
    print(f"Could not align {total_unaligned} tables from fintabnet.")

    # Score each predicted table against its ground-truth HTML in parallel.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        marker_results = list(
            tqdm(
                executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
            )
        )

    avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
    headers = ["Avg score", "Total tables"]
    data = [f"{avg_score:.3f}", len(marker_results)]

    gemini_results = None
    if use_gemini:
        # Optionally score the Gemini predictions with the same metric.
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            gemini_results = list(
                tqdm(
                    executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
                    total=len(results)
                )
            )
        avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
        headers.append("Avg Gemini score")
        data.append(f"{avg_gemini_score:.3f}")

    table = tabulate([data], headers=headers, tablefmt="github")
    print(table)
    print("Avg score computed by comparing marker predicted HTML with original HTML")

    results = {
        "marker": marker_results,
        "gemini": gemini_results
    }
    out_path = Path(result_path)
    out_path.mkdir(parents=True, exist_ok=True)
    with open(out_path / "table.json", "w+") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {out_path}.")

if __name__ == '__main__':
    main()
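
# Example invocation (a sketch; the module path below is an assumption based on
# the `benchmarks.table.inference` import and may differ in your checkout):
#
#   python benchmarks/table/table.py --max_rows 100 --max_workers 8 --use_llm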