# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections.abc
import warnings
from itertools import repeat

import paddle
from paddle import nn

# Batch-norm layer types that ``ResNet3d.train`` puts back into eval mode
# when ``norm_eval`` is enabled (public Paddle classes only).
_BATCH_NORM_TYPES = (nn.BatchNorm1D, nn.BatchNorm2D, nn.BatchNorm3D,
                     nn.SyncBatchNorm)


def _ntuple(n):
    """Return a parser that normalizes a value into a tuple.

    Iterables are converted to a tuple as-is (their length is NOT checked);
    any other value is repeated ``n`` times.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


_triple = _ntuple(3)


class ConvBNLayer(nn.Layer):
    """A ``Conv3D -> BatchNorm3D -> optional ReLU`` block.

    Args:
        in_channels (int): Number of channels in the input feature map.
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
        padding (int | tuple[int]): Zero-padding added to both sides of the
            input. Default: 0.
        stride (int | tuple[int]): Stride of the convolution. Default: 1.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Default: 1.
        groups (int): Number of blocked connections from input channels to
            output channels. Default: 1.
        act (str | None): When not None, a ReLU is applied after the norm
            layer. The value itself is not inspected: any non-None value
            (e.g. ``'relu'``) selects ReLU. Default: None.
        bias (bool | None): Forwarded to ``nn.Conv3D`` as ``bias_attr``;
            ``False`` disables the conv bias, ``None`` keeps Paddle's default
            learnable bias. Default: None.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            padding=0,
            stride=1,
            dilation=1,
            groups=1,
            act=None,
            bias=None,
    ):
        super(ConvBNLayer, self).__init__()

        self._conv = nn.Conv3D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias_attr=bias)

        # NOTE(review): Paddle's BatchNorm ``momentum`` weights the *old*
        # running statistic (running = running * momentum + batch *
        # (1 - momentum)), the reverse of PyTorch's convention. 0.1 looks
        # like PyTorch's default carried over verbatim -- confirm intended.
        self._batch_norm = nn.BatchNorm3D(out_channels, momentum=0.1)

        self.act = act
        if act is not None:
            self._act_op = nn.ReLU()

    def forward(self, inputs):
        """Apply conv, batch norm and (if configured) ReLU to ``inputs``."""
        y = self._conv(inputs)
        y = self._batch_norm(y)
        if self.act is not None:
            y = self._act_op(y)
        return y


class Bottleneck3d(nn.Layer):
    """Bottleneck 3d block for ResNet3D.

    The block is ``conv1 (1x1x1 or 3x1x1) -> conv2 (1x3x3 or 3x3x3) ->
    conv3 (1x1x1, expands channels by ``expansion``)``, plus a residual
    shortcut (optionally through ``downsample``) and a final ReLU.

    Args:
        inplanes (int): Number of channels for the input in first conv3d layer.
        planes (int): Number of channels produced by conv1/conv2; conv3
            outputs ``planes * expansion`` channels.
        spatial_stride (int): Spatial stride of conv2. Default: 1.
        temporal_stride (int): Temporal stride of conv2. Default: 1.
        dilation (int): Spatial dilation of conv2. Default: 1.
        downsample (nn.Layer | None): Layer applied to the input to build the
            shortcut. Default: None (identity shortcut).
        inflate (bool): Whether to inflate kernels temporally. Default: True.
        inflate_style (str): ``3x1x1`` or ``3x3x3``, which determines the
            kernel sizes and paddings of conv1 and conv2. Default: '3x1x1'.
        non_local (bool): Whether to apply a non-local module in this block.
            Non-local modules are not implemented here, so True raises
            ``NotImplementedError``. Default: False.
        non_local_cfg (dict | None): Config for the non-local module; stored
            only. Default: None (treated as ``dict()``).
        conv_cfg (dict | None): Stored for config compatibility; layers are
            always built with ``nn.Conv3D``. Default: None (treated as
            ``dict(type='Conv3d')``).
        norm_cfg (dict | None): Stored for config compatibility; layers are
            always built with ``nn.BatchNorm3D``. Default: None (treated as
            ``dict(type='BN3d')``).
        act_cfg (dict | None): Stored for config compatibility; activations
            are always ``nn.ReLU``. Default: None (treated as
            ``dict(type='ReLU')``).
        with_cp (bool): Stored for config compatibility; gradient
            checkpointing is not implemented. Default: False.
    """
    expansion = 4

    def __init__(self,
                 inplanes,
                 planes,
                 spatial_stride=1,
                 temporal_stride=1,
                 dilation=1,
                 downsample=None,
                 inflate=True,
                 inflate_style='3x1x1',
                 non_local=False,
                 non_local_cfg=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 with_cp=False):
        super().__init__()
        assert inflate_style in ['3x1x1', '3x3x3']
        if non_local:
            # No non-local layer is ever built in this file; fail at
            # construction instead of with an AttributeError in forward().
            raise NotImplementedError(
                'Bottleneck3d does not implement non-local blocks; '
                'set non_local=False.')

        # Normalize config defaults without sharing mutable dicts.
        if non_local_cfg is None:
            non_local_cfg = dict()
        if conv_cfg is None:
            conv_cfg = dict(type='Conv3d')
        if norm_cfg is None:
            norm_cfg = dict(type='BN3d')
        if act_cfg is None:
            act_cfg = dict(type='ReLU')

        self.inplanes = inplanes
        self.planes = planes
        self.spatial_stride = spatial_stride
        self.temporal_stride = temporal_stride
        self.dilation = dilation
        self.inflate = inflate
        self.inflate_style = inflate_style
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.non_local = non_local
        self.non_local_cfg = non_local_cfg

        # Only conv2 carries the block's stride; conv1 is always stride 1.
        self.conv1_stride_s = 1
        self.conv2_stride_s = spatial_stride
        self.conv1_stride_t = 1
        self.conv2_stride_t = temporal_stride

        if self.inflate:
            if inflate_style == '3x1x1':
                conv1_kernel_size = (3, 1, 1)
                conv1_padding = (1, 0, 0)
                conv2_kernel_size = (1, 3, 3)
                conv2_padding = (0, dilation, dilation)
            else:
                conv1_kernel_size = (1, 1, 1)
                conv1_padding = (0, 0, 0)
                conv2_kernel_size = (3, 3, 3)
                conv2_padding = (1, dilation, dilation)
        else:
            conv1_kernel_size = (1, 1, 1)
            conv1_padding = (0, 0, 0)
            conv2_kernel_size = (1, 3, 3)
            conv2_padding = (0, dilation, dilation)

        self.conv1 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=conv1_kernel_size,
            stride=(self.conv1_stride_t, self.conv1_stride_s,
                    self.conv1_stride_s),
            padding=conv1_padding,
            bias=False,
            act='relu')

        self.conv2 = ConvBNLayer(
            in_channels=planes,
            out_channels=planes,
            kernel_size=conv2_kernel_size,
            stride=(self.conv2_stride_t, self.conv2_stride_s,
                    self.conv2_stride_s),
            padding=conv2_padding,
            dilation=(1, dilation, dilation),
            bias=False,
            act='relu')

        # No activation here: ReLU is applied after the residual addition.
        self.conv3 = ConvBNLayer(
            in_channels=planes,
            out_channels=planes * self.expansion,
            kernel_size=1,
            bias=False,
            act=None,
        )

        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the bottleneck branch, add the shortcut, then apply ReLU."""
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)

        identity = x if self.downsample is None else self.downsample(x)
        out = self.relu(out + identity)

        # Kept for subclasses that attach their own ``non_local_block``;
        # this class rejects non_local=True at construction time.
        if self.non_local:
            out = self.non_local_block(out)

        return out


class ResNet3d(nn.Layer):
    """ResNet 3d backbone.

    Args:
        depth (int): Depth of resnet, one of {50, 101, 152}.
        stage_blocks (tuple | None): Number of blocks in each res layer.
            Default: None (taken from the depth's architecture settings).
        pretrained2d (bool): Stored only; no 2D weights are loaded by this
            class. Default: True.
        in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Resnet stages, between 1 and 4. Default: 4.
        base_channels (int): Channel num of stem output features. Default: 64.
        out_indices (Sequence[int]): Indices of output feature. Default: (3, ).
        spatial_strides (Sequence[int]):
            Spatial strides of residual blocks of each stage.
            Default: ``(1, 2, 2, 2)``.
        temporal_strides (Sequence[int]):
            Temporal strides of residual blocks of each stage.
            Default: ``(1, 1, 1, 1)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        conv1_kernel (Sequence[int]): Kernel size of the first conv layer.
            Default: ``(3, 7, 7)``.
        conv1_stride_s (int): Spatial stride of the first conv layer.
            Default: 2.
        conv1_stride_t (int): Temporal stride of the first conv layer.
            Default: 1.
        pool1_stride_s (int): Spatial stride of the first pooling layer.
            Default: 2.
        pool1_stride_t (int): Temporal stride of the first pooling layer.
            Default: 1.
        with_pool1 (bool): Whether to apply the stem max-pooling layer.
            Default: True.
        with_pool2 (bool): Whether to apply temporal pooling after the first
            stage. Default: True.
        inflate (Sequence[int]): Inflate Dims of each block.
            Default: (1, 1, 1, 1).
        inflate_style (str): ``3x1x1`` or ``3x3x3``, which determines the
            kernel sizes and paddings of conv1 and conv2 in each block.
            Default: '3x1x1'.
        conv_cfg (dict | None): Stored and forwarded to the blocks for config
            compatibility; not used to build layers. Default: None (treated
            as ``dict(type='Conv3d')``).
        norm_cfg (dict | None): Stored and forwarded to the blocks for config
            compatibility; not used to build layers. Default: None (treated
            as ``dict(type='BN3d', requires_grad=True)``).
        act_cfg (dict | None): Stored and forwarded to the blocks for config
            compatibility; not used to build layers. Default: None (treated
            as ``dict(type='ReLU', inplace=True)``).
        norm_eval (bool): Whether to keep batch-norm layers in eval mode
            (frozen running mean/var) when ``train()`` is called.
            Default: False.
        with_cp (bool): Forwarded to the blocks; checkpointing is not
            implemented. Default: False.
        non_local (Sequence[int]): Whether to apply a non-local module in
            the corresponding block of each stage. Non-local blocks are not
            implemented, so any 1 raises ``NotImplementedError``.
            Default: (0, 0, 0, 0).
        non_local_cfg (dict | None): Config for non-local module.
            Default: None (treated as ``dict()``).
        zero_init_residual (bool): Stored only; weight initialization is
            currently a no-op. Default: True.
        kwargs (dict, optional): Extra keyword arguments for
            ``make_res_layer``.
    """

    arch_settings = {
        50: (Bottleneck3d, (3, 4, 6, 3)),
        101: (Bottleneck3d, (3, 4, 23, 3)),
        152: (Bottleneck3d, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth,
                 stage_blocks=None,
                 pretrained2d=True,
                 in_channels=3,
                 num_stages=4,
                 base_channels=64,
                 out_indices=(3, ),
                 spatial_strides=(1, 2, 2, 2),
                 temporal_strides=(1, 1, 1, 1),
                 dilations=(1, 1, 1, 1),
                 conv1_kernel=(3, 7, 7),
                 conv1_stride_s=2,
                 conv1_stride_t=1,
                 pool1_stride_s=2,
                 pool1_stride_t=1,
                 with_pool1=True,
                 with_pool2=True,
                 inflate=(1, 1, 1, 1),
                 inflate_style='3x1x1',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 norm_eval=False,
                 with_cp=False,
                 non_local=(0, 0, 0, 0),
                 non_local_cfg=None,
                 zero_init_residual=True,
                 **kwargs):
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')

        # Normalize config defaults without sharing mutable dicts.
        if conv_cfg is None:
            conv_cfg = dict(type='Conv3d')
        if norm_cfg is None:
            norm_cfg = dict(type='BN3d', requires_grad=True)
        if act_cfg is None:
            act_cfg = dict(type='ReLU', inplace=True)
        if non_local_cfg is None:
            non_local_cfg = dict()

        self.depth = depth
        self.pretrained2d = pretrained2d
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.stage_blocks = stage_blocks
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.spatial_strides = spatial_strides
        self.temporal_strides = temporal_strides
        self.dilations = dilations
        assert len(spatial_strides) == len(temporal_strides) == len(
            dilations) == num_stages
        if self.stage_blocks is not None:
            assert len(self.stage_blocks) == num_stages

        self.conv1_kernel = conv1_kernel
        self.conv1_stride_s = conv1_stride_s
        self.conv1_stride_t = conv1_stride_t
        self.pool1_stride_s = pool1_stride_s
        self.pool1_stride_t = pool1_stride_t
        self.with_pool1 = with_pool1
        self.with_pool2 = with_pool2
        self.stage_inflations = _ntuple(num_stages)(inflate)
        self.non_local_stages = _ntuple(num_stages)(non_local)
        self.inflate_style = inflate_style
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.block, stage_blocks = self.arch_settings[depth]

        if self.stage_blocks is None:
            self.stage_blocks = stage_blocks[:num_stages]

        self.inplanes = self.base_channels

        self.non_local_cfg = non_local_cfg

        self._make_stem_layer()

        # Names of the registered residual stages, in forward order.
        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            spatial_stride = spatial_strides[i]
            temporal_stride = temporal_strides[i]
            dilation = dilations[i]
            planes = self.base_channels * 2**i
            res_layer = self.make_res_layer(
                self.block,
                self.inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg,
                non_local=self.non_local_stages[i],
                non_local_cfg=self.non_local_cfg,
                inflate=self.stage_inflations[i],
                inflate_style=self.inflate_style,
                with_cp=with_cp,
                **kwargs)
            self.inplanes = planes * self.block.expansion
            layer_name = f'layer{i + 1}'
            self.add_sublayer(layer_name, res_layer)
            self.res_layers.append(layer_name)

        # Channel count of the last stage's output.
        self.feat_dim = self.block.expansion * self.base_channels * 2**(
            len(self.stage_blocks) - 1)

    @staticmethod
    def make_res_layer(block,
                       inplanes,
                       planes,
                       blocks,
                       spatial_stride=1,
                       temporal_stride=1,
                       dilation=1,
                       inflate=1,
                       inflate_style='3x1x1',
                       non_local=0,
                       non_local_cfg=None,
                       norm_cfg=None,
                       act_cfg=None,
                       conv_cfg=None,
                       with_cp=False,
                       **kwargs):
        """Build residual layer for ResNet3D.

        Args:
            block (nn.Layer): Residual block class to be built.
            inplanes (int): Number of channels for the input feature
                in each block.
            planes (int): Base channel count of each block; blocks output
                ``planes * block.expansion`` channels.
            blocks (int): Number of residual blocks.
            spatial_stride (int): Spatial stride of the first block (and of
                its shortcut). Default: 1.
            temporal_stride (int): Temporal stride of the first block (and
                of its shortcut). Default: 1.
            dilation (int): Spacing between kernel elements. Default: 1.
            inflate (int | Sequence[int]): Whether to inflate each block;
                an int is broadcast to every block. Default: 1.
            inflate_style (str): ``3x1x1`` or ``3x3x3``, which determines
                the kernel sizes and paddings of conv1 and conv2 in each
                block. Default: '3x1x1'.
            non_local (int | Sequence[int]): Whether to apply a non-local
                module in each block; an int is broadcast. Default: 0.
            non_local_cfg (dict | None): Config for non-local module.
                Default: None.
            norm_cfg (dict | None): Config for norm layers. Default: None.
            act_cfg (dict | None): Config for activate layers.
                Default: None.
            conv_cfg (dict | None): Config for conv layers. Default: None.
            with_cp (bool): Forwarded to the blocks. Default: False.

        Returns:
            nn.Layer: An ``nn.Sequential`` of ``blocks`` residual blocks.
        """
        inflate = inflate if not isinstance(inflate,
                                            int) else (inflate, ) * blocks
        non_local = non_local if not isinstance(
            non_local, int) else (non_local, ) * blocks
        assert len(inflate) == blocks and len(non_local) == blocks

        # The shortcut needs a projection whenever the first block changes
        # the output shape: spatial size, temporal length, or channels.
        downsample = None
        if (spatial_stride != 1 or temporal_stride != 1
                or inplanes != planes * block.expansion):
            downsample = ConvBNLayer(
                in_channels=inplanes,
                out_channels=planes * block.expansion,
                kernel_size=1,
                stride=(temporal_stride, spatial_stride, spatial_stride),
                bias=False,
                act=None)

        layers = []
        layers.append(
            block(
                inplanes,
                planes,
                spatial_stride=spatial_stride,
                temporal_stride=temporal_stride,
                dilation=dilation,
                downsample=downsample,
                inflate=(inflate[0] == 1),
                inflate_style=inflate_style,
                non_local=(non_local[0] == 1),
                non_local_cfg=non_local_cfg,
                norm_cfg=norm_cfg,
                conv_cfg=conv_cfg,
                act_cfg=act_cfg,
                with_cp=with_cp,
                **kwargs))
        inplanes = planes * block.expansion
        # Remaining blocks keep the resolution: stride 1, identity shortcut.
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    spatial_stride=1,
                    temporal_stride=1,
                    dilation=dilation,
                    inflate=(inflate[i] == 1),
                    inflate_style=inflate_style,
                    non_local=(non_local[i] == 1),
                    non_local_cfg=non_local_cfg,
                    norm_cfg=norm_cfg,
                    conv_cfg=conv_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp,
                    **kwargs))

        return nn.Sequential(*layers)

    @staticmethod
    def _inflate_conv_params(conv3d, state_dict_2d, module_name_2d,
                             inflated_param_names):
        """Inflate a conv layer's parameters from 2d to 3d.

        Args:
            conv3d (nn.Layer): The destination conv3d layer.
            state_dict_2d (dict): The state dict of pretrained 2d model
                (values may be paddle Tensors or numpy arrays).
            module_name_2d (str): The name of corresponding conv layer in the
                2d model.
            inflated_param_names (list[str]): List of parameters that have
                been inflated; appended to in place.
        """
        weight_2d_name = module_name_2d + '.weight'

        conv2d_weight = paddle.to_tensor(state_dict_2d[weight_2d_name])
        kernel_t = conv3d.weight.shape[2]

        # Repeat the 2D kernel along a new temporal axis (axis 2) and divide
        # by its length, so summing over time on identical frames reproduces
        # the 2D response.
        new_weight = paddle.expand_as(
            paddle.unsqueeze(conv2d_weight, axis=2), conv3d.weight) / kernel_t
        conv3d.weight.set_value(new_weight)
        inflated_param_names.append(weight_2d_name)

        if getattr(conv3d, 'bias', None) is not None:
            bias_2d_name = module_name_2d + '.bias'
            conv3d.bias.set_value(state_dict_2d[bias_2d_name])
            inflated_param_names.append(bias_2d_name)

    @staticmethod
    def _inflate_bn_params(bn3d, state_dict_2d, module_name_2d,
                           inflated_param_names):
        """Copy a norm layer's parameters and buffers from a 2d model.

        Args:
            bn3d (nn.Layer): The destination bn3d layer.
            state_dict_2d (dict): The state dict of pretrained 2d model.
            module_name_2d (str): The name of corresponding bn layer in the
                2d model.
            inflated_param_names (list[str]): List of parameters that have
                been inflated; appended to in place.
        """
        for param_name, param in bn3d.named_parameters():
            param_2d_name = f'{module_name_2d}.{param_name}'
            param_2d = state_dict_2d[param_2d_name]
            # Compare as lists: Paddle shapes are lists, numpy shapes tuples.
            if list(param.shape) != list(param_2d.shape):
                warnings.warn(f'The parameter of {module_name_2d} is not '
                              'loaded due to incompatible shapes.')
                return

            param.set_value(param_2d)
            inflated_param_names.append(param_2d_name)

        for param_name, param in bn3d.named_buffers():
            param_2d_name = f'{module_name_2d}.{param_name}'
            # Some buffers may be missing from old checkpoints; skip them.
            if param_2d_name in state_dict_2d:
                param.set_value(state_dict_2d[param_2d_name])
                inflated_param_names.append(param_2d_name)

    def _make_stem_layer(self):
        """Build the stem (conv+BN+ReLU), the stem max-pool and pool2."""
        self.conv1 = ConvBNLayer(
            in_channels=self.in_channels,
            out_channels=self.base_channels,
            kernel_size=self.conv1_kernel,
            stride=(self.conv1_stride_t, self.conv1_stride_s,
                    self.conv1_stride_s),
            # "Same"-style padding for odd kernel sizes.
            padding=tuple([(k - 1) // 2 for k in _triple(self.conv1_kernel)]),
            bias=False,
            act="relu")

        self.maxpool = nn.MaxPool3D(
            kernel_size=(1, 3, 3),
            stride=(self.pool1_stride_t, self.pool1_stride_s,
                    self.pool1_stride_s),
            padding=(0, 1, 1))

        # Temporal-only pooling applied after the first stage.
        self.pool2 = nn.MaxPool3D(kernel_size=(2, 1, 1), stride=(2, 1, 1))

    @staticmethod
    def _init_weights(self, pretrained=None):
        # Intentionally a no-op. Declared static but called with ``self``
        # passed explicitly by ``init_weights``; kept as-is for callers and
        # subclasses that rely on this signature.
        pass

    def init_weights(self, pretrained=None):
        """Initialize weights (currently a no-op)."""
        self._init_weights(self, pretrained)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (paddle.Tensor): The input data.

        Returns:
            paddle.Tensor | tuple[paddle.Tensor]: The feature of the input
            samples extracted by the backbone; a single tensor when only one
            stage is listed in ``out_indices``, otherwise a tuple.
        """
        x = self.conv1(x)
        if self.with_pool1:
            x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i == 0 and self.with_pool2:
                x = self.pool2(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]

        return tuple(outs)

    def train(self, mode=True):
        """Switch to training (``mode=True``) or evaluation mode.

        Paddle's ``Layer.train`` takes no argument, so ``mode=False`` is
        mapped to ``Layer.eval``. When ``norm_eval`` is set, batch-norm
        sublayers are put back into eval mode after switching to training,
        freezing their running statistics.
        """
        if mode:
            super().train()
        else:
            super().eval()
        if mode and self.norm_eval:
            for layer in self.sublayers():
                if isinstance(layer, _BATCH_NORM_TYPES):
                    layer.eval()