From 4771380746b1684348ccf8d4832a59e8d0fe9634 Mon Sep 17 00:00:00 2001 From: Parker Schuh Date: Tue, 28 Apr 2026 18:48:12 -0700 Subject: [PATCH] Automated Code Change PiperOrigin-RevId: 907289393 --- pathwaysutils/jax/__init__.py | 2 +- pathwaysutils/profiling.py | 8 +++++++- .../test/experimental/reshard_test.py | 18 +++++++++++++++--- pathwaysutils/test/reshard_test.py | 12 ++++++++++-- 4 files changed, 33 insertions(+), 7 deletions(-) diff --git a/pathwaysutils/jax/__init__.py b/pathwaysutils/jax/__init__.py index 8608704..7ae663d 100644 --- a/pathwaysutils/jax/__init__.py +++ b/pathwaysutils/jax/__init__.py @@ -104,7 +104,7 @@ def ifrt_reshard_available() -> bool: transfer_to_shardings( [jax.numpy.array([0])], - [jax.sharding.SingleDeviceSharding(jax.devices()[0])], + [jax.sharding.make_single_device_sharding(jax.devices()[0])], ) except (ImportError, NameError, jax.errors.JaxRuntimeError): diff --git a/pathwaysutils/profiling.py b/pathwaysutils/profiling.py index b4f4378..6378d2b 100644 --- a/pathwaysutils/profiling.py +++ b/pathwaysutils/profiling.py @@ -250,7 +250,13 @@ def stop_trace() -> None: and "xprofTraceOptions" in _profile_state.profile_request ): out_avals = [jax.core.ShapedArray((1,), jnp.object_)] - out_shardings = [jax.sharding.SingleDeviceSharding(jax.devices()[0])] + out_shardings = [ + getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(jax.devices()[0]) + ] else: out_avals = () out_shardings = () diff --git a/pathwaysutils/test/experimental/reshard_test.py b/pathwaysutils/test/experimental/reshard_test.py index 79c0434..5754513 100644 --- a/pathwaysutils/test/experimental/reshard_test.py +++ b/pathwaysutils/test/experimental/reshard_test.py @@ -37,7 +37,11 @@ def test_sidechannel_reshard_donate( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(devices[0]) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable", autospec=True) @@ -64,7 +68,11 @@ def test_sidechannel_reshard_cache_resharding_plans( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(devices[0]) mock_pe = self.enter_context( mock.patch.object(plugin_executable, "PluginExecutable") @@ -92,7 +100,11 @@ def test_sidechannel_reshard_cache_resharding_plans( def test_sidechannel_reshard_pytree(self): x = {"a": jnp.array([1]), "b": [jnp.array([2])]} devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(devices[0]) # Tree prefix sharding tree_sharding = {"a": sharding, "b": [sharding]} diff --git a/pathwaysutils/test/reshard_test.py b/pathwaysutils/test/reshard_test.py index 6c80221..0cc6c3f 100644 --- a/pathwaysutils/test/reshard_test.py +++ b/pathwaysutils/test/reshard_test.py @@ -41,7 +41,11 @@ def test_ifrt_reshard_donate( ): x = jnp.array([1, 2]) devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(devices[0]) mock_transfer = self.enter_context( mock.patch.object(pw_jax, "transfer_to_shardings", autospec=True) @@ -60,7 +64,11 @@ def test_ifrt_reshard_donate( def test_ifrt_reshard_pytree(self): x = {"a": jnp.array([1]), "b": [jnp.array([2])]} devices = jax.devices() - sharding = jax.sharding.SingleDeviceSharding(devices[0]) + sharding = getattr( + jax.sharding, + "make_single_device_sharding", + lambda x: jax.sharding.SingleDeviceSharding(x), + )(devices[0]) # Tree prefix sharding tree_sharding = {"a": sharding, "b": [sharding]}