You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

64 lines
1.9 KiB
Python

import json
from json import JSONDecodeError
from pathlib import Path
import datasets
from tqdm import tqdm
class Downloader:
    """Base class that runs a conversion service over a benchmark dataset.

    Each sample's result is cached as ``<cache_path>/<idx>.json`` so that a
    re-run skips samples that already have a result. The cached results are
    then collected and pushed to the Hugging Face Hub as a private dataset.
    Subclasses assign ``service`` and implement ``get_html``.
    """

    # Directory holding one JSON result file per processed sample.
    cache_path: Path = Path("cache")
    # Service identifier used in the progress label and in the uploaded
    # dataset name. Declared without a value: subclasses must assign it.
    service: str

    def __init__(self, api_key, app_id, max_rows: int = 2200):
        """Create the cache directory and load the benchmark dataset.

        Args:
            api_key: Stored on the instance; not used by this base class.
            app_id: Stored on the instance; not used by this base class.
            max_rows: Maximum number of dataset samples to consider.
        """
        # parents=True keeps this working if a subclass points cache_path at
        # a nested directory.
        self.cache_path.mkdir(parents=True, exist_ok=True)
        self.max_rows = max_rows
        self.api_key = api_key
        self.app_id = app_id
        self.ds = datasets.load_dataset("datalab-to/marker_benchmark", split="train")

    def get_html(self, pdf_bytes):
        """Convert one PDF into a result record; subclasses must override.

        The returned object must support item assignment and be JSON
        serialisable, because ``generate_data`` adds a ``"uuid"`` key and
        writes it to the cache. ``upload_ds`` builds the dataset with the
        columns ``"md"``, ``"uuid"`` and ``"time"``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def upload_ds(self):
        """Collect every cached JSON result and push it to the Hub.

        The dataset is pushed privately as
        ``datalab-to/marker_benchmark_<service>``.
        """
        # Sorted (lexicographically by path) so the row order does not depend
        # on filesystem enumeration order.
        rows = []
        for file in sorted(self.cache_path.glob("*.json")):
            with open(file, "r", encoding="utf-8") as f:
                rows.append(json.load(f))

        features = datasets.Features({
            "md": datasets.Value("string"),
            "uuid": datasets.Value("string"),
            "time": datasets.Value("float"),
        })
        out_ds = datasets.Dataset.from_list(rows, features=features)
        out_ds.push_to_hub(f"datalab-to/marker_benchmark_{self.service}", private=True)

    def _iter_samples(self):
        """Yield ``(index, sample)`` for at most ``max_rows`` samples."""
        for idx, sample in enumerate(self.ds):
            # Enforce the limit before doing any work, so exactly max_rows
            # samples are considered and no skip path can bypass the check.
            if idx >= self.max_rows:
                return
            yield idx, sample

    def generate_data(self):
        """Run the service on each uncached sample and cache the result.

        Samples whose cache file already exists are skipped. A sample whose
        conversion raises is reported on stdout and skipped, so one failure
        does not abort the whole run.
        """
        for idx, sample in tqdm(self._iter_samples(), desc=f"Saving {self.service} results"):
            cache_file = self.cache_path / f"{idx}.json"
            if cache_file.exists():
                continue

            pdf_bytes = sample["pdf"]  # This is a single page PDF
            try:
                out_data = self.get_html(pdf_bytes)
            except Exception as e:
                # Deliberate best-effort: any service failure (including a
                # malformed JSON response) skips only this sample.
                print(f"Error with sample {idx}: {e}")
                continue

            out_data["uuid"] = sample["uuid"]
            # Serialise before touching the file: a serialisation error then
            # cannot leave a truncated cache file that later runs would treat
            # as a finished sample.
            cache_file.write_text(json.dumps(out_data), encoding="utf-8")

    def __call__(self):
        """Generate any missing results, then upload everything cached."""
        self.generate_data()
        self.upload_ds()