# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

__dir__ = os.path.dirname(__file__)
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, "..", "..", ".."))
sys.path.append(os.path.join(__dir__, "..", "..", "..", "tools"))

import paddle
import paddle.distributed as dist
from ppocr.data import build_dataloader, set_signal_handlers
from ppocr.modeling.architectures import build_model
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
import tools.program as program


def get_pruned_params(parameters):
    # collect the names of conv weights that are eligible for pruning,
    # skipping depthwise/transpose convs and the final output convs
    params = []
    for param in parameters:
        if (
            len(param.shape) == 4
            and "depthwise" not in param.name
            and "transpose" not in param.name
            and "conv2d_57" not in param.name
            and "conv2d_56" not in param.name
        ):
            params.append(param.name)
    return params


def main(config, device, logger, vdl_writer):
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    set_signal_handlers()
    train_dataloader = build_dataloader(config, "Train", device, logger)
    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        config["Architecture"]["Head"]["out_channels"] = char_num
    model = build_model(config["Architecture"])
    if config["Architecture"]["model_type"] == "det":
        input_shape = [1, 3, 640, 640]
    elif config["Architecture"]["model_type"] == "rec":
        input_shape = [1, 3, 32, 320]
    else:
        # fail early instead of hitting a NameError on input_shape below
        raise ValueError(
            "Unsupported model_type for pruning: {}".format(
                config["Architecture"]["model_type"]
            )
        )
    flops = paddle.flops(model, input_shape)
    logger.info("FLOPs before pruning: {}".format(flops))

    from paddleslim.dygraph import FPGMFilterPruner

    model.train()
    pruner = FPGMFilterPruner(model, input_shape)

    # build loss
    loss_class = build_loss(config["Loss"])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config["Optimizer"],
        epochs=config["Global"]["epoch_num"],
        step_each_epoch=len(train_dataloader),
        model=model,
    )

    # build metric
    eval_class = build_metric(config["Metric"])

    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer)

    logger.info(
        "train dataloader has {} iters, valid dataloader has {} iters".format(
            len(train_dataloader),
            len(valid_dataloader) if valid_dataloader is not None else 0,
        )
    )

    def eval_fn():
        metric = program.eval(
            model, valid_dataloader, post_process_class, eval_class, False
        )
"det": main_indicator = "hmean" else: main_indicator = "acc" logger.info("metric[{}]: {}".format(main_indicator, metric[main_indicator])) return metric[main_indicator] run_sensitive_analysis = False """ run_sensitive_analysis=True: Automatically compute the sensitivities of convolutions in a model. The sensitivity of a convolution is the losses of accuracy on test dataset in different pruned ratios. The sensitivities can be used to get a group of best ratios with some condition. run_sensitive_analysis=False: Set prune trim ratio to a fixed value, such as 10%. The larger the value, the more convolution weights will be cropped. """ if run_sensitive_analysis: params_sensitive = pruner.sensitive( eval_func=eval_fn, sen_file="./deploy/slim/prune/sen.pickle", skip_vars=[ "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0", ], ) logger.info( "The sensitivity analysis results of model parameters saved in sen.pickle" ) # calculate pruned params's ratio params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02) for key in params_sensitive.keys(): logger.info("{}, {}".format(key, params_sensitive[key])) else: params_sensitive = {} for param in model.parameters(): if "transpose" not in param.name and "linear" not in param.name: # set prune ratio as 10%. The larger the value, the more convolution weights will be cropped params_sensitive[param.name] = 0.1 plan = pruner.prune_vars(params_sensitive, [0]) flops = paddle.flops(model, input_shape) logger.info("FLOPs after pruning: {}".format(flops)) # start train program.train( config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer, ) if __name__ == "__main__": config, device, logger, vdl_writer = program.preprocess(is_train=True) main(config, device, logger, vdl_writer)