Skip to content

Commit dbf6e4a

Browse files
0.17.1
improved get_device()
1 parent 8cc5ce0 commit dbf6e4a

4 files changed

Lines changed: 79 additions & 5 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
spotpython-0.17.1
2+
3+
- improved get_device function
4+
15
spotpython-0.17.0
26

37
- designs.py and spacefilling.py:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.17.0"
10+
version = "0.17.1"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/device.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,43 @@
22

33

44
def getDevice(device=None):
5-
"""Get cpu, gpu or mps device for training.
5+
"""Get CPU, GPU (CUDA), or MPS device for training.
6+
67
Args:
78
device (str):
8-
Device for training. If None or "auto" the device is selected automatically.
9+
Device for training. If None or "auto", the device is selected automatically based on availability.
910
1011
Returns:
1112
device (str):
1213
Device for training.
1314
15+
Raises:
16+
ValueError: If the requested device is not recognized or available.
17+
1418
Examples:
1519
>>> from spotpython.utils.device import getDevice
16-
>>> getDevice()
17-
'cuda:0'
20+
getDevice()
21+
'cuda:0'
1822
"""
1923
if device is None or device == "auto":
24+
# Automatically select device
2025
device = "cpu"
2126
if torch.cuda.is_available():
2227
device = "cuda:0"
2328
elif torch.backends.mps.is_available():
2429
device = "mps"
30+
return device
31+
32+
# Check the explicit device request
33+
if device.startswith("cuda"):
34+
if not torch.cuda.is_available():
35+
raise ValueError("CUDA device requested but no CUDA device is available.")
36+
elif device == "mps":
37+
if not torch.backends.mps.is_available():
38+
raise ValueError("MPS device requested but MPS is not available.")
39+
elif device == "cpu":
40+
return "cpu"
41+
else:
42+
raise ValueError(f"Unrecognized device: {device}. Valid options are 'cpu', 'cuda:x', or 'mps'.")
43+
2544
return device

test/test_get_device.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import pytest
2+
from unittest.mock import patch
3+
import torch
4+
from spotpython.utils.device import getDevice
5+
6+
def test_get_device_auto_cpu():
7+
with patch('torch.cuda.is_available') as mock_cuda, patch('torch.backends.mps.is_available') as mock_mps:
8+
mock_cuda.return_value = False
9+
mock_mps.return_value = False
10+
assert getDevice() == 'cpu'
11+
12+
def test_get_device_auto_cuda():
13+
with patch('torch.cuda.is_available') as mock_cuda, patch('torch.backends.mps.is_available') as mock_mps:
14+
mock_cuda.return_value = True
15+
mock_mps.return_value = False
16+
assert getDevice() == 'cuda:0'
17+
18+
def test_get_device_auto_mps():
19+
with patch('torch.cuda.is_available') as mock_cuda, patch('torch.backends.mps.is_available') as mock_mps:
20+
mock_cuda.return_value = False
21+
mock_mps.return_value = True
22+
assert getDevice() == 'mps'
23+
24+
def test_get_device_explicit_cpu():
25+
assert getDevice('cpu') == 'cpu'
26+
27+
def test_get_device_explicit_cuda_available():
28+
with patch('torch.cuda.is_available') as mock_cuda:
29+
mock_cuda.return_value = True
30+
assert getDevice('cuda:0') == 'cuda:0'
31+
32+
def test_get_device_explicit_cuda_unavailable():
33+
with patch('torch.cuda.is_available') as mock_cuda:
34+
mock_cuda.return_value = False
35+
with pytest.raises(ValueError, match="CUDA device requested but no CUDA device is available."):
36+
getDevice('cuda:0')
37+
38+
def test_get_device_explicit_mps_available():
39+
with patch('torch.backends.mps.is_available') as mock_mps:
40+
mock_mps.return_value = True
41+
assert getDevice('mps') == 'mps'
42+
43+
def test_get_device_explicit_mps_unavailable():
44+
with patch('torch.backends.mps.is_available') as mock_mps:
45+
mock_mps.return_value = False
46+
with pytest.raises(ValueError, match="MPS device requested but MPS is not available."):
47+
getDevice('mps')
48+
49+
def test_get_device_invalid():
50+
with pytest.raises(ValueError, match="Unrecognized device: invalid_device. Valid options are 'cpu', 'cuda:x', or 'mps'."):
51+
getDevice('invalid_device')

0 commit comments

Comments
 (0)