import os
import signal

import numpy as np
from paddle.io import DataLoader, Dataset, DistributedBatchSampler
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize


def term_mp(sig_num, frame):
    """Kill the whole process group so DataLoader worker processes exit cleanly."""
    pid = os.getpid()
    pgid = os.getpgid(pid)
    print("main proc {} exit, kill process group {}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")
    if mode.lower() == "train":
        dataset = Cifar100(mode="train", transform=normalize)
    elif mode.lower() in ["test", "valid", "eval"]:
        dataset = Cifar100(mode="test", transform=normalize)
    else:
        raise ValueError(f"{mode} should be one of ['train', 'test', 'valid', 'eval']")

    # Define the batch sampler; DistributedBatchSampler shards the dataset across
    # trainers when the script is launched in distributed mode.
    batch_sampler = DistributedBatchSampler(
        dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
    )

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=False,
    )

    # Support clean exit on Ctrl+C / termination by killing the process group.
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)
    return data_loader


# cifar100 = Cifar100(mode='train', transform=normalize)
# data = cifar100[0]
# image, label = data
# reader = build_dataloader('train')
# for idx, data in enumerate(reader):
#     print(idx, data[0].shape, data[1].shape)
#     if idx >= 10:
#         break
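

# Minimal runnable usage sketch (not part of the original script). It assumes a
# single-process run on the device given to build_dataloader; for multi-GPU runs,
# launch with `python -m paddle.distributed.launch` and call
# paddle.distributed.init_parallel_env() before building the loader so that
# DistributedBatchSampler shards batches across trainers.
if __name__ == "__main__":
    loader = build_dataloader("train", batch_size=4)
    for idx, (images, labels) in enumerate(loader):
        print(idx, images.shape, labels.shape)
        if idx >= 10:
            break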