77from torch .utils .data import random_split
88
99
10- class Net_Core_CV (nn .Module ):
10+ class Net_Core (nn .Module ):
1111 def __init__ (self , lr , batch_size , epochs , k_folds ):
12- super (Net_Core_CV , self ).__init__ ()
12+ super (Net_Core , self ).__init__ ()
1313 self .lr = lr
1414 self .batch_size = batch_size
1515 self .epochs = epochs
1616 self .k_folds = k_folds
1717 self .results = {}
1818
19- # def evaluate_cv_old(self, dataset, shuffle=False):
20- # try:
21- # device = getDevice()
22- # self.to(device)
23- # if torch.cuda.device_count() > 1:
24- # self = nn.DataParallel(self)
25- # criterion = nn.CrossEntropyLoss()
26- # optimizer = optim.SGD(self.parameters(), lr=self.lr, momentum=0.9)
27- # # TODO:
28- # # if checkpoint_dir:
29- # # model_state, optimizer_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
30- # # model.load_state_dict(model_state)
31- # # optimizer.load_state_dict(optimizer_state)
32- # # TODO:
33- # # trainset, testset = load_data(data_dir)
34- # # dataset = fun_control["train"]
35- # kfold = KFold(n_splits=self.k_folds, shuffle=shuffle)
36-
37- # # test_abs = int(len(dataset) * 0.6)
38- # # train_subset, val_subset = random_split(dataset, [test_abs, len(dataset) - test_abs])
39- # for fold, (train_ids, val_ids) in enumerate(kfold.split(dataset)):
40- # print(f"Fold {fold}")
41- # # Sample elements randomly from a given list of ids, no replacement.
42- # train_subsampler = torch.utils.data.SubsetRandomSampler(train_ids)
43- # val_subsampler = torch.utils.data.SubsetRandomSampler(val_ids)
44- # # Define data loaders for training and testing data in this fold
45- # trainloader = torch.utils.data.DataLoader(dataset,
46- # batch_size=self.batch_size, sampler=train_subsampler)
47- # valloader = torch.utils.data.DataLoader(dataset,
48- # batch_size=self.batch_size, sampler=val_subsampler)
49- # self.reset_weights()
50- # # Define best_score, counter, and patience for early stopping:
51- # best_score = None
52- # counter = 0
53- # patience = 10
54- # # path = os.path.join(".", "checkpoint")
55- # for epoch in range(self.epochs): # loop over the dataset multiple times
56- # running_loss = 0.0
57- # epoch_steps = 0
58- # for i, data in enumerate(trainloader, 0):
59- # # get the inputs; data is a list of [inputs, labels]
60- # inputs, labels = data
61- # inputs, labels = inputs.to(device), labels.to(device)
62-
63- # # zero the parameter gradients
64- # optimizer.zero_grad()
65-
66- # # forward + backward + optimize
67- # outputs = self(inputs)
68- # loss = criterion(outputs, labels)
69- # loss.backward()
70- # optimizer.step()
71-
72- # # print statistics
73- # running_loss += loss.item()
74- # epoch_steps += 1
75- # if i % 2000 == 1999: # print every 2000 mini-batches
76- # print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1, running_loss / epoch_steps))
77- # running_loss = 0.0
78-
79- # # Validation loss
80- # val_loss = 0.0
81- # val_steps = 0
82- # total = 0
83- # correct = 0
84- # for i, data in enumerate(valloader, 0):
85- # with torch.no_grad():
86- # inputs, labels = data
87- # inputs, labels = inputs.to(device), labels.to(device)
88-
89- # outputs = self(inputs)
90- # _, predicted = torch.max(outputs.data, 1)
91- # total += labels.size(0)
92- # correct += (predicted == labels).sum().item()
93- # loss = criterion(outputs, labels)
94- # val_loss += loss.cpu().numpy()
95- # val_steps += 1
96- # # Print accuracy
97- # print("Accuracy for fold %d: %d %%" % (fold, 100.0 * correct / total))
98- # print("--------------------------------")
99- # self.results[fold] = 100.0 * (correct / total)
100- # # early stopping:
101- # # https://stackoverflow.com/questions/60200088/
102- # how-to-make-early-stopping-in-image-classification-pytorch
103- # if best_score is None:
104- # best_score = val_loss
105- # else:
106- # # Check if val_loss improves or not.
107- # if val_loss < best_score:
108- # # val_loss improves, we update the latest best_score,
109- # # and save the current model
110- # best_score = val_loss
111- # # TODO:
112- # # torch.save({'state_dict':self.state_dict()}, path)
113- # else:
114- # # val_loss does not improve, we increase the counter,
115- # # stop training if it exceeds the amount of patience
116- # counter += 1
117- # if counter >= patience:
118- # break
119- # # TODO:
120- # # torch.save((self.state_dict(), optimizer.state_dict()), path)
121- # # Print fold results
122- # print(f"k-fold CV results for {self.k_folds} folds")
123- # print("--------------------------------")
124- # sum = 0.0
125- # for key, value in self.results.items():
126- # print(f"Fold {key}: {value} %")
127- # sum += value
128- # avg = sum / len(self.results.items())
129- # print(f"Average: {avg} %")
130- # df_eval = avg
131- # df_preds = np.nan
132- # except Exception as err:
133- # print(f"Error in Net_Core. Call to evaluate() failed. {err=}, {type(err)=}")
134- # df_eval = np.nan
135- # df_preds = np.nan
136- # return df_eval, df_preds
137-
13819 def reset_weights (self ):
13920 for layer in self .children ():
14021 if hasattr (layer , "reset_parameters" ):
@@ -195,7 +76,7 @@ def evaluate_cv(self, dataset, shuffle=False):
19576 df_eval = sum (self .results .values ()) / len (self .results .values ())
19677 df_preds = np .nan
19778 except Exception as err :
198- print (f"Error in Net_Core_CV . Call to evaluate_cv() failed. { err = } , { type (err )= } " )
79+ print (f"Error in Net_Core . Call to evaluate_cv() failed. { err = } , { type (err )= } " )
19980 df_eval = np .nan
20081 df_preds = np .nan
20182 return df_eval , df_preds
@@ -213,16 +94,30 @@ def evaluate_hold_out(self, dataset, shuffle, test_dataset=None):
21394 else :
21495 trainloader , valloader = self .create_train_test_data_loaders (dataset , shuffle , test_dataset )
21596 scheduler = torch .optim .lr_scheduler .StepLR (optimizer , step_size = 30 , gamma = 0.1 )
97+ # Early stopping parameters
98+ patience = 5
99+ best_val_loss = float ("inf" )
100+ counter = 0
216101 for epoch in range (epochs ):
217102 self .train_hold_out (trainloader , criterion , optimizer , device = device , epoch = epoch )
218103 scheduler .step ()
219- val_accuracy , val_loss = self .validate_hold_out (valloader = valloader , criterion = criterion , device = device )
220- df_eval = val_loss
104+ # Early stopping check
105+ val_accuracy , val_loss = self .validate_hold_out (valloader = valloader , criterion = criterion , device = device )
106+ if val_loss < best_val_loss :
107+ best_val_loss = val_loss
108+ counter = 0
109+ else :
110+ counter += 1
111+ if counter >= patience :
112+ print (f"Early stopping at epoch { epoch } " )
113+ break
114+ df_eval = best_val_loss
221115 df_preds = np .nan
222116 except Exception as err :
223- print (f"Error in Net_Core_CV . Call to evaluate_hold_out() failed. { err = } , { type (err )= } " )
117+ print (f"Error in Net_Core . Call to evaluate_hold_out() failed. { err = } , { type (err )= } " )
224118 df_eval = np .nan
225119 df_preds = np .nan
120+ print (f"Returned to Spot: Best validation loss: { df_eval } " )
226121 return df_eval , df_preds
227122
228123 def create_train_val_data_loaders (self , dataset , shuffle ):
@@ -260,8 +155,11 @@ def train_hold_out(self, trainloader, criterion, optimizer, device, epoch):
260155 # print statistics
261156 running_loss += loss .item ()
262157 epoch_steps += 1
263- if i % 2000 == 1999 : # print every 2000 mini-batches
264- print ("[%d, %5d] loss: %.3f" % (epoch + 1 , i + 1 , running_loss / epoch_steps ))
158+ if i % 1000 == 999 : # print every 1000 mini-batches
159+ print (
160+ "Epoch: %d, Batch: %5d. Batch Size: %d. Training Loss: %.3f"
161+ % (epoch + 1 , i + 1 , int (self .batch_size ), running_loss / epoch_steps )
162+ )
265163 running_loss = 0.0
266164
267165 def validate_hold_out (self , valloader , criterion , device ):
@@ -284,6 +182,6 @@ def validate_hold_out(self, valloader, criterion, device):
284182 val_steps += 1
285183 accuracy = correct / total
286184 loss = val_loss / val_steps
287- print (f"Accuracy on hold-out set: { accuracy } " )
288185 print (f"Loss on hold-out set: { loss } " )
186+ print (f"Accuracy on hold-out set: { accuracy } " )
289187 return accuracy , loss
0 commit comments