# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import os.path as osp

import paddle
import paddle.distributed as dist

from ..loader.builder import build_dataloader, build_dataset
from ..modeling.builder import build_model
from ..solver import build_lr, build_optimizer
from ..utils import do_preciseBN
from paddlevideo.utils import get_logger, coloring
from paddlevideo.utils import (AverageMeter, build_record, log_batch, log_epoch,
                               save, load, mkdir)
from paddlevideo.utils.multigrid import (MultigridSchedule,
                                         aggregate_sub_bn_stats, subn_load,
                                         subn_save, is_eval_epoch)


def construct_loader(cfg, places, validate, precise_bn, num_iters_precise_bn,
                     world_size):
    """Build the train, valid and precise-BN dataloaders described by cfg."""
    batch_size = cfg.DATASET.get('batch_size', 2)
    train_dataset = build_dataset((cfg.DATASET.train, cfg.PIPELINE.train))
    precise_bn_dataloader_setting = dict(
        batch_size=batch_size,
        num_workers=cfg.DATASET.get('num_workers', 0),
        places=places,
    )
    if precise_bn:
        cfg.DATASET.train.num_samples_precise_bn = num_iters_precise_bn * batch_size * world_size
        precise_bn_dataset = build_dataset((cfg.DATASET.train,
                                            cfg.PIPELINE.train))
        precise_bn_loader = build_dataloader(precise_bn_dataset,
                                             **precise_bn_dataloader_setting)
        cfg.DATASET.train.num_samples_precise_bn = None
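        # num_samples_precise_bn is set only while the precise-BN dataset is
        # built, so that it draws exactly num_iters * batch_size * world_size
        # samples; resetting it leaves the regular training dataset untouched.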
    else:
        precise_bn_loader = None

    if cfg.MULTIGRID.SHORT_CYCLE:
        # get batch size list in short cycle schedule
        bs_factor = [
            int(
                round((float(cfg.PIPELINE.train.transform[1]['MultiCrop'][
                    'target_size']) / (s * cfg.MULTIGRID.default_crop_size))
                      **2)) for s in cfg.MULTIGRID.short_cycle_factors
        ]
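        # For example, with the hypothetical config values target_size=224,
        # default_crop_size=224 and short_cycle_factors=[0.5, 0.707], this
        # gives bs_factor == [4, 2]: halving the spatial crop allows a
        # roughly 4x larger batch at a similar memory cost.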
        batch_sizes = [
            batch_size * bs_factor[0],
            batch_size * bs_factor[1],
            batch_size,
        ]
        train_dataloader_setting = dict(
            batch_size=batch_sizes,
            multigrid=True,
            num_workers=cfg.DATASET.get('num_workers', 0),
            places=places,
        )
    else:
        train_dataloader_setting = precise_bn_dataloader_setting

    train_loader = build_dataloader(train_dataset, **train_dataloader_setting)
    if validate:
        valid_dataset = build_dataset((cfg.DATASET.valid, cfg.PIPELINE.valid))
        validate_dataloader_setting = dict(
            batch_size=batch_size,
            num_workers=cfg.DATASET.get('num_workers', 0),
            places=places,
            drop_last=False,
            shuffle=False)
        valid_loader = build_dataloader(valid_dataset,
                                        **validate_dataloader_setting)
    else:
        valid_loader = None

    return train_loader, valid_loader, precise_bn_loader


def build_trainer(cfg, places, parallel, validate, precise_bn,
                  num_iters_precise_bn, world_size):
    """
    Build the training model and its associated tools, including optimizer,
    dataloaders and meters.
    Args:
        cfg (CfgNode): configs.
    Returns:
        model: training model.
        optimizer: optimizer.
        train_loader: training data loader.
        val_loader: validation data loader.
        precise_bn_loader: training data loader for computing precise BN.
    """
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    train_loader, valid_loader, precise_bn_loader = \
        construct_loader(cfg,
                         places,
                         validate,
                         precise_bn,
                         num_iters_precise_bn,
                         world_size,
                         )

    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(cfg.OPTIMIZER, lr, model=model)

    return (
        model,
        lr,
        optimizer,
        train_loader,
        valid_loader,
        precise_bn_loader,
    )


def train_model_multigrid(cfg, world_size=1, validate=True):
    """Train model entry

    Args:
        cfg (dict): configuration.
        world_size (int): number of trainer processes. Default: 1.
        validate (bool): whether to run evaluation during training. Default: True.
    """
    # Init multigrid.
    multigrid = None
    if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE:
        multigrid = MultigridSchedule()
        cfg = multigrid.init_multigrid(cfg)
        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0)
        multi_save_epoch = [i[-1] - 1 for i in multigrid.schedule]

    parallel = world_size != 1
    logger = get_logger("paddlevideo")
    batch_size = cfg.DATASET.get('batch_size', 2)

    if cfg.get('use_npu', False):
        places = paddle.set_device('npu')
    elif cfg.get('use_xpu', False):
        places = paddle.set_device('xpu')
    else:
        places = paddle.set_device('gpu')

    model_name = cfg.model_name
    output_dir = cfg.get("output_dir", f"./output/{model_name}")
    mkdir(output_dir)
    local_rank = dist.ParallelEnv().local_rank
    precise_bn = cfg.get("PRECISEBN")
    num_iters_precise_bn = cfg.PRECISEBN.num_iters_preciseBN

    # 1. Construct model
    model = build_model(cfg.MODEL)
    if parallel:
        model = paddle.DataParallel(model)

    # 2. Construct dataloader
    train_loader, valid_loader, precise_bn_loader = \
        construct_loader(cfg,
                         places,
                         validate,
                         precise_bn,
                         num_iters_precise_bn,
                         world_size,
                         )

    # 3. Construct optimizer
    lr = build_lr(cfg.OPTIMIZER.learning_rate, len(train_loader))
    optimizer = build_optimizer(
        cfg.OPTIMIZER, lr, parameter_list=model.parameters())

    # Resume
    resume_epoch = cfg.get("resume_epoch", 0)
    if resume_epoch:
        filename = osp.join(
            output_dir,
            model_name + str(local_rank) + '_' + f"{resume_epoch:05d}")
        subn_load(model, filename, optimizer)

    # 4. Train Model
    best = 0.
    total_epochs = int(cfg.epochs * cfg.MULTIGRID.epoch_factor)
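    # epoch_factor stretches the nominal epoch count: the cheaper long-cycle
    # phases (smaller inputs, larger batches) make the extra epochs affordable
    # at a roughly comparable overall compute budget.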
    for epoch in range(total_epochs):
        if epoch < resume_epoch:
            logger.info(
                f"| epoch: [{epoch+1}] <= resume_epoch: [{resume_epoch}], continue... "
            )
            continue

        if cfg.MULTIGRID.LONG_CYCLE:
            cfg, changed = multigrid.update_long_cycle(cfg, epoch)
            if changed:
                logger.info("====== Rebuild model/optimizer/loader =====")
                (
                    model,
                    lr,
                    optimizer,
                    train_loader,
                    valid_loader,
                    precise_bn_loader,
                ) = build_trainer(cfg, places, parallel, validate, precise_bn,
                                  num_iters_precise_bn, world_size)

                # Load the checkpoint after rebuilding the model.
                if epoch != 0:
                    # Checkpoint filenames already store epoch + 1 at save
                    # time, so no -1 is needed here.
                    filename = osp.join(
                        output_dir,
                        model_name + str(local_rank) + '_' + f"{(epoch):05d}")
                    subn_load(model, filename, optimizer)
                # Fast-forward the rebuilt scheduler to the current epoch
                # instead of restoring its state from the checkpoint.
                lr.last_epoch = epoch
                lr.step(rebuild=True)

        model.train()
        record_list = build_record(cfg.MODEL)
        tic = time.time()
        for i, data in enumerate(train_loader):
            record_list['reader_time'].update(time.time() - tic)
            # 4.1 forward
            outputs = model(data, mode='train')
            # 4.2 backward
            avg_loss = outputs['loss']
            avg_loss.backward()
            # 4.3 minimize
            optimizer.step()
            optimizer.clear_grad()

            # log record
            record_list['lr'].update(
                float(optimizer._global_learning_rate()), batch_size)
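            # (_global_learning_rate() is an internal Paddle optimizer helper
            # that returns the current scheduled learning-rate value.)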
            for name, value in outputs.items():
                record_list[name].update(float(value), batch_size)
            record_list['batch_time'].update(time.time() - tic)
            tic = time.time()

            if i % cfg.get("log_interval", 10) == 0:
                ips = "ips: {:.5f} instance/sec.".format(
                    batch_size / record_list["batch_time"].val)
                log_batch(record_list, i, epoch + 1, total_epochs, "train", ips)

            # learning rate iter step
            if cfg.OPTIMIZER.learning_rate.get("iter_step"):
                lr.step()

        # learning rate epoch step
        if not cfg.OPTIMIZER.learning_rate.get("iter_step"):
            lr.step()

        ips = "ips: {:.5f} instance/sec.".format(
            batch_size * record_list["batch_time"].count /
            record_list["batch_time"].sum)
        log_epoch(record_list, epoch + 1, "train", ips)

        def evaluate(best):
            model.eval()
            record_list = build_record(cfg.MODEL)
            record_list.pop('lr')
            tic = time.time()
            for i, data in enumerate(valid_loader):
                outputs = model(data, mode='valid')

                # log_record
                for name, value in outputs.items():
                    record_list[name].update(float(value), batch_size)

                record_list['batch_time'].update(time.time() - tic)
                tic = time.time()

                if i % cfg.get("log_interval", 10) == 0:
                    ips = "ips: {:.5f} instance/sec.".format(
                        batch_size / record_list["batch_time"].val)
                    log_batch(record_list, i, epoch + 1, total_epochs, "val",
                              ips)

            ips = "ips: {:.5f} instance/sec.".format(
                batch_size * record_list["batch_time"].count /
                record_list["batch_time"].sum)
            log_epoch(record_list, epoch + 1, "val", ips)

            best_flag = False
            if record_list.get('top1') and record_list['top1'].avg > best:
                best = record_list['top1'].avg
                best_flag = True
            return best, best_flag

        # Use precise BN to improve accuracy.
        if is_eval_epoch(cfg, epoch, total_epochs, multigrid.schedule):
            logger.info(f"do precise BN in {epoch+1} ...")
            do_preciseBN(model, precise_bn_loader, parallel,
                         min(num_iters_precise_bn, len(precise_bn_loader)))

            # aggregate sub_BN stats
            logger.info("Aggregate sub_BatchNorm stats...")
            aggregate_sub_bn_stats(model)

        # 5. Validation
        if is_eval_epoch(cfg, epoch, total_epochs, multigrid.schedule):
            logger.info(f"eval in {epoch+1} ...")
            with paddle.no_grad():
                best, save_best_flag = evaluate(best)
            # save best
            if save_best_flag:
                save(optimizer.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdopt"))
                save(model.state_dict(),
                     osp.join(output_dir, model_name + "_best.pdparams"))
                logger.info(
                    f"Already saved the best model (top1 acc: {int(best * 10000) / 10000})"
                )

        # 6. Save model and optimizer
        if is_eval_epoch(
                cfg, epoch,
                total_epochs, multigrid.schedule) or epoch % cfg.get(
                    "save_interval", 10) == 0 or epoch in multi_save_epoch:
            logger.info("[Save parameters] ======")
            subn_save(output_dir, model_name + str(local_rank) + '_',
                      epoch + 1, model, optimizer)

    logger.info(f'training {model_name} finished')