import os
import signal

import numpy as np
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize


def term_mp(sig_num, frame):
    """Kill all child processes by signalling the whole process group."""
    pid = os.getpid()
    pgid = os.getpgid(pid)
    print(f"main proc {pid} exits, killing process group {pgid}")
    os.killpg(pgid, signal.SIGKILL)
    return


def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
    """Build a CIFAR-100 DataLoader for 'train', 'test', 'valid' or 'eval' mode.

    `seed` is accepted for API compatibility but is not used here.
    """
    normalize = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC")

    if mode.lower() == "train":
        dataset = Cifar100(mode="train", transform=normalize)
    elif mode.lower() in ["test", "valid", "eval"]:
        dataset = Cifar100(mode="test", transform=normalize)
    else:
        raise ValueError(f"{mode} should be one of ['train', 'test', 'valid', 'eval']")

    # define batch sampler
    batch_sampler = DistributedBatchSampler(
        dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
    )

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=False,
    )

    # support exit using ctrl+c: SIGINT/SIGTERM kill the whole process group,
    # so DataLoader worker processes are cleaned up as well
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader


# cifar100 = Cifar100(mode='train', transform=normalize)
# data = cifar100[0]
# image, label = data

# reader = build_dataloader('train')
# for idx, data in enumerate(reader):
#     print(idx, data[0].shape, data[1].shape)
#     if idx >= 10:
#         break
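
# Minimal usage sketch: builds the train loader on CPU and prints the shapes of a
# few batches. It assumes PaddlePaddle is installed and that CIFAR-100 can be
# downloaded on first use; drop the `device` override to use the default "gpu:0".
if __name__ == "__main__":
    import paddle

    loader = build_dataloader("train", batch_size=4, device=paddle.CPUPlace())
    for idx, batch in enumerate(loader):
        images, labels = batch
        # each image batch is normalized HWC data, roughly [4, 32, 32, 3];
        # the label batch shape may be [4] or [4, 1] depending on the Paddle version
        print(idx, images.shape, labels.shape)
        if idx >= 2:
            break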