Skip to content

Commit f346e3b

Browse files
committed
Misc fixes on test configs
- add missing test mark - disable xla preallocate memory Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent ced0707 commit f346e3b

3 files changed

Lines changed: 16 additions & 1 deletion

File tree

pytest.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ markers =
1010
use_mlir: mark tests that depend on mlir which requires the "internal" extension
1111
env =
1212
CUDA_TILE_COMPILER_TIMEOUT_SEC=60
13+
XLA_PYTHON_CLIENT_PREALLOCATE=false

test/conftest.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
from functools import cache
1313

1414
from cuda.tile._bytecode.version import BytecodeVersion
15-
from cuda.tile._compile import _get_max_supported_bytecode_version, _SUPPORTED_VERSIONS
15+
from cuda.tile._compile import (
16+
_get_max_supported_bytecode_version,
17+
_SUPPORTED_VERSIONS,
18+
_find_compiler_bin)
1619
from cuda.tile._cext import dev_features_enabled
1720
from util import require_blackwell_or_newer, require_hopper_or_newer
1821

@@ -39,6 +42,16 @@ def strict_importorskip(modname, *args, **kwargs):
3942
pytest.importorskip = strict_importorskip
4043

4144

45+
def pytest_sessionstart(session):
46+
"""
47+
Called after the Session object has been created and
48+
before performing collection and entering the run test loop.
49+
"""
50+
print("Tile compiler path:", _find_compiler_bin().path)
51+
print("Dev features enabled:", dev_features_enabled())
52+
print("Bytecode version:", get_tileiras_version().as_string())
53+
54+
4255
@cache
4356
def get_tileiras_version():
4457
return _get_max_supported_bytecode_version(tempfile.gettempdir(),

test/test_token_order.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1402,6 +1402,7 @@ def compile_kernel(self, kernel):
14021402
Y = torch.arange(tile_size, device="cuda", dtype=torch.int32)
14031403
return get_bytecode(kernel, (X, Y, tile_size))
14041404

1405+
@requires_tileiras(BytecodeVersion.V_13_3)
14051406
@pytest.mark.parametrize("kernel, check_directive", make_cases(
14061407
(tv_atomic, TVAtomicCheckDirective),
14071408
))

0 commit comments

Comments
 (0)