|
3793 | 3793 | }, |
3794 | 3794 | { |
3795 | 3795 | "cell_type": "code", |
3796 | | - "execution_count": null, |
| 3796 | + "execution_count": 1, |
3797 | 3797 | "metadata": {}, |
3798 | | - "outputs": [], |
| 3798 | + "outputs": [ |
| 3799 | + { |
| 3800 | + "name": "stdout", |
| 3801 | + "output_type": "stream", |
| 3802 | + "text": [ |
| 3803 | + "LightDataModule.setup(): stage: None\n", |
| 3804 | + "train_size: 0.25, val_size: 0.25 used for train & val data.\n", |
| 3805 | + "test_size: 0.5 used for test dataset.\n", |
| 3806 | + "test_size: 0.5 used for predict dataset.\n", |
| 3807 | + "Training set size: 3\n" |
| 3808 | + ] |
| 3809 | + } |
| 3810 | + ], |
3799 | 3811 | "source": [ |
3800 | 3812 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
3801 | 3813 | "from spotPython.data.csvdataset import CSVDataset\n", |
3802 | 3814 | "from spotPython.data.pkldataset import PKLDataset\n", |
| 3815 | + "from spotPython.utils.scaler import TorchStandardScaler\n", |
3803 | 3816 | "import torch\n", |
3804 | | - "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
3805 | | - "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)\n", |
| 3817 | + "\n", |
| 3818 | + "scaler=TorchStandardScaler()\n", |
| 3819 | + "\n", |
| 3820 | + "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.float64)\n", |
| 3821 | + "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5, scaler=scaler)\n", |
3806 | 3822 | "data_module.setup()\n", |
3807 | 3823 | "print(f\"Training set size: {len(data_module.data_train)}\")\n" |
3808 | 3824 | ] |
3809 | 3825 | }, |
3810 | 3826 | { |
3811 | 3827 | "cell_type": "code", |
3812 | | - "execution_count": null, |
| 3828 | + "execution_count": 2, |
3813 | 3829 | "metadata": {}, |
3814 | | - "outputs": [], |
| 3830 | + "outputs": [ |
| 3831 | + { |
| 3832 | + "data": { |
| 3833 | + "text/plain": [ |
| 3834 | + "0.19878798965729408" |
| 3835 | + ] |
| 3836 | + }, |
| 3837 | + "execution_count": 2, |
| 3838 | + "metadata": {}, |
| 3839 | + "output_type": "execute_result" |
| 3840 | + } |
| 3841 | + ], |
3815 | 3842 | "source": [ |
3816 | 3843 | "from sklearn.datasets import load_diabetes\n", |
3817 | 3844 | "diabetes = load_diabetes()\n", |
|
3821 | 3848 | }, |
3822 | 3849 | { |
3823 | 3850 | "cell_type": "code", |
3824 | | - "execution_count": null, |
| 3851 | + "execution_count": 4, |
3825 | 3852 | "metadata": {}, |
3826 | | - "outputs": [], |
| 3853 | + "outputs": [ |
| 3854 | + { |
| 3855 | + "name": "stdout", |
| 3856 | + "output_type": "stream", |
| 3857 | + "text": [ |
| 3858 | + "Batch Size: 1\n", |
| 3859 | + "---------------\n", |
| 3860 | + "Inputs: tensor([[ 0.0381, 0.0507, 0.0617, 0.0219, -0.0442, -0.0348, -0.0434, -0.0026,\n", |
| 3861 | + " 0.0199, -0.0176]])\n", |
| 3862 | + "Targets: tensor([151.])\n" |
| 3863 | + ] |
| 3864 | + } |
| 3865 | + ], |
3827 | 3866 | "source": [ |
3828 | 3867 | "from spotPython.data.lightdatamodule import LightDataModule\n", |
3829 | 3868 | "from spotPython.data.csvdataset import CSVDataset\n", |
|
3908 | 3947 | }, |
3909 | 3948 | { |
3910 | 3949 | "cell_type": "code", |
3911 | | - "execution_count": 1, |
| 3950 | + "execution_count": 6, |
3912 | 3951 | "metadata": {}, |
3913 | 3952 | "outputs": [ |
3914 | 3953 | { |
|
3923 | 3962 | "Validation set size: 5160\n", |
3924 | 3963 | "Test set size: 10320\n", |
3925 | 3964 | "LightDataModule.train_dataloader(). data_train size: 5160\n", |
3926 | | - "[tensor([[ 5.6063e+00, 1.6000e+01, 6.4174e+00, 9.6957e-01, 1.5250e+03,\n", |
3927 | | - " 3.3152e+00, 3.7450e+01, -1.2190e+02],\n", |
3928 | | - " [ 3.3462e+00, 3.4000e+01, 3.9503e+00, 9.8619e-01, 8.0500e+02,\n", |
3929 | | - " 2.2238e+00, 3.4020e+01, -1.1841e+02]]), tensor([3.2050, 3.0700])]\n", |
| 3965 | + "[tensor([[-0.2677, -0.2508, -0.2664, -0.2752, 2.1991, -0.2714, -0.2160, -0.4747],\n", |
| 3966 | + " [-0.2714, -0.2216, -0.2704, -0.2752, 1.0301, -0.2732, -0.2216, -0.4690]],\n", |
| 3967 | + " grad_fn=<StackBackward0>), tensor([3.2050, 3.0700])]\n", |
3930 | 3968 | "LightDataModule.train_dataloader(). data_train size: 5160\n", |
3931 | | - "[[ 5.6062999e+00 1.6000000e+01 6.4173913e+00 9.6956521e-01\n", |
3932 | | - " 1.5250000e+03 3.3152175e+00 3.7450001e+01 -1.2190000e+02]\n", |
3933 | | - " [ 3.3462000e+00 3.4000000e+01 3.9502761e+00 9.8618782e-01\n", |
3934 | | - " 8.0500000e+02 2.2237568e+00 3.4020000e+01 -1.1841000e+02]]\n" |
| 3969 | + "[[-0.267703 -0.25082865 -0.26638618 -0.2752308 2.1990557 -0.2714226\n", |
| 3970 | + " -0.21600425 -0.47471142]\n", |
| 3971 | + " [-0.2713723 -0.22160538 -0.27039158 -0.2752038 1.0301248 -0.27319458\n", |
| 3972 | + " -0.2215729 -0.46904534]]\n" |
3935 | 3973 | ] |
3936 | 3974 | } |
3937 | 3975 | ], |
|
3940 | 3978 | "from spotPython.data.california_housing import CaliforniaHousing\n", |
3941 | 3979 | "import torch\n", |
3942 | 3980 | "dataset = CaliforniaHousing(feature_type=torch.float32, target_type=torch.float32)\n", |
3943 | | - "data_module = LightDataModule(dataset=dataset, batch_size=2, test_size=0.5)\n", |
| 3981 | + "data_module = LightDataModule(dataset=dataset, batch_size=2, test_size=0.5, scaler=scaler)\n", |
3944 | 3982 | "data_module.setup()\n", |
3945 | 3983 | "print(f\"Training set size: {len(data_module.data_train)}\")\n", |
3946 | 3984 | "print(f\"Validation set size: {len(data_module.data_val)}\")\n", |
3947 | 3985 | "print(f\"Test set size: {len(data_module.data_test)}\")\n", |
3948 | 3986 | "# print the first batch of the training set from data_module.data_train\n", |
3949 | 3987 | "print(next(iter(data_module.train_dataloader())))\n", |
3950 | 3988 | "# print the first batch of the training set from data_module.data_train as a numpy array\n", |
3951 | | - "print(next(iter(data_module.train_dataloader()))[0].numpy())\n" |
| 3989 | + "print(next(iter(data_module.train_dataloader()))[0].detach().numpy())\n" |
3952 | 3990 | ] |
3953 | 3991 | }, |
3954 | 3992 | { |
|
4050 | 4088 | "name": "python", |
4051 | 4089 | "nbconvert_exporter": "python", |
4052 | 4090 | "pygments_lexer": "ipython3", |
4053 | | - "version": "3.11.7" |
| 4091 | + "version": "3.11.8" |
4054 | 4092 | } |
4055 | 4093 | }, |
4056 | 4094 | "nbformat": 4, |
|
0 commit comments