# -*- coding: utf-8 -*-
# @Time    : 2018/6/11 15:54
# @Author  : zhoujun
import os
import sys
import pathlib

# NOTE: ``__dir__`` holds the absolute path of this file itself (not its
# directory). The first append therefore adds the file path, and the second
# adds the parent of the directory containing this script. Both entries are
# kept as-is so that import resolution stays unchanged.
__dir__ = pathlib.Path(os.path.abspath(__file__))
sys.path.append(str(__dir__))
sys.path.append(str(__dir__.parent.parent))

import argparse
import time

import paddle
from tqdm.auto import tqdm


class EVAL:
    """Evaluate a saved model checkpoint on its configured validation set.

    The checkpoint loaded from ``model_path`` must contain a ``"config"``
    entry (used to build the data loader, model, post-processing and metric)
    and a ``"state_dict"`` entry (loaded into the model).
    """

    def __init__(self, model_path, gpu_id=0):
        """Load the checkpoint and build everything needed for evaluation.

        Args:
            model_path: Path passed to ``paddle.load`` to read the checkpoint.
            gpu_id: Index of the GPU to run on. The CPU is used when this is
                not an ``int`` or when Paddle was not compiled with CUDA.
        """
        # Project-local imports are resolved through the sys.path entries
        # appended at module import time.
        from models import build_model
        from data_loader import get_dataloader
        from post_processing import get_post_processing
        from utils import get_metric

        self.gpu_id = gpu_id
        # isinstance() already rejects None, so no separate None check is needed.
        if isinstance(self.gpu_id, int) and paddle.device.is_compiled_with_cuda():
            paddle.device.set_device("gpu:{}".format(self.gpu_id))
        else:
            paddle.device.set_device("cpu")

        checkpoint = paddle.load(model_path)
        config = checkpoint["config"]
        # Presumably avoids fetching pretrained backbone weights, since the
        # checkpoint's own state dict is loaded right below -- TODO confirm.
        config["arch"]["backbone"]["pretrained"] = False

        self.validate_loader = get_dataloader(
            config["dataset"]["validate"], config["distributed"]
        )

        self.model = build_model(config["arch"])
        self.model.set_state_dict(checkpoint["state_dict"])

        self.post_process = get_post_processing(config["post_processing"])
        self.metric_cls = get_metric(config["metric"])

    def eval(self):
        """Run the model over the validation loader and report metrics.

        Prints the measured throughput as ``FPS:<value>`` (``0.0`` when no
        time was measured, e.g. for an empty loader).

        Returns:
            A dict with the keys ``"recall"``, ``"precision"`` and
            ``"fmeasure"``, each holding the ``.avg`` attribute of the
            corresponding gathered metric.
        """
        self.model.eval()
        raw_metrics = []
        total_frame = 0.0
        total_time = 0.0
        for batch in tqdm(
            self.validate_loader,
            total=len(self.validate_loader),
            desc="test model",
        ):
            with paddle.no_grad():
                # Only the forward pass and post-processing are timed; the
                # metric computation below is excluded. perf_counter() is a
                # monotonic, high-resolution clock suited to intervals.
                start = time.perf_counter()
                preds = self.model(batch["Crop_img"])
                boxes, scores = self.post_process(
                    batch, preds, is_output_polygon=self.metric_cls.is_output_polygon
                )
                # The leading dimension of the input is counted as frames.
                total_frame += batch["Crop_img"].shape[0]
                total_time += time.perf_counter() - start
            raw_metric = self.metric_cls.validate_measure(batch, (boxes, scores))
            raw_metrics.append(raw_metric)
        metrics = self.metric_cls.gather_measure(raw_metrics)

        # Guard against ZeroDivisionError when nothing was timed.
        fps = total_frame / total_time if total_time > 0 else 0.0
        print("FPS:{}".format(fps))
        return {
            "recall": metrics["recall"].avg,
            "precision": metrics["precision"].avg,
            "fmeasure": metrics["fmeasure"].avg,
        }


def init_args():
    """Parse command-line arguments.

    Returns:
        The parsed ``argparse.Namespace`` with a ``model_path`` attribute.
    """
    parser = argparse.ArgumentParser(description="DBNet.paddle")
    parser.add_argument(
        "--model_path",
        required=False,
        default="output/DBNet_resnet18_FPN_DBHead/checkpoint/1.pth",
        type=str,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = init_args()
    # Named ``evaluator`` so the builtin ``eval`` is not shadowed.
    evaluator = EVAL(args.model_path)
    result = evaluator.eval()
    print(result)