|
7192 | 7192 | }, |
7193 | 7193 | { |
7194 | 7194 | "cell_type": "code", |
7195 | | - "execution_count": null, |
| 7195 | + "execution_count": 1, |
7196 | 7196 | "metadata": {}, |
7197 | | - "outputs": [], |
| 7197 | + "outputs": [ |
| 7198 | + { |
| 7199 | + "name": "stderr", |
| 7200 | + "output_type": "stream", |
| 7201 | + "text": [ |
| 7202 | + "Seed set to 123\n", |
| 7203 | + "Seed set to 123\n" |
| 7204 | + ] |
| 7205 | + }, |
| 7206 | + { |
| 7207 | + "name": "stdout", |
| 7208 | + "output_type": "stream", |
| 7209 | + "text": [ |
| 7210 | + "S.X: [[ 0. 0. ]\n", |
| 7211 | + " [ 0. 1. ]\n", |
| 7212 | + " [ 1. 0. ]\n", |
| 7213 | + " [ 1. 1. ]\n", |
| 7214 | + " [-0.90924339 -0.15823458]\n", |
| 7215 | + " [-0.20581711 -0.48124909]\n", |
| 7216 | + " [ 0.94974117 -0.94631272]\n", |
| 7217 | + " [-0.12095571 0.06383589]\n", |
| 7218 | + " [-0.66278702 0.17431637]\n", |
| 7219 | + " [ 0.28200844 0.93001011]\n", |
| 7220 | + " [ 0.47878812 0.65321058]]\n", |
| 7221 | + "S.y: [0. 1. 1. 2. 0.85176172 0.27396137\n", |
| 7222 | + " 1.79751605 0.01870531 0.46967283 0.94444757 0.65592212]\n" |
| 7223 | + ] |
| 7224 | + } |
| 7225 | + ], |
7198 | 7226 | "source": [ |
7199 | 7227 | "import numpy as np\n", |
7200 | 7228 | "from spotpython.fun.objectivefunctions import Analytical\n", |
|
7228 | 7256 | }, |
7229 | 7257 | { |
7230 | 7258 | "cell_type": "code", |
7231 | | - "execution_count": null, |
| 7259 | + "execution_count": 2, |
7232 | 7260 | "metadata": {}, |
7233 | | - "outputs": [], |
| 7261 | + "outputs": [ |
| 7262 | + { |
| 7263 | + "name": "stderr", |
| 7264 | + "output_type": "stream", |
| 7265 | + "text": [ |
| 7266 | + "Seed set to 123\n" |
| 7267 | + ] |
| 7268 | + }, |
| 7269 | + { |
| 7270 | + "name": "stdout", |
| 7271 | + "output_type": "stream", |
| 7272 | + "text": [ |
| 7273 | + "Moving TENSORBOARD_PATH: runs/ to TENSORBOARD_PATH_OLD: runs_OLD/runs_2025_01_12_10_59_57\n", |
| 7274 | + "Created spot_tensorboard_path: runs/spot_logs/00_p040025_2025-01-12_10-59-57 for SummaryWriter()\n" |
| 7275 | + ] |
| 7276 | + } |
| 7277 | + ], |
7234 | 7278 | "source": [ |
7235 | 7279 | "import numpy as np\n", |
7236 | 7280 | "from spotpython.fun import Analytical\n", |
|
7260 | 7304 | }, |
7261 | 7305 | { |
7262 | 7306 | "cell_type": "code", |
7263 | | - "execution_count": 2, |
| 7307 | + "execution_count": 3, |
7264 | 7308 | "metadata": {}, |
7265 | 7309 | "outputs": [ |
7266 | 7310 | { |
|
7317 | 7361 | }, |
7318 | 7362 | { |
7319 | 7363 | "cell_type": "code", |
7320 | | - "execution_count": 1, |
| 7364 | + "execution_count": 4, |
7321 | 7365 | "metadata": {}, |
7322 | 7366 | "outputs": [ |
7323 | 7367 | { |
7324 | 7368 | "name": "stderr", |
7325 | 7369 | "output_type": "stream", |
7326 | 7370 | "text": [ |
7327 | | - "Seed set to 123\n", |
7328 | 7371 | "Seed set to 123\n" |
7329 | 7372 | ] |
7330 | 7373 | }, |
|
7369 | 7412 | "print(f\"S.X: {S.X}\")\n", |
7370 | 7413 | "print(f\"S.y: {S.y}\")" |
7371 | 7414 | ] |
| 7415 | + }, |
| 7416 | + { |
| 7417 | + "cell_type": "markdown", |
| 7418 | + "metadata": {}, |
| 7419 | + "source": [ |
| 7420 | + "## save experiment" |
| 7421 | + ] |
| 7422 | + }, |
| 7423 | + { |
| 7424 | + "cell_type": "code", |
| 7425 | + "execution_count": null, |
| 7426 | + "metadata": {}, |
| 7427 | + "outputs": [], |
| 7428 | + "source": [ |
| 7429 | + "import copy\n", |
| 7430 | + "import os\n", |
| 7431 | + "import pickle\n", |
| 7432 | + "import logging\n", |
| 7433 | + "from torch.utils.tensorboard import SummaryWriter\n", |
| 7434 | + "\n", |
| 7435 | + "class Experiment:\n", |
| 7436 | + " def save_experiment(self, filename=None, path=None, overwrite=True) -> None:\n", |
| 7437 | + " \"\"\"\n", |
| 7438 | + " Save the experiment to a file.\n", |
| 7439 | + "\n", |
| 7440 | + " Args:\n", |
| 7441 | + " filename (str): The filename of the experiment file.\n", |
| 7442 | + " path (str): The path to the experiment file.\n", |
| 7443 | + " overwrite (bool): If `True`, the file will be overwritten if it already exists. Default is `True`.\n", |
| 7444 | + "\n", |
| 7445 | + " Returns:\n", |
| 7446 | + " None\n", |
| 7447 | + " \"\"\"\n", |
| 7448 | + " # Ensure we don't accidentally try to pickle unpicklable components\n", |
| 7449 | + " self.close_and_del_spot_writer()\n", |
| 7450 | + " self.remove_logger_handlers()\n", |
| 7451 | + "\n", |
| 7452 | + " # Create deep copies of control dictionaries\n", |
| 7453 | + " fun_control = copy.deepcopy(self.fun_control)\n", |
| 7454 | + " optimizer_control = copy.deepcopy(self.optimizer_control)\n", |
| 7455 | + " surrogate_control = copy.deepcopy(self.surrogate_control)\n", |
| 7456 | + " design_control = copy.deepcopy(self.design_control)\n", |
| 7457 | + "\n", |
| 7458 | + " # Prepare an experiment dictionary excluding any explicitly unpickable components\n", |
| 7459 | + " experiment = {\n", |
| 7460 | + " \"design_control\": design_control,\n", |
| 7461 | + " \"fun_control\": fun_control,\n", |
| 7462 | + " \"optimizer_control\": optimizer_control,\n", |
| 7463 | + " \"spot_tuner\": self._get_pickle_safe_spot_tuner(),\n", |
| 7464 | + " \"surrogate_control\": surrogate_control,\n", |
| 7465 | + " }\n", |
| 7466 | + "\n", |
| 7467 | + " # Determine the filename based on PREFIX if not provided\n", |
| 7468 | + " PREFIX = fun_control.get(\"PREFIX\", \"experiment\")\n", |
| 7469 | + " if filename is None:\n", |
| 7470 | + " filename = self.get_experiment_filename(PREFIX)\n", |
| 7471 | + "\n", |
| 7472 | + " if path is not None:\n", |
| 7473 | + " filename = os.path.join(path, filename)\n", |
| 7474 | + " if not os.path.exists(path):\n", |
| 7475 | + " os.makedirs(path)\n", |
| 7476 | + "\n", |
| 7477 | + " # Check if the file already exists\n", |
| 7478 | + " if filename is not None and os.path.exists(filename) and not overwrite:\n", |
| 7479 | + " print(f\"Error: File {filename} already exists. Use overwrite=True to overwrite the file.\")\n", |
| 7480 | + " return\n", |
| 7481 | + "\n", |
| 7482 | + " # Serialize the experiment dictionary to the pickle file\n", |
| 7483 | + " if filename is not None:\n", |
| 7484 | + " with open(filename, \"wb\") as handle:\n", |
| 7485 | + " try:\n", |
| 7486 | + " pickle.dump(experiment, handle, protocol=pickle.HIGHEST_PROTOCOL)\n", |
| 7487 | + " except Exception as e:\n", |
| 7488 | + " print(f\"Error during pickling: {e}\")\n", |
| 7489 | + " raise e\n", |
| 7490 | + " print(f\"Experiment saved to {filename}\")\n", |
| 7491 | + "\n", |
| 7492 | + " def remove_logger_handlers(self) -> None:\n", |
| 7493 | + " \"\"\"\n", |
| 7494 | + " Remove handlers from the logger to avoid pickling issues.\n", |
| 7495 | + " \"\"\"\n", |
| 7496 | + " logger = logging.getLogger(__name__)\n", |
| 7497 | + " for handler in list(logger.handlers): # Copy the list to avoid modification during iteration\n", |
| 7498 | + " logger.removeHandler(handler)\n", |
| 7499 | + "\n", |
| 7500 | + " def close_and_del_spot_writer(self) -> None:\n", |
| 7501 | + " \"\"\"\n", |
| 7502 | + " Delete the spot_writer attribute from the object\n", |
| 7503 | + " if it exists and close the writer.\n", |
| 7504 | + " \"\"\"\n", |
| 7505 | + " if hasattr(self, \"spot_writer\") and self.spot_writer is not None:\n", |
| 7506 | + " self.spot_writer.flush()\n", |
| 7507 | + " self.spot_writer.close()\n", |
| 7508 | + " del self.spot_writer\n", |
| 7509 | + "\n", |
| 7510 | + " def _get_pickle_safe_spot_tuner(self):\n", |
| 7511 | + " \"\"\"\n", |
| 7512 | + " Create a copy of self excluding unpickleable components for safe pickling.\n", |
| 7513 | + " This ensures no unpicklable components are passed to pickle.dump().\n", |
| 7514 | + " \"\"\"\n", |
| 7515 | + " # Make a deepcopy and manually remove unpickleable components\n", |
| 7516 | + " spot_tuner = copy.deepcopy(self)\n", |
| 7517 | + " for attr in ['spot_writer']:\n", |
| 7518 | + " if hasattr(spot_tuner, attr):\n", |
| 7519 | + " delattr(spot_tuner, attr)\n", |
| 7520 | + " return spot_tuner\n", |
| 7521 | + "\n", |
| 7522 | + " def get_experiment_filename(self, prefix):\n", |
| 7523 | + " \"\"\"\n", |
| 7524 | + " Generate a filename based on a given prefix with additional unique identifiers or timestamps.\n", |
| 7525 | + " \"\"\"\n", |
| 7526 | + " # Implement the logic to generate a filename\n", |
| 7527 | + " return f\"{prefix}_experiment.pkl\"" |
| 7528 | + ] |
7372 | 7529 | } |
7373 | 7530 | ], |
7374 | 7531 | "metadata": { |
|
0 commit comments