You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
92 lines
2.6 KiB
Python
92 lines
2.6 KiB
Python
import base64
|
|
import json
|
|
import tempfile
|
|
import time
|
|
from io import BytesIO
|
|
|
|
import torch
|
|
from PIL import Image
|
|
|
|
from benchmarks.overall.methods import BaseMethod, BenchmarkResult
|
|
|
|
|
|
def _extract_natural_text(raw_output: str) -> str:
    """Pull the ``natural_text`` field out of the model's decoded output.

    The model is expected to emit a JSON object containing a
    ``"natural_text"`` key. If the output is not valid JSON (for example
    because generation was cut off), fall back to returning whatever
    follows the first occurrence of the literal ``natural_text`` in the raw
    string. An empty string is returned when nothing usable is found.
    """
    try:
        parsed = json.loads(raw_output)
    except (ValueError, RecursionError):
        # Not parseable as JSON: best-effort salvage from the raw string.
        pieces = raw_output.split("natural_text")
        return pieces[1].strip() if len(pieces) > 1 else ""

    if not isinstance(parsed, dict) or "natural_text" not in parsed:
        return ""

    natural_text = parsed["natural_text"]
    # A JSON null must not leak out as None; callers expect text.
    return "" if natural_text is None else natural_text


def convert_single_page(filename: str, model, processor, device):
    """Run the olmOCR model on page 1 of a PDF and return its extracted text.

    Args:
        filename: Path to the PDF file on disk.
        model: Generation model exposing ``generate(**inputs, ...)``.
        processor: Matching processor exposing ``apply_chat_template``,
            a ``tokenizer`` attribute and being callable on text/images.
        device: Device that the processor's output tensors are moved to.

    Returns:
        The ``natural_text`` value parsed from the model output, or ``""``
        when the output could not be interpreted.
    """
    # Imported lazily so that olmocr is only required when this method runs.
    from olmocr.data.renderpdf import render_pdf_to_base64png
    from olmocr.prompts import build_finetuning_prompt
    from olmocr.prompts.anchor import get_anchor_text

    image_base64 = render_pdf_to_base64png(filename, 1, target_longest_image_dim=1024)

    # Build the prompt, using document metadata
    anchor_text = get_anchor_text(filename, 1, pdf_engine="pdfreport", target_length=4000)
    prompt = build_finetuning_prompt(anchor_text)

    # Build the full prompt
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            ],
        }
    ]

    # Apply the chat template and processor
    chat_text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    main_image = Image.open(BytesIO(base64.b64decode(image_base64)))

    inputs = processor(
        text=[chat_text],
        images=[main_image],
        padding=True,
        return_tensors="pt",
    )
    inputs = {key: value.to(device) for (key, value) in inputs.items()}

    # Generate the output
    output = model.generate(
        **inputs,
        temperature=0.8,
        max_new_tokens=8192,
        num_return_sequences=1,
        do_sample=True,
    )

    # Decode the output, dropping the prompt tokens that precede the
    # newly generated ones.
    prompt_length = inputs["input_ids"].shape[1]
    new_tokens = output[:, prompt_length:]
    raw_output = processor.tokenizer.batch_decode(
        new_tokens, skip_special_tokens=True
    )[0]

    return _extract_natural_text(raw_output)
|
|
|
|
|
|
class OlmOCRMethod(BaseMethod):
    """Benchmark method that converts a single-page PDF sample with olmOCR.

    Calling an instance writes the sample's PDF bytes to a temporary file,
    runs ``convert_single_page`` on it and reports the resulting text along
    with the wall-clock conversion time.
    """

    # Mapping that must provide a "model" and a "processor" entry; the
    # model's ``device`` attribute selects where inference runs.
    olmocr_model: dict = None
    # Not read by the code in this class.
    use_llm: bool = False

    def __call__(self, sample) -> BenchmarkResult:
        """Convert ``sample["pdf"]`` and return the text and elapsed time.

        Args:
            sample: Mapping whose ``"pdf"`` entry holds the PDF bytes.

        Returns:
            A dict with ``"markdown"`` (the converted text) and ``"time"``
            (seconds spent inside ``convert_single_page``).
        """
        pdf_bytes = sample["pdf"]  # This is a single page PDF

        with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as f:
            f.write(pdf_bytes)
            # The converter reopens the file by name, so the buffered bytes
            # must reach the file first or it may see a truncated PDF.
            f.flush()

            start = time.time()
            result = convert_single_page(
                f.name,
                self.olmocr_model["model"],
                self.olmocr_model["processor"],
                self.olmocr_model["model"].device,
            )
            total = time.time() - start

        return {
            "markdown": result,
            "time": total,
        }
|