|
1527 | 1527 | }, |
1528 | 1528 | { |
1529 | 1529 | "cell_type": "code", |
1530 | | - "execution_count": 16, |
| 1530 | + "execution_count": 1, |
1531 | 1531 | "metadata": {}, |
1532 | 1532 | "outputs": [ |
1533 | 1533 | { |
|
1680 | 1680 | " break" |
1681 | 1681 | ] |
1682 | 1682 | }, |
| 1683 | + { |
| 1684 | + "cell_type": "markdown", |
| 1685 | + "metadata": {}, |
| 1686 | + "source": [ |
| 1687 | + "## Test HyperLight" |
| 1688 | + ] |
| 1689 | + }, |
1683 | 1690 | { |
1684 | 1691 | "cell_type": "code", |
1685 | | - "execution_count": null, |
| 1692 | + "execution_count": 3, |
1686 | 1693 | "metadata": {}, |
1687 | | - "outputs": [], |
1688 | | - "source": [] |
| 1694 | + "outputs": [ |
| 1695 | + { |
| 1696 | + "name": "stderr", |
| 1697 | + "output_type": "stream", |
| 1698 | + "text": [ |
| 1699 | + "Seed set to 1234\n" |
| 1700 | + ] |
| 1701 | + }, |
| 1702 | + { |
| 1703 | + "data": { |
| 1704 | + "text/plain": [ |
| 1705 | + "array([[ True, True, True, True, True, True, True, True, True],\n", |
| 1706 | + " [ True, True, True, True, True, True, True, True, True]])" |
| 1707 | + ] |
| 1708 | + }, |
| 1709 | + "execution_count": 3, |
| 1710 | + "metadata": {}, |
| 1711 | + "output_type": "execute_result" |
| 1712 | + } |
| 1713 | + ], |
| 1714 | + "source": [ |
| 1715 | + "import numpy as np\n", |
| 1716 | + "from spotPython.utils.init import fun_control_init\n", |
| 1717 | + "from spotPython.light.regression.netlightregression import NetLightRegression\n", |
| 1718 | + "from spotPython.hyperdict.light_hyper_dict import LightHyperDict\n", |
| 1719 | + "from spotPython.hyperparameters.values import add_core_model_to_fun_control\n", |
| 1720 | + "from spotPython.fun.hyperlight import HyperLight\n", |
| 1721 | + "from spotPython.hyperparameters.values import get_var_name\n", |
| 1722 | + "fun_control = fun_control_init()\n", |
| 1723 | + "add_core_model_to_fun_control(core_model=NetLightRegression,\n", |
| 1724 | + " fun_control=fun_control,\n", |
| 1725 | + " hyper_dict=LightHyperDict)\n", |
| 1726 | + "hyper_light = HyperLight(seed=126, log_level=50)\n", |
| 1727 | + "n_hyperparams = len(get_var_name(fun_control))\n", |
| 1728 | + "# generate a random np.array X with shape (2, n_hyperparams)\n", |
| 1729 | + "X = np.random.rand(2, n_hyperparams)\n", |
| 1730 | + "X == hyper_light.check_X_shape(X, fun_control)" |
| 1731 | + ] |
| 1732 | + }, |
| 1733 | + { |
| 1734 | + "cell_type": "code", |
| 1735 | + "execution_count": 7, |
| 1736 | + "metadata": {}, |
| 1737 | + "outputs": [ |
| 1738 | + { |
| 1739 | + "name": "stderr", |
| 1740 | + "output_type": "stream", |
| 1741 | + "text": [ |
| 1742 | + "Seed set to 1234\n", |
| 1743 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'act_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['act_fn'])`.\n", |
| 1744 | + "GPU available: True (mps), used: True\n", |
| 1745 | + "TPU available: False, using: 0 TPU cores\n", |
| 1746 | + "IPU available: False, using: 0 IPUs\n", |
| 1747 | + "HPU available: False, using: 0 HPUs\n", |
| 1748 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:639: Checkpoint directory runs/lightning_logs/-2621019871347415348/checkpoints exists and is not empty.\n", |
| 1749 | + "\n", |
| 1750 | + " | Name | Type | Params | In sizes | Out sizes\n", |
| 1751 | + "-------------------------------------------------------------\n", |
| 1752 | + "0 | layers | Sequential | 157 | [16, 10] | [16, 1] \n", |
| 1753 | + "-------------------------------------------------------------\n", |
| 1754 | + "157 Trainable params\n", |
| 1755 | + "0 Non-trainable params\n", |
| 1756 | + "157 Total params\n", |
| 1757 | + "0.001 Total estimated model params size (MB)\n", |
| 1758 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", |
| 1759 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n", |
| 1760 | + "/Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:293: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n" |
| 1761 | + ] |
| 1762 | + }, |
| 1763 | + { |
| 1764 | + "name": "stdout", |
| 1765 | + "output_type": "stream", |
| 1766 | + "text": [ |
| 1767 | + "Train_model(): Test set size: 266\n" |
| 1768 | + ] |
| 1769 | + }, |
| 1770 | + { |
| 1771 | + "data": { |
| 1772 | + "text/html": [ |
| 1773 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 1774 | + "┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
| 1775 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 1776 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 27259.2421875 </span>│\n", |
| 1777 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 27259.2421875 </span>│\n", |
| 1778 | + "└───────────────────────────┴───────────────────────────┘\n", |
| 1779 | + "</pre>\n" |
| 1780 | + ], |
| 1781 | + "text/plain": [ |
| 1782 | + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 1783 | + "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
| 1784 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 1785 | + "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 27259.2421875 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 1786 | + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 27259.2421875 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 1787 | + "└───────────────────────────┴───────────────────────────┘\n" |
| 1788 | + ] |
| 1789 | + }, |
| 1790 | + "metadata": {}, |
| 1791 | + "output_type": "display_data" |
| 1792 | + }, |
| 1793 | + { |
| 1794 | + "name": "stderr", |
| 1795 | + "output_type": "stream", |
| 1796 | + "text": [ |
| 1797 | + "GPU available: True (mps), used: True\n", |
| 1798 | + "TPU available: False, using: 0 TPU cores\n", |
| 1799 | + "IPU available: False, using: 0 IPUs\n", |
| 1800 | + "HPU available: False, using: 0 HPUs\n", |
| 1801 | + "\n", |
| 1802 | + " | Name | Type | Params | In sizes | Out sizes\n", |
| 1803 | + "-------------------------------------------------------------\n", |
| 1804 | + "0 | layers | Sequential | 157 | [16, 10] | [16, 1] \n", |
| 1805 | + "-------------------------------------------------------------\n", |
| 1806 | + "157 Trainable params\n", |
| 1807 | + "0 Non-trainable params\n", |
| 1808 | + "157 Total params\n", |
| 1809 | + "0.001 Total estimated model params size (MB)\n" |
| 1810 | + ] |
| 1811 | + }, |
| 1812 | + { |
| 1813 | + "name": "stdout", |
| 1814 | + "output_type": "stream", |
| 1815 | + "text": [ |
| 1816 | + "train_model result: {'val_loss': 27259.2421875, 'hp_metric': 27259.2421875}\n", |
| 1817 | + "Train_model(): Test set size: 266\n" |
| 1818 | + ] |
| 1819 | + }, |
| 1820 | + { |
| 1821 | + "name": "stderr", |
| 1822 | + "output_type": "stream", |
| 1823 | + "text": [ |
| 1824 | + "`Trainer.fit` stopped: `max_epochs=16` reached.\n" |
| 1825 | + ] |
| 1826 | + }, |
| 1827 | + { |
| 1828 | + "data": { |
| 1829 | + "text/html": [ |
| 1830 | + "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 1831 | + "┃<span style=\"font-weight: bold\"> Validate metric </span>┃<span style=\"font-weight: bold\"> DataLoader 0 </span>┃\n", |
| 1832 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 1833 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> hp_metric </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 24221.275390625 </span>│\n", |
| 1834 | + "│<span style=\"color: #008080; text-decoration-color: #008080\"> val_loss </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> 24221.275390625 </span>│\n", |
| 1835 | + "└───────────────────────────┴───────────────────────────┘\n", |
| 1836 | + "</pre>\n" |
| 1837 | + ], |
| 1838 | + "text/plain": [ |
| 1839 | + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", |
| 1840 | + "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", |
| 1841 | + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", |
| 1842 | + "│\u001b[36m \u001b[0m\u001b[36m hp_metric \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 24221.275390625 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 1843 | + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 24221.275390625 \u001b[0m\u001b[35m \u001b[0m│\n", |
| 1844 | + "└───────────────────────────┴───────────────────────────┘\n" |
| 1845 | + ] |
| 1846 | + }, |
| 1847 | + "metadata": {}, |
| 1848 | + "output_type": "display_data" |
| 1849 | + }, |
| 1850 | + { |
| 1851 | + "name": "stdout", |
| 1852 | + "output_type": "stream", |
| 1853 | + "text": [ |
| 1854 | + "train_model result: {'val_loss': 24221.275390625, 'hp_metric': 24221.275390625}\n" |
| 1855 | + ] |
| 1856 | + }, |
| 1857 | + { |
| 1858 | + "data": { |
| 1859 | + "text/plain": [ |
| 1860 | + "array([27259.2421875 , 24221.27539062])" |
| 1861 | + ] |
| 1862 | + }, |
| 1863 | + "execution_count": 7, |
| 1864 | + "metadata": {}, |
| 1865 | + "output_type": "execute_result" |
| 1866 | + } |
| 1867 | + ], |
| 1868 | + "source": [ |
| 1869 | + "from spotPython.utils.init import fun_control_init\n", |
| 1870 | + "from spotPython.light.regression.netlightregression import NetLightRegression\n", |
| 1871 | + "from spotPython.hyperdict.light_hyper_dict import LightHyperDict\n", |
| 1872 | + "from spotPython.hyperparameters.values import (add_core_model_to_fun_control,\n", |
| 1873 | + " get_default_hyperparameters_as_array)\n", |
| 1874 | + "from spotPython.fun.hyperlight import HyperLight\n", |
| 1875 | + "from spotPython.data.diabetes import Diabetes\n", |
| 1876 | + "from spotPython.hyperparameters.values import set_data_set\n", |
| 1877 | + "import numpy as np\n", |
| 1878 | + "fun_control = fun_control_init(\n", |
| 1879 | + " _L_in=10,\n", |
| 1880 | + " _L_out=1,)\n", |
| 1881 | + "dataset = Diabetes()\n", |
| 1882 | + "set_data_set(fun_control=fun_control,\n", |
| 1883 | + " data_set=dataset)\n", |
| 1884 | + "add_core_model_to_fun_control(core_model=NetLightRegression,\n", |
| 1885 | + " fun_control=fun_control,\n", |
| 1886 | + " hyper_dict=LightHyperDict)\n", |
| 1887 | + "hyper_light = HyperLight(seed=126, log_level=50)\n", |
| 1888 | + "X = get_default_hyperparameters_as_array(fun_control)\n", |
| 1889 | + "# combine X and X to a np.array with shape (2, n_hyperparams)\n", |
| 1890 | + "# so that two values are returned\n", |
| 1891 | + "X = np.vstack((X, X))\n", |
| 1892 | + "hyper_light.fun(X, fun_control)\n" |
| 1893 | + ] |
1689 | 1894 | }, |
1690 | 1895 | { |
1691 | 1896 | "cell_type": "code", |
|
0 commit comments