You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
import argparse
|
|
import os.path
|
|
import random
|
|
import re
|
|
import time
|
|
from functools import partial
|
|
from pathlib import Path
|
|
from typing import List
|
|
|
|
import click
|
|
import datasets
|
|
from tabulate import tabulate
|
|
from bs4 import BeautifulSoup
|
|
|
|
from surya.settings import settings
|
|
from surya.texify import TexifyPredictor, TexifyResult
|
|
import json
|
|
import io
|
|
from rapidfuzz.distance import Levenshtein
|
|
|
|
def normalize_text(text):
|
|
soup = BeautifulSoup(text, "html.parser")
|
|
text = soup.get_text()
|
|
text = re.sub(r"\n", " ", text)
|
|
text = re.sub(r"\s+", " ", text)
|
|
return text.strip()
|
|
|
|
|
|
def score_text(predictions, references):
|
|
lev_dist = []
|
|
for p, r in zip(predictions, references):
|
|
p = normalize_text(p)
|
|
r = normalize_text(r)
|
|
lev_dist.append(Levenshtein.normalized_distance(p, r))
|
|
|
|
return sum(lev_dist) / len(lev_dist)
|
|
|
|
|
|
def inference_texify(source_data, predictor):
|
|
texify_predictions: List[TexifyResult] = predictor([sd["image"] for sd in source_data])
|
|
out_data = [
|
|
{"text": texify_predictions[i].text, "equation": source_data[i]["equation"]}
|
|
for i in range(len(texify_predictions))
|
|
]
|
|
|
|
return out_data
|
|
|
|
|
|
def image_to_bmp(image):
|
|
img_out = io.BytesIO()
|
|
image.save(img_out, format="BMP")
|
|
return img_out
|
|
|
|
@click.command(help="Benchmark the performance of texify.")
|
|
@click.option("--ds_name", type=str, help="Path to dataset file with source images/equations.", default=settings.TEXIFY_BENCHMARK_DATASET)
|
|
@click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
|
|
@click.option("--max_rows", type=int, help="Maximum number of images to benchmark.", default=None)
|
|
def main(ds_name: str, results_dir: str, max_rows: int):
|
|
predictor = TexifyPredictor()
|
|
ds = datasets.load_dataset(ds_name, split="train")
|
|
|
|
if max_rows:
|
|
ds = ds.filter(lambda x, idx: idx < max_rows, with_indices=True)
|
|
|
|
start = time.time()
|
|
predictions = inference_texify(ds, predictor)
|
|
time_taken = time.time() - start
|
|
|
|
text = [p["text"] for p in predictions]
|
|
references = [p["equation"] for p in predictions]
|
|
scores = score_text(text, references)
|
|
|
|
write_data = {
|
|
"scores": scores,
|
|
"text": [{"prediction": p, "reference": r} for p, r in zip(text, references)]
|
|
}
|
|
|
|
score_table = [
|
|
["texify", write_data["scores"], time_taken]
|
|
]
|
|
score_headers = ["edit", "time taken (s)"]
|
|
score_dirs = ["⬇", "⬇"]
|
|
|
|
score_headers = [f"{h} {d}" for h, d in zip(score_headers, score_dirs)]
|
|
table = tabulate(score_table, headers=["Method", *score_headers])
|
|
print()
|
|
print(table)
|
|
|
|
result_path = Path(results_dir) / "texify_bench"
|
|
result_path.mkdir(parents=True, exist_ok=True)
|
|
with open(result_path / "results.json", "w", encoding="utf-8") as f:
|
|
json.dump(write_data, f, indent=4)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |