Skip to content

Commit 1552b02

Browse files
hyperdict path
1 parent 329491d commit 1552b02

3 files changed

Lines changed: 33 additions & 18 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.7.0"
10+
version = "0.7.1"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/hyperdict/sklearn_hyper_dict.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from spotPython.data import base
3+
import pathlib
34

45

56
class SklearnHyperDict(base.FileConfig):
@@ -11,15 +12,21 @@ class SklearnHyperDict(base.FileConfig):
1112
filename (str): The name of the file where the hyperparameters are stored.
1213
"""
1314

14-
def __init__(self):
15-
"""Initialize the SklearnHyperDict object.
16-
17-
Examples:
18-
>>> shd = SklearnHyperDict()
19-
"""
20-
super().__init__(
21-
filename="sklearn_hyper_dict.json",
22-
)
15+
def __init__(
16+
self,
17+
filename: str = "sklearn_hyper_dict.json",
18+
directory: None = None,
19+
) -> None:
20+
super().__init__(filename=filename, directory=directory)
21+
self.filename = filename
22+
self.directory = directory
23+
self.hyper_dict = self.load()
24+
25+
@property
26+
def path(self):
27+
if self.directory:
28+
return pathlib.Path(self.directory).joinpath(self.filename)
29+
return pathlib.Path(__file__).parent.joinpath(self.filename)
2330

2431
def load(self) -> dict:
2532
"""Load the hyperparameters from the file.

src/spotPython/hyperdict/torch_hyper_dict.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from spotPython.data import base
3+
import pathlib
34

45

56
class TorchHyperDict(base.FileConfig):
@@ -11,14 +12,21 @@ class TorchHyperDict(base.FileConfig):
1112
filename (str): The name of the file where the hyperparameters are stored.
1213
"""
1314

14-
def __init__(self):
15-
"""Initialize the TorchHyperDict object.
16-
Examples:
17-
>>> thd = TorchHyperDict()
18-
"""
19-
super().__init__(
20-
filename="torch_hyper_dict.json",
21-
)
15+
def __init__(
16+
self,
17+
filename: str = "torch_hyper_dict.json",
18+
directory: None = None,
19+
) -> None:
20+
super().__init__(filename=filename, directory=directory)
21+
self.filename = filename
22+
self.directory = directory
23+
self.hyper_dict = self.load()
24+
25+
@property
26+
def path(self):
27+
if self.directory:
28+
return pathlib.Path(self.directory).joinpath(self.filename)
29+
return pathlib.Path(__file__).parent.joinpath(self.filename)
2230

2331
def load(self) -> dict:
2432
"""Load the hyperparameters from the file.

0 commit comments

Comments
 (0)