@@ -316,32 +316,6 @@ def __init__(
316316 if self .surrogate_control ["n_theta" ] > 1 :
317317 surrogate_control .update ({"n_theta" : self .k })
318318
319- # # If no surrogate model is specified, use the internal
320- # # spotpython kriging surrogate:
321- # if self.surrogate is None:
322- # # Call kriging with surrogate_control parameters:
323- # self.surrogate = Kriging(
324- # name="kriging",
325- # noise=self.surrogate_control["noise"],
326- # model_optimizer=self.surrogate_control["model_optimizer"],
327- # model_fun_evals=self.surrogate_control["model_fun_evals"],
328- # seed=self.surrogate_control["seed"],
329- # log_level=self.log_level,
330- # min_theta=self.surrogate_control["min_theta"],
331- # max_theta=self.surrogate_control["max_theta"],
332- # metric_factorial=self.surrogate_control["metric_factorial"],
333- # n_theta=self.surrogate_control["n_theta"],
334- # theta_init_zero=self.surrogate_control["theta_init_zero"],
335- # p_val=self.surrogate_control["p_val"],
336- # n_p=self.surrogate_control["n_p"],
337- # optim_p=self.surrogate_control["optim_p"],
338- # min_Lambda=self.surrogate_control["min_Lambda"],
339- # max_Lambda=self.surrogate_control["max_Lambda"],
340- # var_type=self.surrogate_control["var_type"],
341- # spot_writer=self.spot_writer,
342- # counter=self.design_control["init_size"] * self.design_control["repeats"] - 1,
343- # )
344-
345319 # Internal attributes:
346320 self .X = None
347321 self .y = None
@@ -355,6 +329,40 @@ def __init__(
355329 self .mean_y = None
356330 self .var_y = None
357331
332+ # save experiment must be called before the spot_writer is initialized
333+ if self .fun_control .get ("save_experiment" ):
334+ filename = self .fun_control .get ("PREFIX" ) + "_exp.pkl"
335+ self .save_experiment (filename = filename , verbosity = self .verbosity )
336+
337+ # Tensorboard must be initialized before the surrogate model:
338+ self .init_spot_writer ()
339+
340+ # If no surrogate model is specified, use the internal
341+ # spotpython kriging surrogate:
342+ if self .surrogate is None :
343+ # Call kriging with surrogate_control parameters:
344+ self .surrogate = Kriging (
345+ name = "kriging" ,
346+ noise = self .surrogate_control ["noise" ],
347+ model_optimizer = self .surrogate_control ["model_optimizer" ],
348+ model_fun_evals = self .surrogate_control ["model_fun_evals" ],
349+ seed = self .surrogate_control ["seed" ],
350+ log_level = self .log_level ,
351+ min_theta = self .surrogate_control ["min_theta" ],
352+ max_theta = self .surrogate_control ["max_theta" ],
353+ metric_factorial = self .surrogate_control ["metric_factorial" ],
354+ n_theta = self .surrogate_control ["n_theta" ],
355+ theta_init_zero = self .surrogate_control ["theta_init_zero" ],
356+ p_val = self .surrogate_control ["p_val" ],
357+ n_p = self .surrogate_control ["n_p" ],
358+ optim_p = self .surrogate_control ["optim_p" ],
359+ min_Lambda = self .surrogate_control ["min_Lambda" ],
360+ max_Lambda = self .surrogate_control ["max_Lambda" ],
361+ var_type = self .surrogate_control ["var_type" ],
362+ spot_writer = self .spot_writer ,
363+ counter = self .design_control ["init_size" ] * self .design_control ["repeats" ] - 1 ,
364+ )
365+
358366 logger .setLevel (self .log_level )
359367 logger .info (f"Starting the logger at level { self .log_level } for module { __name__ } :" )
360368 logger .debug ("In Spot() init(): fun_control: %s" , self .fun_control )
@@ -791,35 +799,6 @@ def run(self, X_start: np.ndarray = None) -> Spot:
791799 3.7179535332164810e-04])
792800
793801 """
794- # Tensorboard:
795- self .init_spot_writer ()
796-
797- # If no surrogate model is specified, use the internal
798- # spotpython kriging surrogate:
799- if self .surrogate is None :
800- # Call kriging with surrogate_control parameters:
801- self .surrogate = Kriging (
802- name = "kriging" ,
803- noise = self .surrogate_control ["noise" ],
804- model_optimizer = self .surrogate_control ["model_optimizer" ],
805- model_fun_evals = self .surrogate_control ["model_fun_evals" ],
806- seed = self .surrogate_control ["seed" ],
807- log_level = self .log_level ,
808- min_theta = self .surrogate_control ["min_theta" ],
809- max_theta = self .surrogate_control ["max_theta" ],
810- metric_factorial = self .surrogate_control ["metric_factorial" ],
811- n_theta = self .surrogate_control ["n_theta" ],
812- theta_init_zero = self .surrogate_control ["theta_init_zero" ],
813- p_val = self .surrogate_control ["p_val" ],
814- n_p = self .surrogate_control ["n_p" ],
815- optim_p = self .surrogate_control ["optim_p" ],
816- min_Lambda = self .surrogate_control ["min_Lambda" ],
817- max_Lambda = self .surrogate_control ["max_Lambda" ],
818- var_type = self .surrogate_control ["var_type" ],
819- spot_writer = self .spot_writer ,
820- counter = self .design_control ["init_size" ] * self .design_control ["repeats" ] - 1 ,
821- )
822-
823802 self .initialize_design (X_start )
824803 self .update_stats ()
825804 self .fit_surrogate ()
@@ -837,8 +816,9 @@ def run(self, X_start: np.ndarray = None) -> Spot:
837816 if self .fun_control .get ("db_dict_name" ) is not None :
838817 self ._write_db_dict ()
839818
840- if self .fun_control .get ("save_experiment" ):
841- self .save_experiment ()
819+ if self .fun_control .get ("save_result" ):
820+ filename = self .fun_control .get ("PREFIX" ) + "_res.pkl"
821+ self .save_experiment (filename = filename , verbosity = self .verbosity )
842822 return self
843823
844824 def initialize_design (self , X_start = None ) -> None :
0 commit comments