# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:52
# @Author  : zhoujun
import copy

import PIL
import numpy as np
import paddle
from paddle.io import DataLoader, DistributedBatchSampler, BatchSampler

from paddle.vision import transforms


def get_dataset(data_path, module_name, transform, dataset_args):
    """
    Instantiate one dataset object by name from the sibling ``dataset`` module.

    :param data_path: list of annotation files, forwarded as the ``data_path``
        keyword (per the original author, each file holds lines of the form
        'path/to/Crop_img\\tlabel').
    :param module_name: attribute name looked up in the ``dataset`` module;
        the resulting callable is invoked to build the dataset.
    :param transform: transform pipeline forwarded as the ``transform`` keyword
        (may be None).
    :param dataset_args: dict of extra keyword arguments for the dataset
        constructor; must not repeat ``transform`` or ``data_path``.
    :return: whatever ``dataset.<module_name>(...)`` returns. No ConcatDataset
        wrapping and no None fallback happen here.
    """
    from . import dataset

    dataset_factory = getattr(dataset, module_name)
    return dataset_factory(transform=transform, data_path=data_path, **dataset_args)


def get_transforms(transforms_config):
    """
    Build a composed ``paddle.vision.transforms`` pipeline from config entries.

    :param transforms_config: iterable of dicts. Each dict has a ``"type"`` key
        naming a class in ``paddle.vision.transforms`` and an optional ``"args"``
        dict of constructor keyword arguments. ``"args"`` may be omitted or set
        to None (e.g. an entry left empty in a config file); both mean
        "construct with no arguments".
    :return: ``transforms.Compose`` wrapping the instantiated transforms in
        config order.
    """
    tr_list = []
    for item in transforms_config:
        # `or {}` also covers an explicit None, which `**` would reject.
        args = item.get("args") or {}
        transform_op = getattr(transforms, item["type"])(**args)
        tr_list.append(transform_op)
    return transforms.Compose(tr_list)


class ICDARCollectFN:
    """
    Collate function that merges a list of per-sample dicts into one batch dict.

    Each key seen in any sample maps to the list of that key's values, in
    sample order. Keys where at least one sample holds an ``np.ndarray``,
    ``paddle.Tensor`` or ``PIL.Image.Image`` are then replaced by
    ``paddle.stack(values, 0)``.
    """

    def __init__(self, *args, **kwargs):
        # Takes no configuration; any arguments are accepted and ignored.
        pass

    def __call__(self, batch):
        merged = {}
        # Dict used as an insertion-ordered set of keys that need stacking.
        stack_keys = {}
        for sample in batch:
            for key, value in sample.items():
                bucket = merged.setdefault(key, [])
                if isinstance(value, (np.ndarray, paddle.Tensor, PIL.Image.Image)):
                    stack_keys[key] = None
                bucket.append(value)
        # NOTE(review): paddle.stack is documented to take a list of Tensors;
        # confirm ndarray / PIL values are accepted here or convert them first.
        for key in stack_keys:
            merged[key] = paddle.stack(merged[key], 0)
        return merged


def get_dataloader(module_config, distributed=False):
    """
    Build a ``paddle.io.DataLoader`` from a loader config dict.

    Keys read by this function::

        {
            "dataset": {
                "type": <name looked up by get_dataset>,
                "args": {"data_path": <path or list of paths>,
                         "transforms": [...],   # optional
                         ...},                  # extra dataset kwargs
            },
            "loader": {
                "batch_size": ...,              # required
                "shuffle": ...,                 # optional, default False
                "drop_last": ...,               # optional, default False
                "collate_fn": <name or "">,     # optional
                ...,                            # other DataLoader kwargs
            },
        }

    ``module_config`` is deep-copied first, so the caller's dict is never
    mutated by the ``pop`` calls below.

    :param module_config: config dict as above, or None.
    :param distributed: if True, batch with ``DistributedBatchSampler``;
        otherwise with ``BatchSampler``.
    :return: the DataLoader, or None when ``module_config`` is None or no
        non-None data path is configured.
    """
    if module_config is None:
        return None
    config = copy.deepcopy(module_config)
    dataset_args = config["dataset"]["args"]
    if "transforms" in dataset_args:
        img_transforms = get_transforms(dataset_args.pop("transforms"))
    else:
        img_transforms = None

    # Build the dataset.
    dataset_name = config["dataset"]["type"]
    data_path = dataset_args.pop("data_path")
    if data_path is None:
        return None
    # A single path given as a plain string would otherwise be iterated
    # character by character by the comprehension below.
    if isinstance(data_path, str):
        data_path = [data_path]
    data_path = [x for x in data_path if x is not None]
    if not data_path:
        return None

    loader_args = config["loader"]
    collate_fn_name = loader_args.get("collate_fn")
    if collate_fn_name:
        # SECURITY: eval() executes arbitrary code taken from the config; only
        # load trusted configs. The intended value is the name of a collate
        # class visible in this module, e.g. "ICDARCollectFN".
        loader_args["collate_fn"] = eval(collate_fn_name)()
    else:
        loader_args["collate_fn"] = None

    data_set = get_dataset(
        data_path=data_path,
        module_name=dataset_name,
        transform=img_transforms,
        dataset_args=dataset_args,
    )

    sampler_cls = DistributedBatchSampler if distributed else BatchSampler
    # batch_size / shuffle / drop_last belong to the batch sampler; DataLoader
    # refuses them when batch_sampler is supplied, so consume them here.
    batch_sampler = sampler_cls(
        dataset=data_set,
        batch_size=loader_args.pop("batch_size"),
        shuffle=loader_args.pop("shuffle", False),
        drop_last=loader_args.pop("drop_last", False),
    )
    return DataLoader(dataset=data_set, batch_sampler=batch_sampler, **loader_args)