import argparse
import collections
import json
import os
import time

import click
import datasets
from PIL import ImageDraw
from tabulate import tabulate

from surya.debug.draw import draw_bboxes_on_image
from surya.input.processing import convert_if_not_rgb
from surya.settings import settings
from surya.table_rec import TableRecPredictor

from benchmark.utils.metrics import penalized_iou_score
from benchmark.utils.tatr import load_tatr, batch_inference_tatr
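
# Benchmarks Surya table recognition against labeled row/column boxes on the
# table rec bench dataset, and can optionally compare with Table Transformer (TATR).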
@click.command(help="Benchmark table rec dataset")
@click.option("--results_dir", type=str, help="Path to JSON file with benchmark results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of images to run benchmark on.", default=512)
@click.option("--tatr", is_flag=True, help="Run table transformer.", default=False)
@click.option("--debug", is_flag=True, help="Enable debug mode.", default=False)
def main(results_dir: str, max_rows: int, tatr: bool, debug: bool):
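    """Run the table rec benchmark and write per-page and mean IoU scores to results.json."""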
    table_rec_predictor = TableRecPredictor()

    pathname = "table_rec_bench"
    # These have already been shuffled randomly, so sampling from the start is fine
    split = "train"
    if max_rows is not None:
        split = f"train[:{max_rows}]"

    dataset = datasets.load_dataset(settings.TABLE_REC_BENCH_DATASET_NAME, split=split)
    images = list(dataset["image"])
    images = convert_if_not_rgb(images)

    if settings.TABLE_REC_STATIC_CACHE:
        # Run through one batch to compile the model
        table_rec_predictor(images[:1])

    start = time.time()
    table_rec_predictions = table_rec_predictor(images)
    surya_time = time.time() - start

    folder_name = os.path.basename(pathname).split(".")[0]
    result_path = os.path.join(results_dir, folder_name)
    os.makedirs(result_path, exist_ok=True)
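
    # Score each page: penalized IoU between predicted and labeled row/column boxes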
    page_metrics = collections.OrderedDict()
    mean_col_iou = 0
    mean_row_iou = 0
    for idx, (pred, image) in enumerate(zip(table_rec_predictions, images)):
        row = dataset[idx]
        pred_row_boxes = [p.bbox for p in pred.rows]
        pred_col_bboxes = [p.bbox for p in pred.cols]
        actual_row_bboxes = [r["bbox"] for r in row["rows"]]
        actual_col_bboxes = [c["bbox"] for c in row["columns"]]

        row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
        col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
        page_results = {
            "row_score": row_score,
            "col_score": col_score,
            "row_count": len(actual_row_bboxes),
            "col_count": len(actual_col_bboxes)
        }

        mean_col_iou += col_score
        mean_row_iou += row_score
        page_metrics[idx] = page_results

        if debug:
            # Save debug images
            draw_img = image.copy()
            draw_bboxes_on_image(pred_row_boxes, draw_img, [f"Row {i}" for i in range(len(pred_row_boxes))])
            draw_bboxes_on_image(pred_col_bboxes, draw_img, [f"Col {i}" for i in range(len(pred_col_bboxes))], color="blue")
            draw_img.save(os.path.join(result_path, f"{idx}_bbox.png"))

            actual_draw_image = image.copy()
            draw_bboxes_on_image(actual_row_bboxes, actual_draw_image, [f"Row {i}" for i in range(len(actual_row_bboxes))])
            draw_bboxes_on_image(actual_col_bboxes, actual_draw_image, [f"Col {i}" for i in range(len(actual_col_bboxes))], color="blue")
            actual_draw_image.save(os.path.join(result_path, f"{idx}_actual.png"))

    mean_col_iou /= len(table_rec_predictions)
    mean_row_iou /= len(table_rec_predictions)
out_data = {"surya": {
"time": surya_time,
"mean_row_iou": mean_row_iou,
"mean_col_iou": mean_col_iou,
"page_metrics": page_metrics
}}
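
    # Optionally benchmark Table Transformer (TATR) on the same images for comparison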
    if tatr:
        tatr_model = load_tatr()
        start = time.time()
        tatr_predictions = batch_inference_tatr(tatr_model, images, 1)
        tatr_time = time.time() - start

        page_metrics = collections.OrderedDict()
        mean_col_iou = 0
        mean_row_iou = 0
        for idx, pred in enumerate(tatr_predictions):
            row = dataset[idx]
            pred_row_boxes = [p["bbox"] for p in pred["rows"]]
            pred_col_bboxes = [p["bbox"] for p in pred["cols"]]
            actual_row_bboxes = [r["bbox"] for r in row["rows"]]
            actual_col_bboxes = [c["bbox"] for c in row["columns"]]

            row_score = penalized_iou_score(pred_row_boxes, actual_row_bboxes)
            col_score = penalized_iou_score(pred_col_bboxes, actual_col_bboxes)
            page_results = {
                "row_score": row_score,
                "col_score": col_score,
                "row_count": len(actual_row_bboxes),
                "col_count": len(actual_col_bboxes)
            }

            mean_col_iou += col_score
            mean_row_iou += row_score
            page_metrics[idx] = page_results

        mean_col_iou /= len(tatr_predictions)
        mean_row_iou /= len(tatr_predictions)

        out_data["tatr"] = {
            "time": tatr_time,
            "mean_row_iou": mean_row_iou,
            "mean_col_iou": mean_col_iou,
            "page_metrics": page_metrics
        }

    with open(os.path.join(result_path, "results.json"), "w+", encoding="utf-8") as f:
        json.dump(out_data, f, indent=4)
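
    # Print a summary table of mean row/column IoU and per-image inference time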
    table = [
        ["Model", "Row Intersection", "Col Intersection", "Time Per Image"],
        ["Surya", f"{out_data['surya']['mean_row_iou']:.2f}", f"{out_data['surya']['mean_col_iou']:.5f}",
         f"{surya_time / len(images):.5f}"],
    ]
    if tatr:
        table.append(["Table transformer", f"{out_data['tatr']['mean_row_iou']:.2f}", f"{out_data['tatr']['mean_col_iou']:.5f}",
                      f"{tatr_time / len(images):.5f}"])

    print(tabulate(table, headers="firstrow", tablefmt="github"))
    print("Intersection is the average intersection % between each actual row/column and the predictions, with penalties for too many/few predictions.")
    print("Note that Table Transformer is unbatched, since the example code in its repo is unbatched.")
    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
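
# Example invocation (the script path below is an assumption; adjust it to where
# this file lives in your checkout):
#   python benchmark/table_recognition.py --max_rows 128 --tatr --debug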