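# NOTE(review): the following lines appear to be page text captured from a web
# repository viewer together with this config. They are not configuration and
# are kept only as comments so the file parses as YAML.
#   You cannot select more than 25 topics Topics must start with a letter or
#   number, can include dashes ('-') and can be up to 35 characters long.
#   209 lines | 4.8 KiB | YAML | 8 months ago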
# Run-wide settings: device, training schedule, checkpoint/output paths and
# the character dictionary.
Global:
  debug: false
  use_gpu: true
  epoch_num: 100
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/rec_svtrv2_gtc_distill_lr00002/
  save_epoch_step: 5
  # NOTE(review): presumably [start_step, interval] for periodic evaluation;
  # confirm against the trainer.
  eval_batch_step:
    - 0
    - 1000
  cal_metric_during_train: false
  pretrained_model: null
  checkpoints: null
  save_inference_dir: null
  use_visualdl: false
  infer_img: doc/imgs_words/ch/word_1.jpg
  character_dict_path: ppocr/utils/ppocr_keys_v1.txt
  # Anchored so the NRTR heads under Architecture reuse the same value.
  max_text_length: &max_text_length 25
  infer_mode: false
  use_space_char: true
  distributed: true
  save_res_path: ./output/rec/predicts_svtrv2_gtc_distill.txt
# AdamW optimizer with a cosine learning-rate schedule and warmup.
Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.99
  # Keep the dot: YAML 1.1 loaders such as PyYAML resolve `1.e-8` as a float,
  # whereas `1e-8` (no dot) would load as a string there.
  epsilon: 1.e-8
  weight_decay: 0.05
  # Space-separated name patterns; presumably parameters matching them are
  # excluded from weight decay — confirm against the optimizer builder.
  no_weight_decay_name: norm pos_embed patch_embed downsample
  one_dim_param_no_weight_decay: true
  lr:
    name: Cosine
    learning_rate: 0.0002  # 8 GPUs, batch size 192
    warmup_epoch: 5
# Distillation model: an SVTRv2 Teacher loaded from a trained checkpoint with
# freeze_params: true, and a repvit_svtr Student with freeze_params: false.
# Both use a MultiHead with a CTC branch and an NRTR branch.
Architecture:
  model_type: rec
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained: ./output/rec_svtrv2_gtc/best_accuracy
      freeze_params: true
      return_all_feats: true
      model_type: rec
      algorithm: SVTR_LCNet
      Transform: null
      Backbone:
        name: SVTRv2
        use_pos_embed: false
        dims: [128, 256, 384]
        depths: [6, 6, 6]
        num_heads: [4, 8, 12]
        # Three rows of six mixer types, matching the three entries of `depths`.
        mixer:
          - ['Conv', 'Conv', 'Conv', 'Conv', 'Conv', 'Conv']
          - ['Conv', 'Conv', 'Global', 'Global', 'Global', 'Global']
          - ['Global', 'Global', 'Global', 'Global', 'Global', 'Global']
        local_k: [[5, 5], [5, 5], [-1, -1]]
        sub_k: [[2, 1], [2, 1], [-1, -1]]
        last_stage: false
        use_pool: true
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 256
                depth: 2
                hidden_dims: 256
                kernel_size: [1, 3]
                use_guide: true
              Head:
                fc_decay: 0.00001
          - NRTRHead:
              nrtr_dim: 384
              num_decoder_layers: 2
              max_text_length: *max_text_length
    Student:
      pretrained: ./output/rec_repsvtr_gtc/best_accuracy
      freeze_params: false
      return_all_feats: true
      model_type: rec
      algorithm: SVTR_LCNet
      Transform: null
      Backbone:
        name: repvit_svtr
      # Same MultiHead layout as the Teacher.
      Head:
        name: MultiHead
        head_list:
          - CTCHead:
              Neck:
                name: svtr
                dims: 256
                depth: 2
                hidden_dims: 256
                kernel_size: [1, 3]
                use_guide: true
              Head:
                fc_decay: 0.00001
          - NRTRHead:
              nrtr_dim: 384
              num_decoder_layers: 2
              max_text_length: *max_text_length
# Combined objective for the Student:
# - DKD distillation against the Teacher on the `gtc` head;
# - its own CTC and NRTR losses;
# - CTC-logit distillation against the Teacher.
Loss:
  name: CombinedLoss
  loss_config_list:
    - DistillationDKDLoss:
        weight: 0.1
        model_name_pairs:
          - - Student
            - Teacher
        key: head_out
        multi_head: true
        alpha: 1.0
        beta: 2.0
        dis_head: gtc
        name: dkd
    - DistillationCTCLoss:
        weight: 1.0
        model_name_list:
          - Student
        key: head_out
        multi_head: true
    - DistillationNRTRLoss:
        weight: 1.0
        smoothing: false
        model_name_list:
          - Student
        key: head_out
        multi_head: true
    - DistillCTCLogits:
        weight: 1.0
        reduction: mean
        model_name_pairs:
          - - Student
            - Teacher
        key: head_out
# Decode predictions from the Student model only.
PostProcess:
  name: DistillationCTCLabelDecode
  model_name:
    - Student
  key: head_out
  multi_head: true
# Recognition metrics on the Student's outputs; `acc` is the main indicator.
Metric:
  name: DistillationMetric
  base_metric_name: RecMetric
  main_indicator: acc
  key: Student
# Multi-scale training data. The per-card batch size is defined once on the
# sampler (&bs) and reused by the loader (*bs).
Train:
  dataset:
    name: MultiScaleDataSet
    ds_width: false
    data_dir: ./train_data/
    ext_op_transform_idx: 1
    label_file_list:
      - ./train_data/train_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - RecAug: null
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  sampler:
    name: MultiScaleSampler
    scales: [[320, 32], [320, 48], [320, 64]]  # presumably [w, h] pairs
    first_bs: &bs 192
    fix_bs: false
    divided_factor: [8, 16]  # w, h
    is_training: true
  loader:
    shuffle: true
    batch_size_per_card: *bs
    drop_last: true
    num_workers: 8
# Evaluation data: every image is resized to a single shape, [3, 48, 320].
Eval:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data
    label_file_list:
      - ./train_data/val_list.txt
    transforms:
      - DecodeImage:
          img_mode: BGR
          channel_first: false
      - MultiLabelEncode:
          gtc_encode: NRTRLabelEncode
      - RecResizeImg:
          image_shape: [3, 48, 320]
      - KeepKeys:
          keep_keys:
            - image
            - label_ctc
            - label_gtc
            - length
            - valid_ratio
  loader:
    shuffle: false
    drop_last: false
    batch_size_per_card: 128
    num_workers: 4