You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.6 KiB
Python
77 lines
2.6 KiB
Python
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
class VQAReTokenRelation(object):
    """Turn raw entity-id links into indexed question -> answer relations.

    The instance is callable: it takes a sample dict, rewrites its
    ``"relations"`` entry and returns the same dict.
    """

    def __init__(self, **kwargs):
        # No configuration is used; keyword arguments are accepted and ignored.
        pass

    def __call__(self, data):
        """
        build relations

        Reads ``data["entities"]`` and ``data["relations"]`` and removes
        ``"id2label"``, ``"empty_entity"`` and ``"entity_id_to_index_map"``
        from ``data`` (they are popped, so they are no longer present on the
        returned dict).

        Each raw relation is indexed as ``rel[0]`` / ``rel[1]`` (a pair of
        entity ids). Duplicates are dropped, relations touching an id found in
        ``empty_entity`` are discarded, and only pairs labelled
        ("question", "answer") or ("answer", "question") are kept.

        Returns:
            The same ``data`` dict, with ``data["relations"]`` replaced by a
            list of dicts with keys ``"head"``, ``"tail"``, ``"start_index"``
            and ``"end_index"``, sorted by ``"head"``.

        Raises:
            KeyError: if a required key is missing from ``data``, or a kept
                relation references an id absent from ``id2label`` or
                ``entity_id_to_index_map``.
        """
        entities = data["entities"]
        raw_relations = data["relations"]
        id2label = data.pop("id2label")
        empty_entity = data.pop("empty_entity")
        entity_id_to_index_map = data.pop("entity_id_to_index_map")

        kv_relations = []
        # set() removes duplicate links; relations must therefore be hashable.
        for rel in set(raw_relations):
            if rel[0] in empty_entity or rel[1] in empty_entity:
                continue
            pair = [id2label[rel[0]], id2label[rel[1]]]
            # Normalise direction: head is the question, tail is the answer.
            if pair == ["question", "answer"]:
                question_id, answer_id = rel[0], rel[1]
            elif pair == ["answer", "question"]:
                question_id, answer_id = rel[1], rel[0]
            else:
                # Any other label combination is not a key/value relation.
                continue
            kv_relations.append(
                {
                    "head": entity_id_to_index_map[question_id],
                    "tail": entity_id_to_index_map[answer_id],
                }
            )

        relations = []
        for rel in kv_relations:
            # Compute the span once and unpack both bounds from it.
            start_index, end_index = self.get_relation_span(rel, entities)
            relations.append(
                {
                    "head": rel["head"],
                    "tail": rel["tail"],
                    "start_index": start_index,
                    "end_index": end_index,
                }
            )
        # list.sort is stable, so relations sharing a head keep their order.
        relations.sort(key=lambda x: x["head"])

        data["relations"] = relations
        return data

    def get_relation_span(self, rel, entities):
        """Return ``(min, max)`` over the start/end values of both entities.

        Args:
            rel: dict with ``"head"`` and ``"tail"`` keys used to index
                ``entities``.
            entities: indexable collection whose items expose ``"start"`` and
                ``"end"`` keys.

        Returns:
            A 2-tuple: the smallest and the largest of the four bound values
            taken from the head and tail entities.
        """
        bound = []
        for entity_index in (rel["head"], rel["tail"]):
            entity = entities[entity_index]
            bound.append(entity["start"])
            bound.append(entity["end"])
        return min(bound), max(bound)
|