1+ import pytest
2+ from unittest .mock import patch
3+ import torch
4+ from spotpython .utils .device import getDevice
5+
6+ def test_get_device_auto_cpu ():
7+ with patch ('torch.cuda.is_available' ) as mock_cuda , patch ('torch.backends.mps.is_available' ) as mock_mps :
8+ mock_cuda .return_value = False
9+ mock_mps .return_value = False
10+ assert getDevice () == 'cpu'
11+
12+ def test_get_device_auto_cuda ():
13+ with patch ('torch.cuda.is_available' ) as mock_cuda , patch ('torch.backends.mps.is_available' ) as mock_mps :
14+ mock_cuda .return_value = True
15+ mock_mps .return_value = False
16+ assert getDevice () == 'cuda:0'
17+
18+ def test_get_device_auto_mps ():
19+ with patch ('torch.cuda.is_available' ) as mock_cuda , patch ('torch.backends.mps.is_available' ) as mock_mps :
20+ mock_cuda .return_value = False
21+ mock_mps .return_value = True
22+ assert getDevice () == 'mps'
23+
24+ def test_get_device_explicit_cpu ():
25+ assert getDevice ('cpu' ) == 'cpu'
26+
27+ def test_get_device_explicit_cuda_available ():
28+ with patch ('torch.cuda.is_available' ) as mock_cuda :
29+ mock_cuda .return_value = True
30+ assert getDevice ('cuda:0' ) == 'cuda:0'
31+
32+ def test_get_device_explicit_cuda_unavailable ():
33+ with patch ('torch.cuda.is_available' ) as mock_cuda :
34+ mock_cuda .return_value = False
35+ with pytest .raises (ValueError , match = "CUDA device requested but no CUDA device is available." ):
36+ getDevice ('cuda:0' )
37+
38+ def test_get_device_explicit_mps_available ():
39+ with patch ('torch.backends.mps.is_available' ) as mock_mps :
40+ mock_mps .return_value = True
41+ assert getDevice ('mps' ) == 'mps'
42+
43+ def test_get_device_explicit_mps_unavailable ():
44+ with patch ('torch.backends.mps.is_available' ) as mock_mps :
45+ mock_mps .return_value = False
46+ with pytest .raises (ValueError , match = "MPS device requested but MPS is not available." ):
47+ getDevice ('mps' )
48+
49+ def test_get_device_invalid ():
50+ with pytest .raises (ValueError , match = "Unrecognized device: invalid_device. Valid options are 'cpu', 'cuda:x', or 'mps'." ):
51+ getDevice ('invalid_device' )
0 commit comments