You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
133 lines
4.6 KiB
Python
133 lines
4.6 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Time : 2019/12/4 14:54
|
|
# @Author : zhoujun
|
|
import paddle
|
|
from paddle import nn, ParamAttr
|
|
|
|
|
|
class DBHead(nn.Layer):
    """Detection head producing shrink, threshold and (in training) binary maps.

    Two parallel branches process the same input feature map:

    * ``binarize`` predicts a 1-channel shrink/probability map (sigmoid output).
    * ``thresh`` predicts a 1-channel threshold map (sigmoid output).

    Each branch upsamples its input by 2x twice, i.e. 4x per spatial dimension.
    In training mode a third map is derived with ``step_function``, a
    differentiable approximation of ``shrink > threshold``.

    Args:
        in_channels: Number of channels of the input feature map. Intermediate
            layers use ``in_channels // 4`` channels.
        out_channels: Accepted for interface compatibility but not used; every
            predicted map has exactly one channel.
        k: Amplification factor of ``step_function`` (steepness of the sigmoid).
    """

    def __init__(self, in_channels, out_channels, k=50):
        super().__init__()
        self.k = k
        # Shrink-map branch: 3x3 conv -> BN -> ReLU, then two stride-2
        # transposed convs (each doubles H and W), ending in one sigmoid channel.
        self.binarize = nn.Sequential(
            nn.Conv2D(
                in_channels,
                in_channels // 4,
                3,
                padding=1,
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            ),
            nn.BatchNorm2D(
                in_channels // 4,
                weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
                bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
            ),
            nn.ReLU(),
            nn.Conv2DTranspose(
                in_channels // 4,
                in_channels // 4,
                2,
                2,
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            ),
            nn.BatchNorm2D(
                in_channels // 4,
                weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
                bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
            ),
            nn.ReLU(),
            nn.Conv2DTranspose(
                in_channels // 4,
                1,
                2,
                2,
                # Wrapped in ParamAttr like every other layer in this head.
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            ),
            nn.Sigmoid(),
        )
        self.thresh = self._init_thresh(in_channels)

    def forward(self, x):
        """Run both branches on ``x`` and concatenate their outputs.

        Args:
            x: Input feature map with ``in_channels`` channels. NCHW layout is
                assumed, since outputs are concatenated along axis 1.

        Returns:
            In training mode, ``concat([shrink, threshold, binary], axis=1)``
            (3 channels); otherwise ``concat([shrink, threshold], axis=1)``
            (2 channels).
        """
        shrink_maps = self.binarize(x)
        threshold_maps = self.thresh(x)
        if self.training:
            # The approximate binary map is only computed in training mode.
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            return paddle.concat((shrink_maps, threshold_maps, binary_maps), axis=1)
        return paddle.concat((shrink_maps, threshold_maps), axis=1)

    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
        """Build the threshold-map branch, store it on ``self.thresh``, return it.

        Args:
            inner_channels: Channels of the input feature map; intermediate
                layers use ``inner_channels // 4`` channels.
            serial: If True, the first conv expects one extra input channel
                (``inner_channels + 1``).
            smooth: If True, each upsampling step is nearest-neighbour
                interpolation plus convolution instead of a transposed conv.
            bias: Passed as ``bias_attr`` to the first conv and forwarded to
                ``_init_upsample``.

        Returns:
            The ``nn.Sequential`` branch (same object as ``self.thresh``).
        """
        in_channels = inner_channels + 1 if serial else inner_channels
        mid_channels = inner_channels // 4
        self.thresh = nn.Sequential(
            nn.Conv2D(
                in_channels,
                mid_channels,
                3,
                padding=1,
                bias_attr=bias,
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            ),
            nn.BatchNorm2D(
                mid_channels,
                weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
                bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
            ),
            nn.ReLU(),
            self._init_upsample(mid_channels, mid_channels, smooth=smooth, bias=bias),
            nn.BatchNorm2D(
                mid_channels,
                weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
                bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
            ),
            nn.ReLU(),
            self._init_upsample(mid_channels, 1, smooth=smooth, bias=bias),
            nn.Sigmoid(),
        )
        return self.thresh

    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
        """Return a layer that doubles H and W and maps channels.

        Args:
            in_channels: Input channel count.
            out_channels: Output channel count.
            smooth: If True, use nearest-neighbour upsampling followed by a
                3x3 conv. When ``out_channels == 1``, the 3x3 conv keeps
                ``in_channels`` and a 1x1 conv projects down to one channel.
                If False, use a single stride-2 transposed conv.
            bias: ``bias_attr`` of the 3x3 conv in the smooth path. It is not
                applied to the transposed conv, which keeps its default bias.

        Returns:
            An ``nn.Sequential`` (smooth) or ``nn.Conv2DTranspose`` layer.
        """
        if not smooth:
            return nn.Conv2DTranspose(
                in_channels,
                out_channels,
                2,
                2,
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            )

        # For a single-channel output, keep the 3x3 conv wide and project
        # down with a separate 1x1 conv.
        inter_out_channels = in_channels if out_channels == 1 else out_channels
        module_list = [
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2D(
                in_channels,
                inter_out_channels,
                3,
                1,
                1,
                bias_attr=bias,
                weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
            ),
        ]
        if out_channels == 1:
            module_list.append(
                nn.Conv2D(
                    inter_out_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    # A 1x1 kernel needs no padding; padding=1 would grow H and
                    # W by 2 and misalign this map with the shrink map.
                    padding=0,
                    bias_attr=True,
                    weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
                )
            )
        # Sequential takes layers as positional arguments, so unpack the list.
        return nn.Sequential(*module_list)

    def step_function(self, x, y):
        """Differentiable approximation of the step function ``x > y``.

        Computes ``1 / (1 + exp(-k * (x - y)))``, a sigmoid of the difference
        scaled by ``self.k``; a larger ``k`` gives a sharper step.
        """
        return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
|