Skip to content

Commit cb31fec

Browse files
0.21.4
modified save experiment
1 parent 469126a commit cb31fec

5 files changed

Lines changed: 485 additions & 188 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 164 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7192,9 +7192,37 @@
71927192
},
71937193
{
71947194
"cell_type": "code",
7195-
"execution_count": null,
7195+
"execution_count": 1,
71967196
"metadata": {},
7197-
"outputs": [],
7197+
"outputs": [
7198+
{
7199+
"name": "stderr",
7200+
"output_type": "stream",
7201+
"text": [
7202+
"Seed set to 123\n",
7203+
"Seed set to 123\n"
7204+
]
7205+
},
7206+
{
7207+
"name": "stdout",
7208+
"output_type": "stream",
7209+
"text": [
7210+
"S.X: [[ 0. 0. ]\n",
7211+
" [ 0. 1. ]\n",
7212+
" [ 1. 0. ]\n",
7213+
" [ 1. 1. ]\n",
7214+
" [-0.90924339 -0.15823458]\n",
7215+
" [-0.20581711 -0.48124909]\n",
7216+
" [ 0.94974117 -0.94631272]\n",
7217+
" [-0.12095571 0.06383589]\n",
7218+
" [-0.66278702 0.17431637]\n",
7219+
" [ 0.28200844 0.93001011]\n",
7220+
" [ 0.47878812 0.65321058]]\n",
7221+
"S.y: [0. 1. 1. 2. 0.85176172 0.27396137\n",
7222+
" 1.79751605 0.01870531 0.46967283 0.94444757 0.65592212]\n"
7223+
]
7224+
}
7225+
],
71987226
"source": [
71997227
"import numpy as np\n",
72007228
"from spotpython.fun.objectivefunctions import Analytical\n",
@@ -7228,9 +7256,25 @@
72287256
},
72297257
{
72307258
"cell_type": "code",
7231-
"execution_count": null,
7259+
"execution_count": 2,
72327260
"metadata": {},
7233-
"outputs": [],
7261+
"outputs": [
7262+
{
7263+
"name": "stderr",
7264+
"output_type": "stream",
7265+
"text": [
7266+
"Seed set to 123\n"
7267+
]
7268+
},
7269+
{
7270+
"name": "stdout",
7271+
"output_type": "stream",
7272+
"text": [
7273+
"Moving TENSORBOARD_PATH: runs/ to TENSORBOARD_PATH_OLD: runs_OLD/runs_2025_01_12_10_59_57\n",
7274+
"Created spot_tensorboard_path: runs/spot_logs/00_p040025_2025-01-12_10-59-57 for SummaryWriter()\n"
7275+
]
7276+
}
7277+
],
72347278
"source": [
72357279
"import numpy as np\n",
72367280
"from spotpython.fun import Analytical\n",
@@ -7260,7 +7304,7 @@
72607304
},
72617305
{
72627306
"cell_type": "code",
7263-
"execution_count": 2,
7307+
"execution_count": 3,
72647308
"metadata": {},
72657309
"outputs": [
72667310
{
@@ -7317,14 +7361,13 @@
73177361
},
73187362
{
73197363
"cell_type": "code",
7320-
"execution_count": 1,
7364+
"execution_count": 4,
73217365
"metadata": {},
73227366
"outputs": [
73237367
{
73247368
"name": "stderr",
73257369
"output_type": "stream",
73267370
"text": [
7327-
"Seed set to 123\n",
73287371
"Seed set to 123\n"
73297372
]
73307373
},
@@ -7369,6 +7412,120 @@
73697412
"print(f\"S.X: {S.X}\")\n",
73707413
"print(f\"S.y: {S.y}\")"
73717414
]
7415+
},
7416+
{
7417+
"cell_type": "markdown",
7418+
"metadata": {},
7419+
"source": [
7420+
"## save experiment"
7421+
]
7422+
},
7423+
{
7424+
"cell_type": "code",
7425+
"execution_count": null,
7426+
"metadata": {},
7427+
"outputs": [],
7428+
"source": [
7429+
"import copy\n",
7430+
"import os\n",
7431+
"import pickle\n",
7432+
"import logging\n",
7433+
"from torch.utils.tensorboard import SummaryWriter\n",
7434+
"\n",
7435+
"class Experiment:\n",
7436+
" def save_experiment(self, filename=None, path=None, overwrite=True) -> None:\n",
7437+
" \"\"\"\n",
7438+
" Save the experiment to a file.\n",
7439+
"\n",
7440+
" Args:\n",
7441+
" filename (str): The filename of the experiment file.\n",
7442+
" path (str): The path to the experiment file.\n",
7443+
" overwrite (bool): If `True`, the file will be overwritten if it already exists. Default is `True`.\n",
7444+
"\n",
7445+
" Returns:\n",
7446+
" None\n",
7447+
" \"\"\"\n",
7448+
" # Ensure we don't accidentally try to pickle unpicklable components\n",
7449+
" self.close_and_del_spot_writer()\n",
7450+
" self.remove_logger_handlers()\n",
7451+
"\n",
7452+
" # Create deep copies of control dictionaries\n",
7453+
" fun_control = copy.deepcopy(self.fun_control)\n",
7454+
" optimizer_control = copy.deepcopy(self.optimizer_control)\n",
7455+
" surrogate_control = copy.deepcopy(self.surrogate_control)\n",
7456+
" design_control = copy.deepcopy(self.design_control)\n",
7457+
"\n",
7458+
" # Prepare an experiment dictionary excluding any explicitly unpickable components\n",
7459+
" experiment = {\n",
7460+
" \"design_control\": design_control,\n",
7461+
" \"fun_control\": fun_control,\n",
7462+
" \"optimizer_control\": optimizer_control,\n",
7463+
" \"spot_tuner\": self._get_pickle_safe_spot_tuner(),\n",
7464+
" \"surrogate_control\": surrogate_control,\n",
7465+
" }\n",
7466+
"\n",
7467+
" # Determine the filename based on PREFIX if not provided\n",
7468+
" PREFIX = fun_control.get(\"PREFIX\", \"experiment\")\n",
7469+
" if filename is None:\n",
7470+
" filename = self.get_experiment_filename(PREFIX)\n",
7471+
"\n",
7472+
" if path is not None:\n",
7473+
" filename = os.path.join(path, filename)\n",
7474+
" if not os.path.exists(path):\n",
7475+
" os.makedirs(path)\n",
7476+
"\n",
7477+
" # Check if the file already exists\n",
7478+
" if filename is not None and os.path.exists(filename) and not overwrite:\n",
7479+
" print(f\"Error: File {filename} already exists. Use overwrite=True to overwrite the file.\")\n",
7480+
" return\n",
7481+
"\n",
7482+
" # Serialize the experiment dictionary to the pickle file\n",
7483+
" if filename is not None:\n",
7484+
" with open(filename, \"wb\") as handle:\n",
7485+
" try:\n",
7486+
" pickle.dump(experiment, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
7487+
" except Exception as e:\n",
7488+
" print(f\"Error during pickling: {e}\")\n",
7489+
" raise e\n",
7490+
" print(f\"Experiment saved to {filename}\")\n",
7491+
"\n",
7492+
" def remove_logger_handlers(self) -> None:\n",
7493+
" \"\"\"\n",
7494+
" Remove handlers from the logger to avoid pickling issues.\n",
7495+
" \"\"\"\n",
7496+
" logger = logging.getLogger(__name__)\n",
7497+
" for handler in list(logger.handlers): # Copy the list to avoid modification during iteration\n",
7498+
" logger.removeHandler(handler)\n",
7499+
"\n",
7500+
" def close_and_del_spot_writer(self) -> None:\n",
7501+
" \"\"\"\n",
7502+
" Delete the spot_writer attribute from the object\n",
7503+
" if it exists and close the writer.\n",
7504+
" \"\"\"\n",
7505+
" if hasattr(self, \"spot_writer\") and self.spot_writer is not None:\n",
7506+
" self.spot_writer.flush()\n",
7507+
" self.spot_writer.close()\n",
7508+
" del self.spot_writer\n",
7509+
"\n",
7510+
" def _get_pickle_safe_spot_tuner(self):\n",
7511+
" \"\"\"\n",
7512+
" Create a copy of self excluding unpickleable components for safe pickling.\n",
7513+
" This ensures no unpicklable components are passed to pickle.dump().\n",
7514+
" \"\"\"\n",
7515+
" # Make a deepcopy and manually remove unpickleable components\n",
7516+
" spot_tuner = copy.deepcopy(self)\n",
7517+
" for attr in ['spot_writer']:\n",
7518+
" if hasattr(spot_tuner, attr):\n",
7519+
" delattr(spot_tuner, attr)\n",
7520+
" return spot_tuner\n",
7521+
"\n",
7522+
" def get_experiment_filename(self, prefix):\n",
7523+
" \"\"\"\n",
7524+
" Generate a filename based on a given prefix with additional unique identifiers or timestamps.\n",
7525+
" \"\"\"\n",
7526+
" # Implement the logic to generate a filename\n",
7527+
" return f\"{prefix}_experiment.pkl\""
7528+
]
73727529
}
73737530
],
73747531
"metadata": {

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.21.3"
10+
version = "0.21.4"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

0 commit comments

Comments
 (0)