1616from math import isfinite
1717import matplotlib .pyplot as plt
1818from numpy import argmin
19+ from torch .utils .tensorboard import SummaryWriter
1920
2021from numpy import repeat
2122from numpy import sqrt
@@ -257,9 +258,8 @@ def __init__(
257258 self .max_surrogate_points = self .fun_control ["max_surrogate_points" ]
258259 self .progress_file = self .fun_control ["progress_file" ]
259260
260- # if the key "spot_writer" is not in the dictionary fun_control,
261- # set self.spot_writer to None else to the value of the key "spot_writer"
262- self .spot_writer = self .fun_control .get ("spot_writer" , None )
261+ # Tensorboard:
262+ self .init_spot_writer ()
263263
264264 # Bounds are internal, because they are functions of self.lower and self.upper
265265 # and used by the optimizer:
@@ -846,8 +846,8 @@ def run(self, X_start=None) -> Spot:
846846 # progress bar:
847847 self .show_progress_if_needed (timeout_start )
848848 if self .spot_writer is not None :
849- writer = self .spot_writer
850- writer .close ()
849+ self .spot_writer . flush ()
850+ self . spot_writer .close ()
851851 if self .fun_control ["db_dict_name" ] is not None :
852852 self .write_db_dict ()
853853 self .save_experiment ()
@@ -929,16 +929,18 @@ def initialize_design(self, X_start=None) -> None:
929929 #
930930 self .counter = self .y .size
931931 if self .spot_writer is not None :
932- writer = self .spot_writer
932+ # writer = self.spot_writer
933933 # range goes to init_size -1 because the last value is added by update_stats(),
934934 # which always adds the last value.
935935 # Changed in 0.5.9:
936936 for j in range (len (self .y )):
937937 X_j = self .X [j ].copy ()
938938 y_j = self .y [j ].copy ()
939939 config = {self .var_name [i ]: X_j [i ] for i in range (self .k )}
940- writer .add_hparams (config , {"spot_y" : y_j })
941- writer .flush ()
940+ # see: https://github.com/pytorch/pytorch/issues/32651
941+ # self.spot_writer.add_hparams(config, {"spot_y": y_j}, run_name=self.spot_tensorboard_path)
942+ self .spot_writer .add_hparams (config , {"spot_y" : y_j })
943+ self .spot_writer .flush ()
942944 #
943945 self .X , self .y = remove_nan (self .X , self .y , stop_on_zero_return = True )
944946 logger .debug ("In Spot() initialize_design(), final X val, after remove nan: self.X: %s" , self .X )
@@ -1257,17 +1259,17 @@ def update_stats(self) -> None:
12571259
12581260 def update_writer (self ) -> None :
12591261 if self .spot_writer is not None :
1260- writer = self .spot_writer
1262+ # writer = self.spot_writer
12611263 # get the last y value:
12621264 y_last = self .y [- 1 ].copy ()
12631265 if self .noise is False :
12641266 y_min = self .min_y .copy ()
12651267 X_min = self .min_X .copy ()
12661268 # y_min: best y value so far
12671269 # y_last: last y value, can be worse than y_min
1268- writer .add_scalars ("spot_y" , {"min" : y_min , "last" : y_last }, self .counter )
1270+ self . spot_writer .add_scalars ("spot_y" , {"min" : y_min , "last" : y_last }, self .counter )
12691271 # X_min: X value of the best y value so far
1270- writer .add_scalars ("spot_X" , {f"X_{ i } " : X_min [i ] for i in range (self .k )}, self .counter )
1272+ self . spot_writer .add_scalars ("spot_X" , {f"X_{ i } " : X_min [i ] for i in range (self .k )}, self .counter )
12711273 else :
12721274 # get the last n y values:
12731275 y_last_n = self .y [- self .fun_repeats :].copy ()
@@ -1277,23 +1279,25 @@ def update_writer(self) -> None:
12771279 X_min_mean = self .min_mean_X .copy ()
12781280 # y_min_var: variance of the min y value so far
12791281 y_min_var = self .min_var_y .copy ()
1280- writer .add_scalar ("spot_y_min_var" , y_min_var , self .counter )
1282+ self . spot_writer .add_scalar ("spot_y_min_var" , y_min_var , self .counter )
12811283 # y_min_mean: best mean y value so far (see above)
1282- writer .add_scalar ("spot_y" , y_min_mean , self .counter )
1284+ self . spot_writer .add_scalar ("spot_y" , y_min_mean , self .counter )
12831285 # last n y values (noisy):
1284- writer .add_scalars (
1286+ self . spot_writer .add_scalars (
12851287 "spot_y" , {f"y_last_n{ i } " : y_last_n [i ] for i in range (self .fun_repeats )}, self .counter
12861288 )
12871289 # X_min_mean: X value of the best mean y value so far (see above)
1288- writer .add_scalars (
1290+ self . spot_writer .add_scalars (
12891291 "spot_X_noise" , {f"X_min_mean{ i } " : X_min_mean [i ] for i in range (self .k )}, self .counter
12901292 )
12911293 # get last value of self.X and convert to dict. take the values from self.var_name as keys:
12921294 X_last = self .X [- 1 ].copy ()
12931295 config = {self .var_name [i ]: X_last [i ] for i in range (self .k )}
12941296 # hyperparameters X and value y of the last configuration:
1295- writer .add_hparams (config , {"spot_y" : y_last })
1296- writer .flush ()
1297+ # see: https://github.com/pytorch/pytorch/issues/32651
1298+ # self.spot_writer.add_hparams(config, {"spot_y": y_last}, run_name=self.spot_tensorboard_path)
1299+ self .spot_writer .add_hparams (config , {"spot_y" : y_last })
1300+ self .spot_writer .flush ()
12971301
12981302 def suggest_new_X (self ) -> np .array :
12991303 """
@@ -2185,3 +2189,13 @@ def save_experiment(self, filename=None) -> None:
21852189 pprint .pprint (spot_tuner )
21862190
21872191 raise e
2192+
2193+ def init_spot_writer (self ) -> None :
2194+ """
2195+ Initialize the spot_writer for the current experiment.
2196+ """
2197+ self .spot_tensorboard_path = self .fun_control ["spot_tensorboard_path" ]
2198+ if self .spot_tensorboard_path is not None :
2199+ self .spot_writer = SummaryWriter (log_dir = self .spot_tensorboard_path )
2200+ else :
2201+ self .spot_writer = None
0 commit comments