@@ -104,12 +104,123 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
104104 # print(f"train_model(): Batch size: {config['batch_size']}")
105105
106106 # Callbacks
107+ #
108+ # EarlyStopping:
109+ # Stop training when a monitored quantity has stopped improving.
110+ # The EarlyStopping callback runs at the end of every validation epoch by default.
111+ # However, the frequency of validation can be modified by setting various parameters
112+ # in the Trainer, for example check_val_every_n_epoch and val_check_interval.
113+ # It must be noted that the patience parameter counts the number of validation checks
114+ # with no improvement, and not the number of training epochs.
115+ # Therefore, with parameters check_val_every_n_epoch=10 and patience=3,
116+ # the trainer will perform at least 40 training epochs before being stopped.
117+ # Args:
118+ # - monitor:
119+ # Quantity to be monitored. Default: 'val_loss'.
120+ # - patience:
121+ # Number of validation checks with no improvement after which training will be stopped.
122+ # In spotpython, this is a hyperparameter.
123+ # - mode (str):
124+ # one of {min, max}. In 'min' mode, training will stop when the monitored quantity
125+ # has stopped decreasing; in 'max' mode, it will stop when the quantity has stopped increasing.
126+ # For 'val_acc', this should be 'max', for 'val_loss' this should be 'min', etc.
127+ # - strict:
128+ # Whether to crash the training if the monitored quantity is not found in the validation metrics. Set to False.
129+ # - verbose:
130+ # If True, prints a message to the logger.
131+ #
132+ # ModelCheckpoint:
133+ # Save the model periodically by monitoring a quantity.
134+ # Every metric logged with log() or log_dict() is a candidate for the monitor key.
135+ # spotpython uses ModelCheckpoint if timestamp is set to False. In this case, the
136+ # config_id has no timestamp and ends with the unique string "_TRAIN". This
137+ # enables loading the model from a checkpoint, because the config_id is unique.
138+ # Args:
139+ # - dirpath:
140+ # Path to the directory where the checkpoints will be saved.
141+ # - monitor (str):
142+ # Quantity to monitor.
143+ # By default it is None which saves a checkpoint only for the last epoch.
144+ # - verbose (bool):
145+ # If True, prints a message to the logger.
146+ # - save_last (Union[bool, Literal['link'], None]):
147+ # When True, saves a last.ckpt copy whenever a checkpoint file gets saved.
148+ # Can be set to 'link' on a local filesystem to create a symbolic link.
149+ # This allows accessing the latest checkpoint in a deterministic manner.
150+ # Default: None.
151+
107152 callbacks = [EarlyStopping (monitor = "val_loss" , patience = config ["patience" ], mode = "min" , strict = False , verbose = False )]
108153 if not timestamp :
109154 # add ModelCheckpoint only if timestamp is False
110- callbacks .append (ModelCheckpoint (dirpath = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ), save_last = True )) # Save the last checkpoint
155+ dirpath = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id )
156+ callbacks .append (ModelCheckpoint (dirpath = dirpath , monitor = None , verbose = False , save_last = True )) # Save the last checkpoint
157+
158+ # Tensorboard logger. The tensorboard is passed to the trainer.
159+ # See: https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.TensorBoardLogger.html
160+ # It uses the following arguments:
161+ # Args:
162+ # - save_dir:
163+ # Where to save logs. Can be specified via fun_control["TENSORBOARD_PATH"]
164+ # - name:
165+ # Experiment name. Defaults to 'default'.
166+ # If it is the empty string then no per-experiment subdirectory is used.
167+ # Changed in spotpython 0.17.2 to the empty string.
168+ # - version:
169+ # Experiment version. If version is not specified the logger inspects the save directory
170+ # for existing versions, then automatically assigns the next available version.
171+ # If it is a string then it is used as the run-specific subdirectory name,
172+ # otherwise 'version_${version}' is used. spotpython uses the config_id as version.
173+ # - log_graph (bool):
174+ # Adds the computational graph to tensorboard.
175+ # This requires that the user has defined the self.example_input_array
176+ # attribute in their model. Set in spotpython to fun_control["log_graph"].
177+ # - default_hp_metric (bool):
178+ # Enables a placeholder metric with key hp_metric when log_hyperparams is called
179+ # without a metric (otherwise calls to log_hyperparams without a metric are ignored).
180+ # spotpython sets this to True.
181+
182+ # Init trainer. See: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.trainer.trainer.Trainer.html#lightning.pytorch.trainer.trainer.Trainer
183+ # Args used by spotpython (there are more):
184+ # - default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
185+ # Default: os.getcwd(). Can be remote file paths such as s3://mybucket/path or ‘hdfs://path/’
186+ # - max_epochs: Stop training once this number of epochs is reached.
187+ # Disabled by default (None).
188+ # If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000.
189+ # To enable infinite training, set max_epochs = -1.
190+ # - accelerator: Supports passing different accelerator types
191+ # (“cpu”, “gpu”, “tpu”, “hpu”, “mps”, “auto”) as well as custom accelerator instances.
192+ # - devices: The devices to use. Can be set to a positive number (int or str),
193+ # a sequence of device indices (list or str), the value -1 to indicate all available devices
194+ # should be used, or "auto" for automatic selection based on the chosen accelerator.
195+ # Default: "auto".
196+ # - strategy: Supports different training strategies with aliases as well as custom strategies.
197+ # Default: "auto".
198+ # - num_nodes: Number of GPU nodes for distributed training. Default: 1.
199+ # - precision: Double precision (64, ‘64’ or ‘64-true’), full precision (32, ‘32’ or ‘32-true’),
200+ # 16bit mixed precision (16, ‘16’, ‘16-mixed’) or bfloat16 mixed precision (‘bf16’, ‘bf16-mixed’).
201+ # Can be used on CPU, GPU, TPUs, or HPUs. Default: '32-true'.
202+ # - logger: Logger (or iterable collection of loggers) for experiment tracking.
203+ # A True value uses the default TensorBoardLogger if it is installed, otherwise CSVLogger.
204+ # False will disable logging. If multiple loggers are provided, local files (checkpoints,
205+ # profiler traces, etc.) are saved in the log_dir of the first logger. Default: True.
206+ # - callbacks: List of callbacks to enable during training. Default: None.
207+ # - enable_progress_bar: If True, enables the progress bar.
208+ # Whether to enable the progress bar by default. Default: True.
209+ # - num_sanity_val_steps:
210+ # Sanity check runs n validation batches before starting the training routine.
211+ # Set it to -1 to run all batches in all validation dataloaders. Default: 2.
212+ # - log_every_n_steps:
213+ # How often to log within steps. Default: 50.
214+ # - gradient_clip_val:
215+ # The value at which to clip gradients. Passing gradient_clip_val=None
216+ # disables gradient clipping. If using Automatic Mixed Precision (AMP),
217+ # the gradients will be unscaled before. Default: None.
218+ # - gradient_clip_algorithm (str):
219+ # The gradient clipping algorithm to use.
220+ # Pass gradient_clip_algorithm="value" to clip by value,
221+ # and gradient_clip_algorithm="norm" to clip by norm.
222+ # By default it will be set to "norm".
111223
112- # Init trainer
113224 trainer = L .Trainer (
114225 # Where to save models
115226 default_root_dir = os .path .join (fun_control ["CHECKPOINT_PATH" ], config_id ),
@@ -119,21 +230,46 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
119230 strategy = fun_control ["strategy" ],
120231 num_nodes = fun_control ["num_nodes" ],
121232 precision = fun_control ["precision" ],
122- logger = TensorBoardLogger (
123- save_dir = fun_control ["TENSORBOARD_PATH" ],
124- version = config_id ,
125- default_hp_metric = True ,
126- log_graph = fun_control ["log_graph" ],
127- ),
233+ logger = TensorBoardLogger (save_dir = fun_control ["TENSORBOARD_PATH" ], version = config_id , default_hp_metric = True , log_graph = fun_control ["log_graph" ], name = "" ),
128234 callbacks = callbacks ,
129235 enable_progress_bar = enable_progress_bar ,
236+ num_sanity_val_steps = fun_control ["num_sanity_val_steps" ],
237+ log_every_n_steps = fun_control ["log_every_n_steps" ],
238+ gradient_clip_val = None ,
239+ gradient_clip_algorithm = "norm" ,
130240 )
131- # Pass the datamodule as arg to trainer.fit to override model hooks :)
132- trainer .fit (model = model , datamodule = dm )
241+
242+ # Fit the model
243+ # Args:
244+ # - model: Model to fit
245+ # - datamodule: A LightningDataModule that defines the train_dataloader
246+ # hook. Pass the datamodule as arg to trainer.fit to override model hooks.
247+ # - ckpt_path: Path/URL of the checkpoint from which training is resumed.
248+ # Could also be one of two special keywords "last" and "hpc".
249+ # If there is no checkpoint file at the path, an exception is raised.
250+ trainer .fit (model = model , datamodule = dm , ckpt_path = None )
251+
133252 # Test best model on validation and test set
134- # result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
253+
135254 verbose = fun_control ["verbosity" ] > 0
136- result = trainer .validate (model = model , datamodule = dm , verbose = verbose )
255+
256+ # Validate the model
257+ # Perform one evaluation epoch over the validation set.
258+ # Args:
259+ # - model: The model to validate.
260+ # - datamodule: A LightningDataModule that defines the val_dataloader hook.
261+ # - verbose: If True, prints the validation results.
262+ # - ckpt_path: Path to a specific checkpoint to load for validation.
263+ # Either "best", "last", "hpc" or path to the checkpoint you wish to validate.
264+ # If None and the model instance was passed, use the current weights.
265+ # Otherwise, the best model checkpoint from the previous trainer.fit call will
266+ # be loaded if a checkpoint callback is configured.
267+ # Returns:
268+ # - List of dictionaries with metrics logged during the validation phase,
269+ # e.g., in model- or callback hooks like validation_step() etc.
270+ # The length of the list corresponds to the number of validation dataloaders used.
271+ result = trainer .validate (model = model , datamodule = dm , ckpt_path = None , verbose = verbose )
272+
137273 # unlist the result (from a list of one dict)
138274 result = result [0 ]
139275 print (f"train_model result: { result } " )
0 commit comments