You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

250 lines
8.6 KiB
Python

# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.
from rapidfuzz.distance import Levenshtein
from apted import APTED, Config
from apted.helpers import Tree
from collections import deque
from .parallel import parallel_process
from tqdm import tqdm
from paddle.utils import try_import
class TableTree(Tree):
    """Tree node describing one HTML table element for APTED.

    Each node stores the element tag, its colspan/rowspan, the cell content
    and the list of child nodes.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        # Extra positional arguments become the initial children.
        self.children = [*children]

    def bracket(self):
        """Return the subtree rooted at this node in brackets notation."""
        if self.tag == "td":
            # Only cells carry span and text information.
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        else:
            head = '"tag": %s' % self.tag
        # Children are rendered recursively and appended in order.
        nested = "".join(child.bracket() for child in self.children)
        return "{" + head + nested + "}"
class CustomConfig(Config):
    """APTED configuration defining the rename cost between table nodes."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when tag, colspan or rowspan differ. For two matching
        ``td`` nodes where at least one has content, it is the normalized
        Levenshtein distance between the contents. Otherwise it is 0.0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.0
        if node1.tag != "td":
            return 0.0
        if not (node1.content or node2.content):
            # Both cells are empty: nothing to compare.
            return 0.0
        return Levenshtein.normalized_distance(node1.content, node2.content)
class CustomConfig_del_short(Config):
    """APTED configuration that masks short cell contents before comparing."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when tag, colspan or rowspan differ. For matching
        ``td`` nodes where at least one has content, any content shorter than
        three tokens is replaced by the placeholder ``["####"]`` and the
        normalized Levenshtein distance of the results is returned.
        Otherwise the cost is 0.0.
        """
        attributes_match = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not attributes_match:
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            # Contents with fewer than three tokens are swapped for a placeholder.
            left = node1.content if len(node1.content) >= 3 else ["####"]
            right = node2.content if len(node2.content) >= 3 else ["####"]
            return Levenshtein.normalized_distance(left, right)
        return 0.0
class CustomConfig_del_block(Config):
    """APTED configuration that ignores blank tokens in cell contents."""

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        The cost is 1.0 when tag, colspan or rowspan differ. For matching
        ``td`` nodes where at least one has content, every ``" "`` token is
        dropped from both contents and the normalized Levenshtein distance of
        the remainders is returned. Otherwise the cost is 0.0.

        The nodes' ``content`` lists are not modified: filtering is done on
        copies, so calling this method has no side effects on the trees.
        """
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td" and (node1.content or node2.content):
            # Build filtered copies in one pass instead of popping blanks
            # from the shared lists one at a time.
            content1 = [token for token in node1.content if token != " "]
            content2 = [token for token in node2.content if token != " "]
            return Levenshtein.normalized_distance(content1, content2)
        return 0.0
class TEDS(object):
    """Tree Edit Distance based Similarity"""

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        """Store the evaluation settings.

        structure_only: when true, ``td`` cells are compared with empty content.
        n_jobs: integer >= 1; values above 1 route batch evaluation through
            ``parallel_process``.
        ignore_nodes: optional iterable of tag names passed to
            ``etree.strip_tags`` before scoring.
        """
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset per cell in load_html_tree().
        self.__tokens__ = []

    def tokenize(self, node):
        """Tokenizes table cells

        Appends to ``self.__tokens__`` the opening tag of ``node``, the
        characters of its text, the tokens of its children (recursively) and
        its closing tag. ``unk`` nodes get no closing tag, and the tail text
        of a ``td`` node is skipped.
        """
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterate the element directly; getchildren() is deprecated.
        for child in node:
            self.tokenize(child)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Converts HTML tree to the format required by apted

        Returns the new root ``TableTree`` when ``parent`` is None; otherwise
        the new node is appended to ``parent.children`` and None is returned.
        """
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the enclosing <td> ... </td> tokens; slicing already copies.
                cell = self.__tokens__[1:-1]
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != "td":
            # Cell contents are already captured as tokens, so only
            # non-cell nodes are descended into.
            for child in node:
                self.load_html_tree(child, new_node)
        if parent is None:
            return new_node
        return None

    def evaluate(self, pred, true):
        """Computes TEDS score between the prediction and the ground truth of a
        given sample

        Returns 0.0 when either input is empty or when either document has no
        ``body/table`` element. Otherwise returns
        ``1 - distance / max(node counts)``, where the node counts are the
        numbers of descendant elements of each table.
        """
        try_import("lxml")
        from lxml import etree, html

        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath("body/table")
        true_tables = true.xpath("body/table")
        if not (pred_tables and true_tables):
            return 0.0
        # Only the first table of each document is scored.
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both tables are bare <table> elements, hence identical; this
            # also avoids dividing by zero below.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @return: {'FILENAME': 'TEDS SCORE', ...}

        Filenames missing from ``pred_json`` are evaluated against an empty
        prediction, which scores 0.0.
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": pred_json.get(filename, ""),
                    "true": true_json[filename]["html"],
                }
                for filename in samples
            ]
            # NOTE(review): assumes parallel_process returns scores in input
            # order - confirm against its implementation.
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples

        ``pred_htmls`` and ``true_htmls`` are paired positionally with ``zip``;
        the list of scores is returned.
        """
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_html, true_html)
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
        else:
            inputs = [
                {"pred": pred_html, "true": true_html}
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        return scores
if __name__ == "__main__":
    import json
    import pprint

    # Read the predictions and the ground truth from the sample files.
    with open("sample_pred.json") as pred_file:
        pred_json = json.load(pred_file)
    with open("sample_gt.json") as gt_file:
        true_json = json.load(gt_file)

    # Score every sample with four jobs and pretty-print the result mapping.
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pprint.PrettyPrinter().pprint(scores)