|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
| 3 | +import pprint |
| 4 | +import os |
| 5 | +import copy |
| 6 | +import json |
3 | 7 | from numpy.random import default_rng |
4 | 8 | from spotPython.design.spacefilling import spacefilling |
5 | 9 | from spotPython.build.kriging import Kriging |
|
35 | 39 | ) |
36 | 40 | import plotly.graph_objects as go |
37 | 41 | from typing import Callable |
| 42 | +from spotPython.utils.numpy2json import NumpyEncoder |
38 | 43 |
|
39 | 44 |
|
40 | 45 | logger = logging.getLogger(__name__) |
@@ -612,6 +617,88 @@ def get_new_X0(self) -> np.array: |
612 | 617 | logger.warning("No new XO found on surrogate. Generate new solution %s", X0) |
613 | 618 | return X0 |
614 | 619 |
|
| 620 | + def write_db_dict(self) -> None: |
| 621 | + """Writes a dictionary with the experiment parameters to the json file spotPython_db.json. |
| 622 | +
|
| 623 | + Args: |
| 624 | + self (object): Spot object |
| 625 | +
|
| 626 | + Returns: |
| 627 | + (NoneType): None |
| 628 | +
|
| 629 | + """ |
| 630 | + # get the time in seconds from 1.1.1970 and convert the time to a string |
| 631 | + t_str = str(time.time()) |
| 632 | + ident = str(self.fun_control["PREFIX"]) + "_" + t_str |
| 633 | + |
| 634 | + spot_tuner = copy.deepcopy(self) |
| 635 | + spot_tuner_control = vars(spot_tuner) |
| 636 | + |
| 637 | + fun_control = copy.deepcopy(spot_tuner_control["fun_control"]) |
| 638 | + design_control = copy.deepcopy(spot_tuner_control["design_control"]) |
| 639 | + optimizer_control = copy.deepcopy(spot_tuner_control["optimizer_control"]) |
| 640 | + surrogate_control = copy.deepcopy(spot_tuner_control["surrogate_control"]) |
| 641 | + |
| 642 | + # remove keys from the dictionaries: |
| 643 | + spot_tuner_control.pop("fun_control", None) |
| 644 | + spot_tuner_control.pop("design_control", None) |
| 645 | + spot_tuner_control.pop("optimizer_control", None) |
| 646 | + spot_tuner_control.pop("surrogate_control", None) |
| 647 | + spot_tuner_control.pop("spot_writer", None) |
| 648 | + spot_tuner_control.pop("design", None) |
| 649 | + spot_tuner_control.pop("fun", None) |
| 650 | + spot_tuner_control.pop("optimizer", None) |
| 651 | + spot_tuner_control.pop("rng", None) |
| 652 | + spot_tuner_control.pop("surrogate", None) |
| 653 | + |
| 654 | + fun_control.pop("core_model", None) |
| 655 | + fun_control.pop("metric_river", None) |
| 656 | + fun_control.pop("metric_sklearn", None) |
| 657 | + fun_control.pop("metric_torch", None) |
| 658 | + fun_control.pop("prep_model", None) |
| 659 | + fun_control.pop("spot_writer", None) |
| 660 | + fun_control.pop("test", None) |
| 661 | + fun_control.pop("train", None) |
| 662 | + |
| 663 | + surrogate_control.pop("model_optimizer", None) |
| 664 | + surrogate_control.pop("surrogate", None) |
| 665 | + |
| 666 | + print("\n**********************") |
| 667 | + print("The following dictionaries are written to the json file spotPython_db.json:") |
| 668 | + print("fun_control:") |
| 669 | + pprint.pprint(fun_control) |
| 670 | + print("design_control:") |
| 671 | + pprint.pprint(design_control) |
| 672 | + print("optimizer_control:") |
| 673 | + pprint.pprint(optimizer_control) |
| 674 | + print("surrogate_control:") |
| 675 | + pprint.pprint(surrogate_control) |
| 676 | + print("spot_tuner_control:") |
| 677 | + pprint.pprint(spot_tuner_control) |
| 678 | + db_dict = { |
| 679 | + str(ident): { |
| 680 | + "fun_control": fun_control, |
| 681 | + "design_control": design_control, |
| 682 | + "surrogate_control": surrogate_control, |
| 683 | + "optimizer_control": optimizer_control, |
| 684 | + "spot_tuner_control": spot_tuner_control, |
| 685 | + } |
| 686 | + } |
| 687 | + |
| 688 | + # check if the directory "db_dicts" exists. |
| 689 | + if not os.path.exists("db_dicts"): |
| 690 | + try: |
| 691 | + os.makedirs("db_dicts") |
| 692 | + except OSError as e: |
| 693 | + raise Exception(f"Error creating directory: {e}") |
| 694 | + if os.path.exists("db_dicts"): |
| 695 | + try: |
| 696 | + with open("db_dicts/" + self.fun_control["db_dict_name"], "a") as f: |
| 697 | + json.dump(db_dict, f, indent=4, cls=NumpyEncoder) |
| 698 | + f.close() |
| 699 | + except OSError as e: |
| 700 | + raise Exception(f"Error writing to file: {e}") |
| 701 | + |
615 | 702 | def run(self, X_start=None) -> Spot: |
616 | 703 | self.initialize_design(X_start) |
617 | 704 | # New: self.update_stats() moved here: |
@@ -640,6 +727,9 @@ def run(self, X_start=None) -> Spot: |
640 | 727 | if self.spot_writer is not None: |
641 | 728 | writer = self.spot_writer |
642 | 729 | writer.close() |
| 730 | + pprint.pprint(self.fun_control) |
| 731 | + if self.fun_control["db_dict_name"] is not None: |
| 732 | + self.write_db_dict() |
643 | 733 | return self |
644 | 734 |
|
645 | 735 | def initialize_design(self, X_start=None) -> None: |
|
0 commit comments