You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

92 lines
2.6 KiB
Python

import base64
import json
import tempfile
import time
from io import BytesIO
import torch
from PIL import Image
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
def convert_single_page(filename: str, model, processor, device):
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.prompts import build_finetuning_prompt
from olmocr.prompts.anchor import get_anchor_text
image_base64 = render_pdf_to_base64png(filename, 1, target_longest_image_dim=1024)
# Build the prompt, using document metadata
anchor_text = get_anchor_text(filename, 1, pdf_engine="pdfreport", target_length=4000)
prompt = build_finetuning_prompt(anchor_text)
# Build the full prompt
messages = [
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
],
}
]
# Apply the chat template and processor
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
inputs = processor(
text=[text],
images=[main_image],
padding=True,
return_tensors="pt",
)
inputs = {key: value.to(device) for (key, value) in inputs.items()}
# Generate the output
output = model.generate(
**inputs,
temperature=0.8,
max_new_tokens=8192,
num_return_sequences=1,
do_sample=True,
)
# Decode the output
prompt_length = inputs["input_ids"].shape[1]
new_tokens = output[:, prompt_length:]
text_output = processor.tokenizer.batch_decode(
new_tokens, skip_special_tokens=True
)[0]
try:
text_output = json.loads(text_output)
text = text_output["natural_text"]
except Exception:
try:
text = text_output.split("natural_text")[1].strip()
except Exception:
text = ""
return text
class OlmOCRMethod(BaseMethod):
olmocr_model: dict = None
use_llm: bool = False
def __call__(self, sample) -> BenchmarkResult:
pdf_bytes = sample["pdf"] # This is a single page PDF
with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
f.write(pdf_bytes)
start = time.time()
result = convert_single_page(f.name, self.olmocr_model["model"], self.olmocr_model["processor"], self.olmocr_model["model"].device)
total = time.time() - start
return {
"markdown": result,
"time": total
}