# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is refer from:
https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/resnetv2.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import collections.abc
from itertools import repeat
from collections import OrderedDict  # pylint: disable=g-importing-member

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingUniform
from functools import partial
from typing import Union, Callable, Type, List, Tuple

IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
normal_ = Normal(mean=0.0, std=0.01)
zeros_ = Constant(value=0.0)
ones_ = Constant(value=1.0)
kaiming_normal_ = KaimingUniform(nonlinearity="relu")


def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


class StdConv2dSame(nn.Conv2D):
    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding="SAME",
        dilation=1,
        groups=1,
        bias_attr=False,
        eps=1e-6,
        is_export=False,
    ):
        padding, is_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation
        )
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias_attr,
        )
        self.same_pad = is_dynamic
        self.export = is_export
        self.eps = eps

    def forward(self, x):
        if self.same_pad:
            if self.export:
                x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
            else:
                x = pad_same(x, self._kernel_size, self._stride, self._dilation)
        running_mean = paddle.to_tensor([0] * self._out_channels, dtype="float32")
        running_variance = paddle.to_tensor([1] * self._out_channels, dtype="float32")
        if self.export:
            weight = paddle.reshape(
                F.batch_norm(
                    self.weight.reshape([1, self._out_channels, -1]),
                    running_mean,
                    running_variance,
                    momentum=0.0,
                    epsilon=self.eps,
                    use_global_stats=False,
                ),
                self.weight.shape,
            )
        else:
            weight = paddle.reshape(
                F.batch_norm(
                    self.weight.reshape([1, self._out_channels, -1]),
                    running_mean,
                    running_variance,
                    training=True,
                    momentum=0.0,
                    epsilon=self.eps,
                ),
                self.weight.shape,
            )
        x = F.conv2d(
            x,
            weight,
            self.bias,
            self._stride,
            self._padding,
            self._dilation,
            self._groups,
        )
        return x


class StdConv2d(nn.Conv2D):
    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.

    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
        https://arxiv.org/abs/1903.10520v2
    """

    def __init__(
        self,
        in_channel,
        out_channels,
        kernel_size,
        stride=1,
        padding=None,
        dilation=1,
        groups=1,
        bias=False,
        eps=1e-6,
    ):
        if padding is None:
            padding = get_padding(kernel_size, stride, dilation)
        super().__init__(
            in_channel,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias,
        )
        self.eps = eps

    def forward(self, x):
        weight = F.batch_norm(
            self.weight.reshape(1, self.out_channels, -1),
            None,
            None,
            training=True,
            momentum=0.0,
            epsilon=self.eps,
        ).reshape_as(self.weight)
        x = F.conv2d(
            x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        return x


class MaxPool2dSame(nn.MaxPool2D):
    """Tensorflow like 'SAME' wrapper for 2D max pooling"""

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        dilation=1,
        ceil_mode=False,
        is_export=False,
    ):
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        dilation = to_2tuple(dilation)
        self.export = is_export
        super(MaxPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), dilation, ceil_mode
        )

    def forward(self, x):
        if self.export:
            x = pad_same_export(x, self.ksize, self.stride, value=-float("inf"))
        else:
            x = pad_same(x, self.ksize, self.stride, value=-float("inf"))
        return F.max_pool2d(x, self.ksize, self.stride, (0, 0), self.ceil_mode)


def get_padding(kernel_size: int, stride: int = 1, dilation: int = 1, **_) -> int:
    padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
    return padding


def is_static_pad(kernel_size: int, stride: int = 1, dilation: int = 1, **_):
    return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0


def get_padding_value(padding, kernel_size, **kwargs) -> Tuple[Tuple, bool]:
    dynamic = False
    if isinstance(padding, str):
        # for any string padding, the padding will be calculated for you, one of three ways
        padding = padding.lower()
        if padding == "same":
            # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact
            if is_static_pad(kernel_size, **kwargs):
                # static case, no extra overhead
                padding = get_padding(kernel_size, **kwargs)
            else:
                # dynamic 'SAME' padding, has runtime/GPU memory overhead
                padding = 0
                dynamic = True
        elif padding == "valid":
            # 'VALID' padding, same as padding=0
            padding = 0
        else:
            # Default to PyTorch style 'same'-ish symmetric padding
            padding = get_padding(kernel_size, **kwargs)
    return padding, dynamic


def create_pool2d(pool_type, kernel_size, stride=None, is_export=False, **kwargs):
    stride = stride or kernel_size
    padding = kwargs.pop("padding", "")
    padding, is_dynamic = get_padding_value(
        padding, kernel_size, stride=stride, **kwargs
    )
    if is_dynamic:
        if pool_type == "avg":
            return AvgPool2dSame(
                kernel_size, stride=stride, is_export=is_export, **kwargs
            )
        elif pool_type == "max":
            return MaxPool2dSame(
                kernel_size, stride=stride, is_export=is_export, **kwargs
            )
        else:
            assert False, f"Unsupported pool type {pool_type}"


def get_same_padding(x, k, s, d):
    return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)


def get_same_padding_export(x, k, s, d):
    x = paddle.to_tensor(x)
    k = paddle.to_tensor(k)
    s = paddle.to_tensor(s)
    d = paddle.to_tensor(d)
    return paddle.max((paddle.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)


def pad_same_export(x, k, s, d=(1, 1), value=0):
    ih, iw = x.shape[-2:]
    pad_h, pad_w = get_same_padding_export(
        ih, k[0], s[0], d[0]
    ), get_same_padding_export(iw, k[1], s[1], d[1])
    pad_h = pad_h.cast(paddle.int32)
    pad_w = pad_w.cast(paddle.int32)
    pad_list = paddle.to_tensor(
        [
            (pad_w // 2),
            (pad_w - pad_w // 2).cast(paddle.int32),
            (pad_h // 2).cast(paddle.int32),
            (pad_h - pad_h // 2).cast(paddle.int32),
        ]
    )

    if pad_h > 0 or pad_w > 0:
        if len(pad_list.shape) == 2:
            pad_list = pad_list.squeeze(1)
        x = F.pad(x, pad_list.cast(paddle.int32), value=value)
    return x


def pad_same(x, k, s, d=(1, 1), value=0, pad_h=None, pad_w=None):
    ih, iw = x.shape[-2:]

    pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(
        iw, k[1], s[1], d[1]
    )
    if pad_h > 0 or pad_w > 0:
        x = F.pad(
            x,
            [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            value=value,
        )
    return x


class AvgPool2dSame(nn.AvgPool2D):
    """Tensorflow like 'SAME' wrapper for 2D average pooling"""

    def __init__(
        self,
        kernel_size: int,
        stride=None,
        padding=0,
        ceil_mode=False,
        count_include_pad=True,
    ):
        kernel_size = to_2tuple(kernel_size)
        stride = to_2tuple(stride)
        super(AvgPool2dSame, self).__init__(
            kernel_size, stride, (0, 0), ceil_mode, count_include_pad
        )

    def forward(self, x):
        x = pad_same(x, self.kernel_size, self.stride)
        return F.avg_pool2d(
            x,
            self.kernel_size,
            self.stride,
            self.padding,
            self.ceil_mode,
            self.count_include_pad,
        )


def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Layer):
    """Drop paths (Stochastic Depth) per sample  (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None, scale_by_keep=True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)


def adaptive_pool_feat_mult(pool_type="avg"):
    if pool_type == "catavgmax":
        return 2
    else:
        return 1


class SelectAdaptivePool2d(nn.Layer):
    """Selectable global pooling layer with dynamic input kernel size"""

    def __init__(self, output_size=1, pool_type="fast", flatten=False):
        super(SelectAdaptivePool2d, self).__init__()
        self.pool_type = (
            pool_type or ""
        )  # convert other falsy values to empty string for consistent TS typing
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()
        if pool_type == "":
            self.pool = nn.Identity()  # pass through

    def is_identity(self):
        return not self.pool_type

    def forward(self, x):
        x = self.pool(x)
        x = self.flatten(x)
        return x

    def feat_mult(self):
        return adaptive_pool_feat_mult(self.pool_type)

    def __repr__(self):
        return (
            self.__class__.__name__
            + " ("
            + "pool_type="
            + self.pool_type
            + ", flatten="
            + str(self.flatten)
            + ")"
        )


def _create_pool(num_features, num_classes, pool_type="avg", use_conv=False):
    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert (
            num_classes == 0 or use_conv
        ), "Pooling can only be disabled if classifier is also removed or conv classifier is used"
        flatten_in_pool = (
            False  # disable flattening if pooling is pass-through (no pooling)
        )
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2D(num_features, num_classes, 1, bias_attr=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias_attr=True)
    return fc


class ClassifierHead(nn.Layer):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
        self, in_chs, num_classes, pool_type="avg", drop_rate=0.0, use_conv=False
    ):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool, num_pooled_features = _create_pool(
            in_chs, num_classes, pool_type, use_conv=use_conv
        )
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def forward(self, x):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        x = self.flatten(x)
        return x


class EvoNormBatch2d(nn.Layer):
    def __init__(
        self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None
    ):
        super(EvoNormBatch2d, self).__init__()
        self.apply_act = apply_act  # apply activation (non-linearity)
        self.momentum = momentum
        self.eps = eps
        self.weight = paddle.create_parameter(
            paddle.ones(num_features), dtype="float32"
        )
        self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32")
        self.v = (
            paddle.create_parameter(paddle.ones(num_features), dtype="float32")
            if apply_act
            else None
        )
        self.register_buffer("running_var", paddle.ones([num_features]))
        self.reset_parameters()

    def reset_parameters(self):
        ones_(self.weight)
        zeros_(self.bias)
        if self.apply_act:
            ones_(self.v)

    def forward(self, x):
        x_type = x.dtype
        if self.v is not None:
            running_var = self.running_var.view(1, -1, 1, 1)
            if self.training:
                var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
                n = x.numel() / x.shape[1]
                running_var = var.detach() * self.momentum * (
                    n / (n - 1)
                ) + running_var * (1 - self.momentum)
                self.running_var.copy_(running_var.view(self.running_var.shape))
            else:
                var = running_var
            v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
            d = x * v + (
                x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps
            ).sqrt().to(dtype=x_type)
            d = d.max((var + self.eps).sqrt().to(dtype=x_type))
            x = x / d
        return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)


class EvoNormSample2d(nn.Layer):
    def __init__(
        self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None
    ):
        super(EvoNormSample2d, self).__init__()
        self.apply_act = apply_act
        self.groups = groups
        self.eps = eps
        self.weight = paddle.create_parameter(
            paddle.ones(num_features), dtype="float32"
        )
        self.bias = paddle.create_parameter(paddle.zeros(num_features), dtype="float32")
        self.v = (
            paddle.create_parameter(paddle.ones(num_features), dtype="float32")
            if apply_act
            else None
        )
        self.reset_parameters()

    def reset_parameters(self):
        ones_(self.weight)
        zeros_(self.bias)
        if self.apply_act:
            ones_(self.v)

    def forward(self, x):
        B, C, H, W = x.shape
        if self.v is not None:
            n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
            x = x.reshape(B, self.groups, -1)
            x = (
                n.reshape(B, self.groups, -1)
                / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
            )
            x = x.reshape(B, C, H, W)
        return x * self.weight.reshape([1, -1, 1, 1]) + self.bias.reshape([1, -1, 1, 1])


class GroupNormAct(nn.GroupNorm):
    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
        self,
        num_channels,
        num_groups=32,
        eps=1e-5,
        affine=True,
        apply_act=True,
        act_layer=nn.ReLU,
        drop_block=None,
    ):
        super(GroupNormAct, self).__init__(num_groups, num_channels, epsilon=eps)
        if affine:
            self.weight = paddle.create_parameter([num_channels], dtype="float32")
            self.bias = paddle.create_parameter([num_channels], dtype="float32")
            ones_(self.weight)
            zeros_(self.bias)
        if act_layer is not None and apply_act:
            act_args = {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        x = F.group_norm(
            x,
            num_groups=self._num_groups,
            epsilon=self._epsilon,
            weight=self.weight,
            bias=self.bias,
        )
        x = self.act(x)
        return x


class BatchNormAct2d(nn.BatchNorm2D):
    def __init__(
        self,
        num_features,
        eps=1e-5,
        momentum=0.1,
        affine=True,
        track_running_stats=True,
        apply_act=True,
        act_layer=nn.ReLU,
        drop_block=None,
    ):
        super(BatchNormAct2d, self).__init__(
            num_features,
            epsilon=eps,
            momentum=momentum,
            use_global_stats=track_running_stats,
        )
        if act_layer is not None and apply_act:
            act_args = dict()
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def _forward_python(self, x):
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        x = self._forward_python(x)
        x = self.act(x)
        return x


def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = (
        conv_weight.float()
    )  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError("Weight format not supported by conversion.")
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / 3))
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv_weight *= 3 / float(in_chans)
    conv_weight = conv_weight.to(conv_type)
    return conv_weight


def named_apply(
    fn: Callable, module: nn.Layer, name="", depth_first=True, include_root=False
) -> nn.Layer:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def _cfg(url="", **kwargs):
    return {
        "url": url,
        "num_classes": 1000,
        "input_size": (3, 224, 224),
        "pool_size": (7, 7),
        "crop_pct": 0.875,
        "interpolation": "bilinear",
        "mean": IMAGENET_INCEPTION_MEAN,
        "std": IMAGENET_INCEPTION_STD,
        "first_conv": "stem.conv",
        "classifier": "head.fc",
        **kwargs,
    }


def make_div(v, divisor=8):
    min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


class PreActBottleneck(nn.Layer):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    """

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
        is_export=False,
    ):
        super().__init__()
        first_dilation = first_dilation or dilation
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                first_dilation=first_dilation,
                preact=True,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
            )
        else:
            self.downsample = None

        self.norm1 = norm_layer(in_chs)
        self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
        self.norm2 = norm_layer(mid_chs)
        self.conv2 = conv_layer(
            mid_chs,
            mid_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
            is_export=is_export,
        )
        self.norm3 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        )

    def zero_init_last(self):
        zeros_(self.conv3.weight)

    def forward(self, x):
        x_preact = self.norm1(x)

        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x_preact)

        # residual branch
        x = self.conv1(x_preact)
        x = self.conv2(self.norm2(x))
        x = self.conv3(self.norm3(x))
        x = self.drop_path(x)
        return x + shortcut


class Bottleneck(nn.Layer):
    """Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT."""

    def __init__(
        self,
        in_chs,
        out_chs=None,
        bottle_ratio=0.25,
        stride=1,
        dilation=1,
        first_dilation=None,
        groups=1,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        proj_layer=None,
        drop_path_rate=0.0,
        is_export=False,
    ):
        super().__init__()
        first_dilation = first_dilation or dilation
        act_layer = act_layer or nn.ReLU
        conv_layer = conv_layer or StdConv2d
        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
        out_chs = out_chs or in_chs
        mid_chs = make_div(out_chs * bottle_ratio)

        if proj_layer is not None:
            self.downsample = proj_layer(
                in_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                preact=False,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                is_export=is_export,
            )
        else:
            self.downsample = None

        self.conv1 = conv_layer(in_chs, mid_chs, 1, is_export=is_export)
        self.norm1 = norm_layer(mid_chs)
        self.conv2 = conv_layer(
            mid_chs,
            mid_chs,
            3,
            stride=stride,
            dilation=first_dilation,
            groups=groups,
            is_export=is_export,
        )
        self.norm2 = norm_layer(mid_chs)
        self.conv3 = conv_layer(mid_chs, out_chs, 1, is_export=is_export)
        self.norm3 = norm_layer(out_chs, apply_act=False)
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        )
        self.act3 = act_layer()

    def zero_init_last(self):
        zeros_(self.norm3.weight)

    def forward(self, x):
        # shortcut branch
        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        # residual
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.drop_path(x)
        x = self.act3(x + shortcut)
        return x


class DownsampleConv(nn.Layer):
    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
        is_export=False,
    ):
        super(DownsampleConv, self).__init__()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=stride, is_export=is_export)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(x))


class DownsampleAvg(nn.Layer):
    def __init__(
        self,
        in_chs,
        out_chs,
        stride=1,
        dilation=1,
        first_dilation=None,
        preact=True,
        conv_layer=None,
        norm_layer=None,
        is_export=False,
    ):
        """AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
        super(DownsampleAvg, self).__init__()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = (
                AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2D
            )
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, exclusive=False)
        else:
            self.pool = nn.Identity()
        self.conv = conv_layer(in_chs, out_chs, 1, stride=1, is_export=is_export)
        self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)

    def forward(self, x):
        return self.norm(self.conv(self.pool(x)))


class ResNetStage(nn.Layer):
    """ResNet Stage."""

    def __init__(
        self,
        in_chs,
        out_chs,
        stride,
        dilation,
        depth,
        bottle_ratio=0.25,
        groups=1,
        avg_down=False,
        block_dpr=None,
        block_fn=PreActBottleneck,
        is_export=False,
        act_layer=None,
        conv_layer=None,
        norm_layer=None,
        **block_kwargs,
    ):
        super(ResNetStage, self).__init__()
        first_dilation = 1 if dilation in (1, 2) else 2
        layer_kwargs = dict(
            act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer
        )
        proj_layer = DownsampleAvg if avg_down else DownsampleConv
        prev_chs = in_chs
        self.blocks = nn.Sequential()
        for block_idx in range(depth):
            drop_path_rate = block_dpr[block_idx] if block_dpr else 0.0
            stride = stride if block_idx == 0 else 1
            self.blocks.add_sublayer(
                str(block_idx),
                block_fn(
                    prev_chs,
                    out_chs,
                    stride=stride,
                    dilation=dilation,
                    bottle_ratio=bottle_ratio,
                    groups=groups,
                    first_dilation=first_dilation,
                    proj_layer=proj_layer,
                    drop_path_rate=drop_path_rate,
                    is_export=is_export,
                    **layer_kwargs,
                    **block_kwargs,
                ),
            )
            prev_chs = out_chs
            first_dilation = dilation
            proj_layer = None

    def forward(self, x):
        x = self.blocks(x)
        return x


def is_stem_deep(stem_type):
    return any([s in stem_type for s in ("deep", "tiered")])


def create_resnetv2_stem(
    in_chs,
    out_chs=64,
    stem_type="",
    preact=True,
    conv_layer=StdConv2d,
    norm_layer=partial(GroupNormAct, num_groups=32),
    is_export=False,
):
    stem = OrderedDict()
    assert stem_type in (
        "",
        "fixed",
        "same",
        "deep",
        "deep_fixed",
        "deep_same",
        "tiered",
    )

    # NOTE conv padding mode can be changed by overriding the conv_layer def
    if is_stem_deep(stem_type):
        # A 3 deep 3x3  conv stack as in ResNet V1D models
        if "tiered" in stem_type:
            stem_chs = (3 * out_chs // 8, out_chs // 2)  # 'T' resnets in resnet.py
        else:
            stem_chs = (out_chs // 2, out_chs // 2)  # 'D' ResNets
        stem["conv1"] = conv_layer(
            in_chs, stem_chs[0], kernel_size=3, stride=2, is_export=is_export
        )
        stem["norm1"] = norm_layer(stem_chs[0])
        stem["conv2"] = conv_layer(
            stem_chs[0], stem_chs[1], kernel_size=3, stride=1, is_export=is_export
        )
        stem["norm2"] = norm_layer(stem_chs[1])
        stem["conv3"] = conv_layer(
            stem_chs[1], out_chs, kernel_size=3, stride=1, is_export=is_export
        )
        if not preact:
            stem["norm3"] = norm_layer(out_chs)
    else:
        # The usual 7x7 stem conv
        stem["conv"] = conv_layer(
            in_chs, out_chs, kernel_size=7, stride=2, is_export=is_export
        )
        if not preact:
            stem["norm"] = norm_layer(out_chs)

    if "fixed" in stem_type:
        # 'fixed' SAME padding approximation that is used in BiT models
        stem["pad"] = paddle.nn.Pad2D(
            1, mode="constant", value=0.0, data_format="NCHW", name=None
        )
        stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=0)
    elif "same" in stem_type:
        # full, input size based 'SAME' padding, used in ViT Hybrid model
        stem["pool"] = create_pool2d(
            "max", kernel_size=3, stride=2, padding="same", is_export=is_export
        )
    else:
        # the usual Pypaddle symmetric padding
        stem["pool"] = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
    stem_seq = nn.Sequential()
    for key, value in stem.items():
        stem_seq.add_sublayer(key, value)

    return stem_seq


class ResNetV2(nn.Layer):
    """Implementation of Pre-activation (v2) ResNet mode.

    Args:
      x: input images with shape [N, 1, H, W]

    Returns:
      The extracted features [N, 1, H//16, W//16]
    """

    def __init__(
        self,
        layers,
        channels=(256, 512, 1024, 2048),
        num_classes=1000,
        in_chans=3,
        global_pool="avg",
        output_stride=32,
        width_factor=1,
        stem_chs=64,
        stem_type="",
        avg_down=False,
        preact=True,
        act_layer=nn.ReLU,
        conv_layer=StdConv2d,
        norm_layer=partial(GroupNormAct, num_groups=32),
        drop_rate=0.0,
        drop_path_rate=0.0,
        zero_init_last=False,
        is_export=False,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.is_export = is_export
        wf = width_factor
        self.feature_info = []
        stem_chs = make_div(stem_chs * wf)
        self.stem = create_resnetv2_stem(
            in_chans,
            stem_chs,
            stem_type,
            preact,
            conv_layer=conv_layer,
            norm_layer=norm_layer,
            is_export=is_export,
        )
        stem_feat = (
            ("stem.conv3" if is_stem_deep(stem_type) else "stem.conv")
            if preact
            else "stem.norm"
        )
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))

        prev_chs = stem_chs
        curr_stride = 4
        dilation = 1
        block_dprs = [
            x.tolist()
            for x in paddle.linspace(0, drop_path_rate, sum(layers)).split(layers)
        ]
        block_fn = PreActBottleneck if preact else Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_div(c * wf)
            stride = 1 if stage_idx == 0 else 2
            if curr_stride >= output_stride:
                dilation *= stride
                stride = 1
            stage = ResNetStage(
                prev_chs,
                out_chs,
                stride=stride,
                dilation=dilation,
                depth=d,
                avg_down=avg_down,
                act_layer=act_layer,
                conv_layer=conv_layer,
                norm_layer=norm_layer,
                block_dpr=bdpr,
                block_fn=block_fn,
                is_export=is_export,
            )
            prev_chs = out_chs
            curr_stride *= stride
            self.feature_info += [
                dict(
                    num_chs=prev_chs,
                    reduction=curr_stride,
                    module=f"stages.{stage_idx}",
                )
            ]
            self.stages.add_sublayer(str(stage_idx), stage)

        self.num_features = prev_chs
        self.norm = norm_layer(self.num_features) if preact else nn.Identity()
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )

        self.init_weights(zero_init_last=zero_init_last)

    def init_weights(self, zero_init_last=True):
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    def load_pretrained(self, checkpoint_path, prefix="resnet/"):
        _load_weights(self, checkpoint_path, prefix)

    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool="avg"):
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
            use_conv=True,
        )

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x


def _init_weights(module: nn.Layer, name: str = "", zero_init_last=True):
    if isinstance(module, nn.Linear) or (
        "head.fc" in name and isinstance(module, nn.Conv2D)
    ):
        normal_(module.weight)
        zeros_(module.bias)
    elif isinstance(module, nn.Conv2D):
        kaiming_normal_(module.weight)
        if module.bias is not None:
            zeros_(module.bias)
    elif isinstance(module, (nn.BatchNorm2D, nn.LayerNorm, nn.GroupNorm)):
        ones_(module.weight)
        zeros_(module.bias)
    elif zero_init_last and hasattr(module, "zero_init_last"):
        module.zero_init_last()


@paddle.no_grad()
def _load_weights(model: nn.Layer, checkpoint_path: str, prefix: str = "resnet/"):
    import numpy as np

    def t2p(conv_weights):
        """Possibly convert HWIO to OIHW."""
        if conv_weights.ndim == 4:
            conv_weights = conv_weights.transpose([3, 2, 0, 1])
        return paddle.to_tensor(conv_weights)

    weights = np.load(checkpoint_path)
    stem_conv_w = adapt_input_conv(
        model.stem.conv.weight.shape[1],
        t2p(weights[f"{prefix}root_block/standardized_conv2d/kernel"]),
    )
    model.stem.conv.weight.copy_(stem_conv_w)
    model.norm.weight.copy_(t2p(weights[f"{prefix}group_norm/gamma"]))
    model.norm.bias.copy_(t2p(weights[f"{prefix}group_norm/beta"]))
    if (
        isinstance(getattr(model.head, "fc", None), nn.Conv2D)
        and model.head.fc.weight.shape[0]
        == weights[f"{prefix}head/conv2d/kernel"].shape[-1]
    ):
        model.head.fc.weight.copy_(t2p(weights[f"{prefix}head/conv2d/kernel"]))
        model.head.fc.bias.copy_(t2p(weights[f"{prefix}head/conv2d/bias"]))
    for i, (sname, stage) in enumerate(model.stages.named_children()):
        for j, (bname, block) in enumerate(stage.blocks.named_children()):
            cname = "standardized_conv2d"
            block_prefix = f"{prefix}block{i + 1}/unit{j + 1:02d}/"
            block.conv1.weight.copy_(t2p(weights[f"{block_prefix}a/{cname}/kernel"]))
            block.conv2.weight.copy_(t2p(weights[f"{block_prefix}b/{cname}/kernel"]))
            block.conv3.weight.copy_(t2p(weights[f"{block_prefix}c/{cname}/kernel"]))
            block.norm1.weight.copy_(t2p(weights[f"{block_prefix}a/group_norm/gamma"]))
            block.norm2.weight.copy_(t2p(weights[f"{block_prefix}b/group_norm/gamma"]))
            block.norm3.weight.copy_(t2p(weights[f"{block_prefix}c/group_norm/gamma"]))
            block.norm1.bias.copy_(t2p(weights[f"{block_prefix}a/group_norm/beta"]))
            block.norm2.bias.copy_(t2p(weights[f"{block_prefix}b/group_norm/beta"]))
            block.norm3.bias.copy_(t2p(weights[f"{block_prefix}c/group_norm/beta"]))
            if block.downsample is not None:
                w = weights[f"{block_prefix}a/proj/{cname}/kernel"]
                block.downsample.conv.weight.copy_(t2p(w))