Skip to content

Commit 54accad

Browse files
v0.2.25
evaluate_hold_out and evaluate_cv set metric.to(device)
1 parent be28da2 commit 54accad

2 files changed

Lines changed: 3 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.24"
10+
version = "0.2.25"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/torch/traintest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def evaluate_cv(
113113
print("We will use", torch.cuda.device_count(), "GPUs!")
114114
net = nn.DataParallel(net)
115115
net.to(device)
116+
metric.to(device)
116117
optimizer = optimizer_handler(
117118
optimizer_name=optimizer_instance,
118119
params=net.parameters(),
@@ -226,6 +227,7 @@ def evaluate_hold_out(
226227
print("We will use", torch.cuda.device_count(), "GPUs!")
227228
net = nn.DataParallel(net)
228229
net.to(device)
230+
metric.to(device)
229231
# loss_function = nn.CrossEntropyLoss()
230232
# TODO: optimizer = optim.Adam(net.parameters(), lr=lr)
231233
# optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

0 commit comments

Comments
 (0)