@@ -29,8 +29,8 @@ def __init__(self, batch_size: int, data_dir: str = "./data", num_workers: int =
2929 def prepare_data (self ) -> None :
3030 """Prepares the data for use."""
3131 # download
32- CIFAR10 (self .data_dir , train = True , download = True )
33- CIFAR10 (self .data_dir , train = False , download = True )
32+ CIFAR10 (root = self .data_dir , train = True , download = True )
33+ CIFAR10 (root = self .data_dir , train = False , download = True )
3434
3535 def setup (self , stage : Optional [str ] = None ) -> None :
3636 """
@@ -42,14 +42,21 @@ def setup(self, stage: Optional[str] = None) -> None:
4242 """
4343 # Assign train/val datasets for use in dataloaders
4444 if stage == "fit" or stage is None :
45- transform = transforms .Compose ([transforms .ToTensor ()])
46- cifar_full = CIFAR10 (self .data_dir , train = True , transform = transform )
47- self .data_train , self .data_val = random_split (cifar_full , [45000 , 5000 ])
45+ transform = transforms .Compose (
46+ [transforms .ToTensor (), transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]
47+ )
48+ data_full = CIFAR10 (root = self .data_dir , train = True , transform = transform )
49+ # self.data_train, self.data_val = random_split(daata_full, [45000, 5000])
50+ test_abs = int (len (data_full ) * 0.6 )
51+ print ("test_abs" , test_abs )
52+ self .data_train , self .data_val = random_split (data_full , [test_abs , len (data_full ) - test_abs ])
4853
4954 # Assign test dataset for use in dataloader(s)
5055 if stage == "test" or stage is None :
51- transform = transforms .Compose ([transforms .ToTensor ()])
52- self .data_test = CIFAR10 (self .data_dir , train = False , transform = transform )
56+ transform = transforms .Compose (
57+ [transforms .ToTensor (), transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]
58+ )
59+ self .data_test = CIFAR10 (root = self .data_dir , train = False , transform = transform )
5360
5461 def train_dataloader (self ) -> DataLoader :
5562 """
@@ -59,7 +66,8 @@ def train_dataloader(self) -> DataLoader:
5966 DataLoader: The training dataloader.
6067
6168 """
62- return DataLoader (self .data_train , batch_size = self .batch_size , num_workers = self .num_workers )
69+ print ("self.batch_size" , self .batch_size )
70+ return DataLoader (self .data_train , batch_size = self .batch_size , shuffle = True , num_workers = self .num_workers )
6371
6472 def val_dataloader (self ) -> DataLoader :
6573 """
@@ -70,7 +78,7 @@ def val_dataloader(self) -> DataLoader:
7078
7179
7280 """
73- return DataLoader (self .data_val , batch_size = self .batch_size , num_workers = self .num_workers )
81+ return DataLoader (self .data_val , batch_size = self .batch_size , shuffle = False , num_workers = self .num_workers )
7482
7583 def test_dataloader (self ) -> DataLoader :
7684 """
@@ -81,4 +89,4 @@ def test_dataloader(self) -> DataLoader:
8189
8290
8391 """
84- return DataLoader (self .data_test , batch_size = self .batch_size , num_workers = self .num_workers )
92+ return DataLoader (self .data_test , batch_size = self .batch_size , shuffle = False , num_workers = self .num_workers )
0 commit comments