You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

505 lines
18 KiB
Python

8 months ago
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Code was based on https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
# reference: https://arxiv.org/abs/2207.06966
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from collections import OrderedDict
from typing import Optional
import copy
from itertools import permutations
class DecoderLayer(paddle.nn.Layer):
"""A Transformer decoder layer supporting two-stream attention (XLNet)
This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
def __init__(
self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation="gelu",
layer_norm_eps=1e-05,
):
super().__init__()
self.self_attn = paddle.nn.MultiHeadAttention(
d_model, nhead, dropout=dropout, need_weights=True
) # paddle.nn.MultiHeadAttention默认为batch_first模式
self.cross_attn = paddle.nn.MultiHeadAttention(
d_model, nhead, dropout=dropout, need_weights=True
)
self.linear1 = paddle.nn.Linear(
in_features=d_model, out_features=dim_feedforward
)
self.dropout = paddle.nn.Dropout(p=dropout)
self.linear2 = paddle.nn.Linear(
in_features=dim_feedforward, out_features=d_model
)
self.norm1 = paddle.nn.LayerNorm(
normalized_shape=d_model, epsilon=layer_norm_eps
)
self.norm2 = paddle.nn.LayerNorm(
normalized_shape=d_model, epsilon=layer_norm_eps
)
self.norm_q = paddle.nn.LayerNorm(
normalized_shape=d_model, epsilon=layer_norm_eps
)
self.norm_c = paddle.nn.LayerNorm(
normalized_shape=d_model, epsilon=layer_norm_eps
)
self.dropout1 = paddle.nn.Dropout(p=dropout)
self.dropout2 = paddle.nn.Dropout(p=dropout)
self.dropout3 = paddle.nn.Dropout(p=dropout)
if activation == "gelu":
self.activation = paddle.nn.GELU()
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = paddle.nn.functional.gelu
super().__setstate__(state)
def forward_stream(
self, tgt, tgt_norm, tgt_kv, memory, tgt_mask, tgt_key_padding_mask
):
"""Forward pass for a single stream (i.e. content or query)
tgt_norm is just a LayerNorm'd tgt. Added as a separate parameter for efficiency.
Both tgt_kv and memory are expected to be LayerNorm'd too.
memory is LayerNorm'd by ViT.
"""
if tgt_key_padding_mask is not None:
tgt_mask1 = (tgt_mask != float("-inf"))[None, None, :, :] & (
tgt_key_padding_mask[:, None, None, :] == False
)
tgt2, sa_weights = self.self_attn(
tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask1
)
else:
tgt2, sa_weights = self.self_attn(
tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask
)
tgt = tgt + self.dropout1(tgt2)
tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.linear2(
self.dropout(self.activation(self.linear1(self.norm2(tgt))))
)
tgt = tgt + self.dropout3(tgt2)
return tgt, sa_weights, ca_weights
def forward(
self,
query,
content,
memory,
query_mask=None,
content_mask=None,
content_key_padding_mask=None,
update_content=True,
):
query_norm = self.norm_q(query)
content_norm = self.norm_c(content)
query = self.forward_stream(
query,
query_norm,
content_norm,
memory,
query_mask,
content_key_padding_mask,
)[0]
if update_content:
content = self.forward_stream(
content,
content_norm,
content_norm,
memory,
content_mask,
content_key_padding_mask,
)[0]
return query, content
def get_clones(module, N):
return paddle.nn.LayerList([copy.deepcopy(module) for i in range(N)])
class Decoder(paddle.nn.Layer):
__constants__ = ["norm"]
def __init__(self, decoder_layer, num_layers, norm):
super().__init__()
self.layers = get_clones(decoder_layer, num_layers)
self.num_layers = num_layers
self.norm = norm
def forward(
self,
query,
content,
memory,
query_mask: Optional[paddle.Tensor] = None,
content_mask: Optional[paddle.Tensor] = None,
content_key_padding_mask: Optional[paddle.Tensor] = None,
):
for i, mod in enumerate(self.layers):
last = i == len(self.layers) - 1
query, content = mod(
query,
content,
memory,
query_mask,
content_mask,
content_key_padding_mask,
update_content=not last,
)
query = self.norm(query)
return query
class TokenEmbedding(paddle.nn.Layer):
def __init__(self, charset_size: int, embed_dim: int):
super().__init__()
self.embedding = paddle.nn.Embedding(
num_embeddings=charset_size, embedding_dim=embed_dim
)
self.embed_dim = embed_dim
def forward(self, tokens: paddle.Tensor):
return math.sqrt(self.embed_dim) * self.embedding(tokens.astype(paddle.int64))
def trunc_normal_init(param, **kwargs):
initializer = nn.initializer.TruncatedNormal(**kwargs)
initializer(param, param.block)
def constant_init(param, **kwargs):
initializer = nn.initializer.Constant(**kwargs)
initializer(param, param.block)
def kaiming_normal_init(param, **kwargs):
initializer = nn.initializer.KaimingNormal(**kwargs)
initializer(param, param.block)
class ParseQHead(nn.Layer):
def __init__(
self,
out_channels,
max_text_length,
embed_dim,
dec_num_heads,
dec_mlp_ratio,
dec_depth,
perm_num,
perm_forward,
perm_mirrored,
decode_ar,
refine_iters,
dropout,
**kwargs,
):
super().__init__()
self.bos_id = out_channels - 2
self.eos_id = 0
self.pad_id = out_channels - 1
self.max_label_length = max_text_length
self.decode_ar = decode_ar
self.refine_iters = refine_iters
decoder_layer = DecoderLayer(
embed_dim, dec_num_heads, embed_dim * dec_mlp_ratio, dropout
)
self.decoder = Decoder(
decoder_layer,
num_layers=dec_depth,
norm=paddle.nn.LayerNorm(normalized_shape=embed_dim),
)
self.rng = np.random.default_rng()
self.max_gen_perms = perm_num // 2 if perm_mirrored else perm_num
self.perm_forward = perm_forward
self.perm_mirrored = perm_mirrored
self.head = paddle.nn.Linear(
in_features=embed_dim, out_features=out_channels - 2
)
self.text_embed = TokenEmbedding(out_channels, embed_dim)
self.pos_queries = paddle.create_parameter(
shape=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).shape,
dtype=paddle.empty(shape=[1, max_text_length + 1, embed_dim]).numpy().dtype,
default_initializer=paddle.nn.initializer.Assign(
paddle.empty(shape=[1, max_text_length + 1, embed_dim])
),
)
self.pos_queries.stop_gradient = not True
self.dropout = paddle.nn.Dropout(p=dropout)
self._device = self.parameters()[0].place
trunc_normal_init(self.pos_queries, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, paddle.nn.Linear):
trunc_normal_init(m.weight, std=0.02)
if m.bias is not None:
constant_init(m.bias, value=0.0)
elif isinstance(m, paddle.nn.Embedding):
trunc_normal_init(m.weight, std=0.02)
if m._padding_idx is not None:
m.weight.data[m._padding_idx].zero_()
elif isinstance(m, paddle.nn.Conv2D):
kaiming_normal_init(m.weight, fan_in=None, nonlinearity="relu")
if m.bias is not None:
constant_init(m.bias, value=0.0)
elif isinstance(
m, (paddle.nn.LayerNorm, paddle.nn.BatchNorm2D, paddle.nn.GroupNorm)
):
constant_init(m.weight, value=1.0)
constant_init(m.bias, value=0.0)
def no_weight_decay(self):
param_names = {"text_embed.embedding.weight", "pos_queries"}
enc_param_names = {("encoder." + n) for n in self.encoder.no_weight_decay()}
return param_names.union(enc_param_names)
def encode(self, img):
return self.encoder(img)
def decode(
self,
tgt,
memory,
tgt_mask=None,
tgt_padding_mask=None,
tgt_query=None,
tgt_query_mask=None,
):
N, L = tgt.shape
null_ctx = self.text_embed(tgt[:, :1])
if L != 1:
tgt_emb = self.pos_queries[:, : L - 1] + self.text_embed(tgt[:, 1:])
tgt_emb = self.dropout(paddle.concat(x=[null_ctx, tgt_emb], axis=1))
else:
tgt_emb = self.dropout(null_ctx)
if tgt_query is None:
tgt_query = self.pos_queries[:, :L].expand(shape=[N, -1, -1])
tgt_query = self.dropout(tgt_query)
return self.decoder(
tgt_query, tgt_emb, memory, tgt_query_mask, tgt_mask, tgt_padding_mask
)
def forward_test(self, memory, max_length=None):
testing = max_length is None
max_length = (
self.max_label_length
if max_length is None
else min(max_length, self.max_label_length)
)
bs = memory.shape[0]
num_steps = max_length + 1
pos_queries = self.pos_queries[:, :num_steps].expand(shape=[bs, -1, -1])
tgt_mask = query_mask = paddle.triu(
x=paddle.full(shape=(num_steps, num_steps), fill_value=float("-inf")),
diagonal=1,
)
if self.decode_ar:
tgt_in = paddle.full(shape=(bs, num_steps), fill_value=self.pad_id).astype(
"int64"
)
tgt_in[:, (0)] = self.bos_id
logits = []
for i in range(paddle.to_tensor(num_steps)):
j = i + 1
tgt_out = self.decode(
tgt_in[:, :j],
memory,
tgt_mask[:j, :j],
tgt_query=pos_queries[:, i:j],
tgt_query_mask=query_mask[i:j, :j],
)
p_i = self.head(tgt_out)
logits.append(p_i)
if j < num_steps:
tgt_in[:, (j)] = p_i.squeeze().argmax(axis=-1)
if (
testing
and (tgt_in == self.eos_id)
.astype("bool")
.any(axis=-1)
.astype("bool")
.all()
):
break
logits = paddle.concat(x=logits, axis=1)
else:
tgt_in = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
tgt_out = self.decode(tgt_in, memory, tgt_query=pos_queries)
logits = self.head(tgt_out)
if self.refine_iters:
temp = paddle.triu(
x=paddle.ones(shape=[num_steps, num_steps], dtype="bool"), diagonal=2
)
posi = np.where(temp.cpu().numpy() == True)
query_mask[posi] = 0
bos = paddle.full(shape=(bs, 1), fill_value=self.bos_id).astype("int64")
for i in range(self.refine_iters):
tgt_in = paddle.concat(x=[bos, logits[:, :-1].argmax(axis=-1)], axis=1)
tgt_padding_mask = (tgt_in == self.eos_id).astype(dtype="int32")
tgt_padding_mask = tgt_padding_mask.cpu()
tgt_padding_mask = tgt_padding_mask.cumsum(axis=-1) > 0
tgt_padding_mask = (
tgt_padding_mask.cuda().astype(dtype="float32") == 1.0
)
tgt_out = self.decode(
tgt_in,
memory,
tgt_mask,
tgt_padding_mask,
tgt_query=pos_queries,
tgt_query_mask=query_mask[:, : tgt_in.shape[1]],
)
logits = self.head(tgt_out)
# transfer to probility
logits = F.softmax(logits, axis=-1)
final_output = {"predict": logits}
return final_output
def gen_tgt_perms(self, tgt):
"""Generate shared permutations for the whole batch.
This works because the same attention mask can be used for the shorter sequences
because of the padding mask.
"""
max_num_chars = tgt.shape[1] - 2
if max_num_chars == 1:
return paddle.arange(end=3).unsqueeze(axis=0)
perms = [paddle.arange(end=max_num_chars)] if self.perm_forward else []
max_perms = math.factorial(max_num_chars)
if self.perm_mirrored:
max_perms //= 2
num_gen_perms = min(self.max_gen_perms, max_perms)
if max_num_chars < 5:
if max_num_chars == 4 and self.perm_mirrored:
selector = [0, 3, 4, 6, 9, 10, 12, 16, 17, 18, 19, 21]
else:
selector = list(range(max_perms))
perm_pool = paddle.to_tensor(
data=list(permutations(range(max_num_chars), max_num_chars)),
place=self._device,
)[selector]
if self.perm_forward:
perm_pool = perm_pool[1:]
perms = paddle.stack(x=perms)
if len(perm_pool):
i = self.rng.choice(
len(perm_pool), size=num_gen_perms - len(perms), replace=False
)
perms = paddle.concat(x=[perms, perm_pool[i]])
else:
perms.extend(
[
paddle.randperm(n=max_num_chars)
for _ in range(num_gen_perms - len(perms))
]
)
perms = paddle.stack(x=perms)
if self.perm_mirrored:
comp = perms.flip(axis=-1)
x = paddle.stack(x=[perms, comp])
perm_2 = list(range(x.ndim))
perm_2[0] = 1
perm_2[1] = 0
perms = x.transpose(perm=perm_2).reshape((-1, max_num_chars))
bos_idx = paddle.zeros(shape=(len(perms), 1), dtype=perms.dtype)
eos_idx = paddle.full(
shape=(len(perms), 1), fill_value=max_num_chars + 1, dtype=perms.dtype
)
perms = paddle.concat(x=[bos_idx, perms + 1, eos_idx], axis=1)
if len(perms) > 1:
perms[(1), 1:] = max_num_chars + 1 - paddle.arange(end=max_num_chars + 1)
return perms
def generate_attn_masks(self, perm):
"""Generate attention masks given a sequence permutation (includes pos. for bos and eos tokens)
:param perm: the permutation sequence. i = 0 is always the BOS
:return: lookahead attention masks
"""
sz = perm.shape[0]
mask = paddle.zeros(shape=(sz, sz))
for i in range(sz):
query_idx = perm[i].cpu().numpy().tolist()
masked_keys = perm[i + 1 :].cpu().numpy().tolist()
if len(masked_keys) == 0:
break
mask[query_idx, masked_keys] = float("-inf")
content_mask = mask[:-1, :-1].clone()
mask[paddle.eye(num_rows=sz).astype("bool")] = float("-inf")
query_mask = mask[1:, :-1]
return content_mask, query_mask
def forward_train(self, memory, tgt):
tgt_perms = self.gen_tgt_perms(tgt)
tgt_in = tgt[:, :-1]
tgt_padding_mask = (tgt_in == self.pad_id) | (tgt_in == self.eos_id)
logits_list = []
final_out = {}
for i, perm in enumerate(tgt_perms):
tgt_mask, query_mask = self.generate_attn_masks(perm)
out = self.decode(
tgt_in, memory, tgt_mask, tgt_padding_mask, tgt_query_mask=query_mask
)
logits = self.head(out)
if i == 0:
final_out["predict"] = logits
logits = logits.flatten(stop_axis=1)
logits_list.append(logits)
final_out["logits_list"] = logits_list
final_out["pad_id"] = self.pad_id
final_out["eos_id"] = self.eos_id
return final_out
def forward(self, feat, targets=None):
# feat : B, N, C
# targets : labels, labels_len
if self.training:
label = targets[0] # label
label_len = targets[1]
max_step = paddle.max(label_len).cpu().numpy()[0] + 2
crop_label = label[:, :max_step]
final_out = self.forward_train(feat, crop_label)
else:
final_out = self.forward_test(feat)
return final_out