# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
import time

import numpy as np
import paddle
import paddle.nn.functional as F
from scipy import ndimage
from tqdm import tqdm

from paddlevideo.utils import get_logger, main_only


def _status_bar(total):
    """Return a text-only tqdm bar (position 1) used to display status text."""
    return tqdm(total=total,
                position=1,
                bar_format='{desc}',
                desc="Loading weights")


def _show_done(desc, num_weights):
    """Write the final completion message to the status bar ``desc``.

    Must be called while the bar is still open: once a tqdm bar is closed,
    ``set_description`` no longer displays anything.
    """
    desc.set_description(
        "loading {:<20d} weights completed.".format(num_weights))


def pretrain_swin_param_trans(model, state_dicts):
    """Convert Swin Transformer weights to match a SwinTransformer3D model.

    ``head.fc.weight`` / ``head.fc.bias`` are dropped and every ``backbone.``
    substring is removed from the keys. If the number of remaining entries
    equals ``len(model.state_dict())`` the weights are treated as 3D weights
    and returned as-is. Otherwise they are treated as 2D weights and
    inflated:

    * entries whose key contains ``relative_position_index`` or
      ``attn_mask`` are dropped (attn_mask is always re-initialised);
    * ``patch_embed.proj.weight`` gains a new axis 2, tiled
      ``model.patch_size[0]`` times and divided by ``model.patch_size[0]``;
    * each ``relative_position_bias_table`` is bicubically resized when its
      length differs from ``(2*Wh-1)*(2*Ww-1)`` and then tiled ``2*Wd-1``
      times along axis 0. Tables whose second dimension does not match the
      model's are removed (not loaded).

    Args:
        model: Target model; ``state_dict()``, ``patch_size`` and
            ``window_size`` are read from it.
        state_dicts (dict): Pre-trained parameters. Not modified.

    Returns:
        dict: A new dict of converted parameters.
    """
    head_keys = ('head.fc.weight', 'head.fc.bias')
    state_dicts = {
        k.replace('backbone.', ''): v
        for k, v in state_dicts.items() if k not in head_keys
    }

    model_state = model.state_dict()
    if len(state_dicts) == len(model_state):
        print("Load 3D weights")
        return state_dicts

    print("Load 2D weights")
    # Neither buffer is taken from the 2D checkpoint; attn_mask is always
    # re-initialised by the model.
    state_dicts = {
        k: v
        for k, v in state_dicts.items()
        if "relative_position_index" not in k and "attn_mask" not in k
    }

    depth = model.patch_size[0]
    state_dicts['patch_embed.proj.weight'] = state_dicts[
        'patch_embed.proj.weight'].unsqueeze(2).tile([1, 1, depth, 1, 1
                                                      ]) / depth

    wd = model.window_size[0]
    target_hw = (2 * model.window_size[1] - 1, 2 * model.window_size[2] - 1)
    target_len = target_hw[0] * target_hw[1]

    bias_keys = [
        k for k in state_dicts if "relative_position_bias_table" in k
    ]
    total_len = len(bias_keys)
    with _status_bar(total_len) as desc:
        for key in tqdm(bias_keys, total=total_len, position=0):
            table = state_dicts[key]
            len_pre, nh_pre = table.shape
            _, nh_cur = model_state[key].shape
            if nh_pre != nh_cur:
                # Really skip: drop the entry instead of storing a
                # wrongly-shaped tensor.
                desc.set_description(f"Error in loading {key}, skip")
                del state_dicts[key]
                continue
            if len_pre != target_len:
                # Assumes the pretrained table covers a square S x S grid.
                side = int(len_pre**0.5)
                resized = F.interpolate(table.transpose([1, 0]).reshape(
                    [1, nh_pre, side, side]),
                                        size=target_hw,
                                        mode='bicubic')
                table = resized.reshape([nh_cur,
                                         target_len]).transpose([1, 0])
            desc.set_description(f"Loading {key}")
            state_dicts[key] = table.tile([2 * wd - 1, 1])
            time.sleep(0.01)
        _show_done(desc, len(model_state))
    return state_dicts


def pretrain_vit_param_trans(model, state_dicts, num_patches, num_seg,
                             attention_type):
    """Convert ViT pre-trained parameters to match the existing model.

    * ``head.weight`` / ``head.bias`` are removed.
    * If ``pos_embed`` does not hold ``num_patches + 1`` positions, the
      entries after position 0 are reshaped to a square grid, linearly
      resized with ``scipy.ndimage.zoom`` to a ``sqrt(num_patches)`` square
      grid and concatenated back after position 0.
    * If ``time_embed`` exists and its axis 1 differs from ``num_seg``, it
      is resized to ``num_seg`` along that axis (nearest interpolation).
    * For ``attention_type == 'divided_space_time'``, each ``blocks`` key
      containing ``attn`` / ``norm1`` is copied to the matching
      ``temporal_attn`` / ``temporal_norm1`` key unless the checkpoint
      already provides it.

    Args:
        model: Target model; only ``len(model.state_dict())`` is read.
        state_dicts (dict): Pre-trained parameters. Modified in place (head
            entries removed, ``pos_embed`` / ``time_embed`` replaced).
        num_patches (int): Spatial patches of the target model; must be a
            perfect square for the ``pos_embed`` resize to succeed.
        num_seg (int): Temporal segments of the target model.
        attention_type (str): Any value other than ``'divided_space_time'``
            returns a plain copy of the (converted) parameters.

    Returns:
        dict: A new dict with the converted parameters.
    """
    for head_key in ('head.weight', 'head.bias'):
        state_dicts.pop(head_key, None)

    total_len = len(model.state_dict())

    if num_patches + 1 != state_dicts['pos_embed'].shape[1]:
        pos_embed = state_dicts['pos_embed']
        # Position 0 is kept untouched; the rest form a gs_old x gs_old grid.
        cls_pos_embed = paddle.to_tensor(
            pos_embed[0, 0, :]).unsqueeze(0).unsqueeze(1)
        other_pos_embed = paddle.to_tensor(pos_embed[0, 1:, :])
        gs_new = int(np.sqrt(num_patches))
        gs_old = int(np.sqrt(other_pos_embed.shape[0]))
        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
        other_pos_embed = paddle.reshape(other_pos_embed,
                                         [gs_old, gs_old, -1])
        # order=1: linear interpolation; the channel axis is not zoomed.
        other_pos_embed = ndimage.zoom(other_pos_embed.numpy(),
                                       zoom,
                                       order=1)
        new_pos_embed = paddle.reshape(paddle.to_tensor(other_pos_embed),
                                       [1, num_patches, -1])
        state_dicts['pos_embed'] = paddle.concat(
            (cls_pos_embed, new_pos_embed), axis=1)
        time.sleep(0.01)

    if 'time_embed' in state_dicts and num_seg != state_dicts[
            'time_embed'].shape[1]:
        time_embed = state_dicts['time_embed'].transpose(
            (0, 2, 1)).unsqueeze(0)
        new_time_embed = F.interpolate(time_embed,
                                       size=(time_embed.shape[-2], num_seg),
                                       mode='nearest')
        state_dicts['time_embed'] = new_time_embed.squeeze(0).transpose(
            (0, 2, 1))
        time.sleep(0.01)

    # Default for every attention type (previously unbound for anything
    # other than 'divided_space_time' / 'space_only').
    new_state_dicts = state_dicts.copy()
    with _status_bar(total_len) as desc:
        if attention_type == 'divided_space_time':
            for key in tqdm(state_dicts):
                # Keys that are already temporal would otherwise yield bogus
                # "temporal_temporal_*" entries.
                if 'blocks' in key and 'temporal_' not in key:
                    for spatial, temporal in (('attn', 'temporal_attn'),
                                              ('norm1', 'temporal_norm1')):
                        if spatial in key:
                            desc.set_description("Loading %s" % key)
                            new_state_dicts.setdefault(
                                key.replace(spatial, temporal),
                                state_dicts[key])
                time.sleep(0.01)
        _show_done(desc, total_len)
    return new_state_dicts


def pretrain_resnet18_param_trans(model, loaded_dict):
    """Map ResNet-18 weights onto ``model.encoder`` and ``model.pose_encoder``.

    Every key of ``loaded_dict`` is tried with each of the prefixes
    ``encoder.``, ``encoder_day.`` and ``encoder_night.``; matches in
    ``model.encoder.state_dict()`` are overwritten. For the pose encoder,
    ``conv1.weight`` is concatenated with itself ``num_input_images`` (2)
    times along axis 1 and divided by that count, and keys are tried with
    the ``encoder.`` prefix.

    Args:
        model: Model exposing ``encoder`` and ``pose_encoder`` sub-layers.
        loaded_dict (dict): Pre-trained parameters; must contain
            ``conv1.weight``. Not modified.

    Returns:
        tuple: ``(encoder_dict, pose_encoder_dict)`` ready for ``load_dict``.
    """
    encoder_dict = model.encoder.state_dict()
    pose_encoder_dict = model.pose_encoder.state_dict()

    total_len = len(loaded_dict)
    for prefix in ('encoder.', 'encoder_day.', 'encoder_night.'):
        with _status_bar(total_len) as desc:
            for key, value in tqdm(loaded_dict.items(),
                                   total=total_len,
                                   position=0):
                key = prefix + key
                if key in encoder_dict:
                    encoder_dict[key] = value
                    desc.set_description('Loading %s' % key)
                time.sleep(0.01)

    # NOTE(review): presumably the pose encoder stacks num_input_images
    # frames along the channel axis; scaling keeps the output magnitude.
    num_input_images = 2
    pose_source = dict(loaded_dict)  # avoid mutating the caller's dict
    pose_source['conv1.weight'] = paddle.concat(
        [loaded_dict['conv1.weight']] * num_input_images,
        1) / num_input_images

    total_len = len(pose_source)
    with _status_bar(total_len) as desc:
        for key, value in tqdm(pose_source.items(),
                               total=total_len,
                               position=0):
            key = 'encoder.' + key
            if key in pose_encoder_dict:
                pose_encoder_dict[key] = value
                # Report the key actually being loaded (was a stale name).
                desc.set_description('Loading %s' % key)
            time.sleep(0.01)
        _show_done(desc, len(model.state_dict()))
    return encoder_dict, pose_encoder_dict


# XXX(shipping): may need to load N times because different cards can have
# different params.
@main_only
def load_ckpt(model, weight_path, **kwargs):
    """Load pre-trained parameters into ``model``.

    1. Load the checkpoint at ``weight_path`` with ``paddle.load``.
    2. Convert it to the parameters the model expects, chosen by the
       model's ``str()`` representation: ResnetEncoder,
       VisionTransformer (needs ``num_patches``, ``num_seg`` and
       ``attention_type`` keyword arguments), SwinTransformer3D, or a
       generic name match that also accepts a ``backbone.`` prefix.
    3. Apply the converted parameters with ``model.set_state_dict``.

    Raises:
        IOError: If ``weight_path`` is not an existing file.
    """
    if not osp.isfile(weight_path):
        raise IOError(f'{weight_path} is not a checkpoint file')

    # NOTE(review): result unused; call kept in case get_logger initialises
    # the project logger as a side effect.
    get_logger("paddlevideo")
    state_dicts = paddle.load(weight_path)

    model_repr = str(model)  # can be long for large models; compute once
    if 'ResnetEncoder' in model_repr:
        encoder_dict, pose_encoder_dict = pretrain_resnet18_param_trans(
            model, state_dicts)
        model.encoder.load_dict(encoder_dict)
        model.pose_encoder.load_dict(pose_encoder_dict)
        tmp = model.state_dict()
    elif "VisionTransformer" in model_repr:  # For TimeSformer case
        tmp = pretrain_vit_param_trans(model, state_dicts,
                                       kwargs['num_patches'],
                                       kwargs['num_seg'],
                                       kwargs['attention_type'])
    elif 'SwinTransformer3D' in model_repr:
        tmp = pretrain_swin_param_trans(model, state_dicts)
    else:
        tmp = {}
        model_state = model.state_dict()
        total_len = len(model_state)
        with _status_bar(total_len) as desc:
            for name in tqdm(model_state, total=total_len, position=0):
                desc.set_description('Loading %s' % name)
                if name in state_dicts:
                    tmp[name] = state_dicts[name]
                elif 'backbone.' + name in state_dicts:
                    tmp[name] = state_dicts['backbone.' + name]
                time.sleep(0.01)
            _show_done(desc, total_len)
    model.set_state_dict(tmp)


def mkdir(dir):
    """Create directory ``dir`` (with parents) if it does not exist yet.

    ``exist_ok=True`` tolerates another process (e.g. another GPU worker)
    creating the directory concurrently; other OS errors propagate instead
    of being silently swallowed.
    """
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


def _extract_student_weights(all_params, student_prefix="Student."):
    """Return the entries whose key starts with ``student_prefix``, prefix stripped."""
    return {
        key[len(student_prefix):]: value
        for key, value in all_params.items()
        if key.startswith(student_prefix)
    }


@main_only
def save(obj, path, save_student_model=False):
    """Save ``obj`` to ``path`` with ``paddle.save``.

    When ``save_student_model`` is true, entries of ``obj`` whose keys start
    with ``"Student."`` are also saved (prefix stripped) to ``path`` with
    ``.pdparams`` replaced by ``_student.pdparams``, provided at least one
    such entry exists.
    """
    if save_student_model:
        s_params = _extract_student_weights(obj)
        if s_params:
            paddle.save(s_params,
                        path.replace(".pdparams", "_student.pdparams"))
    paddle.save(obj, path)


def load(file_name):
    """Load and return the object stored at ``file_name`` via ``paddle.load``.

    Raises:
        IOError: If ``file_name`` is not an existing file.
    """
    if not osp.isfile(file_name):
        raise IOError(f'{file_name} not exist')
    return paddle.load(file_name)