Skip to content

Commit e6d4544

Browse files
v0.2.29
MAPK updated
1 parent bbf557f commit e6d4544

3 files changed

Lines changed: 87 additions & 45 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.28"
10+
version = "0.2.29"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/torch/mapk.py

Lines changed: 85 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,66 +1,108 @@
1+
import torchmetrics
12
import torch
2-
from torchmetrics import Metric
33
import numpy as np
44

55

6-
class MAPK(Metric):
7-
"""Computes the mean average precision at k.
6+
class MAPK(torchmetrics.Metric):
7+
"""
8+
Mean Average Precision at K (MAPK) metric.
9+
10+
This class inherits from the `Metric` class of the `torchmetrics` library.
11+
812
Args:
9-
k: Number of predictions to consider
10-
dist_sync_on_step: Whether to sync the output across all GPUs
11-
device: Device to use for the computation
13+
k (int): The number of top predictions to consider when calculating the metric.
14+
dist_sync_on_step (bool): Whether to synchronize the metric states across processes during the forward pass.
15+
16+
Attributes:
17+
total (torch.Tensor): The cumulative sum of the metric scores across all batches.
18+
count (torch.Tensor): The number of batches processed.
19+
1220
Example:
13-
>>> from torchmetrics import MAPK
14-
>>> target = torch.tensor([0, 1, 2, 3])
15-
>>> preds = torch.tensor([[0, 1, 2, 3],
16-
... [0, 2, 1, 3],
17-
... [0, 1, 3, 2],
18-
... [0, 3, 1, 2]])
19-
>>> mapk = MAPK(k=3)
20-
>>> mapk(preds, target)
21-
tensor(0.3333)
22-
23-
>>> y_pred = torch.tensor([[0.5, 0.2, 0.2], # 0 is in top 2
24-
[0.3, 0.4, 0.2], # 1 is in top 2
25-
[0.2, 0.4, 0.3], # 2 is in top 2
26-
[0.7, 0.2, 0.1]]) # 2 isn't in top 2
27-
>>> y_true = torch.tensor([0, 1, 2, 2])
28-
>>> mapk_metric = MAPK(k=2)
29-
>>> mapk_metric.update(y_pred, y_true)
30-
>>> result = mapk_metric.compute()
31-
>>> print(result) # tensor(0.37500)
21+
from spotPython.torch.mapk import MAPK
22+
import torch
23+
mapk = MAPK(k=2)
24+
target = torch.tensor([0, 1, 2, 2])
25+
preds = torch.tensor(
26+
[
27+
[0.5, 0.2, 0.2], # 0 is in top 2
28+
[0.3, 0.4, 0.2], # 1 is in top 2
29+
[0.2, 0.4, 0.3], # 2 is in top 2
30+
[0.7, 0.2, 0.1], # 2 isn't in top 2
31+
]
32+
)
33+
mapk.update(preds, target)
34+
print(mapk.compute()) # tensor(0.6250)
3235
"""
3336

34-
def __init__(self, k=3, dist_sync_on_step=False, device=None):
35-
super().__init__(dist_sync_on_step=dist_sync_on_step, device=device)
37+
def __init__(self, k=10, dist_sync_on_step=False):
38+
super().__init__(dist_sync_on_step=dist_sync_on_step)
3639
self.k = k
37-
self.add_state("actual", default=[], dist_reduce_fx="cat")
38-
self.add_state("predicted", default=[], dist_reduce_fx="cat")
40+
self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
41+
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
42+
43+
def update(self, predicted: torch.Tensor, actual: torch.Tensor):
44+
"""
45+
Update the state variables with a new batch of data.
3946
40-
def update(self, y_pred: torch.Tensor, y: torch.Tensor):
41-
sorted_prediction_ids = np.argsort(-y_pred.cpu().numpy(), axis=1)
42-
top_k_prediction_ids = sorted_prediction_ids[:, : self.k]
43-
self.actual.append(y.cpu().numpy().reshape(-1, 1))
44-
self.predicted.append(top_k_prediction_ids)
47+
Args:
48+
predicted (torch.Tensor): A 2D tensor containing the predicted scores for each class.
49+
actual (torch.Tensor): A 1D tensor containing the ground truth labels.
50+
51+
52+
Raises:
53+
AssertionError: If `actual` is not a 1D tensor or if `predicted` is not a 2D tensor
54+
or if `actual` and `predicted` do not have the same number of elements.
55+
"""
56+
assert len(actual.shape) == 1, "actual must be a 1D tensor"
57+
assert len(predicted.shape) == 2, "predicted must be a 2D tensor"
58+
assert actual.shape[0] == predicted.shape[0], "actual and predicted must have the same number of elements"
59+
60+
# Convert actual to list of lists
61+
actual = actual.tolist()
62+
actual = [[a] for a in actual]
63+
64+
# Convert predicted to list of lists of indices sorted by confidence score
65+
_, predicted = predicted.topk(k=self.k, dim=1)
66+
predicted = predicted.tolist()
67+
68+
score = np.mean([self.apk(p, a, self.k) for p, a in zip(predicted, actual)])
69+
self.total += score
70+
self.count += 1
4571

4672
def compute(self):
47-
actual = np.concatenate(self.actual)
48-
predicted = np.concatenate(self.predicted)
49-
return self.mapk(actual, predicted)
73+
"""
74+
Compute the mean average precision at k.
75+
76+
Returns:
77+
float: The mean average precision at k.
78+
"""
79+
return self.total / self.count
5080

5181
@staticmethod
52-
def apk(actual, predicted, k=10):
82+
def apk(predicted, actual, k=10):
83+
"""
84+
Calculate the average precision at k for a single pair of actual and predicted labels.
85+
86+
Args:
87+
predicted (list): A list of predicted labels.
88+
actual (list): A list of ground truth labels.
89+
k (int): The number of top predictions to consider.
90+
91+
Returns:
92+
float: The average precision at k.
93+
"""
94+
if not actual:
95+
return 0.0
96+
5397
if len(predicted) > k:
5498
predicted = predicted[:k]
99+
55100
score = 0.0
56101
num_hits = 0.0
102+
57103
for i, p in enumerate(predicted):
58104
if p in actual and p not in predicted[:i]:
59105
num_hits += 1.0
60106
score += num_hits / (i + 1.0)
61-
if not actual:
62-
return 0.0
63-
return score / min(len(actual), k)
64107

65-
def mapk(self, actual, predicted):
66-
return np.mean([self.apk(a, p, self.k) for a, p in zip(actual, predicted)])
108+
return score / min(len(actual), k)

src/spotPython/torch/traintest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def evaluate_cv(
157157
metric_name = "Metric"
158158
if metric is not None:
159159
metric_name = type(metric).__name__
160-
print(f"{metric_name} value on hold-out data: {metric_values[fold]}")
160+
# print(f"{metric_name} value on hold-out data: {metric_values[fold]}")
161161
if writer is not None:
162162
writer.add_scalars(
163163
"evaluate_cv fold:" + str(fold + 1) + ". Train & Val Loss and Val Metric" + writerId,

0 commit comments

Comments
 (0)