@@ -5948,7 +5948,18 @@ <h2 id="spotpython.light.trainmodel.train_model" class="doc doc-heading">
59485948< span class ="normal "> 367</ span >
59495949< span class ="normal "> 368</ span >
59505950< span class ="normal "> 369</ span >
5951- < span class ="normal "> 370</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
5951+ < span class ="normal "> 370</ span >
5952+ < span class ="normal "> 371</ span >
5953+ < span class ="normal "> 372</ span >
5954+ < span class ="normal "> 373</ span >
5955+ < span class ="normal "> 374</ span >
5956+ < span class ="normal "> 375</ span >
5957+ < span class ="normal "> 376</ span >
5958+ < span class ="normal "> 377</ span >
5959+ < span class ="normal "> 378</ span >
5960+ < span class ="normal "> 379</ span >
5961+ < span class ="normal "> 380</ span >
5962+ < span class ="normal "> 381</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
59525963< span class ="w "> </ span > < span class ="sd "> """</ span >
59535964< span class ="sd "> Trains a model using the given configuration and function control parameters.</ span >
59545965
@@ -6078,7 +6089,18 @@ <h2 id="spotpython.light.trainmodel.train_model" class="doc doc-heading">
60786089 < span class ="c1 "> # This allows accessing the latest checkpoint in a deterministic manner.</ span >
60796090 < span class ="c1 "> # Default: None.</ span >
60806091 < span class ="n "> config_id</ span > < span class ="o "> =</ span > < span class ="n "> generate_config_id_with_timestamp</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="o "> =</ span > < span class ="n "> timestamp</ span > < span class ="p "> )</ span >
6081- < span class ="n "> callbacks</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> EarlyStopping</ span > < span class ="p "> (</ span > < span class ="n "> monitor</ span > < span class ="o "> =</ span > < span class ="s2 "> "val_loss"</ span > < span class ="p "> ,</ span > < span class ="n "> patience</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "patience"</ span > < span class ="p "> ],</ span > < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="s2 "> "min"</ span > < span class ="p "> ,</ span > < span class ="n "> strict</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> verbose</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )]</ span >
6092+ < span class ="n "> callbacks</ span > < span class ="o "> =</ span > < span class ="p "> [</ span >
6093+ < span class ="n "> EarlyStopping</ span > < span class ="p "> (</ span >
6094+ < span class ="n "> monitor</ span > < span class ="o "> =</ span > < span class ="s2 "> "val_loss"</ span > < span class ="p "> ,</ span >
6095+ < span class ="n "> patience</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "patience"</ span > < span class ="p "> ],</ span >
6096+ < span class ="n "> divergence_threshold</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "divergence_threshold"</ span > < span class ="p "> ],</ span >
6097+ < span class ="n "> check_finite</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "check_finite"</ span > < span class ="p "> ],</ span >
6098+ < span class ="n "> stopping_threshold</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "stopping_threshold"</ span > < span class ="p "> ],</ span >
6099+ < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="s2 "> "min"</ span > < span class ="p "> ,</ span >
6100+ < span class ="n "> strict</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
6101+ < span class ="n "> verbose</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
6102+ < span class ="p "> )</ span >
6103+ < span class ="p "> ]</ span >
60826104 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span >
60836105 < span class ="c1 "> # add ModelCheckpoint only if timestamp is False</ span >
60846106 < span class ="n "> dirpath</ span > < span class ="o "> =</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "CHECKPOINT_PATH"</ span > < span class ="p "> ],</ span > < span class ="n "> config_id</ span > < span class ="p "> )</ span >
@@ -6406,18 +6428,7 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
64066428
64076429 < details class ="quote ">
64086430 < summary > Source code in < code > spotpython/light/trainmodel.py</ code > </ summary >
6409- < div class ="highlight "> < table class ="highlighttable "> < tr > < td class ="linenos "> < div class ="linenodiv "> < pre > < span > </ span > < span class ="normal "> 373</ span >
6410- < span class ="normal "> 374</ span >
6411- < span class ="normal "> 375</ span >
6412- < span class ="normal "> 376</ span >
6413- < span class ="normal "> 377</ span >
6414- < span class ="normal "> 378</ span >
6415- < span class ="normal "> 379</ span >
6416- < span class ="normal "> 380</ span >
6417- < span class ="normal "> 381</ span >
6418- < span class ="normal "> 382</ span >
6419- < span class ="normal "> 383</ span >
6420- < span class ="normal "> 384</ span >
6431+ < div class ="highlight "> < table class ="highlighttable "> < tr > < td class ="linenos "> < div class ="linenodiv "> < pre > < span > </ span > < span class ="normal "> 384</ span >
64216432< span class ="normal "> 385</ span >
64226433< span class ="normal "> 386</ span >
64236434< span class ="normal "> 387</ span >
@@ -6765,7 +6776,29 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
67656776< span class ="normal "> 729</ span >
67666777< span class ="normal "> 730</ span >
67676778< span class ="normal "> 731</ span >
6768- < span class ="normal "> 732</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model_xai</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
6779+ < span class ="normal "> 732</ span >
6780+ < span class ="normal "> 733</ span >
6781+ < span class ="normal "> 734</ span >
6782+ < span class ="normal "> 735</ span >
6783+ < span class ="normal "> 736</ span >
6784+ < span class ="normal "> 737</ span >
6785+ < span class ="normal "> 738</ span >
6786+ < span class ="normal "> 739</ span >
6787+ < span class ="normal "> 740</ span >
6788+ < span class ="normal "> 741</ span >
6789+ < span class ="normal "> 742</ span >
6790+ < span class ="normal "> 743</ span >
6791+ < span class ="normal "> 744</ span >
6792+ < span class ="normal "> 745</ span >
6793+ < span class ="normal "> 746</ span >
6794+ < span class ="normal "> 747</ span >
6795+ < span class ="normal "> 748</ span >
6796+ < span class ="normal "> 749</ span >
6797+ < span class ="normal "> 750</ span >
6798+ < span class ="normal "> 751</ span >
6799+ < span class ="normal "> 752</ span >
6800+ < span class ="normal "> 753</ span >
6801+ < span class ="normal "> 754</ span > </ pre > </ div > </ td > < td class ="code "> < div > < pre > < span > </ span > < code > < span class ="k "> def</ span > < span class ="nf "> train_model_xai</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> fun_control</ span > < span class ="p "> :</ span > < span class ="nb "> dict</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span > < span class ="nb "> bool</ span > < span class ="o "> =</ span > < span class ="kc "> True</ span > < span class ="p "> )</ span > < span class ="o "> -></ span > < span class ="nb "> float</ span > < span class ="p "> :</ span >
67696802< span class ="w "> </ span > < span class ="sd "> """</ span >
67706803< span class ="sd "> Trains a model using the given configuration and function control parameters. Performs feature attribution analysis and calculates consistency of these methods.</ span >
67716804
@@ -6895,7 +6928,18 @@ <h2 id="spotpython.light.trainmodel.train_model_xai" class="doc doc-heading">
68956928 < span class ="c1 "> # This allows accessing the latest checkpoint in a deterministic manner.</ span >
68966929 < span class ="c1 "> # Default: None.</ span >
68976930 < span class ="n "> config_id</ span > < span class ="o "> =</ span > < span class ="n "> generate_config_id_with_timestamp</ span > < span class ="p "> (</ span > < span class ="n "> config</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> ,</ span > < span class ="n "> timestamp</ span > < span class ="o "> =</ span > < span class ="n "> timestamp</ span > < span class ="p "> )</ span >
6898- < span class ="n "> callbacks</ span > < span class ="o "> =</ span > < span class ="p "> [</ span > < span class ="n "> EarlyStopping</ span > < span class ="p "> (</ span > < span class ="n "> monitor</ span > < span class ="o "> =</ span > < span class ="s2 "> "val_loss"</ span > < span class ="p "> ,</ span > < span class ="n "> patience</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "patience"</ span > < span class ="p "> ],</ span > < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="s2 "> "min"</ span > < span class ="p "> ,</ span > < span class ="n "> strict</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span > < span class ="n "> verbose</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> )]</ span >
6931+ < span class ="n "> callbacks</ span > < span class ="o "> =</ span > < span class ="p "> [</ span >
6932+ < span class ="n "> EarlyStopping</ span > < span class ="p "> (</ span >
6933+ < span class ="n "> monitor</ span > < span class ="o "> =</ span > < span class ="s2 "> "val_loss"</ span > < span class ="p "> ,</ span >
6934+ < span class ="n "> patience</ span > < span class ="o "> =</ span > < span class ="n "> config</ span > < span class ="p "> [</ span > < span class ="s2 "> "patience"</ span > < span class ="p "> ],</ span >
6935+ < span class ="n "> divergence_threshold</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "divergence_threshold"</ span > < span class ="p "> ],</ span >
6936+ < span class ="n "> check_finite</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "check_finite"</ span > < span class ="p "> ],</ span >
6937+ < span class ="n "> stopping_threshold</ span > < span class ="o "> =</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "stopping_threshold"</ span > < span class ="p "> ],</ span >
6938+ < span class ="n "> mode</ span > < span class ="o "> =</ span > < span class ="s2 "> "min"</ span > < span class ="p "> ,</ span >
6939+ < span class ="n "> strict</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
6940+ < span class ="n "> verbose</ span > < span class ="o "> =</ span > < span class ="kc "> False</ span > < span class ="p "> ,</ span >
6941+ < span class ="p "> )</ span >
6942+ < span class ="p "> ]</ span >
68996943 < span class ="k "> if</ span > < span class ="ow "> not</ span > < span class ="n "> timestamp</ span > < span class ="p "> :</ span >
69006944 < span class ="c1 "> # add ModelCheckpoint only if timestamp is False</ span >
69016945 < span class ="n "> dirpath</ span > < span class ="o "> =</ span > < span class ="n "> os</ span > < span class ="o "> .</ span > < span class ="n "> path</ span > < span class ="o "> .</ span > < span class ="n "> join</ span > < span class ="p "> (</ span > < span class ="n "> fun_control</ span > < span class ="p "> [</ span > < span class ="s2 "> "CHECKPOINT_PATH"</ span > < span class ="p "> ],</ span > < span class ="n "> config_id</ span > < span class ="p "> )</ span >
0 commit comments