You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

182 lines
7.0 KiB
Python

from typing import List
import numpy as np
from bs4 import BeautifulSoup
import pypdfium2 as pdfium
from tqdm import tqdm
import base64
import tempfile
from benchmarks.table.gemini import gemini_table_rec
from marker.config.parser import ConfigParser
from marker.converters.table import TableConverter
from marker.models import create_model_dict
from marker.processors.llm.llm_table import LLMTableProcessor
from marker.processors.table import TableProcessor
from marker.renderers.json import JSONBlockOutput
from marker.schema.polygon import PolygonBox
from marker.util import matrix_intersection_area
def extract_tables(children: List[JSONBlockOutput]):
    """Collect every ``Table`` block from a tree of rendered JSON blocks.

    The tree is walked depth-first, in the order the blocks appear. A block
    whose ``block_type`` is ``'Table'`` is collected as-is and is not descended
    into; any other block is searched recursively when its ``children`` is
    non-empty.

    Args:
        children: Blocks to search; each exposes ``block_type`` and ``children``.

    Returns:
        A flat list of the table blocks, in traversal order.
    """
    found = []
    for block in children:
        if block.block_type == 'Table':
            found.append(block)
            continue
        # Non-table containers may hold tables further down the tree
        nested = block.children
        if nested:
            found += extract_tables(nested)
    return found
def fix_table_html(table_html: str) -> str:
    """Rewrite table HTML into the form used for comparison with ground truth.

    The markup is parsed with ``html.parser`` and then:

    * the first ``<tbody>`` found (if any) is unwrapped, keeping its contents;
    * every ``<th>`` tag is renamed to ``<td>``;
    * every ``<br>`` tag is replaced by an empty string;
    * newlines in the serialized result are replaced by single spaces.

    Args:
        table_html: HTML source of a table (may be an empty string).

    Returns:
        The rewritten HTML as a string.
    """
    soup = BeautifulSoup(table_html, 'html.parser')

    # Only the first <tbody> match is unwrapped
    first_tbody = soup.find('tbody')
    if first_tbody:
        first_tbody.unwrap()

    # Header cells become ordinary cells
    for header_cell in soup.find_all('th'):
        header_cell.name = 'td'

    # Line breaks are dropped outright (replaced with an empty text node)
    for line_break in soup.find_all('br'):
        line_break.replace_with(soup.new_string(''))

    # Fintabnet uses spaces instead of newlines
    return str(soup).replace("\n", " ")
def _render_first_page(pdf_path: str):
    """Render page 0 of the PDF at ``pdf_path`` to a PIL image (scale 96/72)."""
    doc = pdfium.PdfDocument(pdf_path)
    try:
        return doc[0].render(scale=96/72).to_pil()
    finally:
        # Close the document even if page lookup or rendering raises,
        # so a failing row does not leak the pdfium handle.
        doc.close()


def _area_close_to_overlap(box_area, overlap_area) -> bool:
    """Return True when ``box_area / overlap_area`` lies strictly between .85 and 1.15."""
    return .85 < box_area / overlap_area < 1.15


def inference_tables(dataset, use_llm: bool, table_rec_batch_size: int | None, max_rows: int | None, use_gemini: bool) -> tuple[list[dict], int]:
    """Run marker's table conversion over a dataset and pair outputs with ground truth.

    For each processed row the PDF is decoded from base64, written to a
    temporary file, converted with ``TableConverter``, and its first page is
    rendered to an image. Detected tables are matched to ground-truth tables by
    bounding-box intersection area; matched pairs are normalized with
    ``fix_table_html`` and appended to the results.

    Args:
        dataset: Indexable collection supporting ``len()``. Each row is read
            via ``row['pdf']`` (base64-encoded PDF) and ``row['tables']`` (a
            sequence of items with ``'normalized_bbox'`` and ``'html'`` keys).
        use_llm: Forwarded to the converter config as ``use_llm``.
        table_rec_batch_size: Forwarded to the converter config as
            ``table_rec_batch_size``.
        max_rows: Upper bound on the number of rows processed; ``None`` means
            process the whole dataset.
        use_gemini: When true, ``gemini_table_rec`` is also called on the
            cropped image of each aligned table; a failure there is printed
            and an empty string is used instead.

    Returns:
        A ``(results, total_unaligned)`` tuple. ``results`` is a list of dicts
        with keys ``"marker_table"``, ``"gt_table"`` and ``"gemini_table"``.
        ``total_unaligned`` counts ground-truth tables that could not be
        paired with a marker table. Rows that raise ``pdfium.PdfiumError`` are
        skipped and contribute to neither value.
    """
    models = create_model_dict()
    config_parser = ConfigParser({'output_format': 'json', "use_llm": use_llm, "table_rec_batch_size": table_rec_batch_size, "disable_tqdm": True})
    total_unaligned = 0
    results = []

    iterations = len(dataset)
    if max_rows is not None:
        iterations = min(max_rows, len(dataset))

    for i in tqdm(range(iterations), desc='Converting Tables'):
        try:
            row = dataset[i]
            pdf_binary = base64.b64decode(row['pdf'])
            gt_tables = row['tables']  # Already sorted by reading order, which is what marker returns

            # Only use the basic table processors
            converter = TableConverter(
                config=config_parser.generate_config_dict(),
                artifact_dict=models,
                processor_list=[
                    "marker.processors.table.TableProcessor",
                    "marker.processors.llm.llm_table.LLMTableProcessor",
                ],
                renderer=config_parser.get_renderer()
            )

            with tempfile.NamedTemporaryFile(suffix=".pdf", mode="wb") as temp_pdf_file:
                temp_pdf_file.write(pdf_binary)
                # The converter and pdfium both reopen the file by name, so the
                # buffered bytes must be on disk before they run.
                temp_pdf_file.flush()
                marker_json = converter(temp_pdf_file.name).children
                page_image = _render_first_page(temp_pdf_file.name)

            if len(marker_json) == 0 or len(gt_tables) == 0:
                print('No tables detected, skipping...')
                total_unaligned += len(gt_tables)
                continue

            marker_tables = extract_tables(marker_json)
            marker_table_boxes = [table.bbox for table in marker_tables]
            page_bbox = marker_json[0].bbox

            if len(marker_tables) != len(gt_tables):
                print('Number of tables do not match, skipping...')
                total_unaligned += len(gt_tables)
                continue

            # Crop each table out of the rendered page. This must happen before
            # the boxes are normalized below, since it needs page coordinates.
            table_images = [
                page_image.crop(
                    PolygonBox.from_bbox(bbox)
                    .rescale(
                        (page_bbox[2], page_bbox[3]), (page_image.width, page_image.height)
                    ).bbox
                )
                for bbox
                in marker_table_boxes
            ]

            # Normalize the bboxes (in place) by the first block's width/height
            for bbox in marker_table_boxes:
                bbox[0] = bbox[0] / page_bbox[2]
                bbox[1] = bbox[1] / page_bbox[3]
                bbox[2] = bbox[2] / page_bbox[2]
                bbox[3] = bbox[3] / page_bbox[3]

            gt_boxes = [table['normalized_bbox'] for table in gt_tables]
            gt_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in gt_boxes]
            marker_areas = [(bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) for bbox in marker_table_boxes]
            table_alignments = matrix_intersection_area(gt_boxes, marker_table_boxes)

            aligned_tables = []
            used_tables = set()
            unaligned_tables = set()
            for table_idx, alignment in enumerate(table_alignments):
                try:
                    max_area = np.max(alignment)
                    aligned_idx = np.argmax(alignment)
                except ValueError:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if max_area <= .01:
                    # No alignment found
                    unaligned_tables.add(table_idx)
                    continue

                if aligned_idx in used_tables:
                    # Marker table already aligned with another gt table
                    unaligned_tables.add(table_idx)
                    continue

                # Gt table doesn't align well with any marker table
                if not _area_close_to_overlap(gt_areas[table_idx], max_area):
                    unaligned_tables.add(table_idx)
                    continue

                # Marker table doesn't align with gt table
                if not _area_close_to_overlap(marker_areas[aligned_idx], max_area):
                    unaligned_tables.add(table_idx)
                    continue

                gemini_html = ""
                if use_gemini:
                    # Best effort: a Gemini failure leaves an empty string
                    # rather than dropping the aligned pair.
                    try:
                        gemini_html = gemini_table_rec(table_images[aligned_idx])
                    except Exception as e:
                        print(f'Gemini failed: {e}')

                aligned_tables.append(
                    (marker_tables[aligned_idx], gt_tables[table_idx], gemini_html)
                )
                used_tables.add(aligned_idx)

            total_unaligned += len(unaligned_tables)

            for marker_table, gt_table, gemini_table in aligned_tables:
                gt_table_html = gt_table['html']

                # marker wraps the table in <tbody> which fintabnet data doesn't
                # Fintabnet doesn't use th tags, need to be replaced for fair comparison
                marker_table_html = fix_table_html(marker_table.html)
                gemini_table_html = fix_table_html(gemini_table)

                results.append({
                    "marker_table": marker_table_html,
                    "gt_table": gt_table_html,
                    "gemini_table": gemini_table_html
                })
        except pdfium.PdfiumError:
            print('Broken PDF, Skipping...')
            continue
    return results, total_unaligned