Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions ccflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def get_registered_names(self, include_orphaned: bool = False) -> List[str]:
for registry_name in registry_names:
full_names.append(REGISTRY_SEPARATOR.join([registry_name, name]))
return full_names
elif self is ModelRegistry.root():
return [""]
elif include_orphaned and isinstance(self, ModelRegistry) and self._was_registered and self.name:
# Orphaned sub-registry: de-registered from its parent but .name is still valid.
# Reconstruct path as if registered directly under root.
Expand Down Expand Up @@ -663,6 +661,17 @@ def _debug_name(self) -> str:
"""Returns the "full name" of the registry. Since registries can have multiple names"""
return "RootModelRegistry"

def get_registered_names(self, include_orphaned: bool = False) -> List[str]:
"""The root registry is always the empty-prefix anchor of every path.

Overrides the base implementation so that a deserialized copy of the
root (e.g. after being pickled into a Ray worker) still returns [""]
rather than [] — the base class falls through to [] because the
identity check ``self is ModelRegistry.root()`` fails for a deserialized
instance.
"""
return [""]


_REGISTRY_ROOT = RootModelRegistry.model_construct()

Expand Down
32 changes: 32 additions & 0 deletions ccflow/tests/test_base_registry.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import collections.abc
import json
import os
import pickle
import sys
from typing import Dict, List
from unittest import TestCase
Expand Down Expand Up @@ -353,6 +354,37 @@ def test_orphaned_subregistry_default_returns_empty(self):
# Explicit True → reconstructs path
self.assertListEqual(m.get_registered_names(include_orphaned=True), ["/foo/model_name"])

def test_get_registered_names_survives_pickling(self):
"""Registered names must round-trip through pickle.

When a model is pickled (e.g., to be sent to a Ray worker) and then
unpickled in a fresh process context, get_registered_names() must still
return the correct full path. The bug: the deserialized RootModelRegistry
is a new object and fails the ``self is ModelRegistry.root()`` identity
check, causing the chain traversal to bottom out with [] instead of [""].
"""
r = ModelRegistry.root()
foo = ModelRegistry(name="foo")
bar = ModelRegistry(name="bar")
m = MyTestModel(a="x", b=0.0)

r.add("foo", foo)
foo.add("bar", bar)
bar.add("baz", m)

self.assertListEqual(m.get_registered_names(), ["/foo/bar/baz"])

# Simulate what happens when a model is sent to a Ray worker: it is
# pickled and unpickled in the same process but as a standalone blob,
# severing the live registry identity chain.
restored = pickle.loads(pickle.dumps(m))

self.assertListEqual(
restored.get_registered_names(),
["/foo/bar/baz"],
"get_registered_names() must return the correct path after pickling",
)


class TestRegistryLoading(TestCase):
def setUp(self) -> None:
Expand Down
Loading