115 lines
3.5 KiB
Python

# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy
import PIL
import numpy as np
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddle.vision import transforms
def get_dataset(data_path, module_name, transform, dataset_args):
    """Instantiate a dataset class, looked up by name, from the local ``dataset`` module.

    :param data_path: list of annotation files; the original author describes
        each line as an image path and a label separated by a tab
    :param module_name: name of the class in ``.dataset`` to construct
    :param transform: transform pipeline handed to the dataset (may be None)
    :param dataset_args: extra keyword arguments forwarded to the class
    :return: the constructed dataset instance (never None)
    """
    from . import dataset

    dataset_cls = getattr(dataset, module_name)
    return dataset_cls(data_path=data_path, transform=transform, **dataset_args)
def get_transforms(transforms_config):
    """Build a ``transforms.Compose`` pipeline from a list of config entries.

    Each entry is a mapping with a ``"type"`` key naming a class in the
    module-level ``transforms`` namespace, plus an optional ``"args"`` mapping
    of constructor keyword arguments. A missing or ``None`` ``"args"`` (for
    example an empty ``args:`` entry in a YAML config) means "no arguments".

    :param transforms_config: iterable of transform config mappings
    :return: the composed pipeline, with transforms in config order
    """
    pipeline = []
    for item in transforms_config:
        # ``or {}`` also covers an explicit ``args: None``, which would
        # otherwise raise TypeError when unpacked with ``**``.
        args = item.get("args") or {}
        pipeline.append(getattr(transforms, item["type"])(**args))
    return transforms.Compose(pipeline)
class ICDARCollectFN:
    """Collate a list of per-sample dicts into one dict of batched values.

    For every key, the values from all samples are gathered in batch order.
    If any sample holds an ``np.ndarray``, ``paddle.Tensor`` or
    ``PIL.Image.Image`` under a key, that key's values are stacked with
    ``paddle.stack`` along a new leading axis. All other keys stay plain
    Python lists.
    """

    def __init__(self, *args, **kwargs):
        # Accepts and ignores any arguments, so it can be built generically
        # (e.g. ``ICDARCollectFN()`` from a config string).
        pass

    @staticmethod
    def _to_tensor(value):
        """Return *value* as a ``paddle.Tensor``, converting array-likes via NumPy."""
        if isinstance(value, paddle.Tensor):
            return value
        return paddle.to_tensor(np.asarray(value))

    def __call__(self, batch):
        """Collate *batch* (a sequence of dicts) into a single dict.

        :param batch: sequence of per-sample dicts
        :return: dict mapping each key to a stacked Tensor (array-like keys)
            or to a list of per-sample values (all other keys)
        """
        # Imported explicitly: a bare ``import PIL`` does not load the
        # ``PIL.Image`` submodule, so ``PIL.Image.Image`` would only resolve
        # if some other module happened to import it first.
        from PIL import Image

        array_types = (np.ndarray, paddle.Tensor, Image.Image)
        data_dict = {}
        to_tensor_keys = []
        for sample in batch:
            for k, v in sample.items():
                data_dict.setdefault(k, []).append(v)
                if isinstance(v, array_types) and k not in to_tensor_keys:
                    to_tensor_keys.append(k)
        for k in to_tensor_keys:
            # paddle.stack expects Tensors, so NumPy arrays / PIL images are
            # converted first instead of being passed through raw.
            data_dict[k] = paddle.stack([self._to_tensor(v) for v in data_dict[k]], 0)
        return data_dict
def get_dataloader(module_config, distributed=False):
    """Build a ``paddle.io.DataLoader`` from a dataset/loader config dict.

    ``module_config`` is deep-copied first, so the caller's dict is never
    mutated. The keys read here are::

        {
            "dataset": {"type": <class name in .dataset>,
                        "args": {"data_path": [...], "transforms": [...], ...}},
            "loader": {"batch_size": ..., "shuffle": ..., "collate_fn": ..., ...},
        }

    Remaining ``dataset.args`` entries are forwarded to the dataset
    constructor. Remaining ``loader`` entries are forwarded to ``DataLoader``.

    :param module_config: config dict as above, or ``None``
    :param distributed: if truthy, use ``DistributedBatchSampler`` instead of
        ``BatchSampler``
    :return: the DataLoader, or ``None`` when ``module_config`` is ``None`` or
        no non-None ``data_path`` entries are provided
    """
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config["dataset"]["args"]
    loader_args = config["loader"]

    if "transforms" in dataset_args:
        img_transforms = get_transforms(dataset_args.pop("transforms"))
    else:
        img_transforms = None

    # Build the dataset inputs.
    dataset_name = config["dataset"]["type"]
    data_path = dataset_args.pop("data_path", None)
    if data_path is None:
        return None
    if isinstance(data_path, str):
        # A single path given as a plain string: without this, the filter
        # below would iterate it character by character.
        data_path = [data_path]
    data_path = [x for x in data_path if x is not None]
    if not data_path:
        return None

    collate_fn = loader_args.get("collate_fn")
    if collate_fn:
        # NOTE(review): ``eval`` executes arbitrary code from the config
        # string (intended for names such as "ICDARCollectFN"); only load
        # configs from trusted sources.
        loader_args["collate_fn"] = eval(collate_fn)()
    else:
        loader_args["collate_fn"] = None

    _dataset = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args,
    )

    # Both sampler classes accept the same arguments used here.
    sampler_cls = DistributedBatchSampler if distributed else BatchSampler
    batch_sampler = sampler_cls(
        dataset=_dataset,
        batch_size=loader_args.pop("batch_size"),
        # Paddle's batch samplers default to shuffle=False; mirror that when
        # the config omits it instead of raising KeyError.
        shuffle=loader_args.pop("shuffle", False),
    )
    return DataLoader(dataset=_dataset, batch_sampler=batch_sampler, **loader_args)