diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 224390d15..aadcaef7d 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, get_stream, randn_strided +from tests.utils import Payload, get_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -82,8 +82,8 @@ def test_gemm( "instantiated specialization" ) - a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) - b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) + a = rand_strided(a_shape, a_strides, dtype=dtype, device=device) + b = rand_strided(b_shape, b_strides, dtype=dtype, device=device) if trans_a: a = a.transpose(-2, -1) @@ -91,7 +91,7 @@ def test_gemm( if trans_b: b = b.transpose(-2, -1) - c = randn_strided(c_shape, c_strides, dtype=dtype, device=device) + c = rand_strided(c_shape, c_strides, dtype=dtype, device=device) use_portable_ref = implementation_index == 2 and not ( device == "cpu" or (