# -*- coding: utf-8 -*-
# @Time    : 2019/8/23 21:58
# @Author  : zhoujun
import time

import paddle
from tqdm import tqdm

from base import BaseTrainer
from utils import runningScore, cal_text_score, Polynomial, profiler
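
# Orientation notes for this module (descriptive comments only; the class and
# method names are from this file, the summaries are inferred from how they
# are used below):
#   - _train_epoch: one pass over train_loader with optional AMP, reader/batch
#     timing stats, pixel-level acc/IoU on the shrink map, and VisualDL logging.
#   - _eval: box-level recall/precision/hmean on validate_loader.
#   - _on_epoch_finish: checkpointing plus best-model tracking.
#   - _initialize_scheduler: builds the LR scheduler, with a special case for
#     the Polynomial schedule imported from utils.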


class Trainer(BaseTrainer):
    def __init__(
        self,
        config,
        model,
        criterion,
        train_loader,
        validate_loader,
        metric_cls,
        post_process=None,
        profiler_options=None,
    ):
        super(Trainer, self).__init__(
            config,
            model,
            criterion,
            train_loader,
            validate_loader,
            metric_cls,
            post_process,
        )
        self.profiler_options = profiler_options
        self.enable_eval = config["trainer"].get("enable_eval", True)
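
        # A minimal sketch of the config keys this trainer reads, assuming a
        # DBNet-style YAML layout (the surrounding schema and the example
        # values are assumptions inferred from the accesses in this file, not
        # a documented contract):
        #
        #   trainer:
        #     epochs: 1200
        #     enable_eval: true          # read above, defaults to True
        #   lr_scheduler:
        #     type: Polynomial
        #     args: {...}                # see _initialize_scheduler
        #   post_processing:
        #     args: {thresh: 0.3}        # threshold for the shrink-map score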


    def _train_epoch(self, epoch):
        self.model.train()
        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        epoch_start = time.time()
        train_loss = 0.0
        running_metric_text = runningScore(2)

        for i, batch in enumerate(self.train_loader):
            profiler.add_profiler_step(self.profiler_options)
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.get_lr()

            cur_batch_size = batch["Crop_img"].shape[0]

            train_reader_cost += time.time() - reader_start
            if self.amp:
                with paddle.amp.auto_cast(
                    enable="gpu" in paddle.device.get_device(),
                    custom_white_list=self.amp.get("custom_white_list", []),
                    custom_black_list=self.amp.get("custom_black_list", []),
                    level=self.amp.get("level", "O2"),
                ):
                    preds = self.model(batch["Crop_img"])
                    loss_dict = self.criterion(preds.astype(paddle.float32), batch)
                scaled_loss = self.amp["scaler"].scale(loss_dict["loss"])
                scaled_loss.backward()
                # minimize() unscales the gradients and applies the step
                self.amp["scaler"].minimize(self.optimizer, scaled_loss)
            else:
                preds = self.model(batch["Crop_img"])
                loss_dict = self.criterion(preds, batch)
                # backward
                loss_dict["loss"].backward()
                self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.clear_grad()
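
            # The AMP branch above follows Paddle's dynamic loss-scaling
            # recipe: GradScaler.scale() multiplies the loss before backward()
            # so float16 gradients do not underflow, and GradScaler.minimize()
            # unscales them and applies the optimizer step. A minimal
            # standalone sketch of the same pattern (variable names here are
            # illustrative, not from this file):
            #
            #   scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
            #   with paddle.amp.auto_cast(level="O2"):
            #       loss = model(x).mean()
            #   scaled = scaler.scale(loss)
            #   scaled.backward()
            #   scaler.minimize(optimizer, scaled)
            #   optimizer.clear_grad()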

            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            total_samples += cur_batch_size

            # acc iou
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch["shrink_map"],
                batch["shrink_mask"],
                running_metric_text,
                thred=self.config["post_processing"]["args"]["thresh"],
            )
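
            # runningScore(2) treats this as binary segmentation: preds[:, 0]
            # is taken as the shrink-map probability, binarized at `thresh`,
            # and compared against shrink_map wherever shrink_mask is valid,
            # yielding pixel accuracy and IoU (this reading of cal_text_score
            # is inferred from its arguments here, not from its source).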

            # log loss and acc
            loss_str = "loss: {:.4f}, ".format(loss_dict["loss"].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == "loss":
                    continue
                loss_str += "{}: {:.4f}".format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ", "

            train_loss += loss_dict["loss"]
            acc = score_shrink_map["Mean Acc"]
            iou_shrink_map = score_shrink_map["Mean IoU"]

            if self.global_step % self.log_iter == 0:
                self.logger_info(
                    "[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr: {:.6}, time: {:.2f}".format(
                        epoch,
                        self.epochs,
                        i + 1,
                        self.train_loader_len,
                        self.global_step,
                        total_samples / train_batch_cost,
                        train_reader_cost / self.log_iter,
                        train_batch_cost / self.log_iter,
                        total_samples / self.log_iter,
                        acc,
                        iou_shrink_map,
                        loss_str,
                        lr,
                        train_batch_cost,
                    )
                )
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0

            if self.visualdl_enable and paddle.distributed.get_rank() == 0:
                # write visualdl scalars
                for key, value in loss_dict.items():
                    self.writer.add_scalar(
                        "TRAIN/LOSS/{}".format(key), value, self.global_step
                    )
                self.writer.add_scalar("TRAIN/ACC_IOU/acc", acc, self.global_step)
                self.writer.add_scalar(
                    "TRAIN/ACC_IOU/iou_shrink_map", iou_shrink_map, self.global_step
                )
                self.writer.add_scalar("TRAIN/lr", lr, self.global_step)
            reader_start = time.time()
        return {
            "train_loss": train_loss / self.train_loader_len,
            "lr": lr,
            "time": time.time() - epoch_start,
            "epoch": epoch,
        }
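
    # Note: train_loss above accumulates one scalar per batch, so dividing by
    # train_loader_len yields the mean batch loss for the epoch; `lr` is the
    # last batch's learning rate. This dict is consumed by _on_epoch_finish.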

    def _eval(self, epoch):
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(
            enumerate(self.validate_loader),
            total=len(self.validate_loader),
            desc="test model",
        ):
            with paddle.no_grad():
                start = time.time()
                if self.amp:
                    with paddle.amp.auto_cast(
                        enable="gpu" in paddle.device.get_device(),
                        custom_white_list=self.amp.get("custom_white_list", []),
                        custom_black_list=self.amp.get("custom_black_list", []),
                        level=self.amp.get("level", "O2"),
                    ):
                        preds = self.model(batch["Crop_img"])
                    preds = preds.astype(paddle.float32)
                else:
                    preds = self.model(batch["Crop_img"])
                boxes, scores = self.post_process(
                    batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
                )
                total_frame += batch["Crop_img"].shape[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info("FPS:{}".format(total_frame / total_time))
        return metrics["recall"].avg, metrics["precision"].avg, metrics["fmeasure"].avg
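
    # For reference, the f-measure returned by gather_measure is conventionally
    # the harmonic mean of precision and recall: hmean = 2 * P * R / (P + R).
    # The FPS logged above covers forward plus post-processing only; data
    # loading happens outside the timed region.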

    def _on_epoch_finish(self):
        self.logger_info(
            "[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}".format(
                self.epoch_result["epoch"],
                self.epochs,
                self.epoch_result["train_loss"],
                self.epoch_result["time"],
                self.epoch_result["lr"],
            )
        )
        net_save_path = "{}/model_latest.pth".format(self.checkpoint_dir)
        net_save_path_best = "{}/model_best.pth".format(self.checkpoint_dir)

        if paddle.distributed.get_rank() == 0:
            self._save_checkpoint(self.epoch_result["epoch"], net_save_path)
            save_best = False
            if (
                self.validate_loader is not None
                and self.metric_cls is not None
                and self.enable_eval
            ):  # use f1 (hmean) as the best-model metric
                recall, precision, hmean = self._eval(self.epoch_result["epoch"])

                if self.visualdl_enable:
                    self.writer.add_scalar("EVAL/recall", recall, self.global_step)
                    self.writer.add_scalar(
                        "EVAL/precision", precision, self.global_step
                    )
                    self.writer.add_scalar("EVAL/hmean", hmean, self.global_step)
                self.logger_info(
                    "test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}".format(
                        recall, precision, hmean
                    )
                )

                if hmean >= self.metrics["hmean"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["hmean"] = hmean
                    self.metrics["precision"] = precision
                    self.metrics["recall"] = recall
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
            else:
                if self.epoch_result["train_loss"] <= self.metrics["train_loss"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
            best_str = "current best, "
            for k, v in self.metrics.items():
                best_str += "{}: {:.6f}, ".format(k, v)
            self.logger_info(best_str)
            if save_best:
                import shutil

                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info("Saving current best: {}".format(net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))
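
            # Best-model policy above: with eval enabled, the hmean (f1) score
            # decides model_best; without a validate_loader, the lowest
            # train_loss does. model_latest is always written first and then
            # copied to model_best when the criterion improves.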

    def _on_train_finish(self):
        if self.enable_eval:
            for k, v in self.metrics.items():
                self.logger_info("{}:{}".format(k, v))
        self.logger_info("finish train")

    def _initialize_scheduler(self):
        if self.config["lr_scheduler"]["type"] == "Polynomial":
            self.config["lr_scheduler"]["args"]["epochs"] = self.config["trainer"][
                "epochs"
            ]
            self.config["lr_scheduler"]["args"]["step_each_epoch"] = len(
                self.train_loader
            )
            self.lr_scheduler = Polynomial(**self.config["lr_scheduler"]["args"])()
        else:
            self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)