# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import signal
import os
import paddle
from paddle.io import DataLoader, DistributedBatchSampler
from .registry import DATASETS, PIPELINES
from ..utils.build_utils import build
from .pipelines.compose import Compose
from paddlevideo.utils import get_logger
from paddlevideo.utils.multigrid import DistributedShortSampler
import numpy as np

logger = get_logger("paddlevideo")


def build_pipeline(cfg):
    """Build the data pipeline.

    Args:
        cfg (dict): pipeline config. ``None`` means no pipeline is built.

    Returns:
        Compose: the composed pipeline, or ``None`` when ``cfg`` is ``None``.
    """
    if cfg is None:
        return
    return Compose(cfg)
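
# Illustrative usage of build_pipeline (a sketch only; the config access below
# mirrors the project's YAML layout and the exact field names are assumptions):
#
#     train_pipeline = build_pipeline(cfg.PIPELINE.train)   # -> Compose(...)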


def build_dataset(cfg):
    """Build a dataset together with its data pipeline.

    Args:
        cfg (tuple): a pair of (dataset config, pipeline config).

    Returns:
        dataset: the constructed dataset with its pipeline attached.
    """
    #XXX: ugly code here!
    cfg_dataset, cfg_pipeline = cfg
    cfg_dataset.pipeline = build_pipeline(cfg_pipeline)
    dataset = build(cfg_dataset, DATASETS, key="format")
    return dataset
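
# Illustrative call (a sketch; the attribute paths into ``cfg`` mirror the
# project's YAML configs and are assumptions, not a guaranteed API):
#
#     train_dataset = build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))
#
# ``build`` dispatches on the dataset config's "format" field to select the
# registered DATASETS class; the composed pipeline is attached to the dataset
# config before construction.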


def build_batch_pipeline(cfg):
    """Build a batch-level pipeline (e.g. a mix operator) from config."""
    batch_pipeline = build(cfg, PIPELINES)
    return batch_pipeline
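
# A batch pipeline is built from a single op config, for example (hypothetical
# op name and parameter, shown only as a sketch):
#
#     mix_op = build_batch_pipeline(dict(name='Mixup', alpha=0.2))
#     mixed_batch = mix_op(batch)   # operates on a whole batch, not one sample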


def build_dataloader(dataset,
                     batch_size,
                     num_workers,
                     places,
                     shuffle=True,
                     drop_last=True,
                     multigrid=False,
                     collate_fn_cfg=None,
                     **kwargs):
    """Build a Paddle DataLoader.

    XXX explain how the dataloader works!

    Args:
        dataset (paddle.io.Dataset): a PaddlePaddle dataset object.
        batch_size (int): batch size on a single card.
        num_workers (int): number of subprocesses used to load data.
        places: device place(s) to put the loaded data on.
        shuffle (bool): whether to shuffle the data every epoch.
        drop_last (bool): whether to drop the last incomplete batch.
        multigrid (bool): whether to use the multigrid short sampler.
        collate_fn_cfg (dict): config of the batch-level (mix) pipeline;
            ``None`` disables it.
    """
    if multigrid:
        sampler = DistributedShortSampler(dataset,
                                          batch_sizes=batch_size,
                                          shuffle=True,
                                          drop_last=True)
    else:
        sampler = DistributedBatchSampler(dataset,
                                          batch_size=batch_size,
                                          shuffle=shuffle,
                                          drop_last=drop_last)

    # NOTE(shipping): when a mix operator such as mixup or cutmix is switched on,
    # a batch like [[img, label, attribute, ...], [img, label, attribute, ...], ...]
    # is recollated to [[img, img, ...], [label, label, ...], [attribute, attribute, ...], ...],
    # i.e. a transpose of the batch (as numpy.transpose would do).

    def mix_collate_fn(batch):
        # Apply the batch pipeline (mix op), then transpose the batch so that
        # items of the same kind (images, labels, ...) are stacked together.
        pipeline = build_batch_pipeline(collate_fn_cfg)
        batch = pipeline(batch)
        slots = []
        for items in batch:
            for i, item in enumerate(items):
                if len(slots) < len(items):
                    slots.append([item])
                else:
                    slots[i].append(item)
        return [np.stack(slot, axis=0) for slot in slots]
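
    # For example (a sketch with assumed shapes): a batch of N samples, each
    # [img(C, T, H, W), label()], becomes [imgs(N, C, T, H, W), labels(N,)]
    # after mix_collate_fn, since np.stack joins each slot along a new axis 0.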

    #if collate_fn_cfg is not None:
    #ugly code here. collate_fn is mix op config
    #    collate_fn = mix_collate_fn(collate_fn_cfg)

    data_loader = DataLoader(
        dataset,
        batch_sampler=sampler,
        places=places,
        num_workers=num_workers,
        collate_fn=mix_collate_fn if collate_fn_cfg is not None else None,
        return_list=True,
        **kwargs)

    return data_loader
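
# Illustrative use of build_dataloader (a sketch; the variable names and the
# ``collate_fn_cfg`` op shown here are assumptions, not part of this module):
#
#     train_loader = build_dataloader(train_dataset,
#                                     batch_size=cfg.DATASET.batch_size,
#                                     num_workers=cfg.DATASET.num_workers,
#                                     places=paddle.set_device('gpu'),
#                                     shuffle=True,
#                                     drop_last=True,
#                                     collate_fn_cfg=dict(name='Mixup', alpha=0.2))
#     for batch_id, data in enumerate(train_loader):
#         ...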


def term_mp(sig_num, frame):
    """Kill all child processes in the current process group."""
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    logger.info("main proc {} exit, kill process group {}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)
    return


signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)