import argparse
from collections import defaultdict

import click

from benchmark.utils.scoring import overlap_score
from surya.input.processing import convert_if_not_rgb
from surya.debug.text import draw_text_on_image
from surya.recognition import RecognitionPredictor
from surya.settings import settings
from surya.recognition.languages import CODE_TO_LANGUAGE
from benchmark.utils.tesseract import tesseract_ocr_parallel, surya_lang_to_tesseract, TESS_CODE_TO_LANGUAGE
from benchmark.utils.textract import textract_ocr_parallel
import os
import datasets
import json
import time
from tabulate import tabulate

KEY_LANGUAGES = ["Chinese", "Spanish", "English", "Arabic", "Hindi", "Bengali", "Russian", "Japanese"]


@click.command(help="Benchmark recognition model.")
@click.option("--results_dir", type=str, help="Path to JSON file with OCR results.", default=os.path.join(settings.RESULT_DIR, "benchmark"))
@click.option("--max_rows", type=int, help="Maximum number of pdf pages to OCR.", default=None)
@click.option("--debug", type=int, help="Debug level - 1 dumps bad detection info, 2 also writes out images.", default=0)
@click.option("--tesseract", is_flag=True, help="Run benchmarks on tesseract.", default=False)
@click.option("--textract", is_flag=True, help="Run benchmarks on textract.", default=False)
@click.option("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
@click.option("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
@click.option("--textract_cpus", type=int, help="Number of CPUs to use for textract.", default=28)
@click.option("--specify_language", is_flag=True, help="Pass language codes into the model.", default=False)
def main(results_dir: str, max_rows: int, debug: int, tesseract: bool, textract: bool, langs: str, tess_cpus: int, textract_cpus: int, specify_language: bool):
    rec_predictor = RecognitionPredictor()

    split = "train"
    dataset = datasets.load_dataset(settings.RECOGNITION_BENCH_DATASET_NAME, split=split)

    if langs:
        langs = langs.split(",")
        dataset = dataset.filter(lambda x: x["language"] in langs, num_proc=4)

    # Cap the number of rows to benchmark.
    if max_rows and max_rows < len(dataset):  # condition completed; the tail of this line was truncated in the source
        dataset = dataset.select(range(max_rows))

    # NOTE: the benchmark body (running recognition and optionally tesseract/textract, scoring
    # with overlap_score, building result_path, and printing the tabulate summary) is missing
    # from this copy; the names referenced below (result_path, flat_surya_scores, lang_list,
    # images, predictions_by_image, line_text, bboxes) are defined in that section.

    if debug >= 1:
        # Dump lines whose overlap score falls below the threshold so they can be inspected.
        bad_detections = []
        for idx, (score, lang) in enumerate(zip(flat_surya_scores, lang_list)):
            if score < .8:
                bad_detections.append((idx, lang, score))
        print(f"Found {len(bad_detections)} bad detections. Writing to file...")
        with open(os.path.join(result_path, "bad_detections.json"), "w+") as f:
            json.dump(bad_detections, f)

    if debug == 2:
        # Render predicted and reference text onto images for side-by-side comparison.
        for idx, (image, pred, ref_text, bbox, lang) in enumerate(zip(images, predictions_by_image, line_text, bboxes, lang_list)):
            pred_image_name = f"{'_'.join(lang)}_{idx}_pred.png"
            ref_image_name = f"{'_'.join(lang)}_{idx}_ref.png"
            pred_text = [l.text for l in pred.text_lines]
            pred_image = draw_text_on_image(bbox, pred_text, image.size, lang)
            pred_image.save(os.path.join(result_path, pred_image_name))
            ref_image = draw_text_on_image(bbox, ref_text, image.size, lang)
            ref_image.save(os.path.join(result_path, ref_image_name))
            image.save(os.path.join(result_path, f"{'_'.join(lang)}_{idx}_image.png"))

    print(f"Wrote results to {result_path}")


if __name__ == "__main__":
    main()
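
# Example invocations (a sketch, not part of the original file; the script path
# `benchmark/recognition.py` is assumed, and the language codes are illustrative):
#   python benchmark/recognition.py --langs en,hi --max_rows 128 --debug 2
#   python benchmark/recognition.py --tesseract --tess_cpus 16 --specify_language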