You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
250 lines
8.6 KiB
Python
250 lines
8.6 KiB
Python
# Copyright 2020 IBM
|
|
# Author: peter.zhong@au1.ibm.com
|
|
#
|
|
# This is free software; you can redistribute it and/or modify
|
|
# it under the terms of the Apache 2.0 License.
|
|
#
|
|
# This software is distributed in the hope that it will be useful,
|
|
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
# Apache 2.0 License for more details.
|
|
|
|
from rapidfuzz.distance import Levenshtein
|
|
from apted import APTED, Config
|
|
from apted.helpers import Tree
|
|
from collections import deque
|
|
from .parallel import parallel_process
|
|
from tqdm import tqdm
|
|
from paddle.utils import try_import
|
|
|
|
|
|
class TableTree(Tree):
    """A node of an HTML table tree in the shape consumed by ``apted``.

    ``colspan``/``rowspan``/``content`` describe table cells; ``bracket``
    formats the spans of a ``td`` node with ``%d``, so ``td`` nodes are
    expected to carry integer spans.
    """

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag, self.colspan, self.rowspan = tag, colspan, rowspan
        self.content = content
        self.children = [*children]

    def bracket(self):
        """Render this subtree in bracket notation, e.g. ``{"tag": tr{...}}``."""
        if self.tag != "td":
            head = '"tag": %s' % self.tag
        else:
            head = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        # Children are rendered in order and appended directly after the head.
        tail = "".join(child.bracket() for child in self.children)
        return "{{{}}}".format(head + tail)
|
|
|
|
|
|
class CustomConfig(Config):
    """APTED cost model comparing tags, cell spans and cell text tokens."""

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``.

        Returns 1.0 when tag, colspan or rowspan differ. For ``td`` nodes where
        at least one side has content, returns the normalized Levenshtein
        distance between the two content token sequences. Otherwise 0.0.
        """
        same_shape = (
            node1.tag == node2.tag
            and node1.colspan == node2.colspan
            and node1.rowspan == node2.rowspan
        )
        if not same_shape:
            return 1.0
        compares_text = node1.tag == "td" and (node1.content or node2.content)
        if compares_text:
            return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.0
|
|
|
|
|
|
class CustomConfig_del_short(Config):
    """APTED cost model that masks very short cell contents before comparing."""

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``.

        Returns 1.0 when tag, colspan or rowspan differ. For ``td`` nodes where
        at least one side has content, any content with fewer than three tokens
        is replaced by the placeholder ``["####"]``, so two short cells always
        match. The result is the normalized Levenshtein distance of the
        (possibly masked) sequences. Otherwise returns 0.0.
        """
        left_shape = (node1.tag, node1.colspan, node1.rowspan)
        right_shape = (node2.tag, node2.colspan, node2.rowspan)
        if left_shape != right_shape:
            return 1.0
        if node1.tag != "td" or not (node1.content or node2.content):
            return 0.0
        left = node1.content if len(node1.content) >= 3 else ["####"]
        right = node2.content if len(node2.content) >= 3 else ["####"]
        return Levenshtein.normalized_distance(left, right)
|
|
|
|
|
|
class CustomConfig_del_block(Config):
    """APTED cost model that ignores single-space tokens in cell contents."""

    def rename(self, node1, node2):
        """Return the cost of relabelling ``node1`` as ``node2``.

        Returns 1.0 when tag, colspan or rowspan differ. For ``td`` nodes where
        at least one side has content, drops every ``" "`` token from both
        token sequences and returns their normalized Levenshtein distance.
        Otherwise returns 0.0.

        The nodes themselves are not modified.
        """
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                # Filter into new lists instead of popping from node.content:
                # rename may be called repeatedly on the same nodes, and the
                # tree's token lists must stay intact. This also avoids the
                # quadratic index()/pop() loop.
                node1_content = [tok for tok in node1.content if tok != " "]
                node2_content = [tok for tok in node2.content if tok != " "]
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.0
|
|
|
|
|
|
class TEDS(object):
    """Tree Edit Distance based Similarity (TEDS) between HTML tables.

    score = 1 - APTED(pred, true) / max(n_pred, n_true), where n_* counts the
    descendants of each ``table`` element (the root itself is not counted).
    """

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        """
        Args:
            structure_only: if True, cell text is ignored and only tags and
                colspan/rowspan are compared.
            n_jobs: worker count passed to ``parallel_process`` by the batch
                methods (1 means evaluate sequentially); integer >= 1.
            ignore_nodes: optional tag names removed with
                ``lxml.etree.strip_tags`` before comparison.
        """
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        # Scratch buffer filled by tokenize(); reset per cell in load_html_tree().
        self.__tokens__ = []

    def tokenize(self, node):
        """Append the tokens of ``node`` and its descendants to ``self.__tokens__``.

        Each element contributes ``<tag>`` and ``</tag>`` tokens. ``unk`` gets
        no closing token. Text and tails are split into single characters. The
        tail of a ``td`` lies outside the cell and is skipped.
        """
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        # Iterating an element yields its direct children (getchildren() is
        # deprecated).
        for n in node:
            self.tokenize(n)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Convert an HTML element tree into ``TableTree`` nodes for APTED.

        A ``td`` element becomes a leaf carrying its colspan/rowspan (default 1)
        and, unless ``structure_only``, its token list without the enclosing
        ``<td>``/``</td>``. Any other element recurses into its children.

        Returns the new root when ``parent`` is None. Otherwise the new node is
        appended to ``parent.children`` and None is returned.
        """
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                # Drop the leading "<td>" and trailing "</td>" tokens.
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
            )
        else:
            new_node = TableTree(node.tag, None, None, None)
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != "td":
            for n in node:
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Compute the TEDS score between one prediction and its ground truth.

        Both inputs are parsed with lxml, and each must expose a ``body/table``
        element under the parsed root (e.g. ``<html><body><table>...``).
        Returns 0.0 when either input is empty or blank, or lacks such a
        table. Otherwise returns a float that is 1.0 for identical trees.
        """
        # Cheap guards first; lxml cannot parse an empty/blank document.
        if not pred or not true or not pred.strip() or not true.strip():
            return 0.0
        try_import("lxml")
        from lxml import etree, html

        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        pred_tables = pred.xpath("body/table")
        true_tables = true.xpath("body/table")
        if not (pred_tables and true_tables):
            return 0.0
        pred = pred_tables[0]
        true = true_tables[0]
        if self.ignore_nodes:
            etree.strip_tags(pred, *self.ignore_nodes)
            etree.strip_tags(true, *self.ignore_nodes)
        n_nodes_pred = len(pred.xpath(".//*"))
        n_nodes_true = len(true.xpath(".//*"))
        n_nodes = max(n_nodes_pred, n_nodes_true)
        if n_nodes == 0:
            # Both are bare <table> roots with no descendants, so the trees
            # are identical; return early instead of dividing by zero.
            return 1.0
        tree_pred = self.load_html_tree(pred)
        tree_true = self.load_html_tree(true)
        distance = APTED(
            tree_pred, tree_true, CustomConfig()
        ).compute_edit_distance()
        return 1.0 - (float(distance) / n_nodes)

    def batch_evaluate(self, pred_json, true_json):
        """Compute TEDS scores for a batch of samples.

        Args:
            pred_json: ``{'FILENAME': 'HTML CODE', ...}``. Files missing here
                are scored against an empty prediction, i.e. 0.0.
            true_json: ``{'FILENAME': {'html': 'HTML CODE'}, ...}``.

        Returns:
            ``{'FILENAME': TEDS score, ...}`` for every key of ``true_json``.
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": pred_json.get(filename, ""),
                    "true": true_json[filename]["html"],
                }
                for filename in samples
            ]
            # NOTE: zipping below assumes parallel_process returns results in
            # input order.
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Compute TEDS scores for paired lists of prediction/ground-truth HTML.

        Returns one score per ``(pred, true)`` pair taken from
        ``zip(pred_htmls, true_htmls)``. Extra items in the longer list are
        ignored.
        """
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_html, true_html)
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
        else:
            inputs = [
                {"pred": pred_html, "true": true_html}
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        return scores
|
|
|
|
|
|
if __name__ == "__main__":
    import json
    import pprint

    # Read the JSON as UTF-8 explicitly. Per RFC 8259 interchange JSON is
    # UTF-8, and the platform default encoding may differ (e.g. on Windows),
    # which would garble or reject non-ASCII cell text.
    with open("sample_pred.json", encoding="utf-8") as fp:
        pred_json = json.load(fp)
    with open("sample_gt.json", encoding="utf-8") as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
|