Skip to content

Commit ca60bea

Browse files
0.14.11
save_experiment
1 parent b8da597 commit ca60bea

3 files changed

Lines changed: 13 additions & 8 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.14.10"
10+
version = "0.14.11"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/utils/file.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
# from torch.utils.tensorboard import SummaryWriter
1010

1111

12-
def load_data(data_dir="./data"):
12+
def load_cifar10_data(data_dir="./data"):
1313
"""Loads the CIFAR10 dataset.
1414
1515
Args:
@@ -19,16 +19,13 @@ def load_data(data_dir="./data"):
1919
trainset (torchvision.datasets.CIFAR10): Training dataset.
2020
2121
Examples:
22-
>>> from spotPython.utils.file import load_data
23-
>>> trainset = load_data(data_dir="./data")
22+
>>> from spotPython.utils.file import load_cifar10_data
23+
>>> trainset = load_cifar10_data(data_dir="./data")
2424
2525
"""
2626
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
27-
2827
trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
29-
3028
testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
31-
3229
return trainset, testset
3330

3431

@@ -106,6 +103,9 @@ def load_experiment(PKL_NAME):
106103
surrogate_control (dict): The surrogate control dictionary.
107104
optimizer_control (dict): The optimizer control dictionary.
108105
106+
Notes:
107+
The corresponding save_experiment function is part of the class spot.
108+
109109
"""
110110
with open(PKL_NAME, "rb") as handle:
111111
experiment = pickle.load(handle)

src/spotPython/utils/init.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ def fun_control_init(
5151
prep_model=None,
5252
prep_model_name=None,
5353
progress_file=None,
54+
scenario=None,
5455
seed=123,
5556
show_models=False,
5657
show_progress=True,
@@ -175,6 +176,8 @@ def fun_control_init(
175176
The name of the preprocessing model. Default is None.
176177
progress_file (str):
177178
The name of the progress file. Default is None.
179+
scenario (str):
180+
The scenario to use. Default is None. Can be "river", "sklearn", or "lightning".
178181
seed (int):
179182
The seed to use for the random number generator. Default is 123.
180183
sigma (float):
@@ -275,8 +278,9 @@ def fun_control_init(
275278
'optimizer': None,
276279
'path': None,
277280
'prep_model': None,
278-
prep_model_name': None,
281+
'prep_model_name': None,
279282
'save_model': False,
283+
'scenario': "lightning",
280284
'seed': 1234,
281285
'show_batch_interval': 1000000,
282286
'shuffle': None,
@@ -383,6 +387,7 @@ def fun_control_init(
383387
"prep_model_name": prep_model_name,
384388
"progress_file": progress_file,
385389
"save_model": False,
390+
"scenario": scenario,
386391
"seed": seed,
387392
"show_batch_interval": 1_000_000,
388393
"show_models": show_models,

0 commit comments

Comments
 (0)