|
1 | 1 | import torchvision |
2 | 2 | import torchvision.transforms as transforms |
| 3 | +import socket |
| 4 | +from datetime import datetime |
| 5 | +from dateutil.tz import tzlocal |
3 | 6 |
|
4 | 7 |
|
5 | 8 | def load_data(data_dir="./data"): |
| 9 | + """Loads the CIFAR10 dataset. |
| 10 | + Args: |
| 11 | + data_dir (str, optional): Directory to save the data. Defaults to "./data". |
| 12 | + Returns: |
| 13 | + trainset (torchvision.datasets.CIFAR10): Training dataset. |
| 14 | + Examples: |
| 15 | + >>> from spotPython.utils.file import load_data |
| 16 | + >>> trainset = load_data(data_dir="./data") |
| 17 | +
|
| 18 | + """ |
6 | 19 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) |
7 | 20 |
|
8 | 21 | trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform) |
9 | 22 |
|
10 | 23 | testset = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform) |
11 | 24 |
|
12 | 25 | return trainset, testset |
| 26 | + |
| 27 | + |
| 28 | +def get_experiment_name(prefix: str = "00") -> str: |
| 29 | + """Returns a unique experiment name with a given prefix. |
| 30 | + Args: |
| 31 | + prefix (str, optional): Prefix for the experiment name. Defaults to "00". |
| 32 | + Returns: |
| 33 | + str: Unique experiment name. |
| 34 | + Examples: |
| 35 | + >>> from spotPython.utils.file import get_experiment_name |
| 36 | + >>> get_experiment_name(prefix="00") |
| 37 | + 00_ubuntu_2021-08-31_14-30-00 |
| 38 | + """ |
| 39 | + start_time = datetime.now(tzlocal()) |
| 40 | + HOSTNAME = socket.gethostname().split(".")[0] |
| 41 | + experiment_name = prefix + "_" + HOSTNAME + "_" + str(start_time).split(".", 1)[0].replace(" ", "_") |
| 42 | + experiment_name = experiment_name.replace(":", "-") |
| 43 | + return experiment_name |
0 commit comments