import collections.abc
from itertools import repeat
from typing import Any, Callable, Optional, Tuple, Union
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn.layer import Identity
from ..registry import BACKBONES
from collections import OrderedDict
container_abcs = collections.abc
"""Model Config
"""
A0 = {'block_num': [0, 1, 3, 3, 4, 4]}
A0['conv1'] = [3, 8, (1, 3, 3), (1, 2, 2), (0, 1, 1)]
A0['b2_l0'] = [8, 8, 24, (1, 5, 5), (1, 2, 2), (0, 2, 2), (0, 1, 1)]
A0['b3_l0'] = [8, 32, 80, (3, 3, 3), (1, 2, 2), (1, 0, 0), (0, 0, 0)]
A0['b3_l1'] = [32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b3_l2'] = [32, 32, 80, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b4_l0'] = [32, 56, 184, (5, 3, 3), (1, 2, 2), (2, 0, 0), (0, 0, 0)]
A0['b4_l1'] = [56, 56, 112, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b4_l2'] = [56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b5_l0'] = [56, 56, 184, (5, 3, 3), (1, 1, 1), (2, 1, 1), (0, 1, 1)]
A0['b5_l1'] = [56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b5_l2'] = [56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b5_l3'] = [56, 56, 184, (3, 3, 3), (1, 1, 1), (1, 1, 1), (0, 1, 1)]
A0['b6_l0'] = [56, 104, 384, (5, 3, 3), (1, 2, 2), (2, 1, 1), (0, 1, 1)]
A0['b6_l1'] = [104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)]
A0['b6_l2'] = [104, 104, 280, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)]
A0['b6_l3'] = [104, 104, 344, (1, 5, 5), (1, 1, 1), (0, 2, 2), (0, 1, 1)]
A0['conv7'] = [104, 480, (1, 1, 1), (1, 1, 1), (0, 0, 0)]
MODEL_CONFIG = {'A0': A0}
def _ntuple(n):
def parse(x):
if isinstance(x, container_abcs.Iterable):
return x
return tuple(repeat(x, n))
return parse
def _make_divisible(v: float,
divisor: int,
min_value: Optional[int] = None) -> int:
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8.
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
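# e.g. _make_divisible(30, 8) == 32, while _make_divisible(19, 8) == 24:
# rounding 19 down to 16 would drop it by more than 10%, so a divisor is added.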
_single = _ntuple(1)
_pair = _ntuple(2)
_triple = _ntuple(3)
_quadruple = _ntuple(4)
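# e.g. _triple(3) == (3, 3, 3); an iterable such as (1, 2, 2) passes through unchanged.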
class CausalModule(nn.Layer):
    """Base class for layers that keep a cross-call `activation` buffer,
    allowing causal (streaming) inference to carry state between clips."""
    def __init__(self) -> None:
        super().__init__()
        self.activation = None

    def reset_activation(self) -> None:
        self.activation = None
class Conv2dBNActivation(nn.Sequential):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Layer]] = None,
activation_layer: Optional[Callable[..., nn.Layer]] = None,
**kwargs: Any,
) -> None:
kernel_size = _pair(kernel_size)
stride = _pair(stride)
padding = _pair(padding)
if norm_layer is None:
norm_layer = Identity
if activation_layer is None:
activation_layer = Identity
self.kernel_size = kernel_size
self.stride = stride
dict_layers = (nn.Conv2D(in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
**kwargs), norm_layer(out_planes,
momentum=0.1),
activation_layer())
self.out_channels = out_planes
super(Conv2dBNActivation, self).__init__(dict_layers[0], dict_layers[1],
dict_layers[2])
class Conv3DBNActivation(nn.Sequential):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: Union[int, Tuple[int, int, int]],
padding: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Layer]] = None,
activation_layer: Optional[Callable[..., nn.Layer]] = None,
**kwargs: Any,
) -> None:
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
if norm_layer is None:
norm_layer = Identity
if activation_layer is None:
activation_layer = Identity
self.kernel_size = kernel_size
self.stride = stride
dict_layers = (nn.Conv3D(in_planes,
out_planes,
kernel_size=kernel_size,
stride=stride,
padding=padding,
groups=groups,
**kwargs), norm_layer(out_planes,
momentum=0.1),
activation_layer())
self.out_channels = out_planes
super(Conv3DBNActivation, self).__init__(dict_layers[0], dict_layers[1],
dict_layers[2])
class ConvBlock3D(CausalModule):
def __init__(
self,
in_planes: int,
out_planes: int,
kernel_size: Union[int, Tuple[int, int, int]],
causal: bool,
conv_type: str,
padding: Union[int, Tuple[int, int, int]] = 0,
stride: Union[int, Tuple[int, int, int]] = 1,
norm_layer: Optional[Callable[..., nn.Layer]] = None,
activation_layer: Optional[Callable[..., nn.Layer]] = None,
bias_attr: bool = False,
**kwargs: Any,
) -> None:
super().__init__()
kernel_size = _triple(kernel_size)
stride = _triple(stride)
padding = _triple(padding)
self.conv_2 = None
if causal is True:
padding = (0, padding[1], padding[2])
if conv_type != "2plus1d" and conv_type != "3d":
raise ValueError("only 2plus2d or 3d are " +
"allowed as 3d convolutions")
if conv_type == "2plus1d":
self.conv_1 = Conv2dBNActivation(in_planes,
out_planes,
kernel_size=(kernel_size[1],
kernel_size[2]),
padding=(padding[1], padding[2]),
stride=(stride[1], stride[2]),
activation_layer=activation_layer,
norm_layer=norm_layer,
bias_attr=bias_attr,
**kwargs)
if kernel_size[0] > 1:
self.conv_2 = Conv2dBNActivation(
in_planes,
out_planes,
kernel_size=(kernel_size[0], 1),
padding=(padding[0], 0),
stride=(stride[0], 1),
activation_layer=activation_layer,
norm_layer=norm_layer,
bias_attr=bias_attr,
**kwargs)
elif conv_type == "3d":
self.conv_1 = Conv3DBNActivation(in_planes,
out_planes,
kernel_size=kernel_size,
padding=padding,
activation_layer=activation_layer,
norm_layer=norm_layer,
stride=stride,
bias_attr=bias_attr,
**kwargs)
self.padding = padding
self.kernel_size = kernel_size
self.dim_pad = self.kernel_size[0] - 1
self.stride = stride
self.causal = causal
self.conv_type = conv_type
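
    # In "2plus1d" mode the 3D convolution is factorised: conv_1 applies the
    # spatial (HxW) kernel frame by frame and conv_2 applies the temporal
    # kernel with H*W folded into one axis. In causal mode the last
    # (kernel_size[0] - 1) frames are cached in `activation`, so streaming
    # inference matches full-clip inference.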
def _forward(self, x: paddle.Tensor) -> paddle.Tensor:
if self.dim_pad > 0 and self.conv_2 is None and self.causal is True:
x = self._cat_stream_buffer(x)
b, c, t, h, w = x.shape
if self.conv_type == "2plus1d":
x = paddle.transpose(x, (0, 2, 1, 3, 4)) # bcthw --> btchw
x = paddle.reshape_(x, (-1, c, h, w)) # btchw --> bt,c,h,w
x = self.conv_1(x)
if self.conv_type == "2plus1d":
b, c, h, w = x.shape
x = paddle.reshape_(x, (-1, t, c, h, w)) # bt,c,h,w --> b,t,c,h,w
x = paddle.transpose(x, (0, 2, 1, 3, 4)) # b,t,c,h,w --> b,c,t,h,w
if self.conv_2 is not None:
if self.dim_pad > 0 and self.causal is True:
x = self._cat_stream_buffer(x)
b, c, t, h, w = x.shape
x = paddle.reshape_(x, (b, c, t, h * w))
x = self.conv_2(x)
b, c, t, _ = x.shape
x = paddle.reshape_(x, (b, c, t, h, w))
return x
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self._forward(x)
return x
def _cat_stream_buffer(self, x: paddle.Tensor) -> paddle.Tensor:
if self.activation is None:
self._setup_activation(x.shape)
x = paddle.concat((self.activation, x), 2)
self._save_in_activation(x)
return x
def _save_in_activation(self, x: paddle.Tensor) -> None:
assert self.dim_pad > 0
        # keep only the last `dim_pad` frames as the carry-over buffer,
        # detached so gradients do not flow between clips
        self.activation = x[:, :, -self.dim_pad:].clone().detach()
def _setup_activation(self, input_shape: Tuple[float, ...]) -> None:
assert self.dim_pad > 0
self.activation = paddle.zeros(shape=[
*input_shape[:2], # type: ignore
self.dim_pad,
*input_shape[3:]
])
class TemporalCGAvgPool3D(CausalModule):
def __init__(self, ) -> None:
super().__init__()
self.n_cumulated_values = 0
self.register_forward_post_hook(self._detach_activation)
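
    # Cumulative global average pooling over time: the running sum (kept in
    # `activation`) and the frame count are carried across calls, so the
    # average at frame t covers every frame seen so far in the stream.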
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
input_shape = x.shape
cumulative_sum = paddle.cumsum(x, axis=2)
if self.activation is None:
self.activation = cumulative_sum[:, :, -1:].clone()
else:
cumulative_sum += self.activation
self.activation = cumulative_sum[:, :, -1:].clone()
noe = paddle.arange(1, input_shape[2] + 1)
axis = paddle.to_tensor([0, 1, 3, 4])
noe = paddle.unsqueeze(noe, axis=axis)
divisor = noe.expand(x.shape)
x = cumulative_sum / (self.n_cumulated_values + divisor)
self.n_cumulated_values += input_shape[2]
return x
    @staticmethod
    def _detach_activation(module: CausalModule, inputs: paddle.Tensor,
                           output: paddle.Tensor) -> None:
        # detach() is not in-place; reassign so the buffer leaves the graph
        module.activation = module.activation.detach()
def reset_activation(self) -> None:
super().reset_activation()
self.n_cumulated_values = 0
class SqueezeExcitation(nn.Layer):
def __init__(self,
input_channels: int,
activation_2: nn.Layer,
activation_1: nn.Layer,
conv_type: str,
causal: bool,
squeeze_factor: int = 4,
bias_attr: bool = True) -> None:
super().__init__()
self.causal = causal
se_multiplier = 2 if causal else 1
squeeze_channels = _make_divisible(
input_channels // squeeze_factor * se_multiplier, 8)
        self.temporal_cumulative_GAvg3D = TemporalCGAvgPool3D()
self.fc1 = ConvBlock3D(input_channels * se_multiplier,
squeeze_channels,
kernel_size=(1, 1, 1),
padding=0,
causal=causal,
conv_type=conv_type,
bias_attr=bias_attr)
self.activation_1 = activation_1()
self.activation_2 = activation_2()
self.fc2 = ConvBlock3D(squeeze_channels,
input_channels,
kernel_size=(1, 1, 1),
padding=0,
causal=causal,
conv_type=conv_type,
bias_attr=bias_attr)
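
    # In causal mode the squeeze step concatenates the cumulative temporal
    # average with the per-frame spatial average (hence se_multiplier == 2),
    # so the excitation never depends on future frames.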
def _scale(self, inputs: paddle.Tensor) -> paddle.Tensor:
if self.causal:
x_space = paddle.mean(inputs, axis=[3, 4], keepdim=True)
            scale = self.temporal_cumulative_GAvg3D(x_space)
scale = paddle.concat((scale, x_space), axis=1)
else:
scale = F.adaptive_avg_pool3d(inputs, 1)
scale = self.fc1(scale)
scale = self.activation_1(scale)
scale = self.fc2(scale)
return self.activation_2(scale)
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
scale = self._scale(inputs)
return scale * inputs
class BasicBneck(nn.Layer):
def __init__(
self,
input_channels,
out_channels,
expanded_channels,
kernel_size,
stride,
padding,
padding_avg,
causal: bool,
conv_type: str,
norm_layer: Optional[Callable[..., nn.Layer]] = None,
activation_layer: Optional[Callable[..., nn.Layer]] = None,
) -> None:
super().__init__()
assert type(stride) is tuple
if (not stride[0] == 1 or not (1 <= stride[1] <= 2)
or not (1 <= stride[2] <= 2)):
raise ValueError('illegal stride value')
        self.res = None
        self.expand = None  # only created below when an expansion conv is needed
layers = []
if expanded_channels != out_channels:
# expand
self.expand = ConvBlock3D(in_planes=input_channels,
out_planes=expanded_channels,
kernel_size=(1, 1, 1),
padding=(0, 0, 0),
causal=causal,
conv_type=conv_type,
norm_layer=norm_layer,
activation_layer=activation_layer)
        # depthwise
self.deep = ConvBlock3D(in_planes=expanded_channels,
out_planes=expanded_channels,
kernel_size=kernel_size,
padding=padding,
stride=stride,
groups=expanded_channels,
causal=causal,
conv_type=conv_type,
norm_layer=norm_layer,
activation_layer=activation_layer)
# SE
self.se = SqueezeExcitation(
expanded_channels,
causal=causal,
activation_1=activation_layer,
activation_2=(nn.Sigmoid if conv_type == "3d" else nn.Hardsigmoid),
conv_type=conv_type)
# project
self.project = ConvBlock3D(expanded_channels,
out_channels,
kernel_size=(1, 1, 1),
padding=(0, 0, 0),
causal=causal,
conv_type=conv_type,
norm_layer=norm_layer,
activation_layer=Identity)
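        # residual path: when the output shape changes, match it with an
        # AvgPool3D (for the spatial stride) followed by a 1x1x1 projection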
if not (stride == (1, 1, 1) and input_channels == out_channels):
if stride != (1, 1, 1):
layers.append(
nn.AvgPool3D((1, 3, 3), stride=stride, padding=padding_avg))
layers.append(
ConvBlock3D(
in_planes=input_channels,
out_planes=out_channels,
kernel_size=(1, 1, 1),
padding=(0, 0, 0),
norm_layer=norm_layer,
activation_layer=Identity,
causal=causal,
conv_type=conv_type,
))
self.res = nn.Sequential(*layers)
        # learnable scalar that scales the bottleneck branch in forward()
        self.alpha = self.create_parameter(shape=[1], dtype="float32")
def forward(self, inputs: paddle.Tensor) -> paddle.Tensor:
if self.res is not None:
residual = self.res(inputs)
else:
residual = inputs
if self.expand is not None:
x = self.expand(inputs)
else:
x = inputs
x = self.deep(x)
x = self.se(x)
x = self.project(x)
result = residual + self.alpha * x
return result
@BACKBONES.register()
class MoViNet(nn.Layer):
def __init__(
self,
model_type: str = 'A0',
hidden_dim: int = 2048,
causal: bool = True,
num_classes: int = 400,
conv_type: str = "3d",
) -> None:
        """
        model_type: model size, a key of MODEL_CONFIG (currently only 'A0')
        hidden_dim: width of the penultimate (dense9) classifier layer
        causal: causal (streaming) mode
        num_classes: number of classes for classification
        conv_type: type of convolution, either "3d" or "2plus1d"
        """
        super().__init__()
blocks_dic = OrderedDict()
cfg = MODEL_CONFIG[model_type]
norm_layer = nn.BatchNorm3D if conv_type == "3d" else nn.BatchNorm2D
activation_layer = nn.Swish if conv_type == "3d" else nn.Hardswish
# conv1
self.conv1 = ConvBlock3D(in_planes=cfg['conv1'][0],
out_planes=cfg['conv1'][1],
kernel_size=cfg['conv1'][2],
stride=cfg['conv1'][3],
padding=cfg['conv1'][4],
causal=causal,
conv_type=conv_type,
norm_layer=norm_layer,
activation_layer=activation_layer)
# blocks
        for i in range(2, len(cfg['block_num']) + 1):
            for j in range(cfg['block_num'][i - 1]):
                # unpack [in, out, expanded, kernel, stride, padding, padding_avg]
                blocks_dic[f'b{i}_l{j}'] = BasicBneck(
                    *cfg[f'b{i}_l{j}'],
                    causal=causal,
                    conv_type=conv_type,
                    norm_layer=norm_layer,
                    activation_layer=activation_layer)
self.blocks = nn.Sequential(*(blocks_dic.values()))
# conv7
self.conv7 = ConvBlock3D(in_planes=cfg['conv7'][0],
out_planes=cfg['conv7'][1],
kernel_size=cfg['conv7'][2],
stride=cfg['conv7'][3],
padding=cfg['conv7'][4],
causal=causal,
conv_type=conv_type,
norm_layer=norm_layer,
activation_layer=activation_layer)
# pool
self.classifier = nn.Sequential(
# dense9
ConvBlock3D(in_planes=cfg['conv7'][1],
out_planes=hidden_dim,
kernel_size=(1, 1, 1),
causal=causal,
conv_type=conv_type,
bias_attr=True),
nn.Swish(),
nn.Dropout(p=0.2),
            # dense10
ConvBlock3D(in_planes=hidden_dim,
out_planes=num_classes,
kernel_size=(1, 1, 1),
causal=causal,
conv_type=conv_type,
bias_attr=True),
)
if causal:
self.cgap = TemporalCGAvgPool3D()
self.apply(self._weight_init)
self.causal = causal
def avg(self, x: paddle.Tensor) -> paddle.Tensor:
if self.causal:
avg = F.adaptive_avg_pool3d(x, (x.shape[2], 1, 1))
avg = self.cgap(avg)[:, :, -1:]
else:
avg = F.adaptive_avg_pool3d(x, 1)
return avg
    @staticmethod
    def _weight_init(m):
        # paddle initializers are objects that must be *called* on a tensor
        if isinstance(m, nn.Conv3D):
            nn.initializer.KaimingNormal()(m.weight)
            if m.bias is not None:
                nn.initializer.Constant(0.0)(m.bias)
        elif isinstance(m, (nn.BatchNorm3D, nn.BatchNorm2D, nn.GroupNorm)):
            nn.initializer.Constant(1.0)(m.weight)
            nn.initializer.Constant(0.0)(m.bias)
        elif isinstance(m, nn.Linear):
            nn.initializer.Normal(std=0.01)(m.weight)
            nn.initializer.Constant(0.0)(m.bias)
def forward(self, x: paddle.Tensor) -> paddle.Tensor:
x = self.conv1(x)
x = self.blocks(x)
x = self.conv7(x)
x = self.avg(x)
x = self.classifier(x)
x = x.flatten(1)
return x
@staticmethod
def _clean_activation_buffers(m):
if issubclass(type(m), CausalModule):
m.reset_activation()
def clean_activation_buffers(self) -> None:
self.apply(self._clean_activation_buffers)
if __name__ == '__main__':
net = MoViNet(causal=False, conv_type='3d')
paddle.summary(net, input_size=(1, 3, 8, 224, 224))
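
    # Hedged streaming sketch (illustrative addition, not in the original
    # file): in causal mode the CausalModule buffers carry state across calls,
    # so consecutive clips of one video can be fed one at a time.
    stream_net = MoViNet(causal=True, conv_type='2plus1d')
    for _ in range(3):  # three consecutive 8-frame clips of the same video
        clip = paddle.randn([1, 3, 8, 224, 224])
        logits = stream_net(clip)  # state accumulates across calls
    stream_net.clean_activation_buffers()  # reset before an unrelated video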