From cda93009d00a7f3d59a66b01c09011851f6d3bcb Mon Sep 17 00:00:00 2001 From: Zhang Shuo <52872288+fuyou4546@users.noreply.github.com> Date: Tue, 16 Jun 2026 08:41:45 +0000 Subject: [PATCH] test(gemm): use rand_strided instead of randn_strided --- tests/test_gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 224390d15..aadcaef7d 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -2,7 +2,7 @@ import pytest import torch -from tests.utils import Payload, get_stream, randn_strided +from tests.utils import Payload, get_stream, rand_strided @pytest.mark.auto_act_and_assert @@ -82,8 +82,8 @@ def test_gemm( "instantiated specialization" ) - a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) - b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) + a = rand_strided(a_shape, a_strides, dtype=dtype, device=device) + b = rand_strided(b_shape, b_strides, dtype=dtype, device=device) if trans_a: a = a.transpose(-2, -1) @@ -91,7 +91,7 @@ def test_gemm( if trans_b: b = b.transpose(-2, -1) - c = randn_strided(c_shape, c_strides, dtype=dtype, device=device) + c = rand_strided(c_shape, c_strides, dtype=dtype, device=device) use_portable_ref = implementation_index == 2 and not ( device == "cpu" or (