@@ -72,7 +72,7 @@ def validation_step(self, batch, batch_idx, prog_bar=False):
7272 preds = torch .argmax (logits , dim = 1 )
7373 acc = accuracy (preds , y , task = "multiclass" , num_classes = self ._L_out )
7474 self .valid_mapk (logits , y )
75- self .log ("valid_mapk" , self .valid_mapk , on_step = False , on_epoch = True )
75+ self .log ("valid_mapk" , self .valid_mapk , on_step = False , on_epoch = True , prog_bar = prog_bar )
7676 self .log ("val_loss" , loss , prog_bar = prog_bar )
7777 self .log ("val_acc" , acc , prog_bar = prog_bar )
7878
@@ -84,7 +84,7 @@ def test_step(self, batch, batch_idx, prog_bar=False):
8484 preds = torch .argmax (logits , dim = 1 )
8585 acc = accuracy (preds , y , task = "multiclass" , num_classes = self ._L_out )
8686 self .test_mapk (logits , y )
87- self .log ("test_mapk" , self .test_mapk , on_step = True , on_epoch = True )
87+ self .log ("test_mapk" , self .test_mapk , on_step = True , on_epoch = True , prog_bar = prog_bar )
8888 self .log ("val_loss" , loss , prog_bar = prog_bar )
8989 self .log ("val_acc" , acc , prog_bar = prog_bar )
9090 return loss , acc
0 commit comments