You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

133 lines
4.6 KiB
Python

# -*- coding: utf-8 -*-
# @Time : 2019/12/4 14:54
# @Author : zhoujun
import paddle
from paddle import nn, ParamAttr
class DBHead(nn.Layer):
def __init__(self, in_channels, out_channels, k=50):
super().__init__()
self.k = k
self.binarize = nn.Sequential(
nn.Conv2D(
in_channels,
in_channels // 4,
3,
padding=1,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
),
nn.BatchNorm2D(
in_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
),
nn.ReLU(),
nn.Conv2DTranspose(
in_channels // 4,
in_channels // 4,
2,
2,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
),
nn.BatchNorm2D(
in_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
),
nn.ReLU(),
nn.Conv2DTranspose(
in_channels // 4, 1, 2, 2, weight_attr=nn.initializer.KaimingNormal()
),
nn.Sigmoid(),
)
self.thresh = self._init_thresh(in_channels)
def forward(self, x):
shrink_maps = self.binarize(x)
threshold_maps = self.thresh(x)
if self.training:
binary_maps = self.step_function(shrink_maps, threshold_maps)
y = paddle.concat((shrink_maps, threshold_maps, binary_maps), axis=1)
else:
y = paddle.concat((shrink_maps, threshold_maps), axis=1)
return y
def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
in_channels = inner_channels
if serial:
in_channels += 1
self.thresh = nn.Sequential(
nn.Conv2D(
in_channels,
inner_channels // 4,
3,
padding=1,
bias_attr=bias,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
),
nn.BatchNorm2D(
inner_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
),
nn.ReLU(),
self._init_upsample(
inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias
),
nn.BatchNorm2D(
inner_channels // 4,
weight_attr=ParamAttr(initializer=nn.initializer.Constant(1)),
bias_attr=ParamAttr(initializer=nn.initializer.Constant(1e-4)),
),
nn.ReLU(),
self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
nn.Sigmoid(),
)
return self.thresh
def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
if smooth:
inter_out_channels = out_channels
if out_channels == 1:
inter_out_channels = in_channels
module_list = [
nn.Upsample(scale_factor=2, mode="nearest"),
nn.Conv2D(
in_channels,
inter_out_channels,
3,
1,
1,
bias_attr=bias,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
),
]
if out_channels == 1:
module_list.append(
nn.Conv2D(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=1,
bias_attr=True,
weight_attr=ParamAttr(
initializer=nn.initializer.KaimingNormal()
),
)
)
return nn.Sequential(module_list)
else:
return nn.Conv2DTranspose(
in_channels,
out_channels,
2,
2,
weight_attr=ParamAttr(initializer=nn.initializer.KaimingNormal()),
)
def step_function(self, x, y):
return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))