From 30a5572e6c351fa38b192c2664ce4f463c872abf Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 19 May 2026 16:31:36 +1000 Subject: [PATCH 1/3] =?UTF-8?q?feat:=20Model.tracker=20=E2=80=94=20snapsho?= =?UTF-8?q?t-managed,=20user-extensible=20run=20state?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Model.tracker is the authoritative *record* of where a run is — time, step, dt, plus any quantity the user parks on it — and it is automatically captured by Model.snapshot() and reverted by Model.restore(). A loose Python variable (model_time = 0.0 in a script) is not reverted; the same value on the tracker is. That contrast is the whole point. Design (per Louis's intent, 2026-05-19): - Authoritative as a *record*, NOT a dependency. Solvers and DDt are untouched; using the tracker is optional. It sits alongside DDt's own _dt_history (captured independently), it does not subsume it. - User-extensible by plain attribute assignment: model.tracker.foo = ... registers foo as managed state — no dataclass authoring, no special status in solvers. - time/step/dt are ordinary pre-seeded managed entries (0.0/0/None), not privileged fields — consistent with "user-added quantities are first-class". - git-stash semantics: restore replaces the managed map wholesale, so a quantity created after the snapshot is dropped on restore. Implementation: - src/underworld3/checkpoint/tracker.py: ModelTracker (uw_object subclass for instance_number) + TrackerState(SnapshottableState) carrying an open `managed` dict. Attribute routing: underscore names are real attributes, public names are managed entries. __setattr__ respects class-level data descriptors so the `state` property setter is honoured (without this guard restore would silently no-op — caught by the test suite; `state` is therefore a reserved name). .state getter deep-copies for isolation. - Model: PrivateAttr _tracker, instantiated and auto-registered as a state-bearer in __init__; exposed via the `tracker` property. Zero new snapshot plumbing — the existing _state_bearers path picks it up. Tests: tests/test_0009_model_tracker.py (9, tier_a level_1) — defaults, builtins revert, user-quantity reverts, numpy-by-value deep-copy, post-snapshot quantity dropped on restore, the loose-var-vs-tracker contrast, bit-identical state roundtrip, and a realistic stepping-loop continuation. Drive-by: test_symbolic_ddt_snapshot_is_deep_copy assumed state_bearers[0] was the DDt; with a tracker now always registered (WeakSet, unordered) it now finds the DDt state by type. Pre-existing fragility, exposed not introduced. 60 tests pass (24 snapshot + 3 real-solver + 9 tracker + 24 regression); parallel ptest still PASS at np 4 with the tracker auto-registered and snapshot/restored alongside everything else. Stacked on feature/in-memory-checkpoint (depends on its Snapshottable/_state_bearers); PRs to development after #195 lands. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- src/underworld3/checkpoint/__init__.py | 3 + src/underworld3/checkpoint/tracker.py | 140 +++++++++++++++++++ src/underworld3/model.py | 26 ++++ tests/test_0007_snapshot_inmemory.py | 9 +- tests/test_0009_model_tracker.py | 177 +++++++++++++++++++++++++ 5 files changed, 354 insertions(+), 1 deletion(-) create mode 100644 src/underworld3/checkpoint/tracker.py create mode 100644 tests/test_0009_model_tracker.py diff --git a/src/underworld3/checkpoint/__init__.py b/src/underworld3/checkpoint/__init__.py index fcde6412..c601b1ff 100644 --- a/src/underworld3/checkpoint/__init__.py +++ b/src/underworld3/checkpoint/__init__.py @@ -26,6 +26,7 @@ restore, ) from .state import Snapshottable, SnapshottableState +from .tracker import ModelTracker, TrackerState __all__ = [ "CheckpointBackend", @@ -37,4 +38,6 @@ "restore", "Snapshottable", "SnapshottableState", + "ModelTracker", + "TrackerState", ] diff --git a/src/underworld3/checkpoint/tracker.py b/src/underworld3/checkpoint/tracker.py new file mode 100644 index 00000000..76f55df4 --- /dev/null +++ b/src/underworld3/checkpoint/tracker.py @@ -0,0 +1,140 @@ +"""Model-dwelling tracker — snapshot-managed evolving state. + +``Model.tracker`` is the authoritative *record* of where a run is: +simulation time, step, dt, plus any user-registered quantities. It is +deliberately NOT something solvers depend on — solvers and DDt are +untouched, and a user need not use the tracker at all. Its one +superpower: everything living in the tracker is automatically +captured by ``Model.snapshot()`` and reverted by ``Model.restore()``, +whereas a loose Python variable (``model_time = 0.0`` in a script) is +not. + +Add managed quantities by plain attribute assignment:: + + model.tracker.time = 0.0 + model.tracker.step = 0 + model.tracker.my_diagnostic = np.zeros(3) + +Any attribute set on the tracker whose name does not start with an +underscore is a managed state variable: part of every snapshot, +restored exactly on rollback, with no special status in solvers and +no dataclass authoring required. Underscore-prefixed names are +internal and not managed. + +``time``, ``step`` and ``dt`` are ordinary managed entries pre-seeded +with sensible defaults (``0.0`` / ``0`` / ``None``). They are +conventions, not privileged fields — consistent with the design +intent that user-added quantities are first-class. +""" + +from __future__ import annotations + +import copy +from dataclasses import dataclass, field + +from underworld3.utilities._api_tools import uw_object + +from .state import SnapshottableState + + +@dataclass +class TrackerState(SnapshottableState): + """Snapshot of a :class:`ModelTracker`. + + The tracker is extensible, so the State carries an open mapping + rather than fixed fields. ``time`` / ``step`` / ``dt`` are + ordinary entries in ``managed``. + """ + + _schema_version: int = 1 + managed: dict = field(default_factory=dict) + + +class ModelTracker(uw_object): + """One per :class:`underworld3.Model`, auto-registered as a + :class:`~underworld3.checkpoint.Snapshottable` state-bearer so the + snapshot machinery captures and restores it with no extra + plumbing. See the module docstring for the user-facing contract. + """ + + def __init__(self): + # _managed must exist before any public attribute assignment + # routes through __setattr__. + object.__setattr__( + self, "_managed", {"time": 0.0, "step": 0, "dt": None} + ) + super().__init__() # uw_object: sets self._uw_id (underscore) + + # --- attribute routing: public -> managed, underscore -> real --- + + def __setattr__(self, name, value): + if name.startswith("_"): + object.__setattr__(self, name, value) + return + # Respect class-level data descriptors — notably the `state` + # property. Without this guard, `tracker.state = ...` (done by + # the snapshot machinery on restore) would be captured as a + # managed quantity instead of invoking the property setter, + # and restore would silently no-op. `state` is therefore a + # reserved name and cannot be a user-managed quantity. + cls_attr = getattr(type(self), name, None) + if hasattr(cls_attr, "__set__") or hasattr(cls_attr, "__get__"): + object.__setattr__(self, name, value) + return + self._managed[name] = value + + def __getattr__(self, name): + # __getattr__ only fires when normal lookup fails, so it never + # shadows real attributes or class properties (state, + # instance_number, ...). + if name.startswith("_"): + raise AttributeError(name) + managed = object.__getattribute__(self, "_managed") + if name in managed: + return managed[name] + raise AttributeError( + f"ModelTracker has no managed quantity {name!r}; assign " + f"model.tracker.{name} = ... to create it" + ) + + def __delattr__(self, name): + if name.startswith("_"): + object.__delattr__(self, name) + elif name in self._managed: + del self._managed[name] + else: + raise AttributeError(name) + + # --- convenience --- + + def __contains__(self, name): + return name in self._managed + + def keys(self): + """Names of all managed quantities (including time/step/dt).""" + return list(self._managed.keys()) + + def __repr__(self): + items = ", ".join(f"{k}={v!r}" for k, v in self._managed.items()) + return f"ModelTracker({items})" + + # --- Snapshottable contract --- + + @property + def state(self) -> TrackerState: + # Deep-copy on read so a held .state is isolated from later + # mutation even if not routed through the snapshot machinery. + return TrackerState(managed=copy.deepcopy(self._managed)) + + @state.setter + def state(self, s: TrackerState) -> None: + if s._schema_version != TrackerState._schema_version: + raise ValueError( + f"TrackerState schema version mismatch: snapshot " + f"{s._schema_version} vs current " + f"{TrackerState._schema_version}" + ) + # Replace wholesale: restore returns to exactly the captured + # point, so a quantity added *after* the snapshot is dropped + # on restore (git-stash semantics). + object.__setattr__(self, "_managed", copy.deepcopy(s.managed)) diff --git a/src/underworld3/model.py b/src/underworld3/model.py index 248aa745..547e690a 100644 --- a/src/underworld3/model.py +++ b/src/underworld3/model.py @@ -134,6 +134,13 @@ class Model(PintNativeModelMixin, BaseModel): # other than checkpoint may also walk it. _state_bearers: Any = PrivateAttr(default_factory=weakref.WeakSet) + # Model-dwelling tracker: the snapshot-managed record of where a + # run is (time, step, dt) plus any user-registered quantities. + # Auto-registered as a state-bearer in __init__ so snapshot / + # restore manage it automatically. See + # src/underworld3/checkpoint/tracker.py. + _tracker: Any = PrivateAttr(default=None) + def __init__(self, name: Optional[str] = None, **kwargs): """ Initialize a new Model instance. @@ -151,6 +158,13 @@ def __init__(self, name: Optional[str] = None, **kwargs): super().__init__(**kwargs) + # Model-dwelling tracker, auto-registered so snapshot/restore + # manage time/step/dt and any user-added quantities for free. + from underworld3.checkpoint.tracker import ModelTracker + + self._tracker = ModelTracker() + self._register_state_bearer(self._tracker) + # Set initial state if not provided if self.state == ModelState.CONFIGURED: # Transition through initializing to configured @@ -573,6 +587,18 @@ def get_solver(self, name: str): """Get a solver by name from the model registry""" return self._solvers.get(name) + @property + def tracker(self): + """Snapshot-managed record of where this run is. + + Holds ``time`` / ``step`` / ``dt`` (pre-seeded conventions) + plus any quantity you assign — ``model.tracker.foo = ...``. + Everything on the tracker is captured by :meth:`snapshot` and + reverted by :meth:`restore`; loose Python variables are not. + Solvers do not depend on it; using it is optional. + """ + return self._tracker + def _register_state_bearer(self, obj) -> None: """Register a Snapshottable object with this model. diff --git a/tests/test_0007_snapshot_inmemory.py b/tests/test_0007_snapshot_inmemory.py index e7e5bb0d..1af2af6f 100644 --- a/tests/test_0007_snapshot_inmemory.py +++ b/tests/test_0007_snapshot_inmemory.py @@ -345,7 +345,14 @@ def test_symbolic_ddt_snapshot_is_deep_copy(): ddt.update_post_solve(dt=0.1) snap = model.snapshot() - captured_state = snap.state_bearers[0][1] # (key, state) + # Find the DDt's captured state by type — state_bearers is + # unordered and now also contains the model tracker. + from underworld3.systems.ddt import DDtSymbolicState + + captured_state = next( + st for _key, st in snap.state_bearers + if isinstance(st, DDtSymbolicState) + ) captured_dt_history = list(captured_state.dt_history) # Scribble the live DDt's internal state — must not leak into snapshot. diff --git a/tests/test_0009_model_tracker.py b/tests/test_0009_model_tracker.py new file mode 100644 index 00000000..35403654 --- /dev/null +++ b/tests/test_0009_model_tracker.py @@ -0,0 +1,177 @@ +import pytest + +pytestmark = [pytest.mark.level_1, pytest.mark.tier_a] + +import numpy as np + + +def _fresh_model(): + import underworld3 as uw + + uw.reset_default_model() + return uw, uw.get_default_model() + + +def test_tracker_exists_with_default_conventions(): + """A fresh model has a tracker pre-seeded with time/step/dt.""" + uw, model = _fresh_model() + assert model.tracker.time == 0.0 + assert model.tracker.step == 0 + assert model.tracker.dt is None + assert set(model.tracker.keys()) == {"time", "step", "dt"} + + +def test_tracker_is_registered_state_bearer(): + """The tracker auto-registers so snapshot/restore see it.""" + uw, model = _fresh_model() + assert model.tracker in model._state_bearers + + +def test_tracker_builtins_revert_on_restore(): + """time/step/dt are managed entries — they roll back.""" + uw, model = _fresh_model() + model.tracker.time = 3.14 + model.tracker.step = 7 + model.tracker.dt = 0.05 + + snap = model.snapshot() + + model.tracker.time = 99.0 + model.tracker.step = 999 + model.tracker.dt = 1.0 + + model.restore(snap) + + assert model.tracker.time == 3.14 + assert model.tracker.step == 7 + assert model.tracker.dt == 0.05 + + +def test_tracker_user_quantity_reverts(): + """A user-added scalar is managed automatically — no dataclass, + no special status — and reverts on restore.""" + uw, model = _fresh_model() + model.tracker.my_diagnostic = 42.0 + + snap = model.snapshot() + model.tracker.my_diagnostic = -1.0 + model.restore(snap) + + assert model.tracker.my_diagnostic == 42.0 + + +def test_tracker_numpy_quantity_reverts_by_value(): + """A numpy array on the tracker is deep-copied into the snapshot, + so post-snapshot in-place mutation doesn't leak, and it reverts.""" + uw, model = _fresh_model() + arr = np.array([1.0, 2.0, 3.0]) + model.tracker.history = arr + + snap = model.snapshot() + model.tracker.history[:] = -9.0 # in-place mutation + assert np.allclose(model.tracker.history, -9.0) + + model.restore(snap) + assert np.allclose(model.tracker.history, [1.0, 2.0, 3.0]) + + +def test_tracker_quantity_added_after_snapshot_is_dropped_on_restore(): + """git-stash semantics: restore returns to exactly the captured + point, so a quantity created after the snapshot disappears.""" + uw, model = _fresh_model() + model.tracker.a = 1.0 + + snap = model.snapshot() + model.tracker.b = 2.0 # created after snapshot + assert "b" in model.tracker + + model.restore(snap) + assert "a" in model.tracker + assert "b" not in model.tracker + + +def test_tracker_is_what_makes_state_revertible(): + """The contrast that motivates the tracker: a loose Python + variable is NOT reverted by restore; the same value parked on the + tracker IS. This is the whole point.""" + uw, model = _fresh_model() + + loose_time = 0.0 + model.tracker.time = 0.0 + + snap = model.snapshot() + + # Advance both the loose variable and the tracked one. + loose_time = 5.0 + model.tracker.time = 5.0 + + model.restore(snap) + + # The loose variable is untouched by restore (the language can't + # know about it); the tracked one rolled back. + assert loose_time == 5.0 # NOT reverted + assert model.tracker.time == 0.0 # reverted automatically + + +def test_tracker_state_roundtrip_is_bit_identical(): + """snapshot S -> mutate -> restore: tracker.state equals the + captured state exactly (dataclass equality).""" + uw, model = _fresh_model() + model.tracker.time = 1.0 + model.tracker.step = 2 + model.tracker.payload = np.arange(5).astype(float) + state_pre = model.tracker.state + + snap = model.snapshot() + model.tracker.time = 12345.0 + model.tracker.payload[:] = 0.0 + model.restore(snap) + + state_post = model.tracker.state + assert state_post.managed["time"] == state_pre.managed["time"] + assert state_post.managed["step"] == state_pre.managed["step"] + assert np.array_equal( + state_post.managed["payload"], state_pre.managed["payload"] + ) + + +def test_tracker_continuation_with_solver_loop(): + """Realistic: drive time/step on the tracker through a stepping + loop, snapshot mid-run, take a regretted step, restore, continue; + the tracker is exactly back and continues correctly.""" + uw, model = _fresh_model() + import sympy + + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 6.0 + ) + x, y = mesh.X + V_fn = sympy.Matrix([[x - 0.5, y - 0.5]]).T + swarm = uw.swarm.Swarm(mesh) + swarm.populate(fill_param=2) + + def do_step(dt): + swarm.advection(V_fn, delta_t=dt, step_limit=False) + model.tracker.time = model.tracker.time + dt + model.tracker.step = model.tracker.step + 1 + + for _ in range(3): + do_step(0.05) + + snap = model.snapshot() + t_snap, s_snap = model.tracker.time, model.tracker.step + + # Regretted big step. + do_step(0.5) + assert model.tracker.step == s_snap + 1 + assert model.tracker.time != t_snap + + model.restore(snap) + assert model.tracker.time == t_snap + assert model.tracker.step == s_snap + + # Continue cleanly. + for _ in range(2): + do_step(0.05) + assert model.tracker.step == s_snap + 2 + assert abs(model.tracker.time - (t_snap + 0.10)) < 1e-12 From 01ae188d35b32fab1ec919bdaa778b63cf4e7218 Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 19 May 2026 16:39:19 +1000 Subject: [PATCH 2/3] docs: user guide for snapshot/restore + Model.tracker (readthedocs) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit User-facing advanced guide covering Model.snapshot() / Model.restore() and Model.tracker, in the Sphinx/MyST form that builds into the readthedocs site. Distinct from the developer state-as-dataclass guide (that one is for people *extending* the mechanism; this is for people *using* it). Contents: the "stash for timesteps" mental model and when to use it (backtrack, adaptive Δt, predictor-corrector, RK staging); the API; what is captured automatically; the loose-variable-vs-tracker trap and how Model.tracker solves it (with the reserved-name and git-stash-semantics caveats); a worked adaptive-Δt CFL backtracking loop; and an explicit guarantees/limitations section (bit-exact discard incl. parallel and through real solvers; in-memory only; fixed rank count; mesh-adapt refused; within-tolerance vs a never-snapshotted solver run). Wired into docs/advanced/index.md prose listing and the hidden toctree. `pixi run -e amr-dev docs-build` succeeds; the page renders to docs/_build/html/advanced/snapshot-restore.html with no page-specific warnings and the toctree link resolves. On feature/model-tracker because the guide documents both the snapshot toolkit (#195) and the tracker, so it can be complete and build against working code; lands with the tracker PR after #195. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- docs/advanced/index.md | 8 ++ docs/advanced/snapshot-restore.md | 206 ++++++++++++++++++++++++++++++ 2 files changed, 214 insertions(+) create mode 100644 docs/advanced/snapshot-restore.md diff --git a/docs/advanced/index.md b/docs/advanced/index.md index 94826f71..48fab6fa 100644 --- a/docs/advanced/index.md +++ b/docs/advanced/index.md @@ -49,6 +49,13 @@ Darcy flow, Richards equation, and variably-saturated groundwater modelling. **[→ Porous Media Flow](porous-flow.md)** +### State Snapshots & Restore +A "stash for timesteps": snapshot the full model state, try a step, +restore exactly if you don't like it. For backtracking, adaptive Δt, +and predictor–corrector workflows. + +**[→ State Snapshots & Restore](snapshot-restore.md)** + ### Troubleshooting Common issues, debugging strategies, and solutions. @@ -85,6 +92,7 @@ custom-meshes curved-boundary-conditions mesh-adaptation porous-flow +snapshot-restore troubleshooting api-patterns SWARM-INTEGRATION-STATISTICS diff --git a/docs/advanced/snapshot-restore.md b/docs/advanced/snapshot-restore.md new file mode 100644 index 00000000..2acaf276 --- /dev/null +++ b/docs/advanced/snapshot-restore.md @@ -0,0 +1,206 @@ +--- +title: "State Snapshots & Restore" +--- + +# State Snapshots & Restore + +## Overview + +`Model.snapshot()` and `Model.restore()` are a *stash for timesteps* — +a quick "hold that thought, I might need to come back" mechanism for +time-stepping code. Take a snapshot, try a step, and if you don't like +the result, restore and try again. The system is put back exactly as it +was, as if the discarded step never happened. + +Typical uses: + +- **Backtrack past an instability** — a step blows up; restore and + continue with a smaller Δt or a different scheme. +- **Adaptive Δt with an error / CFL check** — take a step, measure it, + restore and retry if it violated your criterion. +- **Predictor–corrector probing** — try a predictor, inspect the + corrector, fall back if it isn't converging. +- **Multi-stage time integration** (RK-style) — restore to the start + of a step between stages. + +This is intentionally *not* archival checkpointing. It is fast, +in-memory, and meant to be used freely within a run. For long-term, +on-disk restart files, use the existing `mesh.write_timestep()` / +`read_timestep()` path, which is unchanged and serves a different +purpose. + +## The API + +```python +import underworld3 as uw + +model = uw.get_default_model() + +# ... set up mesh, variables, swarm, solvers, step a few times ... + +token = model.snapshot() # capture everything, return a token + +# ... take a speculative step you might regret ... + +model.restore(token) # put everything back exactly +``` + +`snapshot()` returns a plain in-memory token. You can hold several at +once and restore any of them. `restore()` returns the model to the +exact state at the moment that token was captured. + +## What is captured + +You do not enumerate anything — `snapshot()` captures the full state +of the model automatically: + +- mesh coordinates, +- all mesh-variable values, +- all swarm particle positions and swarm-variable values, +- solver-internal time-integration history (the `DDt` operators that + drive `AdvDiffusion`, viscoelastic stress history, etc.), +- everything on the model tracker (see below). + +Restore rebuilds swarm populations from the snapshot, so it is correct +even if particles migrated, were added, or were lost between snapshot +and restore — that is exactly the situation restore exists for. + +## The model tracker: time, step, and your own quantities + +A subtle trap in time-stepping scripts: your loop counter and +simulation time usually live in plain Python variables, and +`restore()` has no way to know about them. + +```python +model_time = 0.0 +token = model.snapshot() +model_time = 5.0 # advance +model.restore(token) +# model_time is still 5.0 — restore cannot reach a local variable +``` + +`Model.tracker` solves this. It is a model-dwelling record of where the +run is — and anything you put on it is automatically captured and +restored. + +```python +model.tracker.time = 0.0 +model.tracker.step = 0 + +token = model.snapshot() + +model.tracker.time = 5.0 +model.tracker.step = 100 + +model.restore(token) + +model.tracker.time # 0.0 — reverted automatically +model.tracker.step # 0 — reverted automatically +``` + +`time`, `step` and `dt` come pre-seeded as conventions, but they have +no special status. Any attribute you assign becomes managed state: + +```python +model.tracker.peak_velocity = 0.0 +model.tracker.energy_history = np.zeros(3) +``` + +These now travel with every snapshot and revert on every restore — no +extra code, no special handling in your solvers. Using the tracker is +optional; solvers do not depend on it. It is simply the place to keep +the things you want `restore()` to manage. + +```{note} Reserved name +`state` is reserved on the tracker (it is the snapshot mechanism's own +hook). Do not use `model.tracker.state` for your own quantity. +``` + +```{note} git-stash semantics +Restore returns to *exactly* the captured point. A quantity you add to +the tracker *after* taking a snapshot is removed by a restore of that +snapshot — the same way `git stash pop` does not keep work you started +afterwards. +``` + +## Worked example: adaptive-Δt backtracking + +A canonical CFL-controlled stepping loop. The speculative step is +taken, checked, and either kept or discarded: + +```python +import numpy as np +import underworld3 as uw + +model = uw.get_default_model() +# ... mesh, swarm, velocity field V_fn, solvers set up ... + +cfl_limit = mesh.get_min_radius() +dt = 0.5 + +while model.tracker.time < t_end: + token = model.snapshot() + coords_before = swarm._particle_coordinates.data.copy() + + # Speculative step at the current Δt. + swarm.advection(V_fn, delta_t=dt) + # ... your solves for this step ... + + # CFL check. + moved = np.linalg.norm( + swarm._particle_coordinates.data - coords_before, axis=1 + ).max() + + if moved > cfl_limit: + # Too big — discard and retry with a smaller Δt. + model.restore(token) + dt *= 0.5 + continue + + # Good step — commit. + model.tracker.time += dt + model.tracker.step += 1 + dt = min(dt * 1.1, dt_max) # let Δt grow again +``` + +Because the swarm, fields, solver history *and* the tracker's `time` / +`step` are all captured, the `continue` path leaves no trace: the next +attempt starts from precisely where the failed one began. + +## Guarantees and scope + +```{note} What is guaranteed +- **Discarding a step leaves no trace.** A snapshot → speculative + step → restore → continue reproduces a run that never took the + speculative step *bit-for-bit*, including across MPI ranks and + through real PETSc solves. +- **Parallel-correct.** Works under MPI at any (fixed) rank count. + Restore recovers the exact global state even if the discarded step + migrated or lost particles across ranks. +``` + +```{warning} Limitations +- **In-memory only.** Snapshots live in process memory and are not + written to disk; they do not survive the process exiting. They are + also a full copy of model state — holding many large snapshots at + once costs memory. +- **Same rank count.** A snapshot taken on *N* MPI ranks is restored + on *N* ranks. Changing the rank count is not supported by this + mechanism (use the `write_timestep` restart path for that). +- **No mesh adaptation across a snapshot.** If the mesh is adapted + between snapshot and restore, restore refuses with a clear error + rather than corrupting state. +- **Recovery vs. a never-snapshotted run** is bit-exact for the + *discarded-step* guarantee above. Continuing after a restore that + ran a real solver may differ from a run that never snapshotted by a + small amount within solver tolerance — restore resyncs solver + fields rather than reproducing their exact internal buffers. This + does not affect the correctness of backtracking. +``` + +## Related + +- [Parallel-Safe Scripting](parallel-computing.md) — MPI patterns; + snapshot/restore is parallel-correct at fixed rank count. +- Developer reference: the state-as-dataclass contract for adding new + snapshot-managed solver helpers lives in the developer guide. From 8922ea1f032d61d5df2346f829cd36228303203a Mon Sep 17 00:00:00 2001 From: lmoresi Date: Tue, 19 May 2026 16:42:37 +1000 Subject: [PATCH 3/3] demo: snapshot/restore back-stepping visualisations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two standalone runnable demos (tests/run_*.py convention), companions to tests/test_0007's back-stepping test and the new advanced user guide: - run_snapshot_backstepping_demo.py: CFL-ratio time series. Two overlapping segments in the snap-back zone — the abandoned big step (dashed, CFL spike) and the kept substep trajectory — making "time is multi-valued where you stashed" visible at a glance. - run_snapshot_backstepping_spatial.py: 2x2 spatial panels (initial / after bad step / after restore / after substep recovery). Top-left and bottom-left are visually identical — the snap-back proof. Each writes a PNG to the cwd; the PNGs themselves are regenerable output and are intentionally not committed. Underworld development team with AI support from Claude Code (https://claude.com/claude-code) --- tests/run_snapshot_backstepping_demo.py | 222 +++++++++++++++++++++ tests/run_snapshot_backstepping_spatial.py | 186 +++++++++++++++++ 2 files changed, 408 insertions(+) create mode 100644 tests/run_snapshot_backstepping_demo.py create mode 100644 tests/run_snapshot_backstepping_spatial.py diff --git a/tests/run_snapshot_backstepping_demo.py b/tests/run_snapshot_backstepping_demo.py new file mode 100644 index 00000000..5a014548 --- /dev/null +++ b/tests/run_snapshot_backstepping_demo.py @@ -0,0 +1,222 @@ +"""Snapshot toolkit demonstration: time-series view of back-stepping. + +A small adaptive-Δt drama on one axis. The y-axis is the canonical +adaptive-Δt diagnostic: max per-step particle displacement compared +to the mesh cell radius (CFL ratio). The story: + + - timestep forward at small Δt for a while (CFL well under 1), + - take a snapshot, + - try one too-large Δt (CFL spikes far above 1), + - detect the bad step, call ``model.restore(snap)``, + - replay the same time interval with many small steps (CFL stays small), + - continue past the speculative end-time. + +The plot shows two overlapping segments in the snap-back zone: + + - the abandoned big step (dashed red X — single tall spike, CFL ≫ 1), + - the kept substep trajectory (solid blue dots — each well under 1). + +At ``t = t_speculative_end`` both an abandoned and a recovered value +exist. The time axis is genuinely multi-valued there — that's the +visual point of the figure. + +Run: + pixi run -e amr-dev python tests/run_snapshot_backstepping_demo.py + +Output: + snapshot_backstepping_demo.png in the current working directory. + +Companion to ``tests/test_0007_snapshot_inmemory.py``'s +``test_backstepping_cfl_recovery_end_to_end``. +""" + +import numpy as np +import sympy +import matplotlib.pyplot as plt + +import underworld3 as uw + + +def _max_step_displacement(coords_now: np.ndarray, coords_before: np.ndarray) -> float: + """Largest distance any local particle moved during the last step.""" + return float(np.max(np.linalg.norm(coords_now - coords_before, axis=1))) + + +def main(out_path: str = "snapshot_backstepping_demo.png"): + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + + x_sym, y_sym = mesh.X + V_fn = sympy.Matrix([[x_sym - 0.5, y_sym - 0.5]]).T + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + + cfl_threshold = mesh.get_min_radius() + small_dt = 0.05 + candidate_dt = 0.5 + n_substeps = int(round(candidate_dt / small_dt)) + + # Time series we'll plot. (t_end_of_step, max_step_displacement) + # for every step we keep. + times_kept = [] + cfl_kept = [] + + def take_step(dt: float): + before = swarm._particle_coordinates.data.copy() + swarm.advection(V_fn, delta_t=dt, step_limit=False) + after = swarm._particle_coordinates.data + return _max_step_displacement(after, before) + + # --- Phase 1: normal stepping --- + n_phase1 = 5 + t = 0.0 + for _ in range(n_phase1): + disp = take_step(small_dt) + t += small_dt + times_kept.append(t) + cfl_kept.append(disp / cfl_threshold) + + t_snap = t + snap = model.snapshot() + + # --- Phase 2: speculative big step --- + disp_bad = take_step(candidate_dt) + t_bad_end = t_snap + candidate_dt + cfl_bad = disp_bad / cfl_threshold + + # --- CFL violated → restore --- + model.restore(snap) + + # --- Phase 3: substep replay --- + times_recovered = [] + cfl_recovered = [] + for _ in range(n_substeps): + disp = take_step(small_dt) + t += small_dt + times_recovered.append(t) + cfl_recovered.append(disp / cfl_threshold) + + # --- Phase 4: continue past the speculative endpoint --- + n_phase4 = 5 + times_post = [] + cfl_post = [] + for _ in range(n_phase4): + disp = take_step(small_dt) + t += small_dt + times_post.append(t) + cfl_post.append(disp / cfl_threshold) + + # --- Plot --- + fig, ax = plt.subplots(figsize=(11, 5.5)) + + # Shaded snap-back zone. + ax.axvspan(t_snap, t_bad_end, color="0.94", zorder=0) + + # CFL = 1 reference. + ax.axhline(1.0, color="0.6", linestyle="-", linewidth=0.8) + ax.text( + 0.005, 1.04, + "CFL = 1 (one cell radius per step)", + fontsize=9, color="0.4", transform=ax.get_yaxis_transform(), + va="bottom", + ) + + # Phase 1: pre-snapshot trajectory. + ax.plot( + times_kept, cfl_kept, + marker="o", markersize=5, linewidth=1.5, color="C0", + label="Time-stepping at Δt = {:.2f}".format(small_dt), + ) + + # Snapshot marker. + ax.axvline(t_snap, color="0.6", linestyle=":", linewidth=1) + ax.annotate( + "snapshot taken", + xy=(t_snap, 0.02), xycoords=("data", "axes fraction"), + xytext=(-4, 4), textcoords="offset points", + ha="right", va="bottom", fontsize=9, color="0.3", + ) + + # Speculative bad step: dashed from snapshot horizontal-ish to (t_bad_end, cfl_bad). + ax.plot( + [t_snap, t_bad_end], [cfl_kept[-1], cfl_bad], + linestyle="--", linewidth=1.5, color="C3", alpha=0.7, + label="Speculative Δt = {:.2f}".format(candidate_dt), + ) + ax.scatter( + [t_bad_end], [cfl_bad], marker="X", s=130, color="C3", zorder=5, + ) + ax.annotate( + "abandoned: CFL = {:.1f}".format(cfl_bad), + xy=(t_bad_end, cfl_bad), + xytext=(8, -2), textcoords="offset points", + ha="left", va="center", fontsize=10, color="C3", fontweight="bold", + ) + + # Snap-back arrow. + ax.annotate( + "", + xy=(t_snap + 0.003, 0.18), + xytext=(t_bad_end - 0.003, max(cfl_bad - 0.5, 1.5)), + arrowprops=dict( + arrowstyle="->", color="0.45", + connectionstyle="arc3,rad=-0.35", linewidth=1.4, + ), + ) + ax.text( + 0.5 * (t_snap + t_bad_end), + 0.4 * cfl_bad, + "model.restore(snap)", + ha="center", va="center", fontsize=10, color="0.35", + style="italic", + bbox=dict(facecolor="white", edgecolor="0.7", boxstyle="round,pad=0.25"), + ) + + # Phase 3: recovered substeps. + ax.plot( + times_recovered, cfl_recovered, + marker="o", markersize=5, linewidth=1.5, color="C0", + ) + + # Phase 4: continuation. + ax.plot( + times_post, cfl_post, + marker="o", markersize=5, linewidth=1.5, color="C0", + ) + + # Snap-back zone label (above the axes). + ax.text( + 0.5 * (t_snap + t_bad_end), + 1.015, "snap-back zone — t is multi-valued", + ha="center", va="bottom", fontsize=10, color="0.4", + transform=ax.get_xaxis_transform(), + ) + + ax.set_xlabel("simulation time t") + ax.set_ylabel("CFL ratio = max per-step displacement / cell radius") + ax.set_title( + "Adaptive-Δt back-stepping • model.snapshot() / model.restore()", + pad=22, + ) + ax.legend(loc="upper right", frameon=False) + ax.grid(True, axis="y", color="0.92", linewidth=0.6) + ax.set_xlim(-0.02, t + 0.02) + ax.set_ylim(-0.3, cfl_bad * 1.12) + + fig.tight_layout() + fig.savefig(out_path, dpi=120) + print(f"Wrote {out_path}") + print(f" t_snap = {t_snap:.3f}") + print(f" speculative Δt = {candidate_dt:.3f}") + print(f" CFL ratio (bad step) = {cfl_bad:.2f}") + print(f" substeps to recover: {n_substeps} × Δt = {small_dt:.3f}") + print(f" max CFL on substeps: {max(cfl_recovered):.3f}") + + +if __name__ == "__main__": + main() diff --git a/tests/run_snapshot_backstepping_spatial.py b/tests/run_snapshot_backstepping_spatial.py new file mode 100644 index 00000000..fb155747 --- /dev/null +++ b/tests/run_snapshot_backstepping_spatial.py @@ -0,0 +1,186 @@ +"""Snapshot toolkit demonstration: spatial view of back-stepping. + +Companion to ``run_snapshot_backstepping_demo.py``. That script answers +"when" via a CFL-ratio time series; this one answers "what" via 2×2 +spatial panels at the four moments that matter: + + [initial state (snapshot taken here)] [after speculative bad step] + [after model.restore(snap)] [after substep recovery to same t] + +Each panel shows the swarm particles coloured by their carried +material value (initial radial position), with the domain boundary +drawn as context. The diagonal pairs tell two stories: + + - top-left vs. bottom-left should be **visually identical**. That's + the proof that model.restore(snap) put the captured state back + exactly. If the figure ever stops showing two identical panels + in that diagonal, the snapshot mechanism has broken. + + - top-right vs. bottom-right are the same simulation time reached + by two different paths: a single too-large Δt step (corner-clumping, + over-stretched, CFL violated) vs. ten substeps at sub-CFL Δt. + +Run: + pixi run -e amr-dev python tests/run_snapshot_backstepping_spatial.py + +Output: + snapshot_backstepping_spatial.png in the current working directory. +""" + +import numpy as np +import sympy +import matplotlib.pyplot as plt + +import underworld3 as uw + + +def _capture(swarm, material): + """Snapshot the swarm spatial state for plotting (positions + material).""" + coords = swarm._particle_coordinates.data.copy() + mat = np.asarray(material.data).copy() + return coords, mat + + +def main(out_path: str = "snapshot_backstepping_spatial.png"): + uw.reset_default_model() + model = uw.get_default_model() + mesh = uw.meshing.UnstructuredSimplexBox( + minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0), cellSize=1.0 / 8.0 + ) + + x_sym, y_sym = mesh.X + V_fn = sympy.Matrix([[x_sym - 0.5, y_sym - 0.5]]).T + + swarm = uw.swarm.Swarm(mesh) + material = swarm.add_variable("material", 1, dtype=float) + swarm.populate(fill_param=2) + + cfl_threshold = mesh.get_min_radius() + small_dt = 0.05 + candidate_dt = 0.5 + n_substeps = int(round(candidate_dt / small_dt)) + + # Colour each particle by its initial radial distance from centre. + coords_initial = swarm._particle_coordinates.data.copy() + material.data[:, 0] = np.linalg.norm(coords_initial - 0.5, axis=1) + + # Initial state. + state_initial = _capture(swarm, material) + + # Take the snapshot — this is the state that bottom-left will + # have to match after restore. + snap = model.snapshot() + + # --- Speculative big step --- + swarm.advection(V_fn, delta_t=candidate_dt, step_limit=False) + state_after_bad = _capture(swarm, material) + max_disp_bad = np.max( + np.linalg.norm(state_after_bad[0] - state_initial[0], axis=1) + ) + cfl_bad = max_disp_bad / cfl_threshold + + # --- model.restore(snap) --- + model.restore(snap) + state_after_restore = _capture(swarm, material) + + # --- Substep recovery to the same target time --- + for _ in range(n_substeps): + swarm.advection(V_fn, delta_t=small_dt, step_limit=False) + state_after_recovery = _capture(swarm, material) + max_disp_recovery = np.max( + np.linalg.norm(state_after_recovery[0] - state_initial[0], axis=1) + ) + cfl_recovery_per_step = ( + np.max( + np.linalg.norm( + state_after_recovery[0] - state_initial[0], axis=1 + ) + ) + / n_substeps + / cfl_threshold + ) + + # --- Plot --- + fig, axes = plt.subplots(2, 2, figsize=(11.5, 11), constrained_layout=True) + + panels = [ + ( + axes[0, 0], + state_initial, + "Initial state", + "t = 0.00 • snapshot taken here", + ), + ( + axes[0, 1], + state_after_bad, + "After speculative Δt = {:.2f}".format(candidate_dt), + "t = {:.2f} • CFL = {:.1f} × threshold".format( + candidate_dt, cfl_bad + ), + ), + ( + axes[1, 0], + state_after_restore, + "After model.restore(snap)", + "t = 0.00 • visually identical to top-left", + ), + ( + axes[1, 1], + state_after_recovery, + "After {} substeps at Δt = {:.3f}".format(n_substeps, small_dt), + "t = {:.2f} • same time as top-right, CFL safe".format( + candidate_dt + ), + ), + ] + + # Common colour scale across all four panels so colours mean the + # same thing everywhere. + all_mat = np.concatenate( + [s[1][:, 0] for s in (state_initial, state_after_bad, + state_after_restore, state_after_recovery)] + ) + vmin, vmax = float(all_mat.min()), float(all_mat.max()) + + last_sc = None + for ax, (coords, mat), title, subtitle in panels: + sc = ax.scatter( + coords[:, 0], coords[:, 1], + c=mat[:, 0], s=8, cmap="viridis", + vmin=vmin, vmax=vmax, + ) + last_sc = sc + # Domain boundary. + ax.plot([0, 1, 1, 0, 0], [0, 0, 1, 1, 0], color="0.5", linewidth=0.9) + # Generous limits so the bad-step overshoot is visible if any + # particles strayed past the boundary. + ax.set_xlim(-0.15, 1.15) + ax.set_ylim(-0.15, 1.15) + ax.set_aspect("equal") + ax.set_title(f"{title}\n{subtitle}", fontsize=10) + ax.set_xticks([]) + ax.set_yticks([]) + + # Single colourbar on the right. + cbar = fig.colorbar( + last_sc, ax=axes.ravel().tolist(), + shrink=0.55, pad=0.02, aspect=30, + ) + cbar.set_label("material = initial radial distance from centre", + fontsize=9) + + fig.suptitle( + "Adaptive-Δt back-stepping • spatial view\n" + "Top-left ↔ Bottom-left identical (snap-back). Top-right ↔ Bottom-right same simulation time, different path.", + fontsize=11, + ) + fig.savefig(out_path, dpi=120, bbox_inches="tight") + print(f"Wrote {out_path}") + print(f" CFL ratio (bad single step): {cfl_bad:.2f}") + print(f" CFL ratio per substep (mean): {cfl_recovery_per_step:.3f}") + print(f" max disp bad path: {max_disp_bad:.4f}") + print(f" max disp recovery path (cumul): {max_disp_recovery:.4f}") + + +if __name__ == "__main__": + main()