Skip to content

Commit 225dc84

Browse files
mapk bug fixed
1 parent 571a45c commit 225dc84

3 files changed

Lines changed: 11 additions & 10 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.43"
10+
version = "0.2.44"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/light/csvmodel.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def training_step(self, batch):
6363
# self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
6464
return loss
6565

66-
def validation_step(self, batch, batch_idx):
66+
def validation_step(self, batch, batch_idx, prog_bar=False):
6767
x, y = batch
6868
logits = self(x)
6969
# compute cross entropy loss from logits and y
@@ -73,10 +73,10 @@ def validation_step(self, batch, batch_idx):
7373
acc = accuracy(preds, y, task="multiclass", num_classes=self._L_out)
7474
self.valid_mapk(logits, y)
7575
self.log("valid_mapk", self.valid_mapk, on_step=False, on_epoch=True)
76-
self.log("val_loss", loss, prog_bar=True)
77-
self.log("val_acc", acc, prog_bar=True)
76+
self.log("val_loss", loss, prog_bar=prog_bar)
77+
self.log("val_acc", acc, prog_bar=prog_bar)
7878

79-
def test_step(self, batch, batch_idx):
79+
def test_step(self, batch, batch_idx, prog_bar=False):
8080
x, y = batch
8181
logits = self(x)
8282
# compute cross entropy loss from logits and y
@@ -85,8 +85,8 @@ def test_step(self, batch, batch_idx):
8585
acc = accuracy(preds, y, task="multiclass", num_classes=self._L_out)
8686
self.test_mapk(logits, y)
8787
self.log("test_mapk", self.test_mapk, on_step=True, on_epoch=True)
88-
self.log("val_loss", loss, prog_bar=True)
89-
self.log("val_acc", acc, prog_bar=True)
88+
self.log("val_loss", loss, prog_bar=prog_bar)
89+
self.log("val_acc", acc, prog_bar=prog_bar)
9090
return loss, acc
9191

9292
def configure_optimizers(self):

src/spotPython/torch/mapk.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,11 @@ def update(self, predicted: torch.Tensor, actual: torch.Tensor):
6464
# Convert predicted to list of lists of indices sorted by confidence score
6565
_, predicted = predicted.topk(k=self.k, dim=1)
6666
predicted = predicted.tolist()
67-
67+
# Code modified according to: "Inplace update to inference tensor outside InferenceMode
68+
# is not allowed. You can make a clone to get a normal tensor before doing inplace update."
6869
score = np.mean([self.apk(p, a, self.k) for p, a in zip(predicted, actual)])
69-
self.total += score
70-
self.count += 1
70+
self.total = self.total + score
71+
self.count = self.count + 1
7172

7273
def compute(self):
7374
"""

0 commit comments

Comments
 (0)