# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is refer from:
https://github.com/LBH1024/CAN/models/can.py
https://github.com/LBH1024/CAN/models/counting.py
https://github.com/LBH1024/CAN/models/decoder.py
https://github.com/LBH1024/CAN/models/attention.py

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
import paddle.nn as nn

"""
Counting Module
"""


class ChannelAtt(nn.Layer):
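    """Channel attention in squeeze-and-excitation style: global average
    pooling followed by a two-layer bottleneck MLP whose sigmoid output
    rescales every channel of the input feature map."""
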
    def __init__(self, channel, reduction):
        super(ChannelAtt, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2D(1)

        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid(),
        )

    def forward(self, x):
        b, c, _, _ = x.shape
        y = paddle.reshape(self.avg_pool(x), [b, c])
        y = paddle.reshape(self.fc(y), [b, c, 1, 1])
        return x * y


class CountingDecoder(nn.Layer):
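    """Counting decoder branch (MSCM): a kernel_size convolution plus channel
    attention produces per-symbol density maps of shape [B, out_channel, H, W];
    summing them over the spatial dimensions yields the symbol-counting vector
    of shape [B, out_channel]."""
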
    def __init__(self, in_channel, out_channel, kernel_size):
        super(CountingDecoder, self).__init__()
        self.in_channel = in_channel
        self.out_channel = out_channel

        self.trans_layer = nn.Sequential(
            nn.Conv2D(
                self.in_channel,
                512,
                kernel_size=kernel_size,
                padding=kernel_size // 2,
                bias_attr=False,
            ),
            nn.BatchNorm2D(512),
        )

        self.channel_att = ChannelAtt(512, 16)

        self.pred_layer = nn.Sequential(
            nn.Conv2D(512, self.out_channel, kernel_size=1, bias_attr=False),
            nn.Sigmoid(),
        )

    def forward(self, x, mask):
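        # x: [B, in_channel, H, W] encoder features; mask: [B, 1, H, W]
        # down-sampled image mask used to zero out padded positions.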
        b, _, h, w = x.shape
        x = self.trans_layer(x)
        x = self.channel_att(x)
        x = self.pred_layer(x)

        if mask is not None:
            x = x * mask
        x = paddle.reshape(x, [b, self.out_channel, -1])
        x1 = paddle.sum(x, axis=-1)

        return x1, paddle.reshape(x, [b, self.out_channel, h, w])


"""
Attention Decoder
"""


class PositionEmbeddingSine(nn.Layer):
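    """DETR-style 2D sinusoidal position embedding computed from the image
    mask, so padded regions do not contribute. The output has
    2 * num_pos_feats channels (y and x encodings concatenated)."""
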
    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, x, mask):
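        # mask: [B, H, W]; cumulative sums along height and width give the
        # (1-based) row and column indices of every valid position.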
        y_embed = paddle.cumsum(mask, 1, dtype="float32")
        x_embed = paddle.cumsum(mask, 2, dtype="float32")

        if self.normalize:
            eps = 1e-6
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = paddle.arange(self.num_pos_feats, dtype="float32")
        dim_d = paddle.expand(paddle.to_tensor(2), dim_t.shape)
        dim_t = self.temperature ** (
            2 * (dim_t / dim_d).astype("int64") / self.num_pos_feats
        )

        pos_x = paddle.unsqueeze(x_embed, [3]) / dim_t
        pos_y = paddle.unsqueeze(y_embed, [3]) / dim_t

        pos_x = paddle.flatten(
            paddle.stack(
                [paddle.sin(pos_x[:, :, :, 0::2]), paddle.cos(pos_x[:, :, :, 1::2])],
                axis=4,
            ),
            3,
        )
        pos_y = paddle.flatten(
            paddle.stack(
                [paddle.sin(pos_y[:, :, :, 0::2]), paddle.cos(pos_y[:, :, :, 1::2])],
                axis=4,
            ),
            3,
        )

        pos = paddle.transpose(paddle.concat([pos_y, pos_x], axis=3), [0, 3, 1, 2])

        return pos


class AttDecoder(nn.Layer):
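    """GRU-based attention decoder. At each step the previous word embedding
    updates the GRU state, a coverage attention over the encoder features
    yields a context vector, and the counting vector adds a global
    symbol-count context; the sum of these projections gives the word logits.
    Training uses teacher forcing (ground-truth labels); inference feeds back
    the argmax prediction."""
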
    def __init__(
        self,
        ratio,
        is_train,
        input_size,
        hidden_size,
        encoder_out_channel,
        dropout,
        dropout_ratio,
        word_num,
        counting_decoder_out_channel,
        attention,
    ):
        super(AttDecoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.out_channel = encoder_out_channel
        self.attention_dim = attention["attention_dim"]
        self.dropout_prob = dropout
        self.ratio = ratio
        self.word_num = word_num

        self.counting_num = counting_decoder_out_channel
        self.is_train = is_train

        self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
        self.embedding = nn.Embedding(self.word_num, self.input_size)
        self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
        self.word_attention = Attention(hidden_size, attention["attention_dim"])

        self.encoder_feature_conv = nn.Conv2D(
            self.out_channel,
            self.attention_dim,
            kernel_size=attention["word_conv_kernel"],
            padding=attention["word_conv_kernel"] // 2,
        )

        self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
        self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
        self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
        self.counting_context_weight = nn.Linear(self.counting_num, self.hidden_size)
        self.word_convert = nn.Linear(self.hidden_size, self.word_num)

        if dropout:
            self.dropout = nn.Dropout(dropout_ratio)

    def forward(self, cnn_features, labels, counting_preds, images_mask):
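        # cnn_features: [B, C, H, W] encoder output; labels: [B, T] token ids
        # (an all-ones tensor at inference); counting_preds: [B, counting_num];
        # images_mask: full-resolution padding mask, down-sampled by ratio below.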
        if self.is_train:
            _, num_steps = labels.shape
        else:
            num_steps = 36  # maximum decoding length at inference

        batch_size, _, height, width = cnn_features.shape
        images_mask = images_mask[:, :, :: self.ratio, :: self.ratio]

        word_probs = paddle.zeros((batch_size, num_steps, self.word_num))
        word_alpha_sum = paddle.zeros((batch_size, 1, height, width))

        hidden = self.init_hidden(cnn_features, images_mask)
        counting_context_weighted = self.counting_context_weight(counting_preds)
        cnn_features_trans = self.encoder_feature_conv(cnn_features)

        position_embedding = PositionEmbeddingSine(256, normalize=True)
        pos = position_embedding(cnn_features_trans, images_mask[:, 0, :, :])

        cnn_features_trans = cnn_features_trans + pos

        word = paddle.ones([batch_size, 1], dtype="int64")  # first token is <sos>
        word = word.squeeze(axis=1)
        for i in range(num_steps):
            word_embedding = self.embedding(word)
            _, hidden = self.word_input_gru(word_embedding, hidden)
            word_context_vec, _, word_alpha_sum = self.word_attention(
                cnn_features, cnn_features_trans, hidden, word_alpha_sum, images_mask
            )

            current_state = self.word_state_weight(hidden)
            word_weighted_embedding = self.word_embedding_weight(word_embedding)
            word_context_weighted = self.word_context_weight(word_context_vec)

            if self.dropout_prob:
                word_out_state = self.dropout(
                    current_state
                    + word_weighted_embedding
                    + word_context_weighted
                    + counting_context_weighted
                )
            else:
                word_out_state = (
                    current_state
                    + word_weighted_embedding
                    + word_context_weighted
                    + counting_context_weighted
                )

            word_prob = self.word_convert(word_out_state)
            word_probs[:, i] = word_prob

            if self.is_train:
                word = labels[:, i]
            else:
                word = word_prob.argmax(1)
                word = paddle.multiply(
                    word, labels[:, i]
                )  # labels is an all-ones tensor in infer/predict mode

        return word_probs

    def init_hidden(self, features, feature_mask):
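        # Masked global average of the encoder features, projected and passed
        # through tanh, serves as the initial GRU hidden state.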
        average = paddle.sum(
            paddle.sum(features * feature_mask, axis=-1), axis=-1
        ) / paddle.sum((paddle.sum(feature_mask, axis=-1)), axis=-1)
        average = self.init_weight(average)
        return paddle.tanh(average)


"""
Attention Module
"""


class Attention(nn.Layer):
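    """Coverage attention: the accumulated attention map (alpha_sum) is encoded
    by a convolution and added to the query and key terms, discouraging the
    decoder from re-attending to positions it has already covered."""
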
    def __init__(self, hidden_size, attention_dim):
        super(Attention, self).__init__()
        self.hidden = hidden_size
        self.attention_dim = attention_dim
        self.hidden_weight = nn.Linear(self.hidden, self.attention_dim)
        self.attention_conv = nn.Conv2D(
            1, 512, kernel_size=11, padding=5, bias_attr=False
        )
        self.attention_weight = nn.Linear(512, self.attention_dim, bias_attr=False)
        self.alpha_convert = nn.Linear(self.attention_dim, 1)

    def forward(
        self, cnn_features, cnn_features_trans, hidden, alpha_sum, image_mask=None
    ):
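        # hidden: [B, hidden_size] decoder state (the query); alpha_sum:
        # [B, 1, H, W] attention accumulated over previous steps (coverage).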
        query = self.hidden_weight(hidden)
        alpha_sum_trans = self.attention_conv(alpha_sum)
        coverage_alpha = self.attention_weight(
            paddle.transpose(alpha_sum_trans, [0, 2, 3, 1])
        )
        alpha_score = paddle.tanh(
            paddle.unsqueeze(query, [1, 2])
            + coverage_alpha
            + paddle.transpose(cnn_features_trans, [0, 2, 3, 1])
        )
        energy = self.alpha_convert(alpha_score)
        energy = energy - energy.max()
        energy_exp = paddle.exp(paddle.squeeze(energy, -1))

        if image_mask is not None:
            energy_exp = energy_exp * paddle.squeeze(image_mask, 1)
        alpha = energy_exp / (
            paddle.unsqueeze(paddle.sum(paddle.sum(energy_exp, -1), -1), [1, 2]) + 1e-10
        )
        alpha_sum = paddle.unsqueeze(alpha, 1) + alpha_sum
        context_vector = paddle.sum(
            paddle.sum((paddle.unsqueeze(alpha, 1) * cnn_features), -1), -1
        )

        return context_vector, alpha, alpha_sum


class CANHead(nn.Layer):
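    """CAN (Counting-Aware Network) recognition head: two counting decoders
    with 3x3 and 5x5 kernels predict symbol-count vectors that are averaged
    and passed, together with the backbone features, to the attention
    decoder."""
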
    def __init__(self, in_channel, out_channel, ratio, attdecoder, **kwargs):
        super(CANHead, self).__init__()

        self.in_channel = in_channel
        self.out_channel = out_channel

        self.counting_decoder1 = CountingDecoder(
            self.in_channel, self.out_channel, 3
        )  # multi-scale counting module (MSCM), 3x3 branch
        self.counting_decoder2 = CountingDecoder(self.in_channel, self.out_channel, 5)

        self.decoder = AttDecoder(ratio, **attdecoder)

        self.ratio = ratio

    def forward(self, inputs, targets=None):
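        # inputs: (cnn_features, images_mask, labels); targets is unused here.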
        cnn_features, images_mask, labels = inputs

        counting_mask = images_mask[:, :, :: self.ratio, :: self.ratio]
        counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
        counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
        counting_preds = (counting_preds1 + counting_preds2) / 2

        word_probs = self.decoder(cnn_features, labels, counting_preds, images_mask)
        return word_probs, counting_preds, counting_preds1, counting_preds2
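

if __name__ == "__main__":
    # Minimal shape-check sketch of CANHead. The hyperparameter values below
    # are illustrative assumptions (channel counts, vocabulary size and ratio
    # are not taken from any shipped config); adjust them to your backbone.
    # attention_dim stays 512 so it matches the 2 * 256 channels produced by
    # the hard-coded PositionEmbeddingSine(256) inside AttDecoder.
    head = CANHead(
        in_channel=684,
        out_channel=111,
        ratio=16,
        attdecoder={
            "is_train": True,
            "input_size": 256,
            "hidden_size": 256,
            "encoder_out_channel": 684,
            "dropout": True,
            "dropout_ratio": 0.5,
            "word_num": 111,
            "counting_decoder_out_channel": 111,
            "attention": {"attention_dim": 512, "word_conv_kernel": 1},
        },
    )
    cnn_features = paddle.rand([2, 684, 4, 16])  # backbone output at 1/16 scale
    images_mask = paddle.ones([2, 1, 64, 256])  # full-resolution padding mask
    labels = paddle.ones([2, 10], dtype="int64")  # token ids, 10 decoding steps
    word_probs, counting_preds, _, _ = head([cnn_features, images_mask, labels])
    print(word_probs.shape, counting_preds.shape)  # [2, 10, 111] [2, 111]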