diff --git a/ccflow/base.py b/ccflow/base.py index e922f96..e37e8cd 100644 --- a/ccflow/base.py +++ b/ccflow/base.py @@ -86,8 +86,6 @@ def get_registered_names(self, include_orphaned: bool = False) -> List[str]: for registry_name in registry_names: full_names.append(REGISTRY_SEPARATOR.join([registry_name, name])) return full_names - elif self is ModelRegistry.root(): - return [""] elif include_orphaned and isinstance(self, ModelRegistry) and self._was_registered and self.name: # Orphaned sub-registry: de-registered from its parent but .name is still valid. # Reconstruct path as if registered directly under root. @@ -663,6 +661,17 @@ def _debug_name(self) -> str: """Returns the "full name" of the registry. Since registries can have multiple names""" return "RootModelRegistry" + def get_registered_names(self, include_orphaned: bool = False) -> List[str]: + """The root registry is always the empty-prefix anchor of every path. + + Overrides the base implementation so that a deserialized copy of the + root (e.g. after being pickled into a Ray worker) still returns [""] + rather than [] — the base class falls through to [] because the + identity check ``self is ModelRegistry.root()`` fails for a deserialized + instance. + """ + return [""] + _REGISTRY_ROOT = RootModelRegistry.model_construct() diff --git a/ccflow/tests/test_base_registry.py b/ccflow/tests/test_base_registry.py index e75fd8c..657a416 100644 --- a/ccflow/tests/test_base_registry.py +++ b/ccflow/tests/test_base_registry.py @@ -1,6 +1,7 @@ import collections.abc import json import os +import pickle import sys from typing import Dict, List from unittest import TestCase @@ -353,6 +354,37 @@ def test_orphaned_subregistry_default_returns_empty(self): # Explicit True → reconstructs path self.assertListEqual(m.get_registered_names(include_orphaned=True), ["/foo/model_name"]) + def test_get_registered_names_survives_pickling(self): + """Registered names must round-trip through pickle. + + When a model is pickled (e.g., to be sent to a Ray worker) and then + unpickled in a fresh process context, get_registered_names() must still + return the correct full path. The bug: the deserialized RootModelRegistry + is a new object and fails the ``self is ModelRegistry.root()`` identity + check, causing the chain traversal to bottom out with [] instead of [""]. + """ + r = ModelRegistry.root() + foo = ModelRegistry(name="foo") + bar = ModelRegistry(name="bar") + m = MyTestModel(a="x", b=0.0) + + r.add("foo", foo) + foo.add("bar", bar) + bar.add("baz", m) + + self.assertListEqual(m.get_registered_names(), ["/foo/bar/baz"]) + + # Simulate what happens when a model is sent to a Ray worker: it is + # pickled and unpickled in the same process but as a standalone blob, + # severing the live registry identity chain. + restored = pickle.loads(pickle.dumps(m)) + + self.assertListEqual( + restored.get_registered_names(), + ["/foo/bar/baz"], + "get_registered_names() must return the correct path after pickling", + ) + class TestRegistryLoading(TestCase): def setUp(self) -> None: