From 89be62eea28fd0de710effc0aa84ea6aafb8e5d9 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 6 Apr 2026 14:59:09 -0400 Subject: [PATCH 1/4] feature: Torch dependency in sagameker-core to be made optional (5457) --- sagemaker-core/pyproject.toml | 7 +- .../src/sagemaker/core/deserializers/base.py | 5 +- .../src/sagemaker/core/serializers/base.py | 8 +- .../unit/test_optional_torch_dependency.py | 154 ++++++++++++++++++ 4 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 sagemaker-core/tests/unit/test_optional_torch_dependency.py diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 2756ce0f1c..4134d50e34 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -32,7 +32,6 @@ dependencies = [ "smdebug_rulesconfig>=1.0.1", "schema>=0.7.5", "omegaconf>=2.1.0", - "torch>=1.9.0", "scipy>=1.5.0", # Remote function dependencies "cloudpickle>=2.0.0", @@ -51,6 +50,12 @@ classifiers = [ ] [project.optional-dependencies] +torch = [ + "torch>=1.9.0", +] +all = [ + "torch>=1.9.0", +] codegen = [ "black>=24.3.0, <25.0.0", "pandas>=2.0.0, <3.0.0", diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 4faae7db74..1f7ec9ab06 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -366,7 +366,10 @@ def __init__(self, accept="tensor/pt"): self.convert_npy_to_tensor = from_numpy except ImportError: - raise Exception("Unable to import pytorch.") + raise ImportError( + "Unable to import torch. Please install torch to use TorchTensorDeserializer: " + "pip install 'sagemaker-core[torch]'" + ) def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index a4ecf7c1dc..4b3ba4fdba 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -443,7 +443,13 @@ class TorchTensorSerializer(SimpleBaseSerializer): def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) - from torch import Tensor + try: + from torch import Tensor + except ImportError: + raise ImportError( + "Unable to import torch. Please install torch to use TorchTensorSerializer: " + "pip install 'sagemaker-core[torch]'" + ) self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() diff --git a/sagemaker-core/tests/unit/test_optional_torch_dependency.py b/sagemaker-core/tests/unit/test_optional_torch_dependency.py new file mode 100644 index 0000000000..f0eb2879fd --- /dev/null +++ b/sagemaker-core/tests/unit/test_optional_torch_dependency.py @@ -0,0 +1,154 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests to verify torch dependency is optional in sagemaker-core.""" +from __future__ import absolute_import + +import io +import sys +from unittest import mock + +import numpy as np +import pytest + + +def test_serializer_module_imports_without_torch(): + """Verify that importing serializers module succeeds without torch installed.""" + # The serializers module should be importable even without torch + # because TorchTensorSerializer uses lazy import in __init__ + from sagemaker.core.serializers.base import ( + CSVSerializer, + NumpySerializer, + JSONSerializer, + IdentitySerializer, + SparseMatrixSerializer, + JSONLinesSerializer, + LibSVMSerializer, + DataSerializer, + StringSerializer, + ) + # Verify non-torch serializers can be instantiated + assert CSVSerializer() is not None + assert NumpySerializer() is not None + assert JSONSerializer() is not None + assert IdentitySerializer() is not None + + +def test_deserializer_module_imports_without_torch(): + """Verify that importing deserializers module succeeds without torch installed.""" + from sagemaker.core.deserializers.base import ( + StringDeserializer, + BytesDeserializer, + CSVDeserializer, + StreamDeserializer, + NumpyDeserializer, + JSONDeserializer, + PandasDeserializer, + JSONLinesDeserializer, + ) + # Verify non-torch deserializers can be instantiated + assert StringDeserializer() is not None + assert BytesDeserializer() is not None + assert CSVDeserializer() is not None + assert NumpyDeserializer() is not None + assert JSONDeserializer() is not None + + +def test_torch_tensor_serializer_raises_import_error_without_torch(): + """Verify TorchTensorSerializer raises ImportError when torch is not installed.""" + import importlib + import sagemaker.core.serializers.base as ser_module + + # Save original torch module if present + original_torch = sys.modules.get('torch') + + try: + # Simulate torch not being installed + sys.modules['torch'] = None + # Need to also handle the case where torch submodules are cached + torch_keys = [key for key in sys.modules if key.startswith('torch.')] + saved = {key: sys.modules.pop(key) for key in torch_keys} + + with pytest.raises(ImportError, match="Unable to import torch"): + ser_module.TorchTensorSerializer() + finally: + # Restore original state + if original_torch is not None: + sys.modules['torch'] = original_torch + elif 'torch' in sys.modules: + del sys.modules['torch'] + for key, val in saved.items(): + sys.modules[key] = val + + +def test_torch_tensor_deserializer_raises_import_error_without_torch(): + """Verify TorchTensorDeserializer raises ImportError when torch is not installed.""" + import sagemaker.core.deserializers.base as deser_module + + # Save original torch module if present + original_torch = sys.modules.get('torch') + + try: + # Simulate torch not being installed + sys.modules['torch'] = None + torch_keys = [key for key in sys.modules if key.startswith('torch.')] + saved = {key: sys.modules.pop(key) for key in torch_keys} + + with pytest.raises(ImportError, match="Unable to import torch"): + deser_module.TorchTensorDeserializer() + finally: + # Restore original state + if original_torch is not None: + sys.modules['torch'] = original_torch + elif 'torch' in sys.modules: + del sys.modules['torch'] + for key, val in saved.items(): + sys.modules[key] = val + + +def test_torch_tensor_serializer_works_with_torch(): + """Verify TorchTensorSerializer works when torch is available.""" + try: + import torch + except ImportError: + pytest.skip("torch is not installed") + + from sagemaker.core.serializers.base import TorchTensorSerializer + + serializer = TorchTensorSerializer() + tensor = torch.tensor([1.0, 2.0, 3.0]) + result = serializer.serialize(tensor) + assert result is not None + # Verify the result can be loaded back as numpy + array = np.load(io.BytesIO(result)) + assert np.array_equal(array, np.array([1.0, 2.0, 3.0])) + + +def test_torch_tensor_deserializer_works_with_torch(): + """Verify TorchTensorDeserializer works when torch is available.""" + try: + import torch + except ImportError: + pytest.skip("torch is not installed") + + from sagemaker.core.deserializers.base import TorchTensorDeserializer + + deserializer = TorchTensorDeserializer() + # Create a numpy array, save it, and deserialize to tensor + array = np.array([1.0, 2.0, 3.0]) + buffer = io.BytesIO() + np.save(buffer, array) + buffer.seek(0) + + result = deserializer.deserialize(buffer, "tensor/pt") + assert isinstance(result, torch.Tensor) + assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0])) From 57b5c02cb6d097312fc055fa4bf83559abbae941 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Mon, 6 Apr 2026 15:04:48 -0400 Subject: [PATCH 2/4] fix: address review comments (iteration #1) --- sagemaker-core/pyproject.toml | 2 +- .../src/sagemaker/core/deserializers/base.py | 4 +- .../src/sagemaker/core/serializers/base.py | 4 +- .../unit/test_optional_torch_dependency.py | 138 +++++++++--------- 4 files changed, 73 insertions(+), 75 deletions(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index 4134d50e34..c0656ab16a 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -54,7 +54,7 @@ torch = [ "torch>=1.9.0", ] all = [ - "torch>=1.9.0", + "sagemaker-core[torch]", ] codegen = [ "black>=24.3.0, <25.0.0", diff --git a/sagemaker-core/src/sagemaker/core/deserializers/base.py b/sagemaker-core/src/sagemaker/core/deserializers/base.py index 1f7ec9ab06..03138ed577 100644 --- a/sagemaker-core/src/sagemaker/core/deserializers/base.py +++ b/sagemaker-core/src/sagemaker/core/deserializers/base.py @@ -365,11 +365,11 @@ def __init__(self, accept="tensor/pt"): from torch import from_numpy self.convert_npy_to_tensor = from_numpy - except ImportError: + except ImportError as e: raise ImportError( "Unable to import torch. Please install torch to use TorchTensorDeserializer: " "pip install 'sagemaker-core[torch]'" - ) + ) from e def deserialize(self, stream, content_type="tensor/pt"): """Deserialize streamed data to TorchTensor diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index 4b3ba4fdba..e8862b66f3 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -445,11 +445,11 @@ def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) try: from torch import Tensor - except ImportError: + except ImportError as e: raise ImportError( "Unable to import torch. Please install torch to use TorchTensorSerializer: " "pip install 'sagemaker-core[torch]'" - ) + ) from e self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() diff --git a/sagemaker-core/tests/unit/test_optional_torch_dependency.py b/sagemaker-core/tests/unit/test_optional_torch_dependency.py index f0eb2879fd..5008244e27 100644 --- a/sagemaker-core/tests/unit/test_optional_torch_dependency.py +++ b/sagemaker-core/tests/unit/test_optional_torch_dependency.py @@ -11,108 +11,106 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Tests to verify torch dependency is optional in sagemaker-core.""" -from __future__ import absolute_import +from __future__ import annotations +import importlib import io import sys -from unittest import mock import numpy as np import pytest +def _block_torch(): + """Block torch imports by setting sys.modules['torch'] to None. + + Returns a dict of saved torch submodule entries so they can be restored. + """ + saved = {} + torch_keys = [key for key in sys.modules if key.startswith("torch.")] + saved = {key: sys.modules.pop(key) for key in torch_keys} + saved["torch"] = sys.modules.get("torch") + sys.modules["torch"] = None + return saved + + +def _restore_torch(saved): + """Restore torch modules from saved dict.""" + original_torch = saved.pop("torch", None) + if original_torch is not None: + sys.modules["torch"] = original_torch + elif "torch" in sys.modules: + del sys.modules["torch"] + for key, val in saved.items(): + sys.modules[key] = val + + def test_serializer_module_imports_without_torch(): - """Verify that importing serializers module succeeds without torch installed.""" - # The serializers module should be importable even without torch - # because TorchTensorSerializer uses lazy import in __init__ - from sagemaker.core.serializers.base import ( - CSVSerializer, - NumpySerializer, - JSONSerializer, - IdentitySerializer, - SparseMatrixSerializer, - JSONLinesSerializer, - LibSVMSerializer, - DataSerializer, - StringSerializer, - ) - # Verify non-torch serializers can be instantiated - assert CSVSerializer() is not None - assert NumpySerializer() is not None - assert JSONSerializer() is not None - assert IdentitySerializer() is not None + """Verify that importing non-torch serializers succeeds without torch installed.""" + saved = {} + try: + saved = _block_torch() + + # Reload the module so it re-evaluates imports with torch blocked + import sagemaker.core.serializers.base as ser_module + + importlib.reload(ser_module) + + # Verify non-torch serializers can be instantiated + assert ser_module.CSVSerializer() is not None + assert ser_module.NumpySerializer() is not None + assert ser_module.JSONSerializer() is not None + assert ser_module.IdentitySerializer() is not None + finally: + _restore_torch(saved) def test_deserializer_module_imports_without_torch(): - """Verify that importing deserializers module succeeds without torch installed.""" - from sagemaker.core.deserializers.base import ( - StringDeserializer, - BytesDeserializer, - CSVDeserializer, - StreamDeserializer, - NumpyDeserializer, - JSONDeserializer, - PandasDeserializer, - JSONLinesDeserializer, - ) - # Verify non-torch deserializers can be instantiated - assert StringDeserializer() is not None - assert BytesDeserializer() is not None - assert CSVDeserializer() is not None - assert NumpyDeserializer() is not None - assert JSONDeserializer() is not None + """Verify that importing non-torch deserializers succeeds without torch installed.""" + saved = {} + try: + saved = _block_torch() + + import sagemaker.core.deserializers.base as deser_module + + importlib.reload(deser_module) + + # Verify non-torch deserializers can be instantiated + assert deser_module.StringDeserializer() is not None + assert deser_module.BytesDeserializer() is not None + assert deser_module.CSVDeserializer() is not None + assert deser_module.NumpyDeserializer() is not None + assert deser_module.JSONDeserializer() is not None + finally: + _restore_torch(saved) def test_torch_tensor_serializer_raises_import_error_without_torch(): """Verify TorchTensorSerializer raises ImportError when torch is not installed.""" - import importlib import sagemaker.core.serializers.base as ser_module - # Save original torch module if present - original_torch = sys.modules.get('torch') - + saved = {} try: - # Simulate torch not being installed - sys.modules['torch'] = None - # Need to also handle the case where torch submodules are cached - torch_keys = [key for key in sys.modules if key.startswith('torch.')] - saved = {key: sys.modules.pop(key) for key in torch_keys} - + saved = _block_torch() + with pytest.raises(ImportError, match="Unable to import torch"): ser_module.TorchTensorSerializer() finally: - # Restore original state - if original_torch is not None: - sys.modules['torch'] = original_torch - elif 'torch' in sys.modules: - del sys.modules['torch'] - for key, val in saved.items(): - sys.modules[key] = val + _restore_torch(saved) def test_torch_tensor_deserializer_raises_import_error_without_torch(): """Verify TorchTensorDeserializer raises ImportError when torch is not installed.""" import sagemaker.core.deserializers.base as deser_module - # Save original torch module if present - original_torch = sys.modules.get('torch') - + saved = {} try: - # Simulate torch not being installed - sys.modules['torch'] = None - torch_keys = [key for key in sys.modules if key.startswith('torch.')] - saved = {key: sys.modules.pop(key) for key in torch_keys} - + saved = _block_torch() + with pytest.raises(ImportError, match="Unable to import torch"): deser_module.TorchTensorDeserializer() finally: - # Restore original state - if original_torch is not None: - sys.modules['torch'] = original_torch - elif 'torch' in sys.modules: - del sys.modules['torch'] - for key, val in saved.items(): - sys.modules[key] = val + _restore_torch(saved) def test_torch_tensor_serializer_works_with_torch(): From 0ed06b2afb0480bdb86131f1c04aeb800cf3f9ca Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:34:58 -0700 Subject: [PATCH 3/4] fix: address review comments (iteration #1) --- sagemaker-core/pyproject.toml | 1 + .../src/sagemaker/core/serializers/base.py | 3 ++- .../unit/test_serializer_implementations.py | 24 ++++++++++++++++++- sagemaker-core/tox.ini | 1 - 4 files changed, 26 insertions(+), 3 deletions(-) diff --git a/sagemaker-core/pyproject.toml b/sagemaker-core/pyproject.toml index c0656ab16a..4b3f3f1739 100644 --- a/sagemaker-core/pyproject.toml +++ b/sagemaker-core/pyproject.toml @@ -66,6 +66,7 @@ test = [ "pytest>=8.0.0, <9.0.0", "pytest-cov>=4.0.0", "pytest-xdist>=3.0.0", + "sagemaker-core[torch]", ] [project.urls] diff --git a/sagemaker-core/src/sagemaker/core/serializers/base.py b/sagemaker-core/src/sagemaker/core/serializers/base.py index e8862b66f3..84b9832c63 100644 --- a/sagemaker-core/src/sagemaker/core/serializers/base.py +++ b/sagemaker-core/src/sagemaker/core/serializers/base.py @@ -445,13 +445,14 @@ def __init__(self, content_type="tensor/pt"): super(TorchTensorSerializer, self).__init__(content_type=content_type) try: from torch import Tensor + + self.torch_tensor = Tensor except ImportError as e: raise ImportError( "Unable to import torch. Please install torch to use TorchTensorSerializer: " "pip install 'sagemaker-core[torch]'" ) from e - self.torch_tensor = Tensor self.numpy_serializer = NumpySerializer() def serialize(self, data): diff --git a/sagemaker-core/tests/unit/test_serializer_implementations.py b/sagemaker-core/tests/unit/test_serializer_implementations.py index 60d7d62b0b..b12471cbf2 100644 --- a/sagemaker-core/tests/unit/test_serializer_implementations.py +++ b/sagemaker-core/tests/unit/test_serializer_implementations.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit tests for sagemaker.core.serializers.implementations module.""" -from __future__ import absolute_import +from __future__ import annotations import pytest from unittest.mock import Mock, patch @@ -162,3 +162,25 @@ def test_numpy_serializer_import(self): def test_record_serializer_deprecated(self): """Test that numpy_to_record_serializer is available as deprecated.""" assert hasattr(implementations, "numpy_to_record_serializer") + + +class TestTorchSerializerWithOptionalDependency: + """Test torch serializer/deserializer with optional torch dependency.""" + + def test_torch_tensor_serializer_instantiation(self): + """Test that TorchTensorSerializer can be instantiated when torch is available.""" + torch = pytest.importorskip("torch") + from sagemaker.core.serializers.base import TorchTensorSerializer + + serializer = TorchTensorSerializer() + assert serializer is not None + assert serializer.content_type == "tensor/pt" + + def test_torch_tensor_deserializer_instantiation(self): + """Test that TorchTensorDeserializer can be instantiated when torch is available.""" + torch = pytest.importorskip("torch") + from sagemaker.core.deserializers.base import TorchTensorDeserializer + + deserializer = TorchTensorDeserializer() + assert deserializer is not None + assert deserializer.accept == "tensor/pt" diff --git a/sagemaker-core/tox.ini b/sagemaker-core/tox.ini index 136b99c69c..0e31a74d80 100644 --- a/sagemaker-core/tox.ini +++ b/sagemaker-core/tox.ini @@ -6,7 +6,6 @@ [tox] isolated_build = true envlist = black-format,flake8,pylint,docstyle,sphinx,doc8,twine,py39,py310,py311,py312 - skip_missing_interpreters = False [flake8] From 528080325ae8b2ff31fad68164c0f74bd96c1bdd Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Wed, 15 Apr 2026 10:12:43 -0700 Subject: [PATCH 4/4] fix: use subprocess instead of importlib.reload to avoid breaking six.with_metaclass super() --- .../unit/test_optional_torch_dependency.py | 108 ++++++++++++------ .../unit/test_serializer_implementations.py | 2 +- 2 files changed, 73 insertions(+), 37 deletions(-) diff --git a/sagemaker-core/tests/unit/test_optional_torch_dependency.py b/sagemaker-core/tests/unit/test_optional_torch_dependency.py index 5008244e27..2b7efbc227 100644 --- a/sagemaker-core/tests/unit/test_optional_torch_dependency.py +++ b/sagemaker-core/tests/unit/test_optional_torch_dependency.py @@ -10,12 +10,21 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Tests to verify torch dependency is optional in sagemaker-core.""" -from __future__ import annotations +"""Tests to verify torch dependency is optional in sagemaker-core. + +The "module imports without torch" tests use subprocess instead of +importlib.reload to avoid poisoning the class hierarchy in the current +process. six.with_metaclass + old-style super() breaks when a module +is reloaded because the class identity changes, causing +``TypeError: super(type, obj): obj must be an instance or subtype of type`` +in subsequent tests that instantiate serializers/deserializers. +""" +from __future__ import absolute_import -import importlib import io +import subprocess import sys +import textwrap import numpy as np import pytest @@ -26,7 +35,6 @@ def _block_torch(): Returns a dict of saved torch submodule entries so they can be restored. """ - saved = {} torch_keys = [key for key in sys.modules if key.startswith("torch.")] saved = {key: sys.modules.pop(key) for key in torch_keys} saved["torch"] = sys.modules.get("torch") @@ -46,43 +54,71 @@ def _restore_torch(saved): def test_serializer_module_imports_without_torch(): - """Verify that importing non-torch serializers succeeds without torch installed.""" - saved = {} - try: - saved = _block_torch() + """Verify that non-torch serializers can be imported and instantiated without torch. - # Reload the module so it re-evaluates imports with torch blocked - import sagemaker.core.serializers.base as ser_module - - importlib.reload(ser_module) - - # Verify non-torch serializers can be instantiated - assert ser_module.CSVSerializer() is not None - assert ser_module.NumpySerializer() is not None - assert ser_module.JSONSerializer() is not None - assert ser_module.IdentitySerializer() is not None - finally: - _restore_torch(saved) + Runs in a subprocess to avoid polluting the current process's class + hierarchy via importlib.reload (which breaks six.with_metaclass). + """ + code = textwrap.dedent("""\ + import sys + # Block torch before any sagemaker imports + sys.modules["torch"] = None + + from sagemaker.core.serializers.base import ( + CSVSerializer, + NumpySerializer, + JSONSerializer, + IdentitySerializer, + ) + + assert CSVSerializer() is not None + assert NumpySerializer() is not None + assert JSONSerializer() is not None + assert IdentitySerializer() is not None + print("OK") + """) + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + ) + assert result.returncode == 0, ( + f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) def test_deserializer_module_imports_without_torch(): - """Verify that importing non-torch deserializers succeeds without torch installed.""" - saved = {} - try: - saved = _block_torch() - - import sagemaker.core.deserializers.base as deser_module + """Verify that non-torch deserializers can be imported and instantiated without torch. - importlib.reload(deser_module) - - # Verify non-torch deserializers can be instantiated - assert deser_module.StringDeserializer() is not None - assert deser_module.BytesDeserializer() is not None - assert deser_module.CSVDeserializer() is not None - assert deser_module.NumpyDeserializer() is not None - assert deser_module.JSONDeserializer() is not None - finally: - _restore_torch(saved) + Runs in a subprocess for the same reason as the serializer test above. + """ + code = textwrap.dedent("""\ + import sys + sys.modules["torch"] = None + + from sagemaker.core.deserializers.base import ( + StringDeserializer, + BytesDeserializer, + CSVDeserializer, + NumpyDeserializer, + JSONDeserializer, + ) + + assert StringDeserializer() is not None + assert BytesDeserializer() is not None + assert CSVDeserializer() is not None + assert NumpyDeserializer() is not None + assert JSONDeserializer() is not None + print("OK") + """) + result = subprocess.run( + [sys.executable, "-c", code], + capture_output=True, + text=True, + ) + assert result.returncode == 0, ( + f"Subprocess failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) def test_torch_tensor_serializer_raises_import_error_without_torch(): diff --git a/sagemaker-core/tests/unit/test_serializer_implementations.py b/sagemaker-core/tests/unit/test_serializer_implementations.py index b12471cbf2..82d5e074b1 100644 --- a/sagemaker-core/tests/unit/test_serializer_implementations.py +++ b/sagemaker-core/tests/unit/test_serializer_implementations.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Unit tests for sagemaker.core.serializers.implementations module.""" -from __future__ import annotations +from __future__ import absolute_import import pytest from unittest.mock import Mock, patch