55# from spotPython.light.utils import create_model
66import torch .optim as optim
77
8- # from spotPython.light.cnn.googlenet import GoogleNet
9- import spotPython .light .cnn .googlenet
8+ from spotPython .light .cnn .googlenet import GoogleNet
9+
10+ # import spotPython.light.cnn.googlenet
1011
1112
1213class NetCNNBase (L .LightningModule ):
13- def __init__ (self , config , fun_control ):
14+ def __init__ (self , model_name , model_hparams , optimizer_name , optimizer_hparams ):
1415 """
1516 Initializes the CNN model.
1617
1718 Args:
18- config (dict): dictionary containing the configuration for the hyperparameter tuning.
19- fun_control (dict): dictionary containing control parameters for the hyperparameter tuning.
19+ model_name (str): name of the model.
20+ model_hparams (dict): dictionary containing the hyperparameters for the model.
21+ optimizer_name (str): name of the optimizer.
22+ optimizer_hparams (dict): dictionary containing the hyperparameters for the optimizer.
2023
2124 Returns:
2225 (object): model object.
@@ -26,38 +29,23 @@ def __init__(self, config, fun_control):
2629 from spotPython.light.cnn.googlenet import GoogleNet
2730 import torch
2831 import torch.nn as nn
29- config = {"c_in": 3, "c_out": 10, "act_fn": nn.ReLU, "optimizer_name": "Adam"}
32+ model_hparams = {"c_in": 3, "c_out": 10, "act_fn": nn.ReLU, "optimizer_name": "Adam"}
3033 fun_control = {"core_model": GoogleNet}
31- model = NetCNNBase(config , fun_control)
34+ model = NetCNNBase(model_hparams , fun_control)
3235 x = torch.randn(1, 3, 32, 32)
3336 y = model(x)
3437 y.shape
3538 torch.Size([1, 10])
3639
3740 """
38- print ("NetCNNBase: Starting" )
39- print (f"NetCNNBase: config: { config } " )
40- print (f"NetCNNBase: fun_control['core_model']: { fun_control ['core_model' ]} " )
41- config = {
42- "c_in" : 3 ,
43- "c_out" : 10 ,
44- "act_fn" : nn .ReLU ,
45- "optimizer_name" : "Adam" ,
46- "optimizer_hparams" : {"lr" : 1e-3 , "weight_decay" : 1e-4 },
47- }
48- print ("fun_control['core_model']: " , fun_control ["core_model" ])
49- print ("fun_control['core_model'].type: " , fun_control ["core_model" ].type )
50- # fun_control = {"core_model": GoogleNet}
51- fun_control = {"core_model" : spotPython .light .cnn .googlenet .GoogleNet }
5241 super ().__init__ ()
5342 # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
54- self .save_hyperparameters () # "fun_control" is not a hyperparameter )
55- print (f"config: { config } " )
43+ self .save_hyperparameters ()
44+ print (f"model_hparams: { model_hparams } " )
45+ print (f"self.hparams: { self .hparams } " )
5646 # Create model
57- print ("Creating model" )
58- # self.model = create_model(config, fun_control)
59- self .model = fun_control ["core_model" ](** config )
60- print ("Model created" )
47+ self .model = self .create_model (model_name , model_hparams )
48+ # self.model = fun_control["core_model"](**model_hparams)
6149 print (f"self.model: { self .model } " )
6250 # Create loss module
6351 self .loss_module = nn .CrossEntropyLoss ()
@@ -69,11 +57,8 @@ def forward(self, imgs):
6957 return self .model (imgs )
7058
7159 def configure_optimizers (self ):
72- # We will support Adam or SGD as optimizers.
73- if self .hparams .config ["optimizer_name" ] == "Adam" :
74- # AdamW is Adam with a correct implementation of weight decay (see here
75- # for details: https://arxiv.org/pdf/1711.05101.pdf)
76- optimizer = optim .AdamW (self .parameters (), ** self .hparams .config ["optimizer_hparams" ])
60+ if self .hparams .optimizer_name == "Adam" :
61+ optimizer = optim .AdamW (self .parameters (), ** self .hparams .optimizer_hparams )
7762 elif self .hparams .optimizer_name == "SGD" :
7863 optimizer = optim .SGD (self .parameters (), ** self .hparams .optimizer_hparams )
7964 else :
@@ -108,3 +93,13 @@ def test_step(self, batch, batch_idx):
10893 acc = (labels == preds ).float ().mean ()
10994 # By default logs it per epoch (weighted average over batches), and returns it afterwards
11095 self .log ("test_acc" , acc )
96+
97+ def create_model (self , model_name , model_hparams ):
98+ print ("create_model: Starting" )
99+ print (f"model_name: { model_name } " )
100+ print (f"model_hparams: { model_hparams } " )
101+ model_dict = {"GoogleNet" : GoogleNet }
102+ if model_name in model_dict :
103+ return model_dict [model_name ](** model_hparams )
104+ else :
105+ assert False , f'Unknown model name "{ model_name } ". Available models are: { str (model_dict .keys ())} '
0 commit comments