# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

from ..registry import BACKBONES
from ..weight_init import weight_init_


def conv_init(conv):
    if conv.weight is not None:
        weight_init_(conv.weight, 'kaiming_normal_', mode='fan_in')
    if conv.bias is not None:
        nn.initializer.Constant(value=0.0)(conv.bias)


def bn_init(bn, scale):
    nn.initializer.Constant(value=float(scale))(bn.weight)
    nn.initializer.Constant(value=0.0)(bn.bias)


def einsum(x1, x3):
    """Matmul replacement for paddle.einsum('ncuv,nctv->nctu', x1, x3),
    since paddle.einsum is only supported in dynamic graph mode.

    x1 : n c u v
    x3 : n c t v
    """
    n, c, u, v1 = x1.shape
    n, c, t, v3 = x3.shape
    assert v1 == v3, "Args of einsum do not match!"
    x1 = paddle.transpose(x1, perm=[0, 1, 3, 2])  # n c v u
    y = paddle.matmul(x3, x1)
    # out: n c t u
    return y
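
# A quick sanity check for the helper above (a minimal sketch, not part of the
# original file; the tensors are placeholders). In dynamic graph mode the two
# expressions below should agree, which is what the matmul rewrite relies on:
#
#     a = paddle.randn([2, 4, 25, 25])   # n c u v
#     b = paddle.randn([2, 4, 16, 25])   # n c t v
#     ref = paddle.einsum('ncuv,nctv->nctu', a, b)
#     out = einsum(a, b)
#     # both have shape [2, 4, 16, 25] and should be numerically close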


class CTRGC(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 rel_reduction=8,
                 mid_reduction=1):
        super(CTRGC, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        if in_channels == 3 or in_channels == 9:
            self.rel_channels = 8
            self.mid_channels = 16
        else:
            self.rel_channels = in_channels // rel_reduction
            self.mid_channels = in_channels // mid_reduction
        self.conv1 = nn.Conv2D(self.in_channels,
                               self.rel_channels,
                               kernel_size=1)
        self.conv2 = nn.Conv2D(self.in_channels,
                               self.rel_channels,
                               kernel_size=1)
        self.conv3 = nn.Conv2D(self.in_channels,
                               self.out_channels,
                               kernel_size=1)
        self.conv4 = nn.Conv2D(self.rel_channels,
                               self.out_channels,
                               kernel_size=1)
        self.tanh = nn.Tanh()

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2D):
                bn_init(m, 1)

    def forward(self, x, A=None, alpha=1):
        x1 = self.conv1(x).mean(-2)
        x2 = self.conv2(x).mean(-2)
        x3 = self.conv3(x)
        x1 = self.tanh(x1.unsqueeze(-1) - x2.unsqueeze(-2))
        x1 = self.conv4(x1) * alpha + (
            A.unsqueeze(0).unsqueeze(0) if A is not None else 0)  # N,C,V,V
        # paddle.einsum is only supported in dynamic graph mode; for inference
        # (static graph) models, use the matmul-based einsum() helper defined
        # above instead of:
        # x1 = paddle.einsum('ncuv,nctv->nctu', x1, x3)
        x1 = einsum(x1, x3)
        return x1


class TemporalConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1):
        super(TemporalConv, self).__init__()
        pad = (kernel_size + (kernel_size - 1) * (dilation - 1) - 1) // 2
        self.conv = nn.Conv2D(in_channels,
                              out_channels,
                              kernel_size=(kernel_size, 1),
                              padding=(pad, 0),
                              stride=(stride, 1),
                              dilation=(dilation, 1))

        self.bn = nn.BatchNorm2D(out_channels)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class MultiScale_TemporalConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 dilations=[1, 2, 3, 4],
                 residual=True,
                 residual_kernel_size=1):
        super(MultiScale_TemporalConv, self).__init__()
        assert out_channels % (len(dilations) + 2) == 0, \
            '# out channels should be multiples of # branches'

        # Multiple branches of temporal convolution
        self.num_branches = len(dilations) + 2
        branch_channels = out_channels // self.num_branches
        if isinstance(kernel_size, list):
            assert len(kernel_size) == len(dilations)
        else:
            kernel_size = [kernel_size] * len(dilations)

        # Temporal Convolution branches
        self.branches = nn.LayerList([
            nn.Sequential(
                nn.Conv2D(in_channels,
                          branch_channels,
                          kernel_size=1,
                          padding=0),
                nn.BatchNorm2D(branch_channels),
                nn.ReLU(),
                TemporalConv(branch_channels,
                             branch_channels,
                             kernel_size=ks,
                             stride=stride,
                             dilation=dilation),
            ) for ks, dilation in zip(kernel_size, dilations)
        ])

        # Additional Max & 1x1 branch
        self.branches.append(
            nn.Sequential(
                nn.Conv2D(in_channels,
                          branch_channels,
                          kernel_size=1,
                          padding=0),
                nn.BatchNorm2D(branch_channels),
                nn.ReLU(),
                nn.MaxPool2D(kernel_size=(3, 1),
                             stride=(stride, 1),
                             padding=(1, 0)),
                nn.BatchNorm2D(branch_channels)))

        self.branches.append(
            nn.Sequential(
                nn.Conv2D(in_channels,
                          branch_channels,
                          kernel_size=1,
                          padding=0,
                          stride=(stride, 1)),
                nn.BatchNorm2D(branch_channels)))

        # Residual connection
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = TemporalConv(in_channels,
                                         out_channels,
                                         kernel_size=residual_kernel_size,
                                         stride=stride)

    def init_weights(self):
        """Initialize the parameters."""
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2D):
                weight_init_(m.weight, 'Normal', std=0.02, mean=1.0)
                nn.initializer.Constant(value=0.0)(m.bias)

    def forward(self, x):
        # Input dim: (N, C, T, V)
        res = self.residual(x)
        branch_outs = []
        for tempconv in self.branches:
            out = tempconv(x)
            branch_outs.append(out)

        out = paddle.concat(branch_outs, axis=1)
        out += res
        return out


class unit_tcn(nn.Layer):
    def __init__(self, in_channels, out_channels, kernel_size=9, stride=1):
        super(unit_tcn, self).__init__()
        pad = int((kernel_size - 1) / 2)
        self.conv = nn.Conv2D(in_channels,
                              out_channels,
                              kernel_size=(kernel_size, 1),
                              padding=(pad, 0),
                              stride=(stride, 1))

        self.bn = nn.BatchNorm2D(out_channels)
        self.relu = nn.ReLU()
        conv_init(self.conv)
        bn_init(self.bn, 1)

    def forward(self, x):
        x = self.bn(self.conv(x))
        return x


class unit_gcn(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 coff_embedding=4,
                 adaptive=True,
                 residual=True):
        super(unit_gcn, self).__init__()
        inter_channels = out_channels // coff_embedding
        self.inter_c = inter_channels
        self.out_c = out_channels
        self.in_c = in_channels
        self.adaptive = adaptive
        self.num_subset = A.shape[0]
        self.convs = nn.LayerList()

        for i in range(self.num_subset):
            self.convs.append(CTRGC(in_channels, out_channels))

        if residual:
            if in_channels != out_channels:
                self.down = nn.Sequential(
                    nn.Conv2D(in_channels, out_channels, 1),
                    nn.BatchNorm2D(out_channels))
            else:
                self.down = lambda x: x
        else:
            self.down = lambda x: 0
        if self.adaptive:
            pa_param = paddle.ParamAttr(
                initializer=paddle.nn.initializer.Assign(A.astype(np.float32)))
            self.PA = paddle.create_parameter(shape=A.shape,
                                              dtype='float32',
                                              attr=pa_param)
        else:
            A_tensor = paddle.to_tensor(A, dtype="float32")
            self.A = paddle.create_parameter(
                shape=A_tensor.shape,
                dtype='float32',
                default_initializer=paddle.nn.initializer.Assign(A_tensor))
            self.A.stop_gradient = True
        alpha_tensor = paddle.to_tensor(np.zeros(1), dtype="float32")
        self.alpha = paddle.create_parameter(
            shape=alpha_tensor.shape,
            dtype='float32',
            default_initializer=paddle.nn.initializer.Assign(alpha_tensor))
        self.bn = nn.BatchNorm2D(out_channels)
        self.soft = nn.Softmax(-2)
        self.relu = nn.ReLU()

    def init_weights(self):
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                conv_init(m)
            elif isinstance(m, nn.BatchNorm2D):
                bn_init(m, 1)
        bn_init(self.bn, 1e-6)

    def forward(self, x):
        y = None
        if self.adaptive:
            A = self.PA
        else:
            A = self.A.cuda(x.get_device())
        for i in range(self.num_subset):
            z = self.convs[i](x, A[i], self.alpha)
            y = z + y if y is not None else z
        y = self.bn(y)
        y += self.down(x)
        y = self.relu(y)
        return y


class TCN_GCN_unit(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 A,
                 stride=1,
                 residual=True,
                 adaptive=True,
                 kernel_size=5,
                 dilations=[1, 2]):
        super(TCN_GCN_unit, self).__init__()
        self.gcn1 = unit_gcn(in_channels, out_channels, A, adaptive=adaptive)
        self.tcn1 = MultiScale_TemporalConv(out_channels,
                                            out_channels,
                                            kernel_size=kernel_size,
                                            stride=stride,
                                            dilations=dilations,
                                            residual=False)
        self.relu = nn.ReLU()
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(in_channels,
                                     out_channels,
                                     kernel_size=1,
                                     stride=stride)

    def forward(self, x):
        y = self.relu(self.tcn1(self.gcn1(x)) + self.residual(x))
        return y


class NTUDGraph:
    def __init__(self, labeling_mode='spatial'):
        num_node = 25
        self_link = [(i, i) for i in range(num_node)]
        inward_ori_index = [(1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5),
                            (7, 6), (8, 7), (9, 21), (10, 9), (11, 10),
                            (12, 11), (13, 1), (14, 13), (15, 14), (16, 15),
                            (17, 1), (18, 17), (19, 18), (20, 19), (22, 23),
                            (23, 8), (24, 25), (25, 12)]
        inward = [(i - 1, j - 1) for (i, j) in inward_ori_index]
        outward = [(j, i) for (i, j) in inward]
        neighbor = inward + outward

        self.num_node = num_node
        self.self_link = self_link
        self.inward = inward
        self.outward = outward
        self.neighbor = neighbor
        self.A = self.get_adjacency_matrix(labeling_mode)

    def edge2mat(self, link, num_node):
        A = np.zeros((num_node, num_node))
        for i, j in link:
            A[j, i] = 1
        return A

    def normalize_digraph(self, A):
        Dl = np.sum(A, 0)
        h, w = A.shape
        Dn = np.zeros((w, w))
        for i in range(w):
            if Dl[i] > 0:
                Dn[i, i] = Dl[i]**(-1)
        AD = np.dot(A, Dn)
        return AD

    def get_spatial_graph(self, num_node, self_link, inward, outward):
        I = self.edge2mat(self_link, num_node)
        In = self.normalize_digraph(self.edge2mat(inward, num_node))
        Out = self.normalize_digraph(self.edge2mat(outward, num_node))
        A = np.stack((I, In, Out))
        return A

    def get_adjacency_matrix(self, labeling_mode=None):
        if labeling_mode is None:
            return self.A
        if labeling_mode == 'spatial':
            A = self.get_spatial_graph(self.num_node, self.self_link,
                                       self.inward, self.outward)
        else:
            raise ValueError()
        return A
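
# Note on the graph above: with labeling_mode='spatial', `NTUDGraph.A` is a
# numpy array of shape (3, 25, 25) stacking (i) the identity self-links,
# (ii) the normalized inward adjacency, and (iii) the normalized outward
# adjacency of the 25-joint NTU RGB+D skeleton. CTRGCN below consumes it as
# `self.graph.A`, and unit_gcn builds one CTRGC branch per subset.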


@BACKBONES.register()
class CTRGCN(nn.Layer):
    """
    CTR-GCN model from:
    `"Channel-wise Topology Refinement Graph Convolution for Skeleton-Based Action Recognition" <https://arxiv.org/abs/2107.12213>`_
    Args:
        num_point: int, number of skeleton joints.
        num_person: int, number of persons per sample.
        base_channel: int, the model's hidden dim.
        graph: str, name of the skeleton adjacency graph.
        graph_args: dict, args for the skeleton adjacency graph class.
        in_channels: int, channels of vertex coordinate. 2 for (x,y), 3 for (x,y,z). Default 3.
        adaptive: bool, whether the adjacency matrix is adaptive (learnable).
    """
    def __init__(self,
                 num_point=25,
                 num_person=2,
                 base_channel=64,
                 graph='ntu_rgb_d',
                 graph_args=dict(),
                 in_channels=3,
                 adaptive=True):
        super(CTRGCN, self).__init__()

        if graph == 'ntu_rgb_d':
            self.graph = NTUDGraph(**graph_args)
        else:
            raise ValueError()

        A = self.graph.A  # 3,25,25

        self.num_point = num_point
        self.data_bn = nn.BatchNorm1D(num_person * in_channels * num_point)
        self.base_channel = base_channel

        self.l1 = TCN_GCN_unit(in_channels,
                               self.base_channel,
                               A,
                               residual=False,
                               adaptive=adaptive)
        self.l2 = TCN_GCN_unit(self.base_channel,
                               self.base_channel,
                               A,
                               adaptive=adaptive)
        self.l3 = TCN_GCN_unit(self.base_channel,
                               self.base_channel,
                               A,
                               adaptive=adaptive)
        self.l4 = TCN_GCN_unit(self.base_channel,
                               self.base_channel,
                               A,
                               adaptive=adaptive)
        self.l5 = TCN_GCN_unit(self.base_channel,
                               self.base_channel * 2,
                               A,
                               stride=2,
                               adaptive=adaptive)
        self.l6 = TCN_GCN_unit(self.base_channel * 2,
                               self.base_channel * 2,
                               A,
                               adaptive=adaptive)
        self.l7 = TCN_GCN_unit(self.base_channel * 2,
                               self.base_channel * 2,
                               A,
                               adaptive=adaptive)
        self.l8 = TCN_GCN_unit(self.base_channel * 2,
                               self.base_channel * 4,
                               A,
                               stride=2,
                               adaptive=adaptive)
        self.l9 = TCN_GCN_unit(self.base_channel * 4,
                               self.base_channel * 4,
                               A,
                               adaptive=adaptive)
        self.l10 = TCN_GCN_unit(self.base_channel * 4,
                                self.base_channel * 4,
                                A,
                                adaptive=adaptive)

    def init_weights(self):
        bn_init(self.data_bn, 1)

    def forward(self, x):
        N, C, T, V, M = x.shape
        x = paddle.transpose(x, perm=[0, 4, 3, 1, 2])
        x = paddle.reshape(x, (N, M * V * C, T))

        x = self.data_bn(x)

        x = paddle.reshape(x, (N, M, V, C, T))
        x = paddle.transpose(x, perm=(0, 1, 3, 4, 2))

        x = paddle.reshape(x, (N * M, C, T, V))

        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        x = self.l4(x)
        x = self.l5(x)
        x = self.l6(x)
        x = self.l7(x)
        x = self.l8(x)
        x = self.l9(x)
        x = self.l10(x)

        return x, N, M
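

# --- Minimal usage sketch (not part of the original file) ---
# Because of the relative imports above, this backbone is meant to be used
# through the PaddleVideo package rather than run standalone. Assuming the
# module is importable, a smoke test could look like the sketch below; the
# input layout is (N, C, T, V, M) and, with the default base_channel=64 and
# the two stride-2 stages (l5, l8), the output feature map has shape
# (N * M, 4 * base_channel, T // 4, V).
#
#     model = CTRGCN(num_point=25, num_person=2, in_channels=3)
#     model.init_weights()
#     data = paddle.randn([2, 3, 64, 25, 2])   # (N, C, T, V, M)
#     feat, N, M = model(data)                 # feat: [4, 256, 16, 25]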