You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

97 lines
3.4 KiB
Python

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import paddle
import csv
import paddle.nn.functional as F
from .registry import METRIC
from .base import BaseMetric
from paddlevideo.utils import get_logger
logger = get_logger("paddlevideo")
@METRIC.register
class SkeletonMetric(BaseMetric):
    """Test-time metric for skeleton based models.

    Labelled batches: per-batch top-1 and top-k accuracy are collected by
    ``update`` and their means are logged by ``accumulate``.
    Unlabelled batches: the predicted class of the first sample of each
    batch is collected by ``update`` and written to a CSV file by
    ``accumulate``.

    note: the unlabelled path only supports batch size = 1, single card test.

    Args:
        data_size: forwarded to ``BaseMetric``; read back as
            ``self.data_size`` for the progress log (presumably the number
            of test samples -- confirm against ``BaseMetric``).
        batch_size: forwarded to ``BaseMetric``; read back as
            ``self.batch_size`` for the progress log.
        out_file: str, file to save test results (CSV) when the data
            carries no labels.
        log_interval: forwarded to ``BaseMetric``; a progress line is logged
            whenever ``batch_id % self.log_interval == 0``.
        top_k: ``k`` used for the second accuracy value (default 5).
    """

    def __init__(self,
                 data_size,
                 batch_size,
                 out_file='submission.csv',
                 log_interval=1,
                 top_k=5):
        """Prepare the containers that ``update`` fills."""
        super().__init__(data_size, batch_size, log_interval)
        self.top1 = []  # per-batch top-1 accuracy (numpy values)
        self.top5 = []  # per-batch top-k accuracy, k = self.k
        self.values = []  # (batch_id, predicted_class) rows for the CSV
        self.out_file = out_file
        self.k = top_k

    def update(self, batch_id, data, outputs):
        """Update metrics with the results of one iteration.

        Args:
            batch_id: index of the current batch; used for the progress log
                and as ``sample_index`` in the unlabelled case.
            data: the batch; ``data[0]`` is the model input and, when
                ``len(data) == 2``, ``data[1]`` holds the labels.
            outputs: model scores with the class dimension last.
        """
        if data[0].shape[0] != outputs.shape[0]:
            # Leading dimensions disagree: data[0].shape[1] is taken as the
            # number of segments per sample (assumes outputs holds one row
            # per segment -- TODO confirm), and segment scores are averaged
            # so that outputs has one row per sample.
            num_segs = data[0].shape[1]
            batch_size = outputs.shape[0]
            outputs = outputs.reshape(
                [batch_size // num_segs, num_segs, outputs.shape[-1]])
            outputs = outputs.mean(axis=1)

        if len(data) == 2:  # data with label
            labels = data[1]
            top1 = paddle.metric.accuracy(input=outputs, label=labels, k=1)
            top5 = paddle.metric.accuracy(input=outputs,
                                          label=labels,
                                          k=self.k)
            if self.world_size > 1:
                # all_reduce sums into the tensor in place; its return value
                # is not the reduced tensor in current Paddle releases, so
                # reduce first and divide the tensor itself afterwards.
                paddle.distributed.all_reduce(
                    top1, op=paddle.distributed.ReduceOp.SUM)
                top1 = top1 / self.world_size
                paddle.distributed.all_reduce(
                    top5, op=paddle.distributed.ReduceOp.SUM)
                top5 = top5 / self.world_size
            self.top1.append(top1.numpy())
            self.top5.append(top5.numpy())
        else:  # data without label, only support batch_size=1. Used for fsd-10.
            prob = F.softmax(outputs)
            # Only the first sample's prediction is kept (batch_size == 1).
            clas = paddle.argmax(prob, axis=1).numpy()[0]
            self.values.append((batch_id, clas))

        # Progress log, emitted for labelled and unlabelled batches alike.
        if batch_id % self.log_interval == 0:
            logger.info("[TEST] Processing batch {}/{} ...".format(
                batch_id,
                self.data_size // (self.batch_size * self.world_size)))

    def accumulate(self):
        """Report the collected results once all iterations have finished.

        Logs the mean top-1 / top-k accuracy if any labelled batch was seen;
        otherwise writes the collected predictions to ``self.out_file``.
        """
        if self.top1:  # data with label
            # The message says "acc5" but the value is top-k with k = self.k.
            logger.info('[TEST] finished, avg_acc1= {}, avg_acc5= {}'.format(
                np.mean(np.array(self.top1)), np.mean(np.array(self.top5))))
        else:
            headers = ['sample_index', 'predict_category']
            # newline='' lets the csv module control line endings itself;
            # without it, platforms that translate '\n' get blank rows.
            with open(self.out_file, 'w', newline='') as fp:
                writer = csv.writer(fp)
                writer.writerow(headers)
                writer.writerows(self.values)
            logger.info("Results saved in {} !".format(self.out_file))