You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
61 lines
1.6 KiB
Python
61 lines
1.6 KiB
Python
8 months ago
|
import numpy as np
|
||
|
from paddle.vision.datasets import Cifar100
|
||
|
from paddle.vision.transforms import Normalize
|
||
|
import signal
|
||
|
import os
|
||
|
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
|
||
|
|
||
|
|
||
|
def term_mp(sig_num, frame):
    """Signal handler that tears down this process's whole process group.

    Prints the current pid and process-group id, then sends SIGKILL to
    every process in that group, including the calling process itself.

    Args:
        sig_num: signal number supplied by the ``signal`` module (unused).
        frame: interrupted stack frame supplied by the ``signal`` module
            (unused).
    """
    current_pid = os.getpid()
    group_id = os.getpgid(current_pid)
    print(f"main proc {current_pid} exit, kill process group {group_id}")
    # SIGKILL cannot be caught, so this terminates the group immediately.
    os.killpg(group_id, signal.SIGKILL)
|
||
|
|
||
|
|
||
|
def build_dataloader(mode, batch_size=4, seed=None, num_workers=0, device="gpu:0"):
    """Build a CIFAR-100 ``DataLoader`` for the requested split.

    Args:
        mode: split name, case-insensitive. ``"train"`` selects the training
            split; ``"test"``, ``"valid"`` and ``"eval"`` all select the test
            split.
        batch_size: number of samples per batch. A trailing incomplete batch
            is dropped (``drop_last=True``).
        seed: currently unused; kept for interface compatibility.
        num_workers: number of worker processes passed to ``DataLoader``.
        device: value forwarded to ``DataLoader`` as ``places``.

    Returns:
        A ``paddle.io.DataLoader`` built with ``return_list=True`` over a
        ``DistributedBatchSampler`` with ``shuffle=False``, so samples are
        visited in a fixed order.

    Raises:
        ValueError: if ``mode`` is not one of the accepted split names.

    Side effects:
        Installs ``term_mp`` as the process-wide SIGINT and SIGTERM handler
        (presumably so Ctrl+C also stops DataLoader worker processes that
        share the process group). ``signal.signal`` only works from the main
        thread.
    """
    # Validate the requested split before constructing anything so that bad
    # input fails fast.
    mode_key = mode.lower()
    if mode_key == "train":
        split = "train"
    elif mode_key in ("test", "valid", "eval"):
        # CIFAR-100 ships no separate validation split; reuse the test split.
        split = "test"
    else:
        raise ValueError(
            f"{mode} should be one of ['train', 'test', 'valid', 'eval']"
        )

    normalize = Normalize(
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format="HWC"
    )
    dataset = Cifar100(mode=split, transform=normalize)

    # define batch sampler
    batch_sampler = DistributedBatchSampler(
        dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True
    )

    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=False,
    )

    # support exit using ctrl+c
    signal.signal(signal.SIGINT, term_mp)
    signal.signal(signal.SIGTERM, term_mp)

    return data_loader
|
||
|
|
||
|
|
||
|
# cifar100 = Cifar100(mode='train', transform=normalize)
|
||
|
|
||
|
# data = cifar100[0]
|
||
|
|
||
|
# image, label = data
|
||
|
|
||
|
# reader = build_dataloader('train')
|
||
|
|
||
|
# for idx, data in enumerate(reader):
|
||
|
# print(idx, data[0].shape, data[1].shape)
|
||
|
# if idx >= 10:
|
||
|
# break
|