You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

115 lines
3.5 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:52
# @Author : zhoujun
import copy
import PIL
import numpy as np
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler
from paddle.vision import transforms
def get_dataset(data_path, module_name, transform, dataset_args):
    """Build the training dataset.

    :param data_path: list of dataset files; each file stores samples as
        ``path/to/Crop_img\\tlabel`` lines
    :param module_name: name of the custom dataset class to use (currently
        only data_loaders.ImageDataset is supported)
    :param transform: transforms applied by this dataset
    :param dataset_args: extra keyword arguments forwarded to ``module_name``
    :return: the constructed dataset object
    """
    # Imported lazily to avoid a circular import at module load time.
    from . import dataset

    dataset_cls = getattr(dataset, module_name)
    return dataset_cls(transform=transform, data_path=data_path, **dataset_args)
def get_transforms(transforms_config):
    """Instantiate paddle.vision transforms from a config list.

    Each entry is a dict with a ``type`` (class name looked up on
    ``paddle.vision.transforms``) and optional ``args`` keyword arguments.

    :param transforms_config: list of transform config dicts
    :return: a ``transforms.Compose`` chaining the instantiated transforms
    """
    ops = [
        getattr(transforms, entry["type"])(**entry.get("args", {}))
        for entry in transforms_config
    ]
    return transforms.Compose(ops)
class ICDARCollectFN:
    """Collate function for ICDAR-style sample dicts.

    Groups each field of the per-sample dicts into a list keyed by field
    name; fields holding arrays/tensors/images are stacked into a single
    batch tensor along a new leading dimension.
    """

    def __init__(self, *args, **kwargs):
        # Accept (and ignore) arbitrary arguments so config-driven
        # construction works regardless of extra parameters.
        pass

    def __call__(self, batch):
        collected = {}
        stack_keys = []
        for sample in batch:
            for key, value in sample.items():
                bucket = collected.setdefault(key, [])
                # Remember which fields must be stacked into a tensor.
                if isinstance(value, (np.ndarray, paddle.Tensor, PIL.Image.Image)):
                    if key not in stack_keys:
                        stack_keys.append(key)
                bucket.append(value)
        for key in stack_keys:
            collected[key] = paddle.stack(collected[key], 0)
        return collected
def get_dataloader(module_config, distributed=False):
    """Build a paddle DataLoader from a config dict.

    :param module_config: dict with a ``dataset`` section (``type`` + ``args``,
        where args may contain ``transforms`` and must contain ``data_path``)
        and a ``loader`` section (``batch_size``, ``shuffle``, optional
        ``collate_fn`` class name, plus extra DataLoader kwargs)
    :param distributed: when True, use DistributedBatchSampler so each rank
        sees a shard of the data
    :return: a DataLoader, or None when the config or data-path list is empty
    """
    if module_config is None:
        return None
    # Deep-copy so the pops below do not mutate the caller's config.
    config = copy.deepcopy(module_config)
    dataset_args = config["dataset"]["args"]
    if "transforms" in dataset_args:
        img_transforms = get_transforms(dataset_args.pop("transforms"))
    else:
        img_transforms = None
    # Build the dataset.
    dataset_name = config["dataset"]["type"]
    data_path = dataset_args.pop("data_path")
    if data_path is None:  # fixed: was `== None`
        return None
    data_path = [x for x in data_path if x is not None]
    if len(data_path) == 0:
        return None
    # Empty/missing collate_fn means "use the default collation".
    collate_name = config["loader"].get("collate_fn")
    if not collate_name:
        config["loader"]["collate_fn"] = None
    else:
        # NOTE(review): eval() on a config-supplied string executes arbitrary
        # code if the config is untrusted; consider resolving the name via an
        # explicit registry instead.
        config["loader"]["collate_fn"] = eval(collate_name)()
    _dataset = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args,
    )
    if distributed:
        # Shard batches across ranks for distributed training.
        batch_sampler = DistributedBatchSampler(
            dataset=_dataset,
            batch_size=config["loader"].pop("batch_size"),
            shuffle=config["loader"].pop("shuffle"),
        )
    else:
        batch_sampler = BatchSampler(
            dataset=_dataset,
            batch_size=config["loader"].pop("batch_size"),
            shuffle=config["loader"].pop("shuffle"),
        )
    # Remaining loader keys (num_workers, etc.) are forwarded verbatim.
    loader = DataLoader(
        dataset=_dataset, batch_sampler=batch_sampler, **config["loader"]
    )
    return loader