You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
# -*- coding: utf-8 -*-
|
|
# @Time : 2019/8/23 21:57
|
|
# @Author : zhoujun
|
|
from addict import Dict
|
|
from paddle import nn
|
|
import paddle.nn.functional as F
|
|
|
|
from models.backbone import build_backbone
|
|
from models.neck import build_neck
|
|
from models.head import build_head
|
|
|
|
|
|
class Model(nn.Layer):
    """Network assembled from a backbone, a neck and a head.

    In :meth:`forward` the head output is bilinearly resized back to the
    spatial size (last two dimensions) of the input.
    """

    def __init__(self, model_config: dict):
        """Build the three sub-networks described by ``model_config``.

        Each of the ``backbone``, ``neck`` and ``head`` sections must carry a
        ``"type"`` key naming the component. That key is popped, and every
        remaining entry of the section is forwarded as a keyword argument to
        the matching builder. The neck and head additionally receive
        ``in_channels`` taken from the ``out_channels`` of the preceding stage.

        :param model_config: model configuration; it is wrapped in an addict
            ``Dict`` before the ``"type"`` keys are popped.
        :raises KeyError: when a section has no ``"type"`` key (``pop`` is
            called without a default).
        """
        super().__init__()
        # addict Dict gives attribute-style access to the config sections.
        cfg = Dict(model_config)
        # Pull the component names out; what is left in each section becomes
        # keyword arguments for that section's builder.
        backbone_type, neck_type, head_type = (
            cfg[section].pop("type") for section in ("backbone", "neck", "head")
        )

        self.backbone = build_backbone(backbone_type, **cfg.backbone)
        # Later stages are sized from the channel count of the stage before.
        self.neck = build_neck(
            neck_type, in_channels=self.backbone.out_channels, **cfg.neck
        )
        self.head = build_head(
            head_type, in_channels=self.neck.out_channels, **cfg.head
        )
        # Identifier of the form "<backbone>_<neck>_<head>".
        self.name = f"{backbone_type}_{neck_type}_{head_type}"

    def forward(self, x):
        """Run backbone -> neck -> head and resize the result to the input size.

        :param x: input batch; it must be 4-D because its shape is unpacked
            into four values (presumably NCHW images — TODO confirm).
        :return: head output bilinearly interpolated, with
            ``align_corners=True``, to the input's last two dimensions.
        """
        _, _, height, width = x.shape
        features = self.neck(self.backbone(x))
        prediction = self.head(features)
        return F.interpolate(
            prediction, size=(height, width), mode="bilinear", align_corners=True
        )
|