You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
205 lines
6.3 KiB
Python
205 lines
6.3 KiB
Python
from functools import partial
|
|
from itertools import repeat
|
|
|
|
import numpy as np
|
|
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
|
|
|
|
|
def box_area(box):
    """Return the area of an axis-aligned box.

    Args:
        box: Indexable holding the box edges as
            ``(x_left, y_top, x_right, y_bottom)``.

    Returns:
        The product of the box width and height. No clamping is applied,
        so a box whose edges are inverted on one axis yields a negative
        value.
    """
    width = box[2] - box[0]
    height = box[3] - box[1]
    return width * height
|
|
|
|
|
|
def calculate_iou(box1, box2, box1_only=False):
    """Compute the overlap ratio between two boxes.

    Args:
        box1: First box as ``(x_left, y_top, x_right, y_bottom)``.
        box2: Second box in the same layout.
        box1_only: When true, divide the intersection by the area of
            ``box1`` alone (i.e. the fraction of ``box1`` that ``box2``
            covers) instead of by the union of both boxes.

    Returns:
        ``intersection / denominator``, or ``0`` when the denominator is
        zero.
    """
    overlap = intersection_area(box1, box2)

    if box1_only:
        denominator = box_area(box1)
    else:
        # Standard union: both areas minus the part counted twice.
        denominator = box_area(box1) + (box_area(box2) - overlap)

    if denominator == 0:
        return 0
    return overlap / denominator
|
|
|
|
|
|
def match_boxes(preds, references):
    """Greedily pair reference boxes with predicted boxes by overlap.

    Every (reference, prediction) combination is scored with
    ``calculate_iou(reference, prediction, box1_only=True)``, i.e. the
    fraction of the reference box covered by the prediction. Pairs are then
    visited from the highest score to the lowest and accepted whenever
    neither box has been used yet. No minimum score is required, so
    zero-overlap pairs are still matched once better candidates run out.

    Args:
        preds: Sequence of predicted boxes.
        references: Sequence of reference (ground-truth) boxes.

    Returns:
        A list of ``(reference_index, prediction_index, score)`` tuples.
        Scores above 0.95 are reported as ``1.0``. References left without
        a prediction appear as ``(index, None, -1.0)`` and predictions left
        without a reference as ``(None, index, 0.0)``.
    """
    n_refs = len(references)
    n_preds = len(preds)

    # overlap[r, p]: share of reference r that prediction p covers.
    overlap = np.zeros((n_refs, n_preds))
    for r, ref_box in enumerate(references):
        for p, pred_box in enumerate(preds):
            overlap[r, p] = calculate_iou(ref_box, pred_box, box1_only=True)

    # Flattened cell indices ordered from best to worst score, mapped back
    # to (row, column) coordinates of the matrix.
    best_first = np.argsort(overlap, axis=None)[::-1]
    ref_order, pred_order = np.unravel_index(best_first, overlap.shape)

    used_refs = set()
    used_preds = set()
    matches = []
    for r, p in zip(ref_order, pred_order):
        if r in used_refs or p in used_preds:
            continue

        score = overlap[r, p]
        if score > .95:  # Account for rounding on box edges
            score = 1.0

        matches.append((r, p, score))
        used_refs.add(r)
        used_preds.add(p)

    # Anything never paired is reported with a fixed score.
    leftover_refs = set(range(n_refs)) - used_refs
    leftover_preds = set(range(n_preds)) - used_preds
    matches.extend([(r, None, -1.0) for r in leftover_refs])
    matches.extend([(None, p, 0.0) for p in leftover_preds])

    return matches
|
|
|
|
def penalized_iou_score(preds, references):
    """Average the match scores produced by ``match_boxes``.

    Unmatched boxes take part in the average with the fixed scores that
    ``match_boxes`` assigns to them, which is what penalizes missing or
    surplus predictions.

    Args:
        preds: Sequence of predicted boxes.
        references: Sequence of reference boxes.

    Returns:
        The mean of the third element of every match tuple.

    Raises:
        ZeroDivisionError: If ``match_boxes`` returns no matches (both
            inputs empty).
    """
    matches = match_boxes(preds, references)
    total_score = sum(score for _, _, score in matches)
    return total_score / len(matches)
|
|
|
|
def intersection_pixels(box1, box2):
    """Enumerate the integer pixel coordinates shared by two boxes.

    Args:
        box1: First box as ``(x_left, y_top, x_right, y_bottom)``.
        box2: Second box in the same layout.

    Returns:
        A set of ``(x, y)`` tuples covering the overlap rectangle. The
        overlap edges are truncated with ``int()`` and the right/bottom
        edges are exclusive. An empty set is returned when the boxes do not
        overlap.
    """
    left = max(box1[0], box2[0])
    top = max(box1[1], box2[1])
    right = min(box1[2], box2[2])
    bottom = min(box1[3], box2[3])

    if right < left or bottom < top:
        return set()

    # Snap the overlap rectangle onto the integer pixel grid.
    columns = np.arange(int(left), int(right))
    rows = np.arange(int(top), int(bottom))

    return {(x, y) for y in rows for x in columns}
|
|
|
|
|
|
def calculate_coverage(box, other_boxes, penalize_double=False):
    """Return the fraction of ``box`` covered by ``other_boxes``.

    Coverage is measured on the integer pixel grid via
    ``intersection_pixels``, so a pixel covered by several boxes counts only
    once toward the covered total. The denominator is the area computed from
    the raw box coordinates.

    Args:
        box: Box to measure, as ``(x_left, y_top, x_right, y_bottom)``.
        other_boxes: Iterable of boxes that may cover ``box``.
        penalize_double: When true, subtract from the covered pixel count
            one pixel for every time a pixel is covered again by a later
            box, so overlapping boxes reduce the score.

    Returns:
        ``covered_pixels / area`` (never negative), or ``0`` when ``box``
        has zero area.
    """
    # Named ``area`` so the module-level ``box_area`` function is not shadowed.
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    covered_pixels = set()
    double_covered = 0
    for other_box in other_boxes:
        pixels = intersection_pixels(box, other_box)
        # Pixels already claimed by an earlier box are covered a second
        # time. Count the pixels themselves: the penalty must scale with the
        # overlapping area, not with the number of boxes examined.
        double_covered += len(covered_pixels & pixels)
        covered_pixels |= pixels

    # Penalize double coverage - having multiple bboxes overlapping the same pixels
    penalty = double_covered if penalize_double else 0
    covered_count = max(0, len(covered_pixels) - penalty)
    return covered_count / area
|
|
|
|
|
|
def intersection_area(box1, box2):
    """Return the area of the overlap between two boxes.

    Args:
        box1: First box as ``(x_left, y_top, x_right, y_bottom)``.
        box2: Second box in the same layout.

    Returns:
        The overlap area, or ``0.0`` when the boxes are disjoint. Boxes that
        merely touch along an edge produce a zero-sized overlap.
    """
    left, top = max(box1[0], box2[0]), max(box1[1], box2[1])
    right, bottom = min(box1[2], box2[2]), min(box1[3], box2[3])

    boxes_are_disjoint = right < left or bottom < top
    if boxes_are_disjoint:
        return 0.0

    overlap_width = right - left
    overlap_height = bottom - top
    return overlap_width * overlap_height
|
|
|
|
|
|
def calculate_coverage_fast(box, other_boxes, penalize_double=False):
    """Vectorized estimate of the fraction of ``box`` covered by other boxes.

    The intersection area of ``box`` with each entry of ``other_boxes`` is
    computed with NumPy and the areas are summed. Overlaps between the
    other boxes are not de-duplicated, so the sum may over-count; the result
    is capped at ``1.0``.

    Args:
        box: Box to measure, as ``(x_left, y_top, x_right, y_bottom)``.
        other_boxes: Sequence of boxes in the same layout.
        penalize_double: Accepted so the signature matches
            ``calculate_coverage``; it is not used by this implementation.

    Returns:
        ``min(1.0, total_intersection / area)``, ``0`` when ``box`` has zero
        area, or ``0.0`` when ``other_boxes`` is empty.
    """
    box = np.array(box)
    other_boxes = np.array(other_boxes)

    # Named ``area`` so the module-level ``box_area`` function is not shadowed.
    area = (box[2] - box[0]) * (box[3] - box[1])
    if area == 0:
        return 0

    # An empty sequence becomes a 1-D array, which cannot be column-indexed
    # below. Nothing covers the box in that case.
    if other_boxes.size == 0:
        return 0.0

    x_left = np.maximum(box[0], other_boxes[:, 0])
    y_top = np.maximum(box[1], other_boxes[:, 1])
    x_right = np.minimum(box[2], other_boxes[:, 2])
    y_bottom = np.minimum(box[3], other_boxes[:, 3])

    # Clamp at zero so disjoint boxes contribute no area.
    widths = np.maximum(0, x_right - x_left)
    heights = np.maximum(0, y_bottom - y_top)
    total_intersect = np.sum(widths * heights)

    return min(1.0, total_intersect / area)
|
|
|
|
|
|
def precision_recall(preds, references, threshold=.5, workers=8, penalize_double=True):
    """Compute coverage-based precision and recall for predicted boxes.

    A predicted box counts toward precision when its coverage by the
    reference boxes exceeds ``threshold``; a reference box counts toward
    recall when its coverage by the predicted boxes exceeds ``threshold``.

    Args:
        preds: Sequence of predicted boxes.
        references: Sequence of reference boxes.
        threshold: Coverage a box must exceed to count as a hit.
        workers: Maximum number of threads used to evaluate boxes.
        penalize_double: When true, the pixel-based ``calculate_coverage``
            is used and double coverage is penalized for the precision
            side. Otherwise ``calculate_coverage_fast`` is used. The recall
            side always calls the coverage function with its default
            ``penalize_double``.

    Returns:
        A dict with ``"precision"`` and ``"recall"``. Both are ``1`` when
        ``references`` is empty and ``0`` when only ``preds`` is empty.
    """
    if len(references) == 0:
        return {"precision": 1, "recall": 1}

    if len(preds) == 0:
        return {"precision": 0, "recall": 0}

    # If we're not penalizing double coverage, we can use a faster calculation
    coverage_func = calculate_coverage if penalize_double else calculate_coverage_fast
    precision_func = partial(coverage_func, penalize_double=penalize_double)

    with ThreadPoolExecutor(max_workers=workers) as executor:
        pred_coverage = executor.map(precision_func, preds, repeat(references))
        ref_coverage = executor.map(coverage_func, references, repeat(preds))

        pred_hits = [1 if score > threshold else 0 for score in pred_coverage]
        ref_hits = [1 if score > threshold else 0 for score in ref_coverage]

    return {
        "precision": sum(pred_hits) / len(pred_hits),
        "recall": sum(ref_hits) / len(ref_hits),
    }
|
|
|
|
|
|
def mean_coverage(preds, references):
    """Average coverage over every reference box and every predicted box.

    Each reference box is scored by how much of it the predictions cover,
    and each predicted box by how much of it the references cover, using
    ``calculate_coverage`` without the double-coverage penalty.

    Args:
        preds: Sequence of predicted boxes.
        references: Sequence of reference boxes.

    Returns:
        ``{"coverage": mean}`` when at least one box was scored. When both
        inputs are empty the bare integer ``0`` is returned instead of a
        dict, so callers must handle both shapes.
    """
    reference_scores = [calculate_coverage(ref_box, preds) for ref_box in references]
    prediction_scores = [calculate_coverage(pred_box, references) for pred_box in preds]
    scores = reference_scores + prediction_scores

    # Calculate the average coverage over all comparisons
    if not scores:
        return 0
    return {"coverage": sum(scores) / len(scores)}
|
|
|
|
|
|
def rank_accuracy(preds, references):
    """Return the share of pairwise orderings on which preds match references.

    For every ordered pair of distinct positions ``(i, j)`` the comparison
    ``preds[i] > preds[j]`` is recorded. A pair is correct when
    ``references[i] > references[j]`` gives the same result.

    Args:
        preds: Sequence of comparable, hashable-comparison values (e.g.
            numbers), one per item.
        references: Sequence of comparable values. Preds and references
            need to be aligned so each position refers to the same bbox.

    Returns:
        ``correct_pairs / total_pairs`` where ``total_pairs`` is the number
        of ordered pairs of distinct prediction positions.

    Raises:
        ZeroDivisionError: If ``preds`` has fewer than two items, because
            there are then no pairs to compare.
    """
    # A set keeps each membership test O(1); scanning a list here made the
    # whole function O(n^4) in the number of items.
    pairs = {
        (i, j, pred > pred2)
        for i, pred in enumerate(preds)
        for j, pred2 in enumerate(preds)
        if i != j
    }

    # Find how many of the prediction rankings are correct. Entries with
    # i == j never match because such pairs were not recorded above.
    correct = sum(
        (i, j, ref > ref2) in pairs
        for i, ref in enumerate(references)
        for j, ref2 in enumerate(references)
    )

    return correct / len(pairs)