# -*- coding: utf-8 -*-
# @Time : 2019/8/23 21:58
# @Author : zhoujun
import time
import paddle
from tqdm import tqdm
from base import BaseTrainer
from utils import runningScore, cal_text_score, Polynomial, profiler


class Trainer(BaseTrainer):
    def __init__(
        self,
        config,
        model,
        criterion,
        train_loader,
        validate_loader,
        metric_cls,
        post_process=None,
        profiler_options=None,
    ):
        super(Trainer, self).__init__(
            config,
            model,
            criterion,
            train_loader,
            validate_loader,
            metric_cls,
            post_process,
        )
        self.profiler_options = profiler_options
        self.enable_eval = config["trainer"].get("enable_eval", True)
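    # NOTE (illustrative assumption, inferred only from how `self.amp` is used
    # in _train_epoch/_eval below; the actual setup lives in BaseTrainer, which
    # is not part of this file): when mixed precision is enabled, `self.amp` is
    # expected to be a dict along the lines of
    #     {
    #         "scaler": paddle.amp.GradScaler(init_loss_scaling=1024),
    #         "custom_white_list": [],
    #         "custom_black_list": [],
    #         "level": "O2",
    #     }
    # and to be falsy (e.g. None) when training in plain FP32.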
    def _train_epoch(self, epoch):
        self.model.train()
        total_samples = 0
        train_reader_cost = 0.0
        train_batch_cost = 0.0
        reader_start = time.time()
        epoch_start = time.time()
        train_loss = 0.0
        running_metric_text = runningScore(2)
        for i, batch in enumerate(self.train_loader):
            profiler.add_profiler_step(self.profiler_options)
            if i >= self.train_loader_len:
                break
            self.global_step += 1
            lr = self.optimizer.get_lr()
            cur_batch_size = batch["Crop_img"].shape[0]
            train_reader_cost += time.time() - reader_start
            if self.amp:
                with paddle.amp.auto_cast(
                    enable="gpu" in paddle.device.get_device(),
                    custom_white_list=self.amp.get("custom_white_list", []),
                    custom_black_list=self.amp.get("custom_black_list", []),
                    level=self.amp.get("level", "O2"),
                ):
                    preds = self.model(batch["Crop_img"])
                    loss_dict = self.criterion(preds.astype(paddle.float32), batch)
                scaled_loss = self.amp["scaler"].scale(loss_dict["loss"])
                scaled_loss.backward()
                self.amp["scaler"].minimize(self.optimizer, scaled_loss)
            else:
                preds = self.model(batch["Crop_img"])
                loss_dict = self.criterion(preds, batch)
                # backward
                loss_dict["loss"].backward()
                self.optimizer.step()
            self.lr_scheduler.step()
            self.optimizer.clear_grad()
            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            total_samples += cur_batch_size

            # acc iou
            score_shrink_map = cal_text_score(
                preds[:, 0, :, :],
                batch["shrink_map"],
                batch["shrink_mask"],
                running_metric_text,
                thred=self.config["post_processing"]["args"]["thresh"],
            )

            # record loss and acc in the log
            loss_str = "loss: {:.4f}, ".format(loss_dict["loss"].item())
            for idx, (key, value) in enumerate(loss_dict.items()):
                loss_dict[key] = value.item()
                if key == "loss":
                    continue
                loss_str += "{}: {:.4f}".format(key, loss_dict[key])
                if idx < len(loss_dict) - 1:
                    loss_str += ", "

            train_loss += loss_dict["loss"]
            acc = score_shrink_map["Mean Acc"]
            iou_shrink_map = score_shrink_map["Mean IoU"]

            if self.global_step % self.log_iter == 0:
                self.logger_info(
                    "[{}/{}], [{}/{}], global_step: {}, ips: {:.1f} samples/sec, avg_reader_cost: {:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, acc: {:.4f}, iou_shrink_map: {:.4f}, {}lr:{:.6}, time:{:.2f}".format(
                        epoch,
                        self.epochs,
                        i + 1,
                        self.train_loader_len,
                        self.global_step,
                        total_samples / train_batch_cost,
                        train_reader_cost / self.log_iter,
                        train_batch_cost / self.log_iter,
                        total_samples / self.log_iter,
                        acc,
                        iou_shrink_map,
                        loss_str,
                        lr,
                        train_batch_cost,
                    )
                )
                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0

            if self.visualdl_enable and paddle.distributed.get_rank() == 0:
                # write to VisualDL
                for key, value in loss_dict.items():
                    self.writer.add_scalar(
                        "TRAIN/LOSS/{}".format(key), value, self.global_step
                    )
                self.writer.add_scalar("TRAIN/ACC_IOU/acc", acc, self.global_step)
                self.writer.add_scalar(
                    "TRAIN/ACC_IOU/iou_shrink_map", iou_shrink_map, self.global_step
                )
                self.writer.add_scalar("TRAIN/lr", lr, self.global_step)
            reader_start = time.time()

        return {
            "train_loss": train_loss / self.train_loader_len,
            "lr": lr,
            "time": time.time() - epoch_start,
            "epoch": epoch,
        }
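    # NOTE: the dict returned by _train_epoch above is presumably stored as
    # `self.epoch_result` by the training loop in BaseTrainer (not shown here);
    # _on_epoch_finish below reads exactly these keys: "epoch", "train_loss",
    # "time" and "lr".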
    def _eval(self, epoch):
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for i, batch in tqdm(
            enumerate(self.validate_loader),
            total=len(self.validate_loader),
            desc="test model",
        ):
            with paddle.no_grad():
                start = time.time()
                if self.amp:
                    with paddle.amp.auto_cast(
                        enable="gpu" in paddle.device.get_device(),
                        custom_white_list=self.amp.get("custom_white_list", []),
                        custom_black_list=self.amp.get("custom_black_list", []),
                        level=self.amp.get("level", "O2"),
                    ):
                        preds = self.model(batch["Crop_img"])
                        preds = preds.astype(paddle.float32)
                else:
                    preds = self.model(batch["Crop_img"])
                boxes, scores = self.post_process(
                    batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
                )
                total_frame += batch["Crop_img"].shape[0]
                total_time += time.time() - start
                raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
                raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)
        self.logger_info("FPS:{}".format(total_frame / total_time))
        return metrics["recall"].avg, metrics["precision"].avg, metrics["fmeasure"].avg
    def _on_epoch_finish(self):
        self.logger_info(
            "[{}/{}], train_loss: {:.4f}, time: {:.4f}, lr: {}".format(
                self.epoch_result["epoch"],
                self.epochs,
                self.epoch_result["train_loss"],
                self.epoch_result["time"],
                self.epoch_result["lr"],
            )
        )
        net_save_path = "{}/model_latest.pth".format(self.checkpoint_dir)
        net_save_path_best = "{}/model_best.pth".format(self.checkpoint_dir)

        if paddle.distributed.get_rank() == 0:
            self._save_checkpoint(self.epoch_result["epoch"], net_save_path)
            save_best = False
            if (
                self.validate_loader is not None
                and self.metric_cls is not None
                and self.enable_eval
            ):  # use f1 (hmean) as the criterion for the best model
                recall, precision, hmean = self._eval(self.epoch_result["epoch"])

                if self.visualdl_enable:
                    self.writer.add_scalar("EVAL/recall", recall, self.global_step)
                    self.writer.add_scalar(
                        "EVAL/precision", precision, self.global_step
                    )
                    self.writer.add_scalar("EVAL/hmean", hmean, self.global_step)
                self.logger_info(
                    "test: recall: {:.6f}, precision: {:.6f}, hmean: {:.6f}".format(
                        recall, precision, hmean
                    )
                )

                if hmean >= self.metrics["hmean"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["hmean"] = hmean
                    self.metrics["precision"] = precision
                    self.metrics["recall"] = recall
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]
            else:
                if self.epoch_result["train_loss"] <= self.metrics["train_loss"]:
                    save_best = True
                    self.metrics["train_loss"] = self.epoch_result["train_loss"]
                    self.metrics["best_model_epoch"] = self.epoch_result["epoch"]

            best_str = "current best, "
            for k, v in self.metrics.items():
                best_str += "{}: {:.6f}, ".format(k, v)
            self.logger_info(best_str)
            if save_best:
                import shutil

                shutil.copy(net_save_path, net_save_path_best)
                self.logger_info("Saving current best: {}".format(net_save_path_best))
            else:
                self.logger_info("Saving checkpoint: {}".format(net_save_path))
    def _on_train_finish(self):
        if self.enable_eval:
            for k, v in self.metrics.items():
                self.logger_info("{}:{}".format(k, v))
        self.logger_info("finish train")
    def _initialize_scheduler(self):
        if self.config["lr_scheduler"]["type"] == "Polynomial":
            self.config["lr_scheduler"]["args"]["epochs"] = self.config["trainer"][
                "epochs"
            ]
            self.config["lr_scheduler"]["args"]["step_each_epoch"] = len(
                self.train_loader
            )
            self.lr_scheduler = Polynomial(**self.config["lr_scheduler"]["args"])()
        else:
            self.lr_scheduler = self._initialize("lr_scheduler", paddle.optimizer.lr)
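# Illustrative sketch only: the config keys that this Trainer itself reads.
# BaseTrainer consumes further keys (optimizer, checkpoint dir, log_iter,
# visualdl, amp, ...), and the values below are placeholders rather than the
# project's defaults, so treat this as an assumption, not the authoritative
# schema.
#
# config = {
#     "trainer": {"epochs": 1200, "enable_eval": True},
#     "lr_scheduler": {"type": "Polynomial", "args": {"learning_rate": 0.007}},  # args depend on utils.Polynomial
#     "post_processing": {"args": {"thresh": 0.3}},
# }
# trainer = Trainer(config, model, criterion, train_loader, validate_loader,
#                   metric_cls, post_process=post_process)
# trainer.train()  # assuming BaseTrainer exposes a train() entry point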