|
2435 | 2435 | "spot_tuner.run()" |
2436 | 2436 | ] |
2437 | 2437 | }, |
| 2438 | + { |
| 2439 | + "cell_type": "code", |
| 2440 | + "execution_count": 1, |
| 2441 | + "metadata": {}, |
| 2442 | + "outputs": [], |
| 2443 | + "source": [ |
| 2444 | + "import pandas as pd\n", |
| 2445 | + "import pytest\n", |
| 2446 | + "import torch\n", |
| 2447 | + "from pyhcf.data.loadHcfData import build_df, load_hcf_data\n", |
| 2448 | + "from torch.utils.data import DataLoader" |
| 2449 | + ] |
| 2450 | + }, |
| 2451 | + { |
| 2452 | + "cell_type": "code", |
| 2453 | + "execution_count": 6, |
| 2454 | + "metadata": {}, |
| 2455 | + "outputs": [ |
| 2456 | + { |
| 2457 | + "name": "stdout", |
| 2458 | + "output_type": "stream", |
| 2459 | + "text": [ |
| 2460 | + "Batch Size: 5\n", |
| 2461 | + "Inputs Shape: 3\n", |
| 2462 | + "P List: ['L', 'AQ', 'AS', 'T']\n", |
| 2463 | + "P List Length: 4\n", |
| 2464 | + "Targets Shape: 5\n" |
| 2465 | + ] |
| 2466 | + } |
| 2467 | + ], |
| 2468 | + "source": [ |
| 2469 | + "p_list=[\"L\", \"AQ\", \"AS\"]\n", |
| 2470 | + "dataset = load_hcf_data(param_list=p_list, target=\"T\",\n", |
| 2471 | + " rmNA=True, rmMF=True,\n", |
| 2472 | + " load_all_features=False,\n", |
| 2473 | + " load_thermo_features=False,\n", |
| 2474 | + " scale_data=True,\n", |
| 2475 | + " return_X_y=False)\n", |
| 2476 | + "assert isinstance(dataset, torch.utils.data.TensorDataset)\n", |
| 2477 | + "assert len(dataset) > 0\n", |
| 2478 | + "# Set batch size for DataLoader\n", |
| 2479 | + "batch_size = 5\n", |
| 2480 | + "# Create DataLoader \n", |
| 2481 | + "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n", |
| 2482 | + "# Iterate over the data in the DataLoader\n", |
| 2483 | + "for batch in dataloader:\n", |
| 2484 | + " inputs, targets = batch\n", |
| 2485 | + " print(f\"Batch Size: {inputs.size(0)}\")\n", |
| 2486 | + " assert inputs.size(0) == batch_size\n", |
| 2487 | + " print(f\"Inputs Shape: {inputs.shape[1]}\")\n", |
| 2488 | + " print(f\"P List: {p_list}\")\n", |
| 2489 | + " print(f\"P List Length: {len(p_list)}\")\n", |
| 2490 | + " # input is p_list + 1 (for target)\n", |
| 2491 | + " # p_list = [\"L\", \"AQ\", \"AS\"] plus target \"N\"\n", |
| 2492 | + " assert inputs.shape[1] + 1 == len(p_list)\n", |
| 2493 | + " print(f\"Targets Shape: {targets.shape[0]}\")\n", |
| 2494 | + " assert targets.shape[0] == batch_size\n", |
| 2495 | + " break" |
| 2496 | + ] |
| 2497 | + }, |
2438 | 2498 | { |
2439 | 2499 | "cell_type": "code", |
2440 | 2500 | "execution_count": null, |
|
0 commit comments