# Copyright 2020 IBM
# Author: peter.zhong@au1.ibm.com
#
# This is free software; you can redistribute it and/or modify
# it under the terms of the Apache 2.0 License.
#
# This software is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Apache 2.0 License for more details.

from collections import deque

from rapidfuzz.distance import Levenshtein
from apted import APTED, Config
from apted.helpers import Tree
from tqdm import tqdm
from paddle.utils import try_import

from .parallel import parallel_process


class TableTree(Tree):
    def __init__(self, tag, colspan=None, rowspan=None, content=None, *children):
        self.tag = tag
        self.colspan = colspan
        self.rowspan = rowspan
        self.content = content
        self.children = list(children)

    def bracket(self):
        """Show tree using brackets notation"""
        if self.tag == "td":
            result = '"tag": %s, "colspan": %d, "rowspan": %d, "text": %s' % (
                self.tag,
                self.colspan,
                self.rowspan,
                self.content,
            )
        else:
            result = '"tag": %s' % self.tag
        for child in self.children:
            result += child.bracket()
        return "{{{}}}".format(result)


class CustomConfig(Config):
    def rename(self, node1, node2):
        """Compares attributes of trees"""
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                return Levenshtein.normalized_distance(node1.content, node2.content)
        return 0.0


class CustomConfig_del_short(Config):
    def rename(self, node1, node2):
        """Compares attributes of trees; very short cell contents are replaced
        by a placeholder before computing the distance."""
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                node1_content = node1.content
                node2_content = node2.content
                if len(node1_content) < 3:
                    node1_content = ["####"]
                if len(node2_content) < 3:
                    node2_content = ["####"]
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.0


class CustomConfig_del_block(Config):
    def rename(self, node1, node2):
        """Compares attributes of trees; whitespace tokens are dropped from the
        cell content before computing the distance."""
        if (
            (node1.tag != node2.tag)
            or (node1.colspan != node2.colspan)
            or (node1.rowspan != node2.rowspan)
        ):
            return 1.0
        if node1.tag == "td":
            if node1.content or node2.content:
                # Filter out whitespace tokens so they do not affect the distance.
                node1_content = [token for token in node1.content if token != " "]
                node2_content = [token for token in node2.content if token != " "]
                return Levenshtein.normalized_distance(node1_content, node2_content)
        return 0.0


class TEDS(object):
    """Tree Edit Distance based Similarity"""

    def __init__(self, structure_only=False, n_jobs=1, ignore_nodes=None):
        assert isinstance(n_jobs, int) and (
            n_jobs >= 1
        ), "n_jobs must be an integer greater than or equal to 1"
        self.structure_only = structure_only
        self.n_jobs = n_jobs
        self.ignore_nodes = ignore_nodes
        self.__tokens__ = []

    def tokenize(self, node):
        """Tokenizes table cells"""
        self.__tokens__.append("<%s>" % node.tag)
        if node.text is not None:
            self.__tokens__ += list(node.text)
        for n in node.getchildren():
            self.tokenize(n)
        if node.tag != "unk":
            self.__tokens__.append("</%s>" % node.tag)
        if node.tag != "td" and node.tail is not None:
            self.__tokens__ += list(node.tail)

    def load_html_tree(self, node, parent=None):
        """Converts HTML tree to the format required by apted"""
        if node.tag == "td":
            if self.structure_only:
                cell = []
            else:
                self.__tokens__ = []
                self.tokenize(node)
                cell = self.__tokens__[1:-1].copy()
            new_node = TableTree(
                node.tag,
                int(node.attrib.get("colspan", "1")),
                int(node.attrib.get("rowspan", "1")),
                cell,
                *deque(),
            )
        else:
            new_node = TableTree(node.tag, None, None, None, *deque())
        if parent is not None:
            parent.children.append(new_node)
        if node.tag != "td":
            for n in node.getchildren():
                self.load_html_tree(n, new_node)
        if parent is None:
            return new_node

    def evaluate(self, pred, true):
        """Computes TEDS score between the prediction and the ground truth of a
        given sample
        """
        try_import("lxml")
        from lxml import etree, html

        if (not pred) or (not true):
            return 0.0
        parser = html.HTMLParser(remove_comments=True, encoding="utf-8")
        pred = html.fromstring(pred, parser=parser)
        true = html.fromstring(true, parser=parser)
        if pred.xpath("body/table") and true.xpath("body/table"):
            pred = pred.xpath("body/table")[0]
            true = true.xpath("body/table")[0]
            if self.ignore_nodes:
                etree.strip_tags(pred, *self.ignore_nodes)
                etree.strip_tags(true, *self.ignore_nodes)
            n_nodes_pred = len(pred.xpath(".//*"))
            n_nodes_true = len(true.xpath(".//*"))
            n_nodes = max(n_nodes_pred, n_nodes_true)
            tree_pred = self.load_html_tree(pred)
            tree_true = self.load_html_tree(true)
            distance = APTED(
                tree_pred, tree_true, CustomConfig()
            ).compute_edit_distance()
            return 1.0 - (float(distance) / n_nodes)
        else:
            return 0.0

    def batch_evaluate(self, pred_json, true_json):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        @params pred_json: {'FILENAME': 'HTML CODE', ...}
        @params true_json: {'FILENAME': {'html': 'HTML CODE'}, ...}
        @return: {'FILENAME': 'TEDS SCORE', ...}
        """
        samples = true_json.keys()
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_json.get(filename, ""), true_json[filename]["html"])
                for filename in tqdm(samples)
            ]
        else:
            inputs = [
                {
                    "pred": pred_json.get(filename, ""),
                    "true": true_json[filename]["html"],
                }
                for filename in samples
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        scores = dict(zip(samples, scores))
        return scores

    def batch_evaluate_html(self, pred_htmls, true_htmls):
        """Computes TEDS score between the prediction and the ground truth of
        a batch of samples
        """
        if self.n_jobs == 1:
            scores = [
                self.evaluate(pred_html, true_html)
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
        else:
            inputs = [
                {"pred": pred_html, "true": true_html}
                for (pred_html, true_html) in zip(pred_htmls, true_htmls)
            ]
            scores = parallel_process(
                inputs, self.evaluate, use_kwargs=True, n_jobs=self.n_jobs, front_num=1
            )
        return scores


if __name__ == "__main__":
    import json
    import pprint

    with open("sample_pred.json") as fp:
        pred_json = json.load(fp)
    with open("sample_gt.json") as fp:
        true_json = json.load(fp)
    teds = TEDS(n_jobs=4)
    scores = teds.batch_evaluate(pred_json, true_json)
    pp = pprint.PrettyPrinter()
    pp.pprint(scores)
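
    # Minimal usage sketch (illustrative only, not part of the sample-file demo
    # above): TEDS.evaluate() can also be called directly on a single
    # prediction / ground-truth pair of full HTML strings. The table below and
    # the names `demo_html` / `single_teds` are made up for this example;
    # identical inputs score 1.0 because the edit distance is 0.
    demo_html = (
        "<html><body><table>"
        "<tr><td>cell A</td><td>cell B</td></tr>"
        "</table></body></html>"
    )
    single_teds = TEDS(structure_only=False, n_jobs=1)
    print(single_teds.evaluate(demo_html, demo_html))  # prints 1.0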