You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
102 lines
3.0 KiB
Python
102 lines
3.0 KiB
Python
8 months ago
|
# -*- coding: utf-8 -*-
|
||
|
# @Time : 2019/12/4 14:39
|
||
|
# @Author : zhoujun
|
||
|
import paddle
|
||
|
import paddle.nn as nn
|
||
|
|
||
|
|
||
|
class BalanceCrossEntropyLoss(nn.Layer):
    """
    Balanced cross entropy loss with hard negative selection.

    All masked positive pixels contribute to the loss; of the masked
    negative pixels only the ones with the largest per-pixel BCE are kept,
    at most ``negative_ratio`` times the number of positives.

    Shape:
        - Input: :math:`(N, 1, H, W)`
        - GT: :math:`(N, 1, H, W)`, same shape as the input
        - Mask: :math:`(N, H, W)`, same spatial shape as the input
        - Output: scalar.

    NOTE(review): ``gt``, ``mask`` and the per-pixel loss are multiplied
    elementwise, so the three tensors have to line up pixel by pixel. The
    shapes listed above are carried over from the original docstring and
    would broadcast rather than match -- confirm against the caller.
    """

    def __init__(self, negative_ratio=3.0, eps=1e-6):
        """
        Args:
            negative_ratio: maximum number of negatives kept per positive.
            eps: small constant added to the denominator to avoid a
                division by zero when no pixel is selected.
        """
        super(BalanceCrossEntropyLoss, self).__init__()
        self.negative_ratio = negative_ratio
        self.eps = eps

    def forward(
        self,
        pred: paddle.Tensor,
        gt: paddle.Tensor,
        mask: paddle.Tensor,
        return_origin=False,
    ):
        """
        Args:
            pred: shape :math:`(N, 1, H, W)`, the prediction of network
            gt: shape :math:`(N, 1, H, W)`, the target
            mask: shape :math:`(N, H, W)`, the mask indicates positive regions
            return_origin: when true, also return the unreduced per-pixel
                binary cross entropy.

        Returns:
            The balanced loss, or ``(balanced_loss, per_pixel_loss)`` when
            ``return_origin`` is true.
        """
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())
        # Cap the negatives at negative_ratio times the positives.
        negative_count = min(
            int(negative.sum()), int(positive_count * self.negative_ratio)
        )
        loss = nn.functional.binary_cross_entropy(pred, gt, reduction="none")
        loss_sum = (loss * positive).sum()

        # topk rejects k < 1, which happens whenever the batch has no masked
        # positives (or no masked negatives); in that case no negative is
        # selected and only the positive term remains.
        if negative_count > 0:
            hardest_negatives, _ = (
                (loss * negative).reshape([-1]).topk(negative_count)
            )
            loss_sum = loss_sum + hardest_negatives.sum()

        balance_loss = loss_sum / (positive_count + negative_count + self.eps)

        if return_origin:
            return balance_loss, loss
        return balance_loss
|
||
|
|
||
|
|
||
|
class DiceLoss(nn.Layer):
    """
    Loss function from https://arxiv.org/abs/1707.03237,
    where the IoU computation is carried out in a heatmap manner to
    measure the diversity between two heatmaps.
    """

    def __init__(self, eps=1e-6):
        """
        Args:
            eps: small constant added to the denominator so that it never
                becomes zero.
        """
        super().__init__()
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask, weights=None):
        """
        Delegate to :meth:`_compute`.

        pred: heatmap; when it is 4-D only channel 0 is used.
        gt: target heatmap; indexed the same way as ``pred`` when ``pred``
            is 4-D.
        mask: must have the same shape as ``pred`` after the channel
            selection.
        weights: optional per-pixel weights with the same shape as ``mask``.
        """
        return self._compute(pred, gt, mask, weights)

    def _compute(self, pred, gt, mask, weights):
        """Return ``1 - 2 * sum(pred * gt * mask) / (sum(pred * mask) + sum(gt * mask) + eps)``."""
        # A 4-D input carries a channel axis; keep only its first channel.
        if len(pred.shape) == 4:
            pred = pred[:, 0, :, :]
            gt = gt[:, 0, :, :]

        assert pred.shape == gt.shape
        assert pred.shape == mask.shape

        # Optional weights are folded into the mask.
        if weights is not None:
            assert weights.shape == mask.shape
            mask = weights * mask

        overlap = (pred * gt * mask).sum()
        pred_area = (pred * mask).sum()
        gt_area = (gt * mask).sum()
        denominator = pred_area + gt_area + self.eps

        loss = 1 - 2.0 * overlap / denominator
        assert loss <= 1
        return loss
|
||
|
|
||
|
|
||
|
class MaskL1Loss(nn.Layer):
    """L1 loss restricted to, and normalised by, a mask."""

    def __init__(self, eps=1e-6):
        """
        Args:
            eps: small constant added to the mask sum so that the division
                is defined even when the mask sums to zero.
        """
        super().__init__()
        self.eps = eps

    def forward(self, pred: paddle.Tensor, gt, mask):
        """Return ``sum(|pred - gt| * mask) / (sum(mask) + eps)``."""
        masked_error = paddle.abs(pred - gt) * mask
        normaliser = mask.sum() + self.eps
        return masked_error.sum() / normaliser
|