# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import os.path as osp
import time

import paddle
import paddle.nn.functional as F
from paddlevideo.utils import get_logger, main_only
from tqdm import tqdm
import numpy as np
from scipy import ndimage


def pretrain_swin_param_trans(model, state_dicts):
    # delete classifier's params
    if 'head.fc' + '.weight' in state_dicts:
        del state_dicts['head.fc' + '.weight']
    if 'head.fc' + '.bias' in state_dicts:
        del state_dicts['head.fc' + '.bias']

    state_dicts = {
        k.replace('backbone.', ''): v
        for k, v in state_dicts.items()
    }

    if len(state_dicts) == len(model.state_dict()):
        print("Load 3D weights")
        return state_dicts

    print("Load 2D weights")
    relative_position_index_keys = [
        k for k in state_dicts.keys() if "relative_position_index" in k
    ]
    for k in relative_position_index_keys:
        del state_dicts[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dicts.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dicts[k]
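
    # Inflate the 2D patch-embedding kernel to 3D: repeat it along the new
    # temporal axis and divide by the temporal patch size, so the inflated
    # kernel's summed response matches the original 2D kernel.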
    state_dicts['patch_embed.proj.weight'] = state_dicts[
        'patch_embed.proj.weight'].unsqueeze(2).tile(
            [1, 1, model.patch_size[0], 1, 1]) / model.patch_size[0]

    # bicubic interpolate relative_position_bias_table if not match
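    # The pre-trained 2D table has (2*Wh-1)*(2*Ww-1) entries per head; if the
    # target window's spatial size differs, it is resized with bicubic
    # interpolation and then tiled (2*Wd-1) times to cover the temporal axis.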
    relative_position_bias_table_keys = [
        k for k in state_dicts.keys() if "relative_position_bias_table" in k
    ]
    total_len = len(relative_position_bias_table_keys)
    with tqdm(total=total_len,
              position=1,
              bar_format='{desc}',
              desc="Loading weights") as desc:
        for key in tqdm(relative_position_bias_table_keys,
                        total=total_len,
                        position=0):
            relative_position_bias_table_pretrained = state_dicts[key]
            relative_position_bias_table_current = model.state_dict()[key]
            L1, nH1 = relative_position_bias_table_pretrained.shape
            L2, nH2 = relative_position_bias_table_current.shape
            L2 = (2 * model.window_size[1] - 1) * (2 * model.window_size[2] - 1)
            wd = model.window_size[0]
            if nH1 != nH2:
                desc.set_description(f"Error in loading {key}, skip")
            else:
                if L1 != L2:
                    S1 = int(L1**0.5)
                    relative_position_bias_table_pretrained_resized = paddle.nn.functional.interpolate(
                        relative_position_bias_table_pretrained.transpose(
                            [1, 0]).reshape([1, nH1, S1, S1]),
                        size=(2 * model.window_size[1] - 1,
                              2 * model.window_size[2] - 1),
                        mode='bicubic')
                    relative_position_bias_table_pretrained = relative_position_bias_table_pretrained_resized.reshape(
                        [nH2, L2]).transpose([1, 0])
                desc.set_description(f"Loading {key}")
                state_dicts[key] = relative_position_bias_table_pretrained.tile(
                    [2 * wd - 1, 1])
            time.sleep(0.01)
        ret_str = "loading {:<20d} weights completed.".format(
            len(model.state_dict()))
        desc.set_description(ret_str)
    return state_dicts


def pretrain_vit_param_trans(model, state_dicts, num_patches, num_seg,
                             attention_type):
    """
    Convert ViT's pre-trained model parameters to a parameter dictionary
    that matches the existing model.
    """
    if 'head' + '.weight' in state_dicts:
        del state_dicts['head' + '.weight']
    if 'head' + '.bias' in state_dicts:
        del state_dicts['head' + '.bias']

    total_len = len(model.state_dict())
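    # If the patch grid changed, resize the spatial position embedding:
    # keep the class-token embedding as-is, reshape the remaining tokens to a
    # square grid, rescale it with ndimage.zoom (order=1, i.e. linear
    # interpolation), and concatenate the two parts again.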
    if num_patches + 1 != state_dicts['pos_embed'].shape[1]:
        pos_embed = state_dicts['pos_embed']
        cls_pos_embed = paddle.to_tensor(
            pos_embed[0, 0, :]).unsqueeze(0).unsqueeze(1)
        other_pos_embed = paddle.to_tensor(pos_embed[0, 1:, :])
        gs_new = int(np.sqrt(num_patches))
        gs_old = int(np.sqrt(other_pos_embed.shape[0]))
        zoom = (gs_new / gs_old, gs_new / gs_old, 1)
        other_pos_embed = paddle.reshape(other_pos_embed, [gs_old, gs_old, -1])
        other_pos_embed = ndimage.zoom(other_pos_embed, zoom, order=1)
        other_pos_embed = paddle.to_tensor(other_pos_embed)
        new_pos_embed = paddle.reshape(other_pos_embed, [1, num_patches, -1])
        new_pos_embed = paddle.concat((cls_pos_embed, new_pos_embed), axis=1)
        state_dicts['pos_embed'] = new_pos_embed
        time.sleep(0.01)
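
    # If the number of temporal segments changed, resize the temporal position
    # embedding along its segment axis with nearest-neighbour interpolation.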
    if 'time_embed' in state_dicts and num_seg != state_dicts[
            'time_embed'].shape[1]:
        time_embed = state_dicts['time_embed'].transpose((0, 2, 1)).unsqueeze(0)
        new_time_embed = F.interpolate(time_embed,
                                       size=(time_embed.shape[-2], num_seg),
                                       mode='nearest')
        state_dicts['time_embed'] = new_time_embed.squeeze(0).transpose(
            (0, 2, 1))
        time.sleep(0.01)
    with tqdm(total=total_len,
              position=1,
              bar_format='{desc}',
              desc="Loading weights") as desc:
        if attention_type == 'divided_space_time':
            new_state_dicts = state_dicts.copy()
            for key in tqdm(state_dicts):
                if 'blocks' in key and 'attn' in key:
                    desc.set_description("Loading %s" % key)
                    new_key = key.replace('attn', 'temporal_attn')
                    if new_key not in state_dicts:
                        new_state_dicts[new_key] = state_dicts[key]
                    else:
                        new_state_dicts[new_key] = state_dicts[new_key]
                if 'blocks' in key and 'norm1' in key:
                    desc.set_description("Loading %s" % key)
                    new_key = key.replace('norm1', 'temporal_norm1')
                    if new_key not in state_dicts:
                        new_state_dicts[new_key] = state_dicts[key]
                    else:
                        new_state_dicts[new_key] = state_dicts[new_key]
                time.sleep(0.01)
        elif attention_type == 'space_only':  # tokenshift raw vit
            new_state_dicts = state_dicts.copy()

        ret_str = "loading {:<20d} weights completed.".format(
            len(model.state_dict()))
        desc.set_description(ret_str)
    return new_state_dicts


def pretrain_resnet18_param_trans(model, loaded_dict):
    encoder_dict = model.encoder.state_dict()
    pose_encoder_dict = model.pose_encoder.state_dict()

    names = ['encoder.', 'encoder_day.', 'encoder_night.']
    for name in names:
        total_len = len(loaded_dict.items())
        with tqdm(total=total_len,
                  position=1,
                  bar_format='{desc}',
                  desc="Loading weights") as desc:
            for key, value in tqdm(loaded_dict.items(),
                                   total=total_len,
                                   position=0):
                key = str(name + key)
                if key in encoder_dict:
                    encoder_dict[key] = value
                    desc.set_description('Loading %s' % key)
                time.sleep(0.01)
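
    # The pose encoder takes num_input_images stacked frames, so the first
    # conv kernel from the single-image checkpoint is repeated along the
    # input-channel axis and rescaled to keep the output magnitude unchanged.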
    num_input_images = 2
    loaded_dict['conv1.weight'] = paddle.concat(
        [loaded_dict['conv1.weight']] * num_input_images, 1) / num_input_images
    total_len = len(loaded_dict.items())
    with tqdm(total=total_len,
              position=1,
              bar_format='{desc}',
              desc="Loading weights") as desc:
        for name, value in tqdm(loaded_dict.items(),
                                total=total_len,
                                position=0):
            name = str('encoder.' + name)
            if name in pose_encoder_dict:
                pose_encoder_dict[name] = value
                desc.set_description('Loading %s' % name)
            time.sleep(0.01)
        ret_str = "loading {:<20d} weights completed.".format(
            len(model.state_dict()))
        desc.set_description(ret_str)
    return encoder_dict, pose_encoder_dict


# XXX(shipping): may need to load N times because different cards have different params.
@main_only
def load_ckpt(model, weight_path, **kargs):
    """
    1. Load pre-trained model parameters
    2. Extract and convert the pre-trained parameters into the parameters
       required by the existing model
    3. Load the converted parameters into the existing model
    """
    #model.set_state_dict(state_dict)

    if not osp.isfile(weight_path):
        raise IOError(f'{weight_path} is not a checkpoint file')
    #state_dicts = load(weight_path)

    logger = get_logger("paddlevideo")
    state_dicts = paddle.load(weight_path)
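
    # Pick the conversion routine by checking which architecture name appears
    # in str(model); the final branch handles the generic case of a plain
    # (possibly 'backbone.'-prefixed) checkpoint.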
    if 'ResnetEncoder' in str(model):
        encoder_dict, pose_encoder_dict = pretrain_resnet18_param_trans(
            model, state_dicts)
        model.encoder.load_dict(encoder_dict)
        model.pose_encoder.load_dict(pose_encoder_dict)
        tmp = model.state_dict()
    elif "VisionTransformer" in str(model):  # For TimeSformer case
        tmp = pretrain_vit_param_trans(model, state_dicts, kargs['num_patches'],
                                       kargs['num_seg'],
                                       kargs['attention_type'])
    elif 'SwinTransformer3D' in str(model):
        tmp = pretrain_swin_param_trans(model, state_dicts)
    else:
        tmp = {}
        total_len = len(model.state_dict())
        with tqdm(total=total_len,
                  position=1,
                  bar_format='{desc}',
                  desc="Loading weights") as desc:
            for item in tqdm(model.state_dict(), total=total_len, position=0):
                name = item
                desc.set_description('Loading %s' % name)
                if name not in state_dicts:  # Convert from non-parallel model
                    if str('backbone.' + name) in state_dicts:
                        tmp[name] = state_dicts['backbone.' + name]
                else:  # Convert from parallel model
                    tmp[name] = state_dicts[name]
                time.sleep(0.01)
            ret_str = "loading {:<20d} weights completed.".format(
                len(model.state_dict()))
            desc.set_description(ret_str)
    model.set_state_dict(tmp)


def mkdir(dir):
    if not os.path.exists(dir):
        # avoid error when training with multiple gpus
        try:
            os.makedirs(dir)
        except OSError:
            pass


def _extract_student_weights(all_params, student_prefix="Student."):
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if student_prefix in key
    }
    return s_params


@main_only
def save(obj, path, save_student_model=False):
    if save_student_model:
        s_params = _extract_student_weights(obj)
        student_path = path.replace(".pdparams", "_student.pdparams")
        if len(s_params) > 0:
            paddle.save(s_params, student_path)
    paddle.save(obj, path)


def load(file_name):
    if not osp.isfile(file_name):
        raise IOError(f'{file_name} does not exist')
    return paddle.load(file_name)
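

# Example usage (illustrative sketch only; the model constructor, checkpoint
# path, and keyword values below are assumptions, not part of this module):
#
#   model = build_timesformer()  # hypothetical helper returning a
#                                # TimeSformer-style VisionTransformer
#   load_ckpt(model,
#             'vit_base_patch16_224.pdparams',
#             num_patches=196,
#             num_seg=8,
#             attention_type='divided_space_time')
#   save(model.state_dict(), 'output/ckpt.pdparams')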