77
88
99def train_model (config , fun_control ):
10+ if fun_control ["enable_progress_bar" ] is None :
11+ enable_progress_bar = False
12+ else :
13+ enable_progress_bar = fun_control ["enable_progress_bar" ]
1014 config_id = generate_config_id (config )
1115 # Init DataModule
1216 dm = CSVDataModule (
@@ -21,8 +25,11 @@ def train_model(config, fun_control):
2125 max_epochs = model .epochs ,
2226 accelerator = "auto" ,
2327 devices = 1 ,
24- logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path" ], version = config_id ),
25- callbacks = [EarlyStopping (monitor = "val_loss" , patience = 3 , mode = "min" , strict = False , verbose = False )],
28+ logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path" ], version = config_id , default_hp_metric = True ),
29+ callbacks = [
30+ EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
31+ ],
32+ enable_progress_bar = enable_progress_bar ,
2633 )
2734 # Pass the datamodule as arg to trainer.fit to override model hooks :)
2835 trainer .fit (model = model , datamodule = dm )
@@ -36,6 +43,10 @@ def train_model(config, fun_control):
3643
3744
3845def test_model (config , fun_control ):
46+ if fun_control ["enable_progress_bar" ] is None :
47+ enable_progress_bar = False
48+ else :
49+ enable_progress_bar = fun_control ["enable_progress_bar" ]
3950 config_id = generate_config_id (config )
4051 # Init DataModule
4152 dm = CSVDataModule (
@@ -50,7 +61,11 @@ def test_model(config, fun_control):
5061 max_epochs = model .epochs ,
5162 accelerator = "auto" ,
5263 devices = 1 ,
53- logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path" ], version = config_id ),
64+ logger = TensorBoardLogger (save_dir = fun_control ["tensorboard_path" ], version = config_id , default_hp_metric = True ),
65+ callbacks = [
66+ EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
67+ ],
68+ enable_progress_bar = enable_progress_bar ,
5469 )
5570 # Pass the datamodule as arg to trainer.fit to override model hooks :)
5671 trainer .fit (model = model , datamodule = dm )
@@ -61,7 +76,11 @@ def test_model(config, fun_control):
6176
6277
6378def cv_model (config , fun_control ):
64- # config_id = generate_config_id(config)
79+ config_id = generate_config_id (config )
80+ if fun_control ["enable_progress_bar" ] is None :
81+ enable_progress_bar = False
82+ else :
83+ enable_progress_bar = fun_control ["enable_progress_bar" ]
6584 results = []
6685 num_folds = 10
6786 split_seed = 12345
@@ -87,7 +106,13 @@ def cv_model(config, fun_control):
87106 max_epochs = model .epochs ,
88107 accelerator = "auto" ,
89108 devices = 1 ,
90- # logger=TensorBoardLogger(save_dir=fun_control["tensorboard_path"], version=config_id),
109+ logger = TensorBoardLogger (
110+ save_dir = fun_control ["tensorboard_path" ], version = config_id , default_hp_metric = True
111+ ),
112+ callbacks = [
113+ EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )
114+ ],
115+ enable_progress_bar = enable_progress_bar ,
91116 )
92117 # Pass the datamodule as arg to trainer.fit to override model hooks :)
93118 trainer .fit (model = model , datamodule = dm )
0 commit comments