From 4a50788e3e0924525f1eb7bc4d64917a59f5acfd Mon Sep 17 00:00:00 2001
From: twsl <45483159+twsl@users.noreply.github.com>
Date: Wed, 25 Dec 2024 16:45:28 +0100
Subject: [PATCH 1/3] Update __init__.py

---
 torchvision/tv_tensors/__init__.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 1ba47f60a36..82f57c74aac 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,3 +1,5 @@
+from typing import TypeVar
+
 import torch
 
 from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat
@@ -7,12 +9,14 @@
 from ._tv_tensor import TVTensor
 from ._video import Video
 
+TVTensorLike = TypeVar("TVTensorLike", TVTensor, BoundingBoxes, Image, Mask, Video)
+
 
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee, *, like, **kwargs):
+def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike:
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
     If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of

From 9e83b32c4b57905f531db4ac33d69c5bdc3ffdfa Mon Sep 17 00:00:00 2001
From: twsl <45483159+twsl@users.noreply.github.com>
Date: Thu, 20 Feb 2025 17:39:19 +0000
Subject: [PATCH 2/3] Fix type hinting issue by ignoring mypy false-positive expectation

---
 torchvision/tv_tensors/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 82f57c74aac..6fe8c45da8c 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -30,7 +30,7 @@ def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike
         Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
-        return BoundingBoxes._wrap(
+        return BoundingBoxes._wrap(  # type: ignore
             wrappee,
             format=kwargs.get("format", like.format),
             canvas_size=kwargs.get("canvas_size", like.canvas_size),

From 9c97e778dcf9069993d46680b0e492dd93bc7f50 Mon Sep 17 00:00:00 2001
From: twsl <45483159+twsl@users.noreply.github.com>
Date: Thu, 20 Feb 2025 22:45:06 +0000
Subject: [PATCH 3/3] Use bound with covariant instead

---
 torchvision/tv_tensors/__init__.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/torchvision/tv_tensors/__init__.py b/torchvision/tv_tensors/__init__.py
index 6fe8c45da8c..2fd9b9fbc7d 100644
--- a/torchvision/tv_tensors/__init__.py
+++ b/torchvision/tv_tensors/__init__.py
@@ -1,4 +1,4 @@
-from typing import TypeVar
+from typing import TypeVar, cast
 
 import torch
 
@@ -9,14 +9,14 @@
 from ._tv_tensor import TVTensor
 from ._video import Video
 
-TVTensorLike = TypeVar("TVTensorLike", TVTensor, BoundingBoxes, Image, Mask, Video)
+TVTensorLike = TypeVar("TVTensorLike", bound=TVTensor, covariant=True)
 
 
 # TODO: Fix this. We skip this method as it leads to
 # RecursionError: maximum recursion depth exceeded while calling a Python object
 # Until `disable` is removed, there will be graph breaks after all calls to functional transforms
 @torch.compiler.disable
-def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike:
+def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike:  # type: ignore
     """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.
 
     If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
@@ -30,10 +30,10 @@ def wrap(wrappee: torch.Tensor, *, like: TVTensorLike, **kwargs) -> TVTensorLike
         Ignored otherwise.
     """
     if isinstance(like, BoundingBoxes):
-        return BoundingBoxes._wrap(  # type: ignore
+        return cast(TVTensorLike, BoundingBoxes._wrap(
             wrappee,
             format=kwargs.get("format", like.format),
             canvas_size=kwargs.get("canvas_size", like.canvas_size),
-        )
+        ))
     else:
         return wrappee.as_subclass(type(like))
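
A note on the `# type: ignore` that PATCH 3/3 adds to the `def` line: mypy normally
rejects a covariant TypeVar appearing in a parameter position ("Cannot use a covariant
type variable as a parameter"), so switching `TVTensorLike` to `bound=TVTensor,
covariant=True` makes the suppression on the signature necessary. A minimal standalone
sketch of the same pattern, with hypothetical names (`T`, `wrap_like`) and only the
base-class bound, independent of the torchvision sources:

    from typing import TypeVar, cast

    import torch

    # Covariant TypeVar bounded by the common base class, mirroring TVTensorLike.
    T = TypeVar("T", bound=torch.Tensor, covariant=True)

    def wrap_like(wrappee: torch.Tensor, *, like: T) -> T:  # type: ignore
        # as_subclass reinterprets the same storage as type(like) without copying;
        # cast only informs the checker that the result matches like's subclass.
        return cast(T, wrappee.as_subclass(type(like)))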
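
For reference, a usage sketch of `wrap` after this series, assuming the public
`torchvision.tv_tensors` API; the inferred types noted in the comments are what a
type checker should now report thanks to the `TVTensorLike` annotations:

    import torch
    from torchvision import tv_tensors

    boxes = tv_tensors.BoundingBoxes(
        torch.tensor([[0, 0, 10, 10]]),
        format=tv_tensors.BoundingBoxFormat.XYXY,
        canvas_size=(32, 32),
    )
    # Most tensor ops return a plain torch.Tensor, so results are re-wrapped;
    # with the typed signature the checker infers BoundingBoxes for `shifted`.
    shifted = tv_tensors.wrap(boxes + 2, like=boxes)

    img = tv_tensors.Image(torch.rand(3, 32, 32))
    out = tv_tensors.wrap(img * 0.5, like=img)  # inferred as Image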