You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
110 lines
3.6 KiB
Python
110 lines
3.6 KiB
Python
""""
|
|
TEDS Code Adapted from https://github.com/ibm-aur-nlp/EDD
|
|
"""
|
|
|
|
import distance
|
|
from apted import APTED, Config
|
|
from apted.helpers import Tree
|
|
from lxml import html
|
|
from collections import deque
|
|
|
|
def wrap_table_html(table_html: str) -> str:
    """Embed an HTML table fragment in a minimal ``<html><body>`` document.

    Args:
        table_html: Markup to place inside the ``<body>`` element.

    Returns:
        The fragment surrounded by ``<html><body>`` ... ``</body></html>``.
    """
    document_template = '<html><body>{}</body></html>'
    return document_template.format(table_html)
|
|
|
|
class TableTree(Tree):
    """Tree node that carries an HTML tag plus optional cell attributes."""

    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        """Store the node attributes and hand the children to the base class.

        Args:
            tag: Tag name of the node.
            colspan: Column span of the node (``None`` when not given).
            rowspan: Row span of the node (``None`` when not given).
            content: Cell content of the node (``None`` when not given).
            *children: Child nodes forwarded to the base initialiser.
        """
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content

        # Sets self.name and self.children
        super().__init__(tag, *children)

    def bracket(self):
        """Return the subtree rooted at this node in brackets notation."""
        if self.tag == 'td':
            # Only 'td' nodes include their span and content in the output.
            fields = (self.tag, self.colspan, self.rowspan, self.content)
            label = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % fields
        else:
            label = '"tag": %s' % self.tag
        nested = ''.join(child.bracket() for child in self.children)
        return "{{{}}}".format(label + nested)
|
|
|
|
class CustomConfig(Config):
    """APTED configuration whose rename cost compares tag, spans and cell content."""

    @staticmethod
    def maximum(*sequences):
        """Return the length of the longest sequence in ``sequences``."""
        return max(map(len, sequences))

    def normalized_distance(self, *sequences):
        """Return the Levenshtein distance divided by the longest sequence length.

        Returns:
            A float; ``0.0`` when every sequence is empty, because there is
            nothing to edit and the division would otherwise be by zero.
        """
        longest = self.maximum(*sequences)
        if longest == 0:
            return 0.0
        return float(distance.levenshtein(*sequences)) / longest

    def rename(self, node1, node2):
        """Return the cost of renaming ``node1`` into ``node2``.

        Nodes that differ in tag, colspan or rowspan cost ``1.``. Two 'td'
        nodes with equal spans cost the normalized distance between their
        contents when at least one content is non-empty. Everything else
        costs ``0.``.
        """
        if (node1.tag != node2.tag) or (node1.colspan != node2.colspan) or (node1.rowspan != node2.rowspan):
            return 1.
        if node1.tag == 'td' and (node1.content or node2.content):
            return self.normalized_distance(node1.content, node2.content)
        return 0.
|
|
|
|
def tokenize(node):
    """Append the tokens of ``node`` and its descendants to the global ``__tokens__``.

    For each element this emits an opening ``<tag>`` token, the element text
    one character at a time, the tokens of its children, a closing ``</tag>``
    token (skipped for 'unk' elements) and, except for 'td' elements, the
    tail text one character at a time.

    The caller must bind the module-level ``__tokens__`` list before calling;
    this function only extends it and returns ``None``.

    Args:
        node: Element exposing ``tag``, ``text``, ``tail`` and iteration over
            its children.
    """
    global __tokens__
    __tokens__.append('<%s>' % node.tag)
    if node.text is not None:
        __tokens__.extend(node.text)
    # Iterate the element directly: getchildren() is deprecated in lxml and
    # no longer exists on standard-library ElementTree elements.
    for child in node:
        tokenize(child)
    if node.tag != 'unk':
        __tokens__.append('</%s>' % node.tag)
    # The tail of a 'td' lies outside the cell, so it is not part of its tokens.
    if node.tag != 'td' and node.tail is not None:
        __tokens__.extend(node.tail)
|
|
|
|
def tree_convert_html(node, convert_cell=False, parent=None):
    """Convert an HTML element tree to the ``TableTree`` format required by apted.

    Args:
        node: HTML element to convert; its children are converted
            recursively unless it is a 'td'.
        convert_cell: When true, the content of each 'td' is tokenized via
            ``tokenize`` and stored on the node; otherwise cells get an
            empty content list.
        parent: ``TableTree`` node that receives the converted node as a
            child, or ``None`` for the root call.

    Returns:
        The converted root node when ``parent`` is ``None``; otherwise
        ``None`` (the new node is attached to ``parent`` instead).
    """
    global __tokens__
    if node.tag == 'td':
        if convert_cell:
            # tokenize() accumulates into the module-level list, so reset it first.
            __tokens__ = []
            tokenize(node)
            # Drop the enclosing <td> and </td> tokens; the slice is already a new list.
            cell = __tokens__[1:-1]
        else:
            cell = []
        new_node = TableTree(node.tag,
                             int(node.attrib.get('colspan', '1')),
                             int(node.attrib.get('rowspan', '1')),
                             cell)
    else:
        new_node = TableTree(node.tag, None, None, None)
    if parent is not None:
        parent.children.append(new_node)
    if node.tag != 'td':
        # Iterate the element directly instead of the deprecated getchildren().
        for child in node:
            tree_convert_html(child, convert_cell, new_node)
    if parent is None:
        return new_node
    return None
|
|
|
|
def similarity_eval_html(pred, true, structure_only=False):
    """Compute the TEDS score between a predicted and a ground-truth table.

    Both inputs are parsed as HTML and the first element matching
    ``body/table`` in each is compared.

    Args:
        pred: HTML string of the prediction.
        true: HTML string of the ground truth.
        structure_only: When true, cell content is ignored and only the
            tree structure (tags and spans) is compared.

    Returns:
        ``1.0 - edit_distance / n_nodes`` where ``n_nodes`` is the larger
        descendant-element count of the two tables; ``1.0`` when neither
        table has any descendant element; ``0.0`` when either input has no
        ``body/table`` element.
    """
    pred_tables = html.fromstring(pred).xpath('body/table')
    true_tables = html.fromstring(true).xpath('body/table')
    if not (pred_tables and true_tables):
        return 0.0

    pred_table, true_table = pred_tables[0], true_tables[0]
    n_nodes = max(len(pred_table.xpath(".//*")), len(true_table.xpath(".//*")))
    if n_nodes == 0:
        # Both tables are bare <table> elements: they are identical, and the
        # normalisation below would divide by zero.
        return 1.0

    tree_pred = tree_convert_html(pred_table, convert_cell=not structure_only)
    tree_true = tree_convert_html(true_table, convert_cell=not structure_only)
    # Named edit_distance so the imported `distance` module is not shadowed.
    edit_distance = APTED(tree_pred, tree_true, CustomConfig()).compute_edit_distance()
    return 1.0 - (float(edit_distance) / n_nodes)
|
|
|