@@ -63,7 +63,7 @@ def training_step(self, batch):
6363 # self.log("train_mapk", self.train_mapk, on_step=True, on_epoch=False)
6464 return loss
6565
66- def validation_step (self , batch , batch_idx ):
66+ def validation_step (self , batch , batch_idx , prog_bar = False ):
6767 x , y = batch
6868 logits = self (x )
6969 # compute cross entropy loss from logits and y
@@ -73,10 +73,10 @@ def validation_step(self, batch, batch_idx):
7373 acc = accuracy (preds , y , task = "multiclass" , num_classes = self ._L_out )
7474 self .valid_mapk (logits , y )
7575 self .log ("valid_mapk" , self .valid_mapk , on_step = False , on_epoch = True )
76- self .log ("val_loss" , loss , prog_bar = True )
77- self .log ("val_acc" , acc , prog_bar = True )
76+ self .log ("val_loss" , loss , prog_bar = prog_bar )
77+ self .log ("val_acc" , acc , prog_bar = prog_bar )
7878
79- def test_step (self , batch , batch_idx ):
79+ def test_step (self , batch , batch_idx , prog_bar = False ):
8080 x , y = batch
8181 logits = self (x )
8282 # compute cross entropy loss from logits and y
@@ -85,8 +85,8 @@ def test_step(self, batch, batch_idx):
8585 acc = accuracy (preds , y , task = "multiclass" , num_classes = self ._L_out )
8686 self .test_mapk (logits , y )
8787 self .log ("test_mapk" , self .test_mapk , on_step = True , on_epoch = True )
88- self .log ("val_loss" , loss , prog_bar = True )
89- self .log ("val_acc" , acc , prog_bar = True )
88+ self .log ("val_loss" , loss , prog_bar = prog_bar )
89+ self .log ("val_acc" , acc , prog_bar = prog_bar )
9090 return loss , acc
9191
9292 def configure_optimizers (self ):
0 commit comments