@@ -27,12 +27,14 @@ def fun_control_init(
2727 PREFIX = None ,
2828 TENSORBOARD_CLEAN = False ,
2929 accelerator = "auto" ,
30+ check_finite = True ,
3031 collate_fn_name = None ,
3132 converters = None ,
3233 core_model = None ,
3334 core_model_name = None ,
3435 data = None ,
3536 data_full_train = None ,
37+ divergence_threshold = None ,
3638 hacky = False , # !TODO: Documentation
3739 data_val = None ,
3840 data_dir = "./data" ,
@@ -90,6 +92,7 @@ def fun_control_init(
9092 shuffle_val = False ,
9193 shuffle_test = False ,
9294 sigma = 0.0 ,
95+ stopping_threshold = None ,
9396 strategy = "auto" ,
9497 surrogate = None ,
9598 target_column = None ,
@@ -129,6 +132,9 @@ def fun_control_init(
129132 The accelerator to be used by the Lighting Trainer.
130133 It can be either "auto", "dp", "ddp", "ddp2", "ddp_spawn", "ddp_cpu", "gpu", "tpu".
131134 Default is "auto".
135+ check_finite (bool):
136+ When set True, stops training when the monitor becomes NaN or infinite.
137+ Default is True.
132138 collate_fn_name (str):
133139 The name of the collate function. Default is None.
134140 converters (dict):
@@ -164,6 +170,9 @@ def fun_control_init(
164170 Default is 1. Can be "auto" or an integer.
165171 design (object):
166172 The experimental design object. Default is None.
173+ divergence_threshold (float):
174+ Stop training as soon as the monitored quantity becomes worse than this threshold.
175+ Default is None.
167176 enable_progress_bar (bool):
168177 Whether to enable the progress bar or not.
169178 eval (str):
@@ -284,6 +293,9 @@ def fun_control_init(
284293 Whether the test data were shuffled or not. Default is False.
285294 surrogate (object):
286295 The surrogate model object. Default is None.
296+ stopping_threshold (float):
297+ Stop training immediately once the monitored quantity reaches this threshold.
298+ Default is None.
287299 strategy (str):
288300 The strategy to use. Default is "auto".
289301 target_column (str):
@@ -355,13 +367,15 @@ def fun_control_init(
355367 '_L_out': 11,
356368 '_L_cond': None,
357369 'accelerator': "auto",
370+ 'check_finite': True,
358371 'core_model': None,
359372 'core_model_name': None,
360373 'data': None,
361374 'data_dir': './data',
362375 'db_dict_name': None,
363376 'device': None,
364377 'devices': "auto",
378+ 'divergence_threshold': None,
365379 'enable_progress_bar': False,
366380 'eval': None,
367381 'horizon': 7,
@@ -391,6 +405,7 @@ def fun_control_init(
391405 'show_batch_interval': 1000000,
392406 'shuffle': None,
393407 'sigma': 0.0,
408+ 'stopping_threshold': None,
394409 'target_column': None,
395410 'target_type': None,
396411 'train': None,
@@ -425,6 +440,7 @@ def fun_control_init(
425440 "_L_cond" : _L_cond ,
426441 "_torchmetric" : _torchmetric ,
427442 "accelerator" : accelerator ,
443+ "check_finite" : check_finite ,
428444 "collate_fn_name" : collate_fn_name ,
429445 "converters" : converters ,
430446 "core_model" : core_model ,
@@ -433,6 +449,7 @@ def fun_control_init(
433449 "data" : data ,
434450 "data_dir" : data_dir ,
435451 "data_full_train" : data_full_train ,
452+ "divergence_threshold" : divergence_threshold ,
436453 "hacky" : hacky ,
437454 "data_module" : data_module ,
438455 "data_set" : data_set ,
@@ -497,6 +514,7 @@ def fun_control_init(
497514 "shuffle_val" : shuffle_val ,
498515 "shuffle_test" : shuffle_test ,
499516 "sigma" : sigma ,
517+ "stopping_threshold" : stopping_threshold ,
500518 "spot_tensorboard_path" : spot_tensorboard_path ,
501519 "strategy" : strategy ,
502520 "target_column" : target_column ,
0 commit comments