@@ -126,6 +126,94 @@ def prepare_data(self) -> None:
126126 # download
127127 pass
128128
def _setup_full_data_provided(self, stage) -> None:
    """Split ``self.data_full`` into train, validation, test and predict sets.

    Used when a single full dataset was provided. The split sizes are
    derived from ``self.test_size``:

    * ``float``: sizes are fractions of the full dataset. The validation
      fraction is ``(1 - test_size) * test_size`` and the training fraction
      is what remains after validation and test.
    * otherwise (``int``): sizes are absolute sample counts. The validation
      count is ``floor((full_size - test_size) * test_size / full_size)``.

    Every stage performs the *same* seeded split with the lengths in the
    order ``[train, val, test]`` and picks the parts it needs.
    ``random_split`` slices one seeded permutation consecutively, so using
    the same order and seed everywhere guarantees that the test/predict
    samples are exactly the ones held out from training and validation.

    Args:
        stage: One of ``"fit"``, ``"test"``, ``"predict"`` or ``None``.
            ``None`` sets up all stages.
    """
    full_size = len(self.data_full)
    test_size = self.test_size

    if isinstance(test_size, float):
        # Fractional sizes: passed to random_split as fractions.
        full_train_size = 1.0 - test_size
        val_size = full_train_size * test_size
        train_size = full_train_size - val_size
    else:
        # Absolute sizes: number of samples in each part.
        full_train_size = full_size - test_size
        val_size = floor(full_train_size * test_size / full_size)
        train_size = full_size - val_size - test_size

    # One canonical order for all stages; see the docstring for why the
    # order must not differ between the fit and the test/predict split.
    lengths = [train_size, val_size, test_size]

    def _split():
        # A fresh generator per call so each stage reproduces the same
        # permutation regardless of which stages ran before.
        generator = torch.Generator().manual_seed(self.test_seed)
        return random_split(self.data_full, lengths, generator=generator)

    # Assign train/val datasets for use in dataloaders
    if stage == "fit" or stage is None:
        if self.verbosity > 0:
            print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
        self.data_train, self.data_val, _ = _split()
        # Handle scaling and transformation if scaler is provided
        if self.scaler is not None:
            self.handle_scaling_and_transform()

    # Assign test dataset for use in dataloader(s)
    if stage == "test" or stage is None:
        if self.verbosity > 0:
            print(f"test_size: {test_size} used for test dataset.")
        _, _, self.data_test = _split()
        if self.scaler is not None:
            # Transform the test data
            self.data_test = self.transform_dataset(self.data_test)

    # Assign pred dataset for use in dataloader(s)
    if stage == "predict" or stage is None:
        if self.verbosity > 0:
            print(f"test_size: {test_size} used for predict dataset.")
        # The predict set is the same held-out part as the test set.
        _, _, self.data_predict = _split()
        if self.scaler is not None:
            # Transform the predict data
            self.data_predict = self.transform_dataset(self.data_predict)
173+
def _setup_test_data_provided(self, stage) -> None:
    """Set up the datasets when train and test data were provided separately.

    ``self.data_full_train`` is split into train and validation parts;
    ``self.data_test`` is used as-is for both the test and predict stages.
    The validation size is derived from ``self.test_size``:

    * ``float``: the validation fraction is ``test_size`` and the training
      fraction is ``1 - test_size`` (fractions of ``data_full_train``).
    * otherwise (``int``): the validation count is
      ``floor(len(data_full_train) * test_size / full_size)`` where
      ``full_size`` is the combined length of the train and test data.

    Args:
        stage: One of ``"fit"``, ``"test"``, ``"predict"`` or ``None``.
            ``None`` sets up all stages.
    """
    full_train_size = len(self.data_full_train)
    test_size = self.test_size

    if isinstance(test_size, float):
        # Fractional sizes: passed to random_split as fractions.
        val_size = test_size
        train_size = 1 - test_size
    else:
        # Absolute sizes: number of samples in each part.
        full_size = full_train_size + len(self.data_test)
        val_size = floor(full_train_size * test_size / full_size)
        train_size = full_train_size - val_size

    # Capture the test data before this call transforms it. With
    # stage=None the test stage replaces self.data_test by its transformed
    # version; deriving the predict data from that would apply the
    # transform twice.
    # NOTE(review): self.data_test is still overwritten with the transformed
    # data, so separate setup("test") / setup("predict") calls can transform
    # it repeatedly unless transform_dataset is idempotent - confirm.
    raw_test_data = None
    if stage in ("test", "predict") or stage is None:
        raw_test_data = self.data_test

    # Assign train/val datasets for use in dataloaders
    if stage == "fit" or stage is None:
        if self.verbosity > 0:
            print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
        generator_fit = torch.Generator().manual_seed(self.test_seed)
        self.data_train, self.data_val = random_split(
            self.data_full_train, [train_size, val_size], generator=generator_fit
        )
        # Handle scaling and transformation if scaler is provided
        if self.scaler is not None:
            self.handle_scaling_and_transform()

    # Assign test dataset for use in dataloader(s)
    if stage == "test" or stage is None:
        if self.verbosity > 0:
            print(f"test_size: {test_size} used for test dataset.")
        if self.scaler is not None:
            # Transform the test data
            self.data_test = self.transform_dataset(raw_test_data)

    # Assign pred dataset for use in dataloader(s)
    if stage == "predict" or stage is None:
        if self.verbosity > 0:
            print(f"test_size: {test_size} used for predict dataset.")
        self.data_predict = raw_test_data
        if self.scaler is not None:
            # Transform the predict data
            self.data_predict = self.transform_dataset(self.data_predict)
216+
129217 def setup (self , stage : Optional [str ] = None ) -> None :
130218 """
131219 Splits the data for use in training, validation, and testing.
@@ -151,91 +239,9 @@ def setup(self, stage: Optional[str] = None) -> None:
151239
152240 """
153241 if self .data_full is not None :
154- full_size = len (self .data_full )
155- test_size = self .test_size
156-
157- # consider the case when test_size is a float
158- if isinstance (self .test_size , float ):
159- full_train_size = 1.0 - self .test_size
160- val_size = full_train_size * self .test_size
161- train_size = full_train_size - val_size
162- else :
163- # test_size is an int, training size calculation directly based on it
164- full_train_size = full_size - self .test_size
165- val_size = floor (full_train_size * self .test_size / full_size )
166- train_size = full_size - val_size - test_size
167-
168- # Assign train/val datasets for use in dataloaders
169- if stage == "fit" or stage is None :
170- if self .verbosity > 0 :
171- print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
172- generator_fit = torch .Generator ().manual_seed (self .test_seed )
173- self .data_train , self .data_val , _ = random_split (self .data_full , [train_size , val_size , test_size ], generator = generator_fit )
174- # Handle scaling and transformation if scaler is provided
175- if self .scaler is not None :
176- self .handle_scaling_and_transform ()
177-
178- # Assign test dataset for use in dataloader(s)
179- if stage == "test" or stage is None :
180- if self .verbosity > 0 :
181- print (f"test_size: { test_size } used for test dataset." )
182- generator_test = torch .Generator ().manual_seed (self .test_seed )
183- self .data_test , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_test )
184- if self .scaler is not None :
185- # Transform the test data
186- self .data_test = self .transform_dataset (self .data_test )
187-
188- # Assign pred dataset for use in dataloader(s)
189- if stage == "predict" or stage is None :
190- if self .verbosity > 0 :
191- print (f"test_size: { test_size } used for predict dataset." )
192- generator_predict = torch .Generator ().manual_seed (self .test_seed )
193- self .data_predict , _ , _ = random_split (self .data_full , [test_size , train_size , val_size ], generator = generator_predict )
194- if self .scaler is not None :
195- # Transform the predict data
196- self .data_predict = self .transform_dataset (self .data_predict )
242+ self ._setup_full_data_provided (stage )
197243 else :
198- # New functionality with separate full_train and test datasets. Use these datasets directly.
199- full_train_size = len (self .data_full_train )
200- test_size = self .test_size
201- # consider the case when test_size is a float
202- if isinstance (self .test_size , float ):
203- val_size = self .test_size
204- train_size = 1 - self .test_size
205- else :
206- # test_size is an int, training size calculation directly based on it
207- full_size = len (self .data_full_train ) + len (self .data_test )
208- full_train_size = len (self .data_full_train )
209- val_size = floor (full_train_size * self .test_size / full_size )
210- train_size = full_train_size - val_size
211-
212- # Assign train/val datasets for use in dataloaders
213- if stage == "fit" or stage is None :
214- if self .verbosity > 0 :
215- print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
216- generator_fit = torch .Generator ().manual_seed (self .test_seed )
217- self .data_train , self .data_val = random_split (self .data_full_train , [train_size , val_size ], generator = generator_fit )
218- # Handle scaling and transformation if scaler is provided
219- if self .scaler is not None :
220- self .handle_scaling_and_transform ()
221-
222- # Assign test dataset for use in dataloader(s)
223- if stage == "test" or stage is None :
224- if self .verbosity > 0 :
225- print (f"test_size: { test_size } used for test dataset." )
226- self .data_test = self .data_test
227- if self .scaler is not None :
228- # Transform the test data
229- self .data_test = self .transform_dataset (self .data_test )
230-
231- # Assign pred dataset for use in dataloader(s)
232- if stage == "predict" or stage is None :
233- if self .verbosity > 0 :
234- print (f"test_size: { test_size } used for predict dataset." )
235- self .data_predict = self .data_test
236- if self .scaler is not None :
237- # Transform the predict data
238- self .data_predict = self .transform_dataset (self .data_predict )
244+ self ._setup_test_data_provided (stage )
239245
240246 def train_dataloader (self ) -> DataLoader :
241247 """
0 commit comments