From 78a18282be1b289ad57ba2438d35bb9ff1ad5a51 Mon Sep 17 00:00:00 2001 From: David Langerman Date: Wed, 17 Jun 2026 06:01:50 -0400 Subject: [PATCH] fix innocuous tracer warnings in onnx export --- dltype/_lib/_core.py | 9 +- dltype/_lib/_dltype_context.py | 76 ++++++++------ dltype/_lib/_tensor_type_base.py | 2 +- dltype/tests/dltype_test.py | 173 ++++++++++++++++++++----------- pyproject.toml | 2 +- uv.lock | 2 +- 6 files changed, 161 insertions(+), 103 deletions(-) diff --git a/dltype/_lib/_core.py b/dltype/_lib/_core.py index 6326d3c..01e6140 100644 --- a/dltype/_lib/_core.py +++ b/dltype/_lib/_core.py @@ -6,7 +6,7 @@ import itertools import warnings from copy import copy -from functools import lru_cache, wraps +from functools import wraps from types import EllipsisType from typing import ( TYPE_CHECKING, @@ -39,12 +39,11 @@ if TYPE_CHECKING: from collections.abc import Callable - -_logger: Final = _log_utils.get_logger(__name__) - P = ParamSpec("P") R = TypeVar("R") +_logger: Final = _log_utils.get_logger(__name__) + class DLTypeAnnotation(NamedTuple): """A class representing a type annotation for a tensor.""" @@ -130,7 +129,6 @@ def from_hint( # noqa: PLR0911 return (cls(tensor_type_hint=tensor_type, dltype_annotation=dltype_hint),) -@lru_cache() def _resolve_types( annotations: tuple[DLTypeAnnotation | None, ...] | None, ) -> tuple[_tensor_type_base.TensorTypeBase | None, ...] | None: @@ -165,7 +163,6 @@ def _maybe_get_type_hints( return None -@lru_cache() def _maybe_get_signature( existing: inspect.Signature | None, func: Callable[P, R], diff --git a/dltype/_lib/_dltype_context.py b/dltype/_lib/_dltype_context.py index 32f7e3e..c9f8433 100644 --- a/dltype/_lib/_dltype_context.py +++ b/dltype/_lib/_dltype_context.py @@ -8,7 +8,16 @@ from typing import Any, Final, NamedTuple, TypeAlias, cast from dltype._lib import _constants, _dtypes, _errors, _log_utils, _parser, _tensor_type_base +from dltype._lib import _dependency_utilities as _deps +if _deps.is_torch_available(): + from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage] +else: + + class _NullWarning(Warning): + pass + + TracerWarning = _NullWarning _logger: Final = _log_utils.get_logger(__name__) EvaluatedDimensionT: TypeAlias = dict[str, int] @@ -123,43 +132,48 @@ def assert_context(self) -> None: """Considering the current context, check if all tensors match their expected types.""" __tracebackhide__ = not _constants.DEBUG_MODE - start_t = time.perf_counter_ns() + with warnings.catch_warnings(): + warnings.simplefilter(category=TracerWarning, action="ignore") - try: - while self._hinted_tensors: - tensor_context = self._hinted_tensors.popleft() - # first check if the tensor could possibly have the right shape - tensor_context.dltype_annotation.check( - tensor_context.tensor, - tensor_name=tensor_context.tensor_arg_name, - ) + start_t = time.perf_counter_ns() - if tensor_context.tensor_arg_name in self.registered_tensor_dtypes: - raise _errors.DLTypeDuplicateError( + try: + while self._hinted_tensors: + tensor_context = self._hinted_tensors.popleft() + # first check if the tensor could possibly have the right shape + tensor_context.dltype_annotation.check( + tensor_context.tensor, tensor_name=tensor_context.tensor_arg_name, ) - self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = tensor_context.tensor.dtype - expected_shape = tensor_context.get_expected_shape( - tensor_context.tensor, - ) - self._assert_tensor_shape( - tensor_context.tensor_arg_name, - expected_shape, - tensor_context.tensor, - ) + if tensor_context.tensor_arg_name in self.registered_tensor_dtypes: + raise _errors.DLTypeDuplicateError( + tensor_name=tensor_context.tensor_arg_name, + ) - finally: - end_t = time.perf_counter_ns() - runtime_ns = end_t - start_t - _logger.debug("Context evaluation took %d ns", runtime_ns) - if _maybe_warn_runtime(runtime_ns): - max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6 - warnings.warn( - f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms", - UserWarning, - stacklevel=2, - ) + self.registered_tensor_dtypes[tensor_context.tensor_arg_name] = ( + tensor_context.tensor.dtype + ) + expected_shape = tensor_context.get_expected_shape( + tensor_context.tensor, + ) + self._assert_tensor_shape( + tensor_context.tensor_arg_name, + expected_shape, + tensor_context.tensor, + ) + + finally: + end_t = time.perf_counter_ns() + runtime_ns = end_t - start_t + _logger.debug("Context evaluation took %d ns", runtime_ns) + if _maybe_warn_runtime(runtime_ns): + max_ms = _constants.MAX_ACCEPTABLE_EVALUATION_TIME_NS / 1e6 + warnings.warn( + f"Type checking took longer than expected {(runtime_ns) / 1e6:.2f}ms > {max_ms:.2f}ms", + UserWarning, + stacklevel=2, + ) def _assert_tensor_shape( self, diff --git a/dltype/_lib/_tensor_type_base.py b/dltype/_lib/_tensor_type_base.py index f120e70..57d8ce5 100644 --- a/dltype/_lib/_tensor_type_base.py +++ b/dltype/_lib/_tensor_type_base.py @@ -153,7 +153,6 @@ def validate_tensor( return core_schema.with_info_after_validator_function( validate_tensor, schema=core_schema.is_instance_schema(source_type), - field_name=handler.field_name, ) def check( @@ -164,6 +163,7 @@ def check( """Check if the tensor matches this type.""" # Basic validation for multi-axis dimensions __tracebackhide__ = not _constants.DEBUG_MODE + if self.multiaxis_index is not None: # Min required dimensions = expected shape length + extra dimensions - 1 (the multi-axis placeholder) min_required_dims = len(self.expected_shape) - 1 diff --git a/dltype/tests/dltype_test.py b/dltype/tests/dltype_test.py index 3200185..9d8ba99 100644 --- a/dltype/tests/dltype_test.py +++ b/dltype/tests/dltype_test.py @@ -18,7 +18,6 @@ import pytest import torch from pydantic import BaseModel -from torch.jit import TracerWarning # pyright: ignore[reportPrivateImportUsage] from typing_extensions import Self import dltype @@ -26,6 +25,9 @@ if TYPE_CHECKING: from collections.abc import Callable +# turns all warnings into errors for this module +pytestmark = pytest.mark.filterwarnings("error") + np_rand = np.random.RandomState(42).rand NPFloatArrayT: TypeAlias = npt.NDArray[np.float32 | np.float64] NPIntArrayT: TypeAlias = npt.NDArray[np.int32 | np.uint16 | np.uint32 | np.uint8] @@ -58,20 +60,24 @@ def good_function( return tensor.permute(2, 3, 0, 1) -@dltype.dltyped() -def incomplete_annotated_function( - tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], -) -> torch.Tensor: - """A function that takes a tensor and returns a tensor.""" - return tensor +with pytest.warns(UserWarning, match="missing a DLType hint"): + @dltype.dltyped() + def incomplete_annotated_function( + tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], + ) -> torch.Tensor: + """A function that takes a tensor and returns a tensor.""" + return tensor -@dltype.dltyped() -def incomplete_return_function( - tensor: torch.Tensor, -) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: - """A function that takes a tensor and returns a tensor.""" - return tensor.permute(2, 3, 0, 1) + +with pytest.warns(UserWarning, match="missing a DLType hint"): + + @dltype.dltyped() + def incomplete_return_function( + tensor: torch.Tensor, + ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: + """A function that takes a tensor and returns a tensor.""" + return tensor.permute(2, 3, 0, 1) @dltype.dltyped() @@ -535,9 +541,8 @@ def forward( ) -> Annotated[torch.Tensor, dltype.FloatTensor("b c h w")]: return torch.multiply(x, 2) - with NamedTemporaryFile() as f, warnings.catch_warnings(): + with NamedTemporaryFile() as f, warnings.catch_warnings(record=True): warnings.simplefilter(category=DeprecationWarning, action="ignore") - warnings.simplefilter(category=TracerWarning, action="ignore") torch.onnx.export( _DummyModule(), (torch.rand(1, 2, 3, 4),), @@ -563,36 +568,61 @@ def forward( ) +class _DummyModule(torch.nn.Module): + @dltype.dltyped() + def forward( + self, + x: Annotated[torch.Tensor, dltype.FloatTensor("b c h w")], + ) -> Annotated[torch.Tensor, dltype.FloatTensor("b c h w")]: + return torch.multiply(x, 2) + + +@dltype.dltyped_namedtuple() +class Container(NamedTuple): + arg: Annotated[torch.Tensor, dltype.IntTensor[None]] + + +class _DummyModuleNamedtuple(torch.nn.Module): + def forward(self, x: Container) -> None: + assert x is not None + + def test_torch_compile() -> None: - class _DummyModule(torch.nn.Module): - @dltype.dltyped() - def forward( - self, - x: Annotated[torch.Tensor, dltype.FloatTensor("b c h w")], - ) -> Annotated[torch.Tensor, dltype.FloatTensor("b c h w")]: - return torch.multiply(x, 2) _DummyModule().forward(torch.rand(1, 2, 3, 4)) with pytest.raises(dltype.DLTypeNDimsError): _DummyModule().forward(torch.rand(1, 2, 3)) - module = torch.compile(_DummyModule()) + regular_mod = _DummyModule() + with warnings.catch_warnings(): + warnings.simplefilter(category=DeprecationWarning, action="ignore") + module = torch.compile(regular_mod) + module2 = torch.compile(_DummyModuleNamedtuple()) + + arg = torch.rand(1, 2, 3, 4) + torch.testing.assert_close(module(arg), regular_mod(arg)) - module(torch.rand(1, 2, 3, 4)) + arg2 = Container(arg=torch.tensor(1)) + module2(arg2) with pytest.raises(dltype.DLTypeNDimsError): module(torch.rand(1, 2, 3)) + +def test_jit_trace() -> None: with warnings.catch_warnings(): - warnings.simplefilter("ignore", TracerWarning) + warnings.simplefilter(category=DeprecationWarning, action="ignore") torch.jit.trace(_DummyModule(), torch.rand(1, 2, 3, 4)) + +def test_jit_script() -> None: if sys.version_info.minor >= 14: - # torch doesn't support script in 3.14 - return + pytest.skip("torch doesn't support script in 3.14") - scripted_module = torch.jit.script(_DummyModule()) + with warnings.catch_warnings(): + warnings.simplefilter(category=DeprecationWarning, action="ignore") + scripted_module = torch.jit.script(_DummyModule()) scripted_module(torch.rand(1, 2, 3, 4)) @@ -600,17 +630,19 @@ def forward( scripted_module(torch.rand(1, 2, 3)) -@dltype.dltyped() -def mixed_func( # noqa: PLR0913 - tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], - array: Annotated[NPFloatArrayT, dltype.TensorTypeBase["b c h w"]], - number: int, - other_tensor: torch.Tensor, - other_array: NPFloatArrayT, - other_number: float, - other_annotated_tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["c c c"]], -) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: - return tensor.permute(2, 3, 0, 1) +with pytest.warns(UserWarning, match="missing a DLType hint"): + + @dltype.dltyped() + def mixed_func( # noqa: PLR0913 + tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["b c h w"]], + array: Annotated[NPFloatArrayT, dltype.TensorTypeBase["b c h w"]], + number: int, + other_tensor: torch.Tensor, + other_array: NPFloatArrayT, + other_number: float, + other_annotated_tensor: Annotated[torch.Tensor, dltype.TensorTypeBase["c c c"]], + ) -> Annotated[torch.Tensor, dltype.TensorTypeBase["h w b c"]]: + return tensor.permute(2, 3, 0, 1) def test_mixed_typing() -> None: @@ -1054,14 +1086,16 @@ def bad_function( # pyright: ignore[reportUnusedFunction] def test_dimension_name_with_underscores() -> None: - @dltype.dltyped() - def good_function( # pyright: ignore[reportUnusedFunction] - tensor: Annotated[ - torch.Tensor, - dltype.IntTensor["batch channels_in channels_out"], - ], - ) -> torch.Tensor: - return tensor + with pytest.warns(UserWarning, match="missing a DLType hint"): + + @dltype.dltyped() + def good_function( # pyright: ignore[reportUnusedFunction] + tensor: Annotated[ + torch.Tensor, + dltype.IntTensor["batch channels_in channels_out"], + ], + ) -> torch.Tensor: + return tensor def test_dimension_with_external_scope() -> None: @@ -1092,11 +1126,11 @@ def good_function( ) -> torch.Tensor: return tensor - with pytest.WarningsRecorder() as rec: + with warnings.catch_warnings(record=True): good_function(torch.ones(1, 3, 4).int()) - with pytest.WarningsRecorder() as rec: + + with warnings.catch_warnings(record=True): good_function(torch.ones(4, 3, 4).int()) - assert len(rec.list) == 0 with pytest.raises(dltype.DLTypeShapeError): good_function(torch.ones(1, 3, 5).int()) @@ -1236,13 +1270,16 @@ def test_annotated_dataclass() -> None: """Test that dltyped correctly handles Annotated dataclasses.""" # Test with a function with annotated dataclass - @dltype.dltyped_dataclass() - @dataclass(frozen=True, kw_only=True, slots=True) - class AnnotatedDataclass: - tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] - tensor_2: Annotated[torch.Tensor, dltype.IntTensor["b c h w"]] - other_thing: int - un_annotated_tensor: torch.Tensor + + with pytest.warns(UserWarning, match="missing a DLType hint"): + + @dltype.dltyped_dataclass() + @dataclass(frozen=True, kw_only=True, slots=True) + class AnnotatedDataclass: + tensor: Annotated[torch.Tensor, dltype.FloatTensor["b c h w"]] + tensor_2: Annotated[torch.Tensor, dltype.IntTensor["b c h w"]] + other_thing: int + un_annotated_tensor: torch.Tensor AnnotatedDataclass( tensor=torch.rand(1, 2, 3, 4), @@ -1362,8 +1399,12 @@ def create( def test_warning_if_decorator_has_no_annotations_to_check() -> None: - with pytest.warns( - UserWarning, match="No DLType hints found for Function: no_annotations, skipping type checking" + with ( + pytest.warns( + UserWarning, match="No DLType hints found for Function: no_annotations, skipping type checking" + ), + pytest.warns(UserWarning, match=re.escape("[tensor] is missing a DLType hint")), + pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")), ): @dltype.dltyped() @@ -1385,7 +1426,11 @@ def some_annotations( some_annotations(torch.rand(1, 2, 3)) - with pytest.warns(UserWarning, match=re.escape("[tensor] has an invalid DLType hint")): + with ( + pytest.warns(UserWarning, match=re.escape("[tensor] has an invalid DLType hint")), + pytest.warns(UserWarning, match="No DLType hints"), + pytest.warns(UserWarning, match=re.escape("[return] is missing a DLType hint")), + ): @dltype.dltyped() def some_annotations( @@ -1735,8 +1780,10 @@ def func(non_optional_tensor: SomeTensorT, optional_tensor: SomeTensorT | None = def test_weird_type() -> None: - @dltype.dltyped() - def func(arg: type) -> None: - pass + with pytest.warns(UserWarning, match="No DLType hints"): + + @dltype.dltyped() + def func(arg: type) -> None: + pass func(int) diff --git a/pyproject.toml b/pyproject.toml index 783fc05..79c5d0b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ license-files = ["LICENSE"] name = "dltype" readme = "README.md" requires-python = ">=3.10" -version = "0.14.0" +version = "0.15.0" [project.optional-dependencies] jax = ["jax>=0.6.2"] diff --git a/uv.lock b/uv.lock index a1b3091..ca58a63 100644 --- a/uv.lock +++ b/uv.lock @@ -222,7 +222,7 @@ nvtx = [ [[package]] name = "dltype" -version = "0.14.0" +version = "0.15.0" source = { virtual = "." } dependencies = [ { name = "pydantic" },