diff --git a/tests/fsdp_state_dict_save.py b/tests/fsdp_state_dict_save.py index 2e56c1c03..4a2e08c2b 100644 --- a/tests/fsdp_state_dict_save.py +++ b/tests/fsdp_state_dict_save.py @@ -20,6 +20,34 @@ import bitsandbytes as bnb +def _current_accelerator_type(): + if hasattr(torch, "accelerator") and torch.accelerator.is_available(): + return str(torch.accelerator.current_accelerator()) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return "xpu" + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + +def _set_device_index(index: int, device_type: str): + if hasattr(torch, "accelerator"): + torch.accelerator.set_device_index(index) + return + if device_type == "cuda": + torch.cuda.set_device(index) + elif device_type == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "set_device"): + torch.xpu.set_device(index) + + +def _get_device_and_backend(): + """Auto-detect accelerator device and distributed backend.""" + device_type = _current_accelerator_type() + backend_map = {"cuda": "nccl", "xpu": "xccl"} + backend = backend_map.get(device_type, "gloo") + return device_type, backend + + class SimpleQLoRAModel(nn.Module): """Minimal model with a frozen 4-bit base layer and a trainable adapter.""" @@ -33,15 +61,16 @@ def forward(self, x): def main(): - dist.init_process_group(backend="nccl") + device_type, backend = _get_device_and_backend() + dist.init_process_group(backend=backend) rank = dist.get_rank() - torch.cuda.set_device(rank) + _set_device_index(rank, device_type) errors = [] for quant_type in ("nf4", "fp4"): model = SimpleQLoRAModel(quant_type=quant_type) - model = model.to("cuda") + model = model.to(device_type) # Freeze quantized base weights (as in real QLoRA) for p in model.base.parameters(): diff --git a/tests/test_functional.py b/tests/test_functional.py index 0e2e5e1eb..a2127c58a 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -98,10 +98,10 @@ class Test8BitBlockwiseQuantizeFunctional: def 
test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): iters = 100 - if device != "cuda": + if device not in ["cuda", "xpu"]: iters = 10 - # This test is slow in our non-CUDA implementations, so avoid atypical use cases. + # This test is slow in our non-CUDA/non-XPU implementations, so avoid atypical use cases. if nested: pytest.skip("Not a typical use case.") if blocksize != 256: diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py index d9a25c90e..f2335e5ea 100644 --- a/tests/test_linear4bit.py +++ b/tests/test_linear4bit.py @@ -569,11 +569,8 @@ def test_params4bit_quant_state_attr_access(device, quant_type, compress_statist assert w.bnb_quantized is True -@pytest.mark.skipif(not torch.cuda.is_available(), reason="FSDP requires CUDA") -@pytest.mark.skipif( - not getattr(torch.distributed, "is_nccl_available", lambda: False)(), - reason="FSDP test requires NCCL backend", -) +@pytest.mark.skipif(platform.system() == "Windows", reason="FSDP is not supported on Windows") +@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="FSDP requires an accelerator device") def test_fsdp_state_dict_save_4bit(): """Integration test: FSDP get_model_state_dict with cpu_offload on a 4-bit model (#1405). 
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 410961e0b..a9d75b92e 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -172,8 +172,9 @@ def test_linear_serialization( assert torch.allclose(x_first.grad, x_third.grad, atol=1e-5) -@pytest.fixture -def linear8bit(requires_cuda): +@pytest.fixture(params=get_available_devices(no_cpu=True)) +def linear8bit(request): + device = request.param linear = torch.nn.Linear(32, 96) linear_custom = Linear8bitLt( linear.in_features, @@ -188,7 +189,7 @@ def linear8bit(requires_cuda): has_fp16_weights=False, ) linear_custom.bias = linear.bias - linear_custom = linear_custom.cuda() + linear_custom = linear_custom.to(device) return linear_custom diff --git a/tests/test_modules.py b/tests/test_modules.py index 8c4d666d3..95f78b6d3 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -448,34 +448,36 @@ def test_4bit_embedding_warnings(device, caplog): assert any("inference" in msg for msg in caplog.messages) -def test_4bit_embedding_weight_fsdp_fix(requires_cuda): +@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) +def test_4bit_embedding_weight_fsdp_fix(device): num_embeddings = 64 embedding_dim = 32 module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim) - module.cuda() + module.to(device) module.weight.quant_state = None - input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda") + input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device=device) module(input_tokens) assert module.weight.quant_state is not None -def test_4bit_linear_weight_fsdp_fix(requires_cuda): +@pytest.mark.parametrize("device", get_available_devices(no_cpu=True)) +def test_4bit_linear_weight_fsdp_fix(device): inp_size = 64 out_size = 32 module = bnb.nn.Linear4bit(inp_size, out_size) - module.cuda() + module.to(device) module.weight.quant_state = None - input_tensor = torch.randn((1, inp_size), 
device="cuda") + input_tensor = torch.randn((1, inp_size), device=device) module(input_tensor)