You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
6.6 KiB
Python
189 lines
6.6 KiB
Python
8 months ago
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
||
|
|
||
|
import numpy as np
|
||
|
import paddle
|
||
|
|
||
|
from .rec_postprocess import AttnLabelDecode
|
||
|
|
||
|
|
||
|
class TableLabelDecode(AttnLabelDecode):
    """Turn table-structure model outputs (and labels) into tokens and boxes.

    The vocabulary is read from a dictionary file (one token per line) and
    extended by ``add_special_char``. Predictions are decoded by ``decode``
    and ground-truth batches by ``decode_label``.
    """

    def __init__(self, character_dict_path, merge_no_span_structure=False, **kwargs):
        """Load the token dictionary and build the token <-> index mappings.

        Args:
            character_dict_path: Path to a UTF-8 text file, one token per line.
            merge_no_span_structure: When true, ensure ``<td></td>`` is part
                of the vocabulary and remove the bare ``<td>`` token.
            **kwargs: Accepted and ignored.
        """
        with open(character_dict_path, "rb") as fin:
            # Only line terminators are stripped; other whitespace is kept
            # because it may be part of a token.
            dict_character = [raw.decode("utf-8").strip("\r\n") for raw in fin]

        if merge_no_span_structure:
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        # For duplicated tokens the last index wins, as with a plain loop.
        self.dict = {char: i for i, char in enumerate(dict_character)}
        self.character = dict_character
        # Structure tokens for which a cell bounding box is emitted.
        self.td_token = ["<td>", "<td", "<td></td>"]

    def __call__(self, preds, batch=None):
        """Decode predictions and, when labels are present, the labels too.

        Args:
            preds: Mapping with ``"structure_probs"`` and ``"loc_preds"``;
                each may be a ``paddle.Tensor`` or an array.
            batch: Sequence whose last element holds the per-image shape
                info. When it has more than one element it is also passed
                to ``decode_label``.

        Returns:
            The ``decode`` result when ``batch`` has exactly one element,
            otherwise a ``(prediction_result, label_result)`` tuple.

        Raises:
            TypeError: If ``batch`` is ``None``.
        """
        structure_probs = preds["structure_probs"]
        bbox_preds = preds["loc_preds"]
        if isinstance(structure_probs, paddle.Tensor):
            structure_probs = structure_probs.numpy()
        if isinstance(bbox_preds, paddle.Tensor):
            bbox_preds = bbox_preds.numpy()

        if batch is None:
            # The shape info is mandatory; fail with a readable message
            # instead of "'NoneType' object is not subscriptable".
            raise TypeError(
                "batch must be a sequence whose last element is the shape "
                "list, got None"
            )

        shape_list = batch[-1]
        result = self.decode(structure_probs, bbox_preds, shape_list)
        if len(batch) == 1:  # only contains shape
            return result

        label_decode_result = self.decode_label(batch)
        return result, label_decode_result

    def decode(self, structure_probs, bbox_preds, shape_list):
        """Convert predicted token probabilities into tokens and cell boxes.

        Args:
            structure_probs: Array indexed ``[batch, step, class]``; argmax
                and max are taken over the last axis.
            bbox_preds: Array indexed ``[batch, step]`` yielding one box per
                step.
            shape_list: Per-image shape info forwarded to ``_bbox_decode``.

        Returns:
            Dict with ``"bbox_batch_list"`` (one array of boxes per image)
            and ``"structure_batch_list"`` (one ``[tokens, mean_score]``
            pair per image). The score is ``0.0`` when no token was decoded.
        """
        ignored_tokens = set(self.get_ignored_tokens())
        end_idx = self.dict[self.end_str]

        structure_idx = structure_probs.argmax(axis=2)
        structure_probs = structure_probs.max(axis=2)

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            score_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                # An end token at step 0 does not terminate the sequence.
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                text = self.character[char_idx]
                if text in self.td_token:
                    bbox = self._bbox_decode(
                        bbox_preds[batch_idx, idx], shape_list[batch_idx]
                    )
                    bbox_list.append(bbox)
                structure_list.append(text)
                score_list.append(structure_probs[batch_idx, idx])
            # np.mean([]) would return nan and emit a RuntimeWarning.
            score = np.mean(score_list) if score_list else 0.0
            structure_batch_list.append([structure_list, score])
            bbox_batch_list.append(np.array(bbox_list))

        return {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }

    def decode_label(self, batch):
        """Convert ground-truth token indices into tokens and boxes.

        Args:
            batch: Sequence where ``batch[1]`` holds the token indices,
                ``batch[2]`` the ground-truth boxes and ``batch[-1]`` the
                per-image shape info.

        Returns:
            Dict with ``"bbox_batch_list"`` (list of decoded boxes per image)
            and ``"structure_batch_list"`` (list of tokens per image).
        """
        structure_idx = batch[1]
        gt_bbox_list = batch[2]
        shape_list = batch[-1]
        ignored_tokens = set(self.get_ignored_tokens())
        end_idx = self.dict[self.end_str]

        structure_batch_list = []
        bbox_batch_list = []
        for batch_idx, sample_idx in enumerate(structure_idx):
            structure_list = []
            bbox_list = []
            for idx, raw_idx in enumerate(sample_idx):
                char_idx = int(raw_idx)
                if idx > 0 and char_idx == end_idx:
                    break
                if char_idx in ignored_tokens:
                    continue
                structure_list.append(self.character[char_idx])

                bbox = gt_bbox_list[batch_idx][idx]
                # Boxes whose entries sum to zero are skipped.
                if bbox.sum() != 0:
                    bbox = self._bbox_decode(bbox, shape_list[batch_idx])
                    bbox_list.append(bbox)
            structure_batch_list.append(structure_list)
            bbox_batch_list.append(bbox_list)

        return {
            "bbox_batch_list": bbox_batch_list,
            "structure_batch_list": structure_batch_list,
        }

    def _bbox_decode(self, bbox, shape):
        """Scale a box by the image size and return it as a new array.

        Even-indexed entries are multiplied by ``w`` and odd-indexed entries
        by ``h``. The input is never modified.

        Args:
            bbox: 1-D sequence of coordinates (presumably normalized to
                [0, 1] -- confirm against the caller).
            shape: Six values ``(h, w, ratio_h, ratio_w, pad_h, pad_w)``;
                only ``h`` and ``w`` are used here.
        """
        h, w, _ratio_h, _ratio_w, _pad_h, _pad_w = shape
        # Work on a copy so the caller's predictions / labels are not
        # overwritten; promote non-float input so in-place scaling is legal.
        bbox = np.array(bbox, copy=True)
        if not np.issubdtype(bbox.dtype, np.floating):
            bbox = bbox.astype(np.float32)
        bbox[0::2] *= w
        bbox[1::2] *= h
        return bbox
|
||
|
|
||
|
|
||
|
class TableMasterLabelDecode(TableLabelDecode):
    """Label decoder variant using ``<SOS>``/``<EOS>``/``<UKN>``/``<PAD>``.

    Boxes are read as ``(x, y, w, h)`` with ``(x, y)`` the center and are
    returned as ``(x1, y1, x2, y2)`` corners.
    """

    def __init__(
        self,
        character_dict_path,
        box_shape="ori",
        merge_no_span_structure=True,
        **kwargs,
    ):
        """Build the vocabulary and record how boxes were normalized.

        Args:
            character_dict_path: Path to the token dictionary file.
            box_shape: ``"ori"`` to scale boxes by the ``h``/``w`` entries of
                the shape info, ``"pad"`` to scale by ``pad_h``/``pad_w``.
            merge_no_span_structure: Forwarded to ``TableLabelDecode``.
            **kwargs: Accepted and ignored.

        Raises:
            AssertionError: If ``box_shape`` is neither ``"ori"`` nor
                ``"pad"``.
        """
        super().__init__(character_dict_path, merge_no_span_structure)
        # Raised explicitly (not via ``assert``) so the check still runs
        # under ``python -O``; the exception type is unchanged.
        if box_shape not in ("ori", "pad"):
            raise AssertionError(
                "The shape used for box normalization must be ori or pad"
            )
        self.box_shape = box_shape

    def add_special_char(self, dict_character):
        """Return ``dict_character`` followed by the four special tokens.

        Also sets ``beg_str``, ``end_str``, ``unknown_str`` and ``pad_str``
        on the instance. The input list is not modified.
        """
        self.beg_str = "<SOS>"
        self.end_str = "<EOS>"
        self.unknown_str = "<UKN>"
        self.pad_str = "<PAD>"
        return dict_character + [
            self.unknown_str,
            self.beg_str,
            self.end_str,
            self.pad_str,
        ]

    def get_ignored_tokens(self):
        """Return the indices of the start, end, pad and unknown tokens."""
        return [
            self.dict[self.beg_str],
            self.dict[self.end_str],
            self.dict[self.pad_str],
            self.dict[self.unknown_str],
        ]

    def _bbox_decode(self, bbox, shape):
        """Rescale a center-format box and convert it to corner format.

        The box is multiplied by the image (or padded) size, divided by the
        resize ratios, and then converted from ``(x, y, w, h)`` to
        ``(x1, y1, x2, y2)``. The input is never modified.

        Args:
            bbox: Sequence of exactly four values ``(x, y, w, h)``.
            shape: Six values ``(h, w, ratio_h, ratio_w, pad_h, pad_w)``.

        Returns:
            A new array ``[x1, y1, x2, y2]``.
        """
        img_h, img_w, ratio_h, ratio_w, pad_h, pad_w = shape
        if self.box_shape == "pad":
            img_h, img_w = pad_h, pad_w

        # Work on a copy so the caller's predictions / labels are not
        # overwritten; promote non-float input so in-place scaling is legal.
        bbox = np.array(bbox, copy=True)
        if not np.issubdtype(bbox.dtype, np.floating):
            bbox = bbox.astype(np.float32)
        bbox[0::2] *= img_w
        bbox[1::2] *= img_h
        bbox[0::2] /= ratio_w
        bbox[1::2] /= ratio_h

        center_x, center_y, box_w, box_h = bbox
        # Floor division is kept on purpose: half extents are rounded down.
        half_w = box_w // 2
        half_h = box_h // 2
        return np.array(
            [
                center_x - half_w,
                center_y - half_h,
                center_x + half_w,
                center_y + half_h,
            ]
        )
|