import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"  # Transformers uses .isin for an op, which is not supported on MPS
from pathlib import Path
from itertools import repeat
from typing import List

import time
import datasets
from tqdm import tqdm
import click
from tabulate import tabulate
import json
from concurrent.futures import ProcessPoolExecutor

from marker.settings import settings
from benchmarks.table.inference import inference_tables

from scoring import wrap_table_html, similarity_eval_html


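# Attach a TEDS-style similarity score (tree edit distance over the table's HTML
# structure, per the function name) to a result dict. `result` is expected to hold
# HTML strings under f"{prefix}_table" and "gt_table", e.g.
#   {"marker_table": "<table>...</table>", "gt_table": "<table>...</table>"}
# and gains f"{prefix}_score" in place.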
def update_teds_score(result, prefix: str = "marker"):
    prediction, ground_truth = result[f'{prefix}_table'], result['gt_table']
    prediction, ground_truth = wrap_table_html(prediction), wrap_table_html(ground_truth)
    score = similarity_eval_html(prediction, ground_truth)
    result.update({f'{prefix}_score': score})
    return result


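# Example invocation (script path assumed from the benchmarks.table import above):
#   python benchmarks/table/table.py --max_rows 100 --use_llm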
@click.command(help="Benchmark Table to HTML Conversion")
@click.option("--result_path", type=str, default=os.path.join(settings.OUTPUT_DIR, "benchmark", "table"), help="Output path for results.")
@click.option("--dataset", type=str, default="datalab-to/fintabnet_bench_marker", help="Dataset to use.")
@click.option("--max_rows", type=int, default=None, help="Maximum number of tables to process.")
@click.option("--max_workers", type=int, default=16, help="Maximum number of workers to use.")
@click.option("--use_llm", is_flag=True, help="Use an LLM to improve table recognition.")
@click.option("--table_rec_batch_size", type=int, default=None, help="Batch size for table recognition.")
@click.option("--use_gemini", is_flag=True, help="Also evaluate Gemini for table recognition.")
def main(
    result_path: str,
    dataset: str,
    max_rows: int | None,
    max_workers: int,
    use_llm: bool,
    table_rec_batch_size: int | None,
    use_gemini: bool = False,
):
    start = time.time()

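    # Shuffle with a fixed seed so repeated runs see the dataset in the same order
    # (and any max_rows cap selects the same subset).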
    dataset = datasets.load_dataset(dataset, split='train')
    dataset = dataset.shuffle(seed=0)

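    # inference_tables runs marker (optionally LLM-assisted, optionally Gemini as well)
    # and returns per-table results plus a count of tables that could not be aligned
    # with the ground truth.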
    results, total_unaligned = inference_tables(dataset, use_llm, table_rec_batch_size, max_rows, use_gemini)

    print(f"Total time: {time.time() - start}.")
    print(f"Could not align {total_unaligned} tables from fintabnet.")

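    # Scoring is CPU-bound, so fan it out across processes rather than threads.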
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        marker_results = list(
            tqdm(
                executor.map(update_teds_score, results), desc='Computing alignment scores', total=len(results)
            )
        )

    avg_score = sum([r["marker_score"] for r in marker_results]) / len(marker_results)
    headers = ["Avg score", "Total tables"]
    data = [f"{avg_score:.3f}", len(marker_results)]
    gemini_results = None

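    # With --use_gemini, score the Gemini predictions with the same TEDS routine,
    # keyed by the "gemini" prefix.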
    if use_gemini:
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            gemini_results = list(
                tqdm(
                    executor.map(update_teds_score, results, repeat("gemini")), desc='Computing Gemini scores',
                    total=len(results)
                )
            )
        avg_gemini_score = sum([r["gemini_score"] for r in gemini_results]) / len(gemini_results)
        headers.append("Avg Gemini score")
        data.append(f"{avg_gemini_score:.3f}")

    table = tabulate([data], headers=headers, tablefmt="github")
    print(table)
    print("Avg score computed by comparing marker's predicted HTML with the original HTML.")

    results = {
        "marker": marker_results,
        "gemini": gemini_results
    }

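    # Persist the raw per-table results; "gemini" is null in the JSON when
    # --use_gemini was not set.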
    out_path = Path(result_path)
    out_path.mkdir(parents=True, exist_ok=True)
    with open(out_path / "table.json", "w+") as f:
        json.dump(results, f, indent=2)

    print(f"Results saved to {out_path}.")


if __name__ == '__main__':
    main()