Skip to content

Commit bbf557f

Browse files
v0.0.28
mapk can handle devices
1 parent 9bda142 commit bbf557f

3 files changed

Lines changed: 4 additions & 6 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.26"
10+
version = "0.2.28"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/torch/mapk.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class MAPK(Metric):
88
Args:
99
k: Number of predictions to consider
1010
dist_sync_on_step: Whether to sync the output across all GPUs
11+
device: Device to use for the computation
1112
Example:
1213
>>> from torchmetrics import MAPK
1314
>>> target = torch.tensor([0, 1, 2, 3])
@@ -30,8 +31,8 @@ class MAPK(Metric):
3031
>>> print(result) # tensor(0.37500)
3132
"""
3233

33-
def __init__(self, k=3, dist_sync_on_step=False):
34-
super().__init__(dist_sync_on_step=dist_sync_on_step)
34+
def __init__(self, k=3, dist_sync_on_step=False, device=None):
35+
super().__init__(dist_sync_on_step=dist_sync_on_step, device=device)
3536
self.k = k
3637
self.add_state("actual", default=[], dist_reduce_fx="cat")
3738
self.add_state("predicted", default=[], dist_reduce_fx="cat")

src/spotPython/torch/traintest.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,6 @@ def evaluate_cv(
113113
print("We will use", torch.cuda.device_count(), "GPUs!")
114114
net = nn.DataParallel(net)
115115
net.to(device)
116-
metric.to(device)
117116
optimizer = optimizer_handler(
118117
optimizer_name=optimizer_instance,
119118
params=net.parameters(),
@@ -220,7 +219,6 @@ def evaluate_hold_out(
220219
print("We will use", torch.cuda.device_count(), "GPUs!")
221220
net = nn.DataParallel(net)
222221
net.to(device)
223-
metric.to(device)
224222
# loss_function = nn.CrossEntropyLoss()
225223
# TODO: optimizer = optim.Adam(net.parameters(), lr=lr)
226224
# optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
@@ -266,7 +264,6 @@ def evaluate_hold_out(
266264
metric_name = "Metric"
267265
if metric is not None:
268266
metric_name = type(metric).__name__
269-
print(f"{metric_name} value on hold-out data: {metric_val}")
270267
if writer is not None:
271268
writer.add_scalars(
272269
"evaluate_hold_out: Train & Val Loss and Val Metric" + writerId,

0 commit comments

Comments
 (0)