"""Functions for multigrid training."""

import numpy as np


class MultigridSchedule(object):
    """
    This class defines the multigrid training schedule and updates cfg accordingly.
    """

    def init_multigrid(self, cfg):
        """
        Update cfg based on multigrid settings.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters.
        Returns:
            cfg (configs): the updated cfg.
        """
        self.schedule = None
        # We may modify cfg.DATASET.batch_size,
        # cfg.PIPELINE.train.decode_sampler.num_frames, and
        # cfg.PIPELINE.train.transform[1]['MultiCrop']['target_size'] during
        # training, so we store their original values in cfg and use them as
        # global defaults.
        cfg.MULTIGRID.default_batch_size = cfg.DATASET.batch_size  # total batch size, e.g. 64
        cfg.MULTIGRID.default_temporal_size = cfg.PIPELINE.train.decode_sampler.num_frames  # e.g. 32
        cfg.MULTIGRID.default_crop_size = cfg.PIPELINE.train.transform[1][
            'MultiCrop']['target_size']  # e.g. 224

        if cfg.MULTIGRID.LONG_CYCLE:
            self.schedule = self.get_long_cycle_schedule(cfg)
            cfg.OPTIMIZER.learning_rate.steps = [0] + [
                s[-1] for s in self.schedule
            ]
            # Fine-tuning phase: move the last lr step to the midpoint of the
            # final interval.
            cfg.OPTIMIZER.learning_rate.steps[-1] = (
                cfg.OPTIMIZER.learning_rate.steps[-2] +
                cfg.OPTIMIZER.learning_rate.steps[-1]) // 2
            # Scale each phase's lr by gamma**step_index and by the phase's
            # relative batch size s[1][0].
            cfg.OPTIMIZER.learning_rate.lrs = [
                cfg.OPTIMIZER.learning_rate.gamma**s[0] * s[1][0]
                for s in self.schedule
            ]
            # Fine-tuning phase: repeat the second-to-last lr for the extra step.
            cfg.OPTIMIZER.learning_rate.lrs = cfg.OPTIMIZER.learning_rate.lrs[:-1] + [
                cfg.OPTIMIZER.learning_rate.lrs[-2],
                cfg.OPTIMIZER.learning_rate.lrs[-1],
            ]

            cfg.OPTIMIZER.learning_rate.max_epoch = self.schedule[-1][-1]

        elif cfg.MULTIGRID.SHORT_CYCLE:
            cfg.OPTIMIZER.learning_rate.steps = [
                int(s * cfg.MULTIGRID.epoch_factor)
                for s in cfg.OPTIMIZER.learning_rate.steps
            ]
            cfg.OPTIMIZER.learning_rate.max_epoch = int(
                cfg.OPTIMIZER.learning_rate.max_epoch *
                cfg.MULTIGRID.epoch_factor)
        return cfg
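
    # Illustrative example (made-up numbers, not defaults from this repo):
    # with gamma = 0.1 and a long cycle schedule
    #     [(0, [8, 8, 158], 16), (1, [4, 16, 224], 44), (2, [2, 32, 158], 80),
    #      (3, [1, 32, 224], 120)],
    # init_multigrid sets steps = [0, 16, 44, 80, 100] (the last entry moved
    # to the midpoint of the final interval for fine-tuning) and sets each
    # phase's lr to gamma**step_index * relative_batch_size, i.e.
    #     [0.1**0 * 8, 0.1**1 * 4, 0.1**2 * 2, 0.1**3 * 1, ...],
    # so larger-batch phases get proportionally larger learning rates
    # (linear lr scaling).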

    def update_long_cycle(self, cfg, cur_epoch):
        """
        Before every epoch, check if the long cycle shape should change. If it
        should, update cfg accordingly.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters.
            cur_epoch (int): current epoch index.
        Returns:
            cfg (configs): the updated cfg.
            changed (bool): whether the long cycle shape changed at this epoch.
        """
        base_b, base_t, base_s = get_current_long_cycle_shape(
            self.schedule, cur_epoch)
        if base_s != cfg.PIPELINE.train.transform[1]['MultiCrop'][
                'target_size'] or base_t != cfg.PIPELINE.train.decode_sampler.num_frames:
            # NOTE: cfg.MODEL.head.num_frames and cfg.MODEL.head.crop_size need
            # no update here; they are only used by pool_size in the head and
            # are None under multigrid.
            # cfg.MODEL.head.num_frames = base_t
            # cfg.MODEL.head.crop_size = base_s
            cfg.PIPELINE.train.decode_sampler.num_frames = base_t
            cfg.PIPELINE.train.transform[1]['MultiCrop']['target_size'] = base_s
            cfg.DATASET.batch_size = base_b * cfg.MULTIGRID.default_batch_size  # scale batch size

            bs_factor = (float(cfg.DATASET.batch_size) /
                         cfg.MULTIGRID.bn_base_size)

            if bs_factor == 1:  # per-card batch size equals bn_base_size (e.g. 8)
                cfg.MODEL.backbone.bn_norm_type = "batchnorm"
            else:
                cfg.MODEL.backbone.bn_norm_type = "sub_batchnorm"
                cfg.MODEL.backbone.bn_num_splits = int(bs_factor)

            cfg.MULTIGRID.long_cycle_sampling_rate = cfg.PIPELINE.train.decode_sampler.sampling_rate * (
                cfg.MULTIGRID.default_temporal_size // base_t)
            print("Long cycle updates:")
            print("\tbn_norm_type: {}".format(cfg.MODEL.backbone.bn_norm_type))
            if cfg.MODEL.backbone.bn_norm_type == "sub_batchnorm":
                print("\tbn_num_splits: {}".format(
                    cfg.MODEL.backbone.bn_num_splits))
            print("\tTRAIN.batch_size[single card]: {}".format(
                cfg.DATASET.batch_size))
            print("\tDATA.NUM_FRAMES x LONG_CYCLE_SAMPLING_RATE: {}x{}".format(
                cfg.PIPELINE.train.decode_sampler.num_frames,
                cfg.MULTIGRID.long_cycle_sampling_rate))
            print("\tDATA.train_crop_size: {}".format(
                cfg.PIPELINE.train.transform[1]['MultiCrop']['target_size']))
            return cfg, True
        else:
            return cfg, False

    def get_long_cycle_schedule(self, cfg):
        """
        Based on multigrid hyperparameters, define the schedule of a long cycle.
        Args:
            cfg (configs): configs that contains training and multigrid specific
                hyperparameters.
        Returns:
            schedule (list): a list of long cycle base shapes and their
                corresponding training epochs.
        """
        steps = cfg.OPTIMIZER.learning_rate.steps

        default_size = float(
            cfg.PIPELINE.train.decode_sampler.num_frames *
            cfg.PIPELINE.train.transform[1]['MultiCrop']['target_size']**
            2)  # T * H * W, e.g. 32 * 224 * 224
        default_iters = steps[-1]  # e.g. 196

        # Get shapes and average batch size for each long cycle shape.
        avg_bs = []
        all_shapes = []
        for item in cfg.MULTIGRID.long_cycle_factors:
            t_factor, s_factor = item["value"]
            base_t = int(
                round(cfg.PIPELINE.train.decode_sampler.num_frames * t_factor))
            base_s = int(
                round(
                    cfg.PIPELINE.train.transform[1]['MultiCrop']['target_size']
                    * s_factor))
            if cfg.MULTIGRID.SHORT_CYCLE:
                # The first two shapes are the short cycle shapes; the last is
                # the base long cycle shape.
                shapes = [
                    [
                        base_t,
                        cfg.MULTIGRID.default_crop_size *
                        cfg.MULTIGRID.short_cycle_factors[0],
                    ],
                    [
                        base_t,
                        cfg.MULTIGRID.default_crop_size *
                        cfg.MULTIGRID.short_cycle_factors[1],
                    ],
                    [base_t, base_s],
                ]
            else:
                shapes = [[base_t, base_s]]

            # (T, S) -> (B, T, S): pick the relative batch size B that keeps
            # B * T * S^2 close to the default input size.
            shapes = [[
                int(round(default_size / (s[0] * s[1] * s[1]))), s[0], s[1]
            ] for s in shapes]
            avg_bs.append(np.mean([s[0] for s in shapes]))
            all_shapes.append(shapes)

        # Get schedule regardless of cfg.MULTIGRID.epoch_factor.
        total_iters = 0
        schedule = []
        for step_index in range(len(steps) - 1):
            step_epochs = steps[step_index + 1] - steps[step_index]

            for long_cycle_index, shapes in enumerate(all_shapes):
                # Ensure that each long cycle shape runs for the same number
                # of iterations.
                cur_epochs = (step_epochs * avg_bs[long_cycle_index] /
                              sum(avg_bs))

                # Convert cur_epochs to iterations.
                cur_iters = cur_epochs / avg_bs[long_cycle_index]
                total_iters += cur_iters
                schedule.append((step_index, shapes[-1], cur_epochs))

        iter_saving = default_iters / total_iters  # ratio between default and actual iterations

        final_step_epochs = cfg.OPTIMIZER.learning_rate.max_epoch - steps[-1]

        # We define the fine-tuning phase to have the same amount of iteration
        # saving as the rest of the training, so final_step_epochs / iter_saving
        # gives fine-tuning the same number of iterations as training.
        ft_epochs = final_step_epochs / iter_saving * avg_bs[-1]

        schedule.append((step_index + 1, all_shapes[-1][-1], ft_epochs))

        # Obtain the final schedule given the desired cfg.MULTIGRID.epoch_factor.
        x = (cfg.OPTIMIZER.learning_rate.max_epoch *
             cfg.MULTIGRID.epoch_factor / sum(s[-1] for s in schedule))

        final_schedule = []
        total_epochs = 0
        for s in schedule:
            epochs = s[2] * x
            total_epochs += epochs
            final_schedule.append((s[0], s[1], int(round(total_epochs))))
        print_schedule(final_schedule)
        return final_schedule
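
# Worked example for one long cycle shape (illustrative; assumes the common
# multigrid factors, e.g. (t_factor, s_factor) = (0.25, 0.5**0.5), with the
# defaults T = 32 frames and S = 224 crop noted above):
#     base_t = round(32 * 0.25)      = 8
#     base_s = round(224 * 0.5**0.5) = 158
#     B      = round(32 * 224 * 224 / (8 * 158 * 158)) = 8
# giving the base shape (B, T, S) = (8, 8, 158): an 8x relative batch size in
# exchange for smaller temporal/spatial extents, so B * T * S^2 stays roughly
# constant across the cycle.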


def print_schedule(schedule):
    """
    Log schedule.
    """
    print(
        "Long_cycle_index\tBase_shape(bs_factor,temporal_size,crop_size)\tEpochs"
    )
    for s in schedule:
        print("{}\t\t\t{}\t\t\t\t\t{}".format(s[0], s[1], s[2]))


def get_current_long_cycle_shape(schedule, epoch):
    """
    Given a schedule and an epoch index, return the long cycle base shape.
    Args:
        schedule (list): a list of long cycle base shapes and their
            corresponding training epochs, as returned by
            get_long_cycle_schedule.
        epoch (int): current epoch index.
    Returns:
        shapes (list): a list describing the base shape in a long cycle:
            [batch size relative to default,
            number of frames, spatial dimension].
    """
    for s in schedule:
        if epoch < s[-1]:
            return s[1]
    return schedule[-1][1]
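

# Minimal usage sketch (illustrative only): how a training loop might drive
# this schedule. `build_model`, `build_dataloader`, and `train_one_epoch` are
# hypothetical helpers, not part of this module; `cfg` must provide the
# MULTIGRID, OPTIMIZER, PIPELINE, and DATASET fields referenced above.
#
#     multigrid = MultigridSchedule()
#     cfg = multigrid.init_multigrid(cfg)
#     for cur_epoch in range(cfg.OPTIMIZER.learning_rate.max_epoch):
#         if cfg.MULTIGRID.LONG_CYCLE:
#             cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch)
#             if changed:
#                 # Input shape and batch size changed: rebuild the model's BN
#                 # layers and the data loader before continuing.
#                 model = build_model(cfg)
#                 loader = build_dataloader(cfg)
#         train_one_epoch(model, loader, cfg)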