Skip to content

Commit d919255

Browse files
v0.6.12
1 parent b41b4c3 commit d919255

6 files changed

Lines changed: 49 additions & 50 deletions

File tree

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ site/
99
Ignored/
1010
spotPython.code-workspace
1111
src/spotPython/_version.py
12+
src/spotPython/_version*.py
1213
Figures.d/*
1314
notebooks/runs/
1415
notebook/runs/*

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.11"
10+
version = "0.6.12"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/fun/hyperlightning.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,18 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
105105
array containing the evaluation results.
106106
107107
Examples:
108-
>>> MAX_TIME = 1
109-
INIT_SIZE = 5
110-
WORKERS = 0
111-
PREFIX="TEST"
112-
from spotPython.utils.init import fun_control_init
108+
>>> from spotPython.utils.init import fun_control_init
113109
from spotPython.utils.file import get_experiment_name, get_spot_tensorboard_path
114110
from spotPython.utils.device import getDevice
111+
from spotPython.light.cnn.googlenet import GoogleNet
112+
from spotPython.data.lightning_hyper_dict import LightningHyperDict
113+
from spotPython.hyperparameters.values import add_core_model_to_fun_control
114+
from spotPython.fun.hyperlightning import HyperLightning
115+
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
116+
MAX_TIME = 1
117+
INIT_SIZE = 3
118+
WORKERS = 8
119+
PREFIX="TEST"
115120
experiment_name = get_experiment_name(prefix=PREFIX)
116121
fun_control = fun_control_init(
117122
spot_tensorboard_path=get_spot_tensorboard_path(experiment_name),
@@ -120,15 +125,10 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
120125
_L_in=3,
121126
_L_out=10,
122127
TENSORBOARD_CLEAN=True)
123-
from spotPython.light.cnn.googlenet import GoogleNet
124-
from spotPython.data.lightning_hyper_dict import LightningHyperDict
125-
from spotPython.hyperparameters.values import add_core_model_to_fun_control
126128
add_core_model_to_fun_control(core_model=GoogleNet,
127129
fun_control=fun_control,
128130
hyper_dict= LightningHyperDict)
129-
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
130131
X_start = get_default_hyperparameters_as_array(fun_control)
131-
from spotPython.fun.hyperlightning import HyperLightning
132132
hyper_light = HyperLightning(seed=126, log_level=50)
133133
hyper_light.fun(X=X_start, fun_control=fun_control)
134134

src/spotPython/light/cnn/netcnnbase.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,21 @@
55
# from spotPython.light.utils import create_model
66
import torch.optim as optim
77

8-
# from spotPython.light.cnn.googlenet import GoogleNet
9-
import spotPython.light.cnn.googlenet
8+
from spotPython.light.cnn.googlenet import GoogleNet
9+
10+
# import spotPython.light.cnn.googlenet
1011

1112

1213
class NetCNNBase(L.LightningModule):
13-
def __init__(self, config, fun_control):
14+
def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):
1415
"""
1516
Initializes the CNN model.
1617
1718
Args:
18-
config (dict): dictionary containing the configuration for the hyperparameter tuning.
19-
fun_control (dict): dictionary containing control parameters for the hyperparameter tuning.
19+
model_name (str): name of the model.
20+
model_hparams (dict): dictionary containing the hyperparameters for the model.
21+
optimizer_name (str): name of the optimizer.
22+
optimizer_hparams (dict): dictionary containing the hyperparameters for the optimizer.
2023
2124
Returns:
2225
(object): model object.
@@ -26,38 +29,23 @@ def __init__(self, config, fun_control):
2629
from spotPython.light.cnn.googlenet import GoogleNet
2730
import torch
2831
import torch.nn as nn
29-
config = {"c_in": 3, "c_out": 10, "act_fn": nn.ReLU, "optimizer_name": "Adam"}
32+
model_hparams = {"c_in": 3, "c_out": 10, "act_fn": nn.ReLU, "optimizer_name": "Adam"}
3033
fun_control = {"core_model": GoogleNet}
31-
model = NetCNNBase(config, fun_control)
34+
model = NetCNNBase(model_hparams, fun_control)
3235
x = torch.randn(1, 3, 32, 32)
3336
y = model(x)
3437
y.shape
3538
torch.Size([1, 10])
3639
3740
"""
38-
print("NetCNNBase: Starting")
39-
print(f"NetCNNBase: config: {config}")
40-
print(f"NetCNNBase: fun_control['core_model']: {fun_control['core_model']}")
41-
config = {
42-
"c_in": 3,
43-
"c_out": 10,
44-
"act_fn": nn.ReLU,
45-
"optimizer_name": "Adam",
46-
"optimizer_hparams": {"lr": 1e-3, "weight_decay": 1e-4},
47-
}
48-
print("fun_control['core_model']: ", fun_control["core_model"])
49-
print("fun_control['core_model'].type: ", fun_control["core_model"].type)
50-
# fun_control = {"core_model": GoogleNet}
51-
fun_control = {"core_model": spotPython.light.cnn.googlenet.GoogleNet}
5241
super().__init__()
5342
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
54-
self.save_hyperparameters() # "fun_control" is not a hyperparameter )
55-
print(f"config: {config}")
43+
self.save_hyperparameters()
44+
print(f"model_hparams: {model_hparams}")
45+
print(f"self.hparams: {self.hparams}")
5646
# Create model
57-
print("Creating model")
58-
# self.model = create_model(config, fun_control)
59-
self.model = fun_control["core_model"](**config)
60-
print("Model created")
47+
self.model = self.create_model(model_name, model_hparams)
48+
# self.model = fun_control["core_model"](**model_hparams)
6149
print(f"self.model: {self.model}")
6250
# Create loss module
6351
self.loss_module = nn.CrossEntropyLoss()
@@ -69,11 +57,8 @@ def forward(self, imgs):
6957
return self.model(imgs)
7058

7159
def configure_optimizers(self):
72-
# We will support Adam or SGD as optimizers.
73-
if self.hparams.config["optimizer_name"] == "Adam":
74-
# AdamW is Adam with a correct implementation of weight decay (see here
75-
# for details: https://arxiv.org/pdf/1711.05101.pdf)
76-
optimizer = optim.AdamW(self.parameters(), **self.hparams.config["optimizer_hparams"])
60+
if self.hparams.optimizer_name == "Adam":
61+
optimizer = optim.AdamW(self.parameters(), **self.hparams.optimizer_hparams)
7762
elif self.hparams.optimizer_name == "SGD":
7863
optimizer = optim.SGD(self.parameters(), **self.hparams.optimizer_hparams)
7964
else:
@@ -108,3 +93,13 @@ def test_step(self, batch, batch_idx):
10893
acc = (labels == preds).float().mean()
10994
# By default logs it per epoch (weighted average over batches), and returns it afterwards
11095
self.log("test_acc", acc)
96+
97+
def create_model(self, model_name, model_hparams):
98+
print("create_model: Starting")
99+
print(f"model_name: {model_name}")
100+
print(f"model_hparams: {model_hparams}")
101+
model_dict = {"GoogleNet": GoogleNet}
102+
if model_name in model_dict:
103+
return model_dict[model_name](**model_hparams)
104+
else:
105+
assert False, f'Unknown model name "{model_name}". Available models are: {str(model_dict.keys())}'

src/spotPython/light/trainmodel.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,10 @@ def train_model(config: dict, fun_control: dict):
4141
"""
4242
print("train_model: Starting")
4343
print(f"train_model: config: {config}")
44-
save_name = "saved_models"
44+
save_name = fun_control["core_model"].__name__
4545
# Create PyTorch Lightning data loaders
46-
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/ConvNets")
47-
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
48-
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")
49-
os.makedirs(DATASET_PATH, exist_ok=True)
46+
CHECKPOINT_PATH = fun_control["CHECKPOINT_PATH"]
47+
DATASET_PATH = fun_control["DATASET_PATH"]
5048

5149
# Create PyTorch Lightning data loaders
5250
# TODO: Replace this by data loaders external to train_model method:
@@ -115,7 +113,12 @@ def train_model(config: dict, fun_control: dict):
115113
else:
116114
L.seed_everything(42) # To be reproducable
117115
print("train_model: Creating model")
118-
model = NetCNNBase(config=config, fun_control=fun_control) # Create model
116+
model = NetCNNBase(
117+
model_name=fun_control["core_model"].__name__,
118+
model_hparams=config,
119+
optimizer_name="Adam",
120+
optimizer_hparams={"lr": 1e-3, "weight_decay": 1e-4},
121+
) # Create model
119122
trainer.fit(model, train_loader, val_loader)
120123
model = NetCNNBase.load_from_checkpoint(
121124
trainer.checkpoint_callback.best_model_path

src/spotPython/utils/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def fun_control_init(
8282
L.seed_everything(42)
8383

8484
# Path to the folder where the pretrained models are saved
85-
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/")
85+
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "runs/saved_models/")
8686
os.makedirs(CHECKPOINT_PATH, exist_ok=True)
8787
# Path to the folder where the datasets are/should be downloaded (e.g. MNIST)
8888
DATASET_PATH = os.environ.get("PATH_DATASETS", "data/")

0 commit comments

Comments
 (0)