# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy

import PIL
import numpy as np
import paddle
from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler
from paddle.vision import transforms


def get_dataset(data_path, module_name, transform, dataset_args):
    """
    Build the training dataset.

    :param data_path: list of dataset annotation files, where each file stores one
        sample per line in the format 'path/to/Crop_img\tlabel'
    :param module_name: name of the custom dataset class to use; currently only
        data_loaders.ImageDataset is supported
    :param transform: transforms applied to this dataset
    :param dataset_args: extra keyword arguments passed to module_name
    :return: the dataset object corresponding to data_path
    """
    from . import dataset

    # Resolve the dataset class by name from the local dataset module and
    # instantiate it with the given transform and arguments.
    s_dataset = getattr(dataset, module_name)(
        transform=transform, data_path=data_path, **dataset_args
    )
    return s_dataset
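

# A minimal usage sketch (the annotation file name and the dataset_args keys
# are assumptions for illustration, not defined in this module):
#
#   train_dataset = get_dataset(
#       data_path=["./train_list.txt"],
#       module_name="ImageDataset",
#       transform=None,
#       dataset_args={"img_mode": "RGB"},
#   )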


def get_transforms(transforms_config):
    tr_list = []
    for item in transforms_config:
        # Each entry names a class in paddle.vision.transforms plus optional kwargs.
        args = item.get("args", {})
        cls = getattr(transforms, item["type"])(**args)
        tr_list.append(cls)
    return transforms.Compose(tr_list)
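

# Illustrative sketch: a transforms_config like the following (the transform
# names are examples; any class in paddle.vision.transforms can be referenced)
# maps each {"type", "args"} entry onto a transform instance, composed in order:
#
#   transforms_config = [
#       {"type": "Resize", "args": {"size": (640, 640)}},
#       {"type": "ToTensor"},
#       {"type": "Normalize", "args": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]}},
#   ]
#   img_transforms = get_transforms(transforms_config)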


class ICDARCollectFN:
    """Collate function that merges per-sample dicts into a single batch dict.

    Values that are numpy arrays, paddle tensors, or PIL images are stacked
    along a new batch dimension; everything else (e.g. strings) is kept as a
    plain list.
    """

    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, batch):
        data_dict = {}
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                if k not in data_dict:
                    data_dict[k] = []
                # Remember which keys hold array-like values so they can be stacked.
                if isinstance(v, (np.ndarray, paddle.Tensor, PIL.Image.Image)):
                    if k not in to_tensor_keys:
                        to_tensor_keys.append(k)
                data_dict[k].append(v)
        # Stack array-like values into a single batch tensor.
        for k in to_tensor_keys:
            data_dict[k] = paddle.stack(data_dict[k], 0)
        return data_dict
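

# Sketch of the collate behaviour (keys and shapes are assumptions):
#
#   batch = [
#       {"img": paddle.zeros([3, 640, 640]), "filename": "a.jpg"},
#       {"img": paddle.zeros([3, 640, 640]), "filename": "b.jpg"},
#   ]
#   out = ICDARCollectFN()(batch)
#   # out["img"] has shape [2, 3, 640, 640]; out["filename"] stays a plain list.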


def get_dataloader(module_config, distributed=False):
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config["dataset"]["args"]
    if "transforms" in dataset_args:
        img_transforms = get_transforms(dataset_args.pop("transforms"))
    else:
        img_transforms = None
    # Build the dataset.
    dataset_name = config["dataset"]["type"]
    data_path = dataset_args.pop("data_path")
    if data_path is None:
        return None

    data_path = [x for x in data_path if x is not None]
    if len(data_path) == 0:
        return None
    if (
        "collate_fn" not in config["loader"]
        or config["loader"]["collate_fn"] is None
        or len(config["loader"]["collate_fn"]) == 0
    ):
        config["loader"]["collate_fn"] = None
    else:
        # The config stores the collate_fn class name as a string; resolve and
        # instantiate it (e.g. "ICDARCollectFN" defined above).
        config["loader"]["collate_fn"] = eval(config["loader"]["collate_fn"])()

    _dataset = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args,
    )
    if distributed:
        # Use DistributedBatchSampler so each rank sees a distinct shard.
        batch_sampler = DistributedBatchSampler(
            dataset=_dataset,
            batch_size=config["loader"].pop("batch_size"),
            shuffle=config["loader"].pop("shuffle"),
        )
    else:
        batch_sampler = BatchSampler(
            dataset=_dataset,
            batch_size=config["loader"].pop("batch_size"),
            shuffle=config["loader"].pop("shuffle"),
        )
    # Remaining loader kwargs (e.g. num_workers) are passed through to DataLoader.
    loader = DataLoader(
        dataset=_dataset, batch_sampler=batch_sampler, **config["loader"]
    )
    return loader
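

# Example config sketch (keys mirror what get_dataloader reads; the dataset
# type, file name, and loader kwargs are illustrative assumptions):
#
#   module_config = {
#       "dataset": {
#           "type": "ImageDataset",
#           "args": {
#               "data_path": ["./train_list.txt"],
#               "transforms": [{"type": "ToTensor"}],
#           },
#       },
#       "loader": {
#           "batch_size": 16,
#           "shuffle": True,
#           "num_workers": 4,
#           "collate_fn": "ICDARCollectFN",
#       },
#   }
#   train_loader = get_dataloader(module_config, distributed=False)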