# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:50
# @Author  : zhoujun
import os
import pathlib
import shutil
from pprint import pformat

import anyconfig
import paddle
import numpy as np
import random
from paddle.jit import to_static
from paddle.static import InputSpec

from utils import setup_logger


class BaseTrainer:
    """Shared scaffolding for Paddle training loops.

    Sets up the output/checkpoint directories, rank-0 logging and optional
    VisualDL, RNG seeding, LR scheduler and optimizer construction from the
    config, checkpoint save/load, optional AMP and multi-GPU wrapping.
    Subclasses must implement ``_train_epoch``, ``_eval``,
    ``_on_epoch_finish`` and ``_on_train_finish``.
    """

    def __init__(
        self,
        config,
        model,
        criterion,
        train_loader,
        validate_loader,
        metric_cls,
        post_process=None,
    ):
        """Build the trainer from ``config``.

        :param config: nested config mapping. It is mutated in place:
            ``trainer.output_dir`` is made absolute and ``name`` gets
            ``"_" + model.name`` appended.
        :param model: Paddle layer; must expose ``.name`` and ``.parameters()``.
        :param criterion: loss object, stored as ``self.criterion``.
        :param train_loader: training data loader (``len()`` and ``.dataset``
            are used).
        :param validate_loader: optional validation loader; when given,
            ``post_process`` and ``metric_cls`` are required.
        :param metric_cls: metric object used by subclasses for evaluation.
        :param post_process: optional post-processing used by subclasses.
        """
        # Relative output dirs are resolved against the current working
        # directory (an absolute output_dir is kept as-is by os.path.join).
        config["trainer"]["output_dir"] = os.path.join(
            os.getcwd(),
            config["trainer"]["output_dir"],
        )
        config["name"] = config["name"] + "_" + model.name
        self.save_dir = config["trainer"]["output_dir"]
        self.checkpoint_dir = os.path.join(self.save_dir, "checkpoint")

        os.makedirs(self.checkpoint_dir, exist_ok=True)

        self.global_step = 0
        self.start_epoch = 0
        self.config = config
        self.criterion = criterion
        # logger and visualdl
        self.visualdl_enable = self.config["trainer"].get("visual_dl", False)
        self.epochs = self.config["trainer"]["epochs"]
        self.log_iter = self.config["trainer"]["log_iter"]
        if paddle.distributed.get_rank() == 0:
            anyconfig.dump(config, os.path.join(self.save_dir, "config.yaml"))
            # Only rank 0 owns a logger; logger_info() guards on the rank.
            self.logger = setup_logger(os.path.join(self.save_dir, "train.log"))
            self.logger_info(pformat(self.config))

        self.model = self.apply_to_static(model)

        # device
        self.with_cuda = (
            paddle.device.is_compiled_with_cuda()
            and paddle.device.cuda.device_count() > 0
        )
        # Seed every RNG independently of the device so CPU runs are
        # reproducible too; skip seeding when no seed is configured.
        seed = self.config["trainer"].get("seed")
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
            paddle.seed(seed)
        self.logger_info(
            "train with cuda: {} and paddle {}".format(
                self.with_cuda, paddle.__version__
            )
        )
        # metrics
        self.metrics = {
            "recall": 0,
            "precision": 0,
            "hmean": 0,
            "train_loss": float("inf"),
            "best_model_epoch": 0,
        }

        self.train_loader = train_loader
        if validate_loader is not None:
            assert (
                post_process is not None and metric_cls is not None
            ), "post_process and metric_cls are required when validate_loader is given"
        self.validate_loader = validate_loader
        self.post_process = post_process
        self.metric_cls = metric_cls
        self.train_loader_len = len(train_loader)

        if self.validate_loader is not None:
            self.logger_info(
                "train dataset has {} samples,{} in dataloader, validate dataset has {} samples,{} in dataloader".format(
                    len(self.train_loader.dataset),
                    self.train_loader_len,
                    len(self.validate_loader.dataset),
                    len(self.validate_loader),
                )
            )
        else:
            self.logger_info(
                "train dataset has {} samples,{} in dataloader".format(
                    len(self.train_loader.dataset), self.train_loader_len
                )
            )

        self._initialize_scheduler()
        self._initialize_optimizer()

        # resume or finetune
        if self.config["trainer"]["resume_checkpoint"] != "":
            self._load_checkpoint(
                self.config["trainer"]["resume_checkpoint"], resume=True
            )
        elif self.config["trainer"]["finetune_checkpoint"] != "":
            self._load_checkpoint(
                self.config["trainer"]["finetune_checkpoint"], resume=False
            )

        if self.visualdl_enable and paddle.distributed.get_rank() == 0:
            from visualdl import LogWriter

            self.writer = LogWriter(self.save_dir)

        # mixed-precision training; the string "None" in the config disables it
        self.amp = self.config.get("amp", None)
        if self.amp == "None":
            self.amp = None
        if self.amp:
            self.amp["scaler"] = paddle.amp.GradScaler(
                init_loss_scaling=self.amp.get("scale_loss", 1024),
                use_dynamic_loss_scaling=self.amp.get("use_dynamic_loss_scaling", True),
            )
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp.get("amp_level", "O2"),
            )

        # distributed training
        if paddle.device.cuda.device_count() > 1:
            self.model = paddle.DataParallel(self.model)

        # Remember the train-set Normalize parameters so inverse_normalize()
        # can undo them (e.g. for visualization).
        self.UN_Normalize = False
        for t in self.config["dataset"]["train"]["dataset"]["args"]["transforms"]:
            if t["type"] == "Normalize":
                self.normalize_mean = t["args"]["mean"]
                self.normalize_std = t["args"]["std"]
                self.UN_Normalize = True

    def apply_to_static(self, model):
        """Optionally convert ``model`` with ``paddle.jit.to_static``.

        Enabled by ``config["trainer"]["to_static"]`` (default False). The input
        spec is a 4-D tensor with 3 at dim 1 and dynamic batch/spatial dims
        (presumably NCHW images — confirm against the model).

        :param model: Paddle layer to convert.
        :return: the converted model, or ``model`` unchanged when disabled.
        """
        if not self.config["trainer"].get("to_static", False):
            return model
        specs = [InputSpec([None, 3, -1, -1])]
        model = to_static(model, input_spec=specs)
        self.logger_info(
            "Successfully to apply @to_static with specs: {}".format(specs)
        )
        return model

    def train(self):
        """
        Full training logic: run epochs ``start_epoch + 1 .. epochs`` and call
        the per-epoch / end-of-training hooks.
        """
        for epoch in range(self.start_epoch + 1, self.epochs + 1):
            self.epoch_result = self._train_epoch(epoch)
            self._on_epoch_finish()
        if paddle.distributed.get_rank() == 0 and self.visualdl_enable:
            self.writer.close()
        self._on_train_finish()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch (implemented by subclasses).

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _eval(self, epoch):
        """
        Eval logic for an epoch (implemented by subclasses).

        :param epoch: Current epoch number
        """
        raise NotImplementedError

    def _on_epoch_finish(self):
        """Hook called after each epoch (implemented by subclasses)."""
        raise NotImplementedError

    def _on_train_finish(self):
        """Hook called once after the last epoch (implemented by subclasses)."""
        raise NotImplementedError

    def _save_checkpoint(self, epoch, file_name):
        """
        Save model/optimizer state, counters, config and metrics.

        :param epoch: current epoch number
        :param file_name: file name inside ``self.checkpoint_dir``
        """
        state_dict = self.model.state_dict()
        state = {
            "epoch": epoch,
            "global_step": self.global_step,
            "state_dict": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config,
            "metrics": self.metrics,
        }
        filename = os.path.join(self.checkpoint_dir, file_name)
        paddle.save(state, filename)

    def _load_checkpoint(self, checkpoint_path, resume):
        """
        Load model weights from a checkpoint, optionally resuming training state.

        :param checkpoint_path: path of a file written by ``_save_checkpoint``
        :param resume: if True, also restore global step, epoch, optimizer state
            and (when present) metrics; if False, only the model weights are
            loaded (finetuning).
        """
        self.logger_info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = paddle.load(checkpoint_path)
        self.model.set_state_dict(checkpoint["state_dict"])
        if resume:
            self.global_step = checkpoint["global_step"]
            self.start_epoch = checkpoint["epoch"]
            # Record the resumed epoch in the config; tolerate a scheduler
            # config without "args", as _initialize does.
            self.config["lr_scheduler"].setdefault("args", {})[
                "last_epoch"
            ] = self.start_epoch
            # NOTE(review): the scheduler was already built before this call,
            # so the config entry above does not rewind it; its state is
            # presumably restored via the optimizer state dict below — confirm.
            self.optimizer.set_state_dict(checkpoint["optimizer"])
            if "metrics" in checkpoint:
                self.metrics = checkpoint["metrics"]
            self.logger_info(
                "resume from checkpoint {} (epoch {})".format(
                    checkpoint_path, self.start_epoch
                )
            )
        else:
            self.logger_info("finetune from checkpoint {}".format(checkpoint_path))

    def _initialize(self, name, module, *args, **kwargs):
        """Instantiate ``module.<config[name]["type"]>``.

        Keyword arguments come from ``config[name]["args"]`` merged with
        ``kwargs``; a key present in both is rejected.

        :param name: top-level config section (e.g. "optimizer").
        :param module: module/namespace holding the class to instantiate.
        :return: the constructed object.
        """
        module_name = self.config[name]["type"]
        # Copy: runtime-only kwargs (model parameters, the LR scheduler object)
        # must not be written back into self.config, which is pickled into
        # every checkpoint by _save_checkpoint.
        module_args = dict(self.config[name].get("args") or {})
        assert all(
            k not in module_args for k in kwargs
        ), "Overwriting kwargs given in config file is not allowed"
        module_args.update(kwargs)
        return getattr(module, module_name)(*args, **module_args)

    def _initialize_scheduler(self):
        """Build ``self.lr_scheduler`` from ``config["lr_scheduler"]``."""
        self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)

    def _initialize_optimizer(self):
        """Build ``self.optimizer`` over the model parameters, driven by the LR scheduler."""
        self.optimizer = self._initialize(
            "optimizer",
            paddle.optimizer,
            parameters=self.model.parameters(),
            learning_rate=self.lr_scheduler,
        )

    def inverse_normalize(self, batch_img):
        """Undo the train-set ``Normalize`` transform in place.

        For each configured channel ``c``:
        ``batch_img[:, c] = batch_img[:, c] * std[c] + mean[c]``.
        Does nothing when no ``Normalize`` transform was configured.
        Returns None (the input is modified in place).

        :param batch_img: 4-D tensor indexed ``[:, channel, :, :]``.
        """
        if not self.UN_Normalize:
            return
        for c, (mean, std) in enumerate(zip(self.normalize_mean, self.normalize_std)):
            batch_img[:, c, :, :] = batch_img[:, c, :, :] * std + mean

    def logger_info(self, s):
        """Log ``s`` at INFO level on rank 0 only (the only rank with ``self.logger``)."""
        if paddle.distributed.get_rank() == 0:
            self.logger.info(s)