# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
import paddle.nn.functional as F

from .registry import METRIC
from .base import BaseMetric
from paddlevideo.utils import get_logger

logger = get_logger("paddlevideo")


@METRIC.register
class MSRVTTMetric(BaseMetric):
    def __init__(self, data_size, batch_size, log_interval=1):
        """Prepare the score, target and rank matrices for retrieval metrics."""
        super().__init__(data_size, batch_size, log_interval)
        self.score_matrix = np.zeros((data_size, data_size))
        self.target_matrix = np.zeros((data_size, data_size))
        self.rank_matrix = np.ones((data_size)) * data_size

    def update(self, batch_id, data, outputs):
        """Update metrics during each iteration."""
        target = data[-1]
        cm_logit = outputs[-1]

        # Softmax probability of the matched (positive) class for every
        # candidate in this row of the retrieval score matrix.
        self.score_matrix[batch_id, :] = F.softmax(
            cm_logit, axis=1)[:, 0].reshape([-1]).numpy()
        self.target_matrix[batch_id, :] = target.reshape([-1]).numpy()

        # Position of the ground-truth candidate after sorting this row's
        # scores in descending order (0 means the correct match ranked first).
        rank = np.where((np.argsort(-self.score_matrix[batch_id]) == np.where(
            self.target_matrix[batch_id] == 1)[0][0]) == 1)[0][0]
        self.rank_matrix[batch_id] = rank

        # Running Recall@1/5/10, median rank and mean rank over the rows
        # processed so far.
        rank_matrix_tmp = self.rank_matrix[:batch_id + 1]
        r1 = 100.0 * np.sum(rank_matrix_tmp < 1) / len(rank_matrix_tmp)
        r5 = 100.0 * np.sum(rank_matrix_tmp < 5) / len(rank_matrix_tmp)
        r10 = 100.0 * np.sum(rank_matrix_tmp < 10) / len(rank_matrix_tmp)

        medr = np.floor(np.median(rank_matrix_tmp) + 1)
        meanr = np.mean(rank_matrix_tmp) + 1
        logger.info(
            "[{}] Final r1:{:.3f}, r5:{:.3f}, r10:{:.3f}, medr:{:.3f}, meanr:{:.3f}"
            .format(batch_id, r1, r5, r10, medr, meanr))

    def accumulate(self):
        """Accumulate metrics when all iterations are finished."""
        logger.info("Eval Finished!")
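

# ---------------------------------------------------------------------------
# Illustrative sketch only: the rank / Recall@K arithmetic from update(),
# replayed on a tiny synthetic score matrix so the indexing is easy to follow.
# The data here is made up and assumes nothing beyond numpy; the loop mirrors
# the one-line rank lookup in update(), minus the redundant `== 1` comparison.
# ---------------------------------------------------------------------------
def _demo_rank_metrics():
    scores = np.array([
        [0.9, 0.05, 0.05],  # ground truth at column 0 -> rank 0 (top-1 hit)
        [0.2, 0.30, 0.50],  # ground truth at column 1 -> rank 1
        [0.6, 0.30, 0.10],  # ground truth at column 2 -> rank 2
    ])
    targets = np.eye(3)  # one-hot ground-truth rows, as in target_matrix

    ranks = np.empty(len(scores))
    for i in range(len(scores)):
        gt_col = np.where(targets[i] == 1)[0][0]
        ranks[i] = np.where(np.argsort(-scores[i]) == gt_col)[0][0]

    r1 = 100.0 * np.sum(ranks < 1) / len(ranks)  # 33.333...
    medr = np.floor(np.median(ranks) + 1)        # 2.0
    meanr = np.mean(ranks) + 1                   # 2.0
    return ranks, r1, medr, meanr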