Skip to content

Commit 3a37f9c

Browse files
0.10.22
visualize nn
1 parent 740aca0 commit 3a37f9c

3 files changed

Lines changed: 383 additions & 62 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2435,6 +2435,66 @@
24352435
"spot_tuner.run()"
24362436
]
24372437
},
2438+
{
2439+
"cell_type": "code",
2440+
"execution_count": 1,
2441+
"metadata": {},
2442+
"outputs": [],
2443+
"source": [
2444+
"import pandas as pd\n",
2445+
"import pytest\n",
2446+
"import torch\n",
2447+
"from pyhcf.data.loadHcfData import build_df, load_hcf_data\n",
2448+
"from torch.utils.data import DataLoader"
2449+
]
2450+
},
2451+
{
2452+
"cell_type": "code",
2453+
"execution_count": 6,
2454+
"metadata": {},
2455+
"outputs": [
2456+
{
2457+
"name": "stdout",
2458+
"output_type": "stream",
2459+
"text": [
2460+
"Batch Size: 5\n",
2461+
"Inputs Shape: 3\n",
2462+
"P List: ['L', 'AQ', 'AS', 'T']\n",
2463+
"P List Length: 4\n",
2464+
"Targets Shape: 5\n"
2465+
]
2466+
}
2467+
],
2468+
"source": [
2469+
"p_list=[\"L\", \"AQ\", \"AS\"]\n",
2470+
"dataset = load_hcf_data(param_list=p_list, target=\"T\",\n",
2471+
" rmNA=True, rmMF=True,\n",
2472+
" load_all_features=False,\n",
2473+
" load_thermo_features=False,\n",
2474+
" scale_data=True,\n",
2475+
" return_X_y=False)\n",
2476+
"assert isinstance(dataset, torch.utils.data.TensorDataset)\n",
2477+
"assert len(dataset) > 0\n",
2478+
"# Set batch size for DataLoader\n",
2479+
"batch_size = 5\n",
2480+
"# Create DataLoader \n",
2481+
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
2482+
"# Iterate over the data in the DataLoader\n",
2483+
"for batch in dataloader:\n",
2484+
" inputs, targets = batch\n",
2485+
" print(f\"Batch Size: {inputs.size(0)}\")\n",
2486+
" assert inputs.size(0) == batch_size\n",
2487+
" print(f\"Inputs Shape: {inputs.shape[1]}\")\n",
2488+
" print(f\"P List: {p_list}\")\n",
2489+
" print(f\"P List Length: {len(p_list)}\")\n",
2490+
" # input is p_list + 1 (for target)\n",
2491+
" # p_list = [\"L\", \"AQ\", \"AS\"] plus target \"N\"\n",
2492+
" assert inputs.shape[1] + 1 == len(p_list)\n",
2493+
" print(f\"Targets Shape: {targets.shape[0]}\")\n",
2494+
" assert targets.shape[0] == batch_size\n",
2495+
" break"
2496+
]
2497+
},
24382498
{
24392499
"cell_type": "code",
24402500
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.10.20"
10+
version = "0.10.22"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

0 commit comments

Comments
 (0)