# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
# reference: https://arxiv.org/abs/2207.06966

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from collections import OrderedDict
from typing import Optional
import copy
from itertools import permutations


class DecoderLayer(paddle.nn.Layer):
    """A Transformer decoder layer supporting two-stream attention (XLNet).

    This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation="gelu",
        layer_norm_eps=1e-05,
    ):
        super().__init__()
        self.self_attn = paddle.nn.MultiHeadAttention(
            d_model, nhead, dropout=dropout, need_weights=True
        )  # paddle.nn.MultiHeadAttention expects batch-first inputs by default
        self.cross_attn = paddle.nn.MultiHeadAttention(
            d_model, nhead, dropout=dropout, need_weights=True
        )
        self.linear1 = paddle.nn.Linear(
            in_features=d_model, out_features=dim_feedforward
        )
        self.dropout = paddle.nn.Dropout(p=dropout)
        self.linear2 = paddle.nn.Linear(
            in_features=dim_feedforward, out_features=d_model
        )
        self.norm1 = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm2 = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm_q = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.norm_c = paddle.nn.LayerNorm(
            normalized_shape=d_model, epsilon=layer_norm_eps
        )
        self.dropout1 = paddle.nn.Dropout(p=dropout)
        self.dropout2 = paddle.nn.Dropout(p=dropout)
        self.dropout3 = paddle.nn.Dropout(p=dropout)
        if activation == "gelu":
            self.activation = paddle.nn.GELU()

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = paddle.nn.functional.gelu
        super().__setstate__(state)

    def forward_stream(
        self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask
    ):
        """Forward pass for a single stream (i.e. content or query)

        tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
        Both tgt_kv and memory are expected to be LayerNorm'd too.
        memory is LayerNorm'd by ViT.
        """
        if tgt_key_padding_mask is not None:
            tgt_mask1 = (tgt_mask != float("-inf"))[None, None, :, :] & (
                tgt_key_padding_mask[:, None, None, :] == False
            )
            tgt2, sa_weights = self.self_attn(
                tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1
            )
        else:
            tgt2, sa_weights = self.self_attn(
                tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask
            )

        tgt = tgt + self.dropout1(tgt2)
        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.linear2(
            self.dropout(self.activation(self.linear1(self.norm2(tgt))))
        )
        tgt = tgt + self.dropout3(tgt2)
        return tgt, sa_weights, ca_weights

    def forward(
        self,
        query,
        content,
        memory,
        query_mask=None,
        content_mask=None,
        content_key_padding_mask=None,
        update_content=True,
    ):
        query_norm = self.norm_q(query)
        content_norm = self.norm_c(content)
        query = self.forward_stream(
            query,
            query_norm,
            content_norm,
            memory,
            query_mask,
            content_key_padding_mask,
        )[0]
        if update_content:
            content = self.forward_stream(
                content,
                content_norm,
                content_norm,
                memory,
                content_mask,
                content_key_padding_mask,
            )[0]
        return query, content


def get_clones(module, N):
    return paddle.nn.LayerList([copy.deepcopy(module) for _ in range(N)])


class Decoder(paddle.nn.Layer):
    __constants__ = ["norm"]

    def __init__(self, decoder_layer, num_layers, norm):
        super().__init__()
        self.layers = get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        query,
        content,
        memory,
        query_mask: Optional[paddle.Tensor] = None,
        content_mask: Optional[paddle.Tensor] = None,
        content_key_padding_mask: Optional[paddle.Tensor] = None,
    ):
        for i, mod in enumerate(self.layers):
            last = i == len(self.layers) - 1
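            # The content stream of the last layer is never consumed (forward returns
            # only the query stream), so its update is skipped to save computation.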
            query, content = mod(
                query,
                content,
                memory,
                query_mask,
                content_mask,
                content_key_padding_mask,
                update_content=not last,
            )
        query = self.norm(query)
        return query


class TokenEmbedding(paddle.nn.Layer):
    def __init__(self, charset_size: int, embed_dim: int):
        super().__init__()
        self.embedding = paddle.nn.Embedding(
            num_embeddings=charset_size, embedding_dim=embed_dim
        )
        self.embed_dim = embed_dim

    def forward(self, tokens: paddle.Tensor):
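        # Scale token embeddings by sqrt(embed_dim), following the usual Transformer convention.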
        return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))


def trunc_normal_init(param, **kwargs):
    initializer = nn.initializer.TruncatedNormal(**kwargs)
    initializer(param, param.block)


def constant_init(param, **kwargs):
    initializer = nn.initializer.Constant(**kwargs)
    initializer(param, param.block)


def kaiming_normal_init(param, **kwargs):
    initializer = nn.initializer.KaimingNormal(**kwargs)
    initializer(param, param.block)


class ParseQHead(nn.Layer):
    def __init__(
        self,
        out_channels,
        max_text_length,
        embed_dim,
        dec_num_heads,
        dec_mlp_ratio,
        dec_depth,
        perm_num,
        perm_forward,
        perm_mirrored,
        decode_ar,
        refine_iters,
        dropout,
        **kwargs,
    ):
        super().__init__()

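        # Special token ids: EOS is index 0 of the charset; BOS and PAD take the last two ids.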
        self.bos_id = out_channels - 2
        self.eos_id = 0
        self.pad_id = out_channels - 1

        self.max_label_length = max_text_length
        self.decode_ar = decode_ar
        self.refine_iters = refine_iters
        decoder_layer = DecoderLayer(
            embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout
        )
        self.decoder = Decoder(
            decoder_layer,
            num_layers=dec_depth,
            norm=paddle.nn.LayerNorm(normalized_shape=embed_dim),
        )
        self.rng = np.random.default_rng()
        self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
        self.perm_forward = perm_forward
        self.perm_mirrored = perm_mirrored
        self.head = paddle.nn.Linear(
            in_features=embed_dim, out_features=out_channels - 2
        )
        self.text_embed = TokenEmbedding(out_channels, embed_dim)
        self.pos_queries = paddle.create_parameter(
            shape=[1, max_text_length + 1, embed_dim],
            dtype=paddle.get_default_dtype(),
            default_initializer=paddle.nn.initializer.Assign(
                paddle.empty(shape=[1, max_text_length + 1, embed_dim])
            ),
        )
        self.pos_queries.stop_gradient = False
        self.dropout = paddle.nn.Dropout(p=dropout)
        self._device = self.parameters()[0].place
        trunc_normal_init(self.pos_queries, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, paddle.nn.Linear):
            trunc_normal_init(m.weight, std=0.02)
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(m, paddle.nn.Embedding):
            trunc_normal_init(m.weight, std=0.02)
            if m._padding_idx is not None:
                m.weight.data[m._padding_idx].zero_()
        elif isinstance(m, paddle.nn.Conv2D):
            kaiming_normal_init(m.weight, fan_in=None, nonlinearity="relu")
            if m.bias is not None:
                constant_init(m.bias, value=0.0)
        elif isinstance(
            m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)
        ):
            constant_init(m.weight, value=1.0)
            constant_init(m.bias, value=0.0)

    def no_weight_decay(self):
        param_names = {"text_embed.embedding.weight", "pos_queries"}
        enc_param_names = {("encoder." + n) for n in self.encoder.no_weight_decay()}
        return param_names.union(enc_param_names)

    def encode(self, img):
        return self.encoder(img)

    def decode(
        self,
        tgt,
        memory,
        tgt_mask=None,
        tgt_padding_mask=None,
        tgt_query=None,
        tgt_query_mask=None,
    ):
        N, L = tgt.shape
        null_ctx = self.text_embed(tgt[:, :1])
        if L != 1:
            tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
            tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
        else:
            tgt_emb = self.dropout(null_ctx)
        if tgt_query is None:
            tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
        tgt_query = self.dropout(tgt_query)
        return self.decoder(
            tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask
        )

    def forward_test(self, memory, max_length=None):
        testing = max_length is None
        max_length = (
            self.max_label_length
            if max_length is None
            else min(max_length, self.max_label_length)
        )
        bs = memory.shape[0]
        num_steps = max_length + 1

        pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
        tgt_mask = query_mask = paddle.triu(
            x=paddle.full(shape=(num_steps, num_steps), fill_value=float("-inf")),
            diagonal=1,
        )
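        # The upper-triangular -inf mask is causal: during autoregressive decoding each
        # position may only attend to positions that have already been decoded.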
        if self.decode_ar:
            tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype(
                "int64"
            )
            tgt_in[:, 0] = self.bos_id

            logits = []
            for i in range(num_steps):
                j = i + 1
                tgt_out = self.decode(
                    tgt_in[:, :j],
                    memory,
                    tgt_mask[:j, :j],
                    tgt_query=pos_queries[:, i:j],
                    tgt_query_mask=query_mask[i:j, :j],
                )
                p_i = self.head(tgt_out)
                logits.append(p_i)
                if j < num_steps:
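                    # Greedy decoding: feed the arg-max token back in as the next input;
                    # in test mode, stop once every sequence in the batch has emitted EOS.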
                    tgt_in[:, j] = p_i.squeeze().argmax(axis=-1)
                    if (
                        testing
                        and (tgt_in == self.eos_id)
                        .astype("bool")
                        .any(axis=-1)
                        .astype("bool")
                        .all()
                    ):
                        break
            logits = paddle.concat(x=logits, axis=1)
        else:
            tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
            tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
            logits = self.head(tgt_out)
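        # Iterative refinement (cloze decoding): re-decode with the previous full prediction
        # as context, under a relaxed mask that also exposes later positions.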
        if self.refine_iters:
            temp = paddle.triu(
                x=paddle.ones(shape=[num_steps, num_steps], dtype="bool"), diagonal=2
            )
            posi = np.where(temp.cpu().numpy() == True)
            query_mask[posi] = 0
            bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
            for i in range(self.refine_iters):
                tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
                tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype="int32")
                tgt_padding_mask = tgt_padding_mask.cpu()
                tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
                tgt_padding_mask = (
                    tgt_padding_mask.cuda().astype(dtype="float32") == 1.0
                )
                tgt_out = self.decode(
                    tgt_in,
                    memory,
                    tgt_mask,
                    tgt_padding_mask,
                    tgt_query=pos_queries,
                    tgt_query_mask=query_mask[:, : tgt_in.shape[1]],
                )
                logits = self.head(tgt_out)

        # convert logits to probabilities
        logits = F.softmax(logits, axis=-1)

        final_output = {"predict": logits}

        return final_output

    def gen_tgt_perms(self, tgt):
        """Generate shared permutations for the whole batch.

        This works because, thanks to the padding mask, the same attention mask
        can be reused for the shorter sequences in the batch.
        """
        max_num_chars = tgt.shape[1] - 2
        if max_num_chars == 1:
            return paddle.arange(end=3).unsqueeze(axis=0)
        perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
        max_perms = math.factorial(max_num_chars)
        if self.perm_mirrored:
            max_perms //= 2
        num_gen_perms = min(self.max_gen_perms, max_perms)
        if max_num_chars < 5:
            if max_num_chars == 4 and self.perm_mirrored:
                selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
            else:
                selector = list(range(max_perms))
            perm_pool = paddle.to_tensor(
                data=list(permutations(range(max_num_chars), max_num_chars)),
                place=self._device,
            )[selector]
            if self.perm_forward:
                perm_pool = perm_pool[1:]
            perms = paddle.stack(x=perms)
            if len(perm_pool):
                i = self.rng.choice(
                    len(perm_pool), size=num_gen_perms - len(perms), replace=False
                )
                perms = paddle.concat(x=[perms, perm_pool[i]])
        else:
            perms.extend(
                [
                    paddle.randperm(n=max_num_chars)
                    for _ in range(num_gen_perms - len(perms))
                ]
            )
            perms = paddle.stack(x=perms)
        if self.perm_mirrored:
            comp = perms.flip(axis=-1)
            x = paddle.stack(x=[perms, comp])
            perm_2 = list(range(x.ndim))
            perm_2[0] = 1
            perm_2[1] = 0
            perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
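        # Shift character positions by +1 and pin BOS (position 0) at the front and
        # EOS (position max_num_chars + 1) at the end of every permutation.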
        bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
        eos_idx = paddle.full(
            shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype
        )
        perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
        if len(perms) > 1:
            perms[1, 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
        return perms

    def generate_attn_masks(self, perm):
        """Generate attention masks given a sequence permutation (includes pos. for BOS and EOS tokens)

        :param perm: the permutation sequence. i = 0 is always the BOS
        :return: lookahead attention masks
        """
        sz = perm.shape[0]
        mask = paddle.zeros(shape=(sz, sz))
        for i in range(sz):
            query_idx = perm[i].cpu().numpy().tolist()
            masked_keys = perm[i + 1 :].cpu().numpy().tolist()
            if len(masked_keys) == 0:
                break
            mask[query_idx, masked_keys] = float("-inf")
        content_mask = mask[:-1, :-1].clone()
        mask[paddle.eye(num_rows=sz).astype("bool")] = float("-inf")
        query_mask = mask[1:, :-1]
        return content_mask, query_mask

    def forward_train(self, memory, tgt):
        tgt_perms = self.gen_tgt_perms(tgt)
        tgt_in = tgt[:, :-1]
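        # Mask PAD and EOS positions so they are ignored as attention keys during decoding.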
        tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
        logits_list = []
        final_out = {}
        for i, perm in enumerate(tgt_perms):
            tgt_mask, query_mask = self.generate_attn_masks(perm)
            out = self.decode(
                tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask
            )
            logits = self.head(out)
            if i == 0:
                final_out["predict"] = logits
            logits = logits.flatten(stop_axis=1)
            logits_list.append(logits)

        final_out["logits_list"] = logits_list
        final_out["pad_id"] = self.pad_id
        final_out["eos_id"] = self.eos_id

        return final_out

    def forward(self, feat, targets=None):
        # feat : B, N, C
        # targets : labels, labels_len

        if self.training:
            label = targets[0]  # label
            label_len = targets[1]
            max_step = paddle.max(label_len).cpu().numpy()[0] + 2
            crop_label = label[:, :max_step]
            final_out = self.forward_train(feat, crop_label)
        else:
            final_out = self.forward_test(feat)

        return final_out
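

# ---------------------------------------------------------------------------
# Illustrative usage sketch. The hyperparameter values below are assumptions made
# only to keep the shapes concrete; see the PaddleOCR ParseQ configs for the
# values actually used in training.
#
#     head = ParseQHead(
#         out_channels=97,       # charset size including EOS/BOS/PAD ids (assumed)
#         max_text_length=25,
#         embed_dim=384,
#         dec_num_heads=12,
#         dec_mlp_ratio=4,
#         dec_depth=1,
#         perm_num=6,
#         perm_forward=True,
#         perm_mirrored=True,
#         decode_ar=True,
#         refine_iters=1,
#         dropout=0.1,
#     )
#     head.eval()
#     feat = paddle.rand([2, 197, 384])  # (batch, num_visual_tokens, embed_dim)
#     out = head(feat)                   # {"predict": per-step character probabilities}
# ---------------------------------------------------------------------------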