|
| 1 | +import torch |
| 2 | +from torchmetrics import Metric |
| 3 | +import numpy as np |
| 4 | + |
| 5 | + |
| 6 | +class MAPK(Metric): |
| 7 | + """Computes the mean average precision at k. |
| 8 | + Args: |
| 9 | + k: Number of predictions to consider |
| 10 | + dist_sync_on_step: Whether to sync the output across all GPUs |
| 11 | + Example: |
| 12 | + >>> from torchmetrics import MAPK |
| 13 | + >>> target = torch.tensor([0, 1, 2, 3]) |
| 14 | + >>> preds = torch.tensor([[0, 1, 2, 3], |
| 15 | + ... [0, 2, 1, 3], |
| 16 | + ... [0, 1, 3, 2], |
| 17 | + ... [0, 3, 1, 2]]) |
| 18 | + >>> mapk = MAPK(k=3) |
| 19 | + >>> mapk(preds, target) |
| 20 | + tensor(0.3333) |
| 21 | +
|
| 22 | + >>> y_pred = torch.tensor([[0.5, 0.2, 0.2], # 0 is in top 2 |
| 23 | + [0.3, 0.4, 0.2], # 1 is in top 2 |
| 24 | + [0.2, 0.4, 0.3], # 2 is in top 2 |
| 25 | + [0.7, 0.2, 0.1]]) # 2 isn't in top 2 |
| 26 | + >>> y_true = torch.tensor([0, 1, 2, 2]) |
| 27 | + >>> mapk_metric = MAPK(k=2) |
| 28 | + >>> mapk_metric.update(y_pred, y_true) |
| 29 | + >>> result = mapk_metric.compute() |
| 30 | + >>> print(result) # tensor(0.37500) |
| 31 | + """ |
| 32 | + |
| 33 | + def __init__(self, k=3, dist_sync_on_step=False): |
| 34 | + super().__init__(dist_sync_on_step=dist_sync_on_step) |
| 35 | + self.k = k |
| 36 | + self.add_state("actual", default=[], dist_reduce_fx="cat") |
| 37 | + self.add_state("predicted", default=[], dist_reduce_fx="cat") |
| 38 | + |
| 39 | + def update(self, y_pred: torch.Tensor, y: torch.Tensor): |
| 40 | + sorted_prediction_ids = np.argsort(-y_pred.cpu().numpy(), axis=1) |
| 41 | + top_k_prediction_ids = sorted_prediction_ids[:, : self.k] |
| 42 | + self.actual.append(y.cpu().numpy().reshape(-1, 1)) |
| 43 | + self.predicted.append(top_k_prediction_ids) |
| 44 | + |
| 45 | + def compute(self): |
| 46 | + actual = np.concatenate(self.actual) |
| 47 | + predicted = np.concatenate(self.predicted) |
| 48 | + return self.mapk(actual, predicted) |
| 49 | + |
| 50 | + @staticmethod |
| 51 | + def apk(actual, predicted, k=10): |
| 52 | + if len(predicted) > k: |
| 53 | + predicted = predicted[:k] |
| 54 | + score = 0.0 |
| 55 | + num_hits = 0.0 |
| 56 | + for i, p in enumerate(predicted): |
| 57 | + if p in actual and p not in predicted[:i]: |
| 58 | + num_hits += 1.0 |
| 59 | + score += num_hits / (i + 1.0) |
| 60 | + if not actual: |
| 61 | + return 0.0 |
| 62 | + return score / min(len(actual), k) |
| 63 | + |
| 64 | + def mapk(self, actual, predicted): |
| 65 | + return np.mean([self.apk(a, p, self.k) for a, p in zip(actual, predicted)]) |
0 commit comments