Skip to content

feat(runtime): TRT-RTX runtime controls via context managers#4330

Open
tp5uiuc wants to merge 36 commits into
pytorch:mainfrom
tp5uiuc:feat/trtrtx-runtime-ctx-managers-upstream
Open

feat(runtime): TRT-RTX runtime controls via context managers#4330
tp5uiuc wants to merge 36 commits into
pytorch:mainfrom
tp5uiuc:feat/trtrtx-runtime-ctx-managers-upstream

Conversation

@tp5uiuc

@tp5uiuc tp5uiuc commented Jun 9, 2026

Copy link
Copy Markdown
Collaborator

Description

Refs #4310, design discussion at #4323.

Moves cuda_graph_strategy, dynamic_shapes_kernel_specialization_strategy, and runtime_cache off CompilationSettings onto runtime context managers — toggle them without recompiling.

Type of change

  • New feature (non-breaking change which adds functionality)
  • Breaking change (the three CompilationSettings fields above move to RuntimeSettings / runtime CMs; old compile-time kwargs are no longer accepted)
  • This change requires a documentation update

Public API

  • torch_tensorrt.runtime.runtime_config(target, **overrides) — pool CM applying any RuntimeSettings field to every TRT submodule under target.
  • torch_tensorrt.runtime.runtime_cache(target, path_or_stream) — shared IRuntimeCache across one or more modules. Accepts str, os.PathLike, or a file-like (io.BytesIO, opened handles).
  • torch_tensorrt.runtime.enable_cudagraphs(target, *, cuda_graph_strategy=...) — RTX cuda-graph strategy + outer cudagraph wrap in one CM (the strategy is RTX-only; non-RTX builds raise on kwarg use).
  • torch_tensorrt.runtime.set_dynamic_shapes_kernel_strategy(target, strategy) — sugar wrapper for the dynamic-shapes field.
  • module.runtime_settings = RuntimeSettings(...) — direct assignment after compile.

User guide: docsrc/user_guide/runtime_performance/runtime_settings.rst.

Examples

Post-compile setter + cudagraph capture with strategy in one CM:

import torch_tensorrt as torchtrt
from torch_tensorrt.runtime import RuntimeSettings, enable_cudagraphs

mod = torchtrt.compile(model, inputs=inputs)
mod.runtime_settings = RuntimeSettings(runtime_cache="/var/cache/jit.bin")

with enable_cudagraphs(mod, cuda_graph_strategy="whole_graph_capture") as wrapped:
    out = wrapped(x)

Putting it all together — shared kernel cache across two modules, dynamic-shapes override + cudagraph capture on the first, the second consuming the first's output under the same cache:

from torch_tensorrt.runtime import (
    runtime_cache,
    runtime_config,
    enable_cudagraphs,
)

with runtime_cache([mod1, mod2], "/var/cache/jit.bin") as rc:
    with (
        runtime_config(
            mod1,
            runtime_cache=rc,
            dynamic_shapes_kernel_specialization_strategy="eager",
        ) as modr,
        enable_cudagraphs(modr, cuda_graph_strategy="whole_graph_capture") as cg,
    ):
        outputs = cg(*inputs)
    mod2(*outputs)

For stream-backed caches (io.BytesIO, opened files), caller-owned RuntimeCache lifetimes, sharing one cache across many modules, and other advanced patterns, see the Runtime Settings user guide at docsrc/user_guide/runtime_performance/runtime_settings.rst.

Architecture

flowchart TB
    classDef api fill:#cce5ff,stroke:#004085,color:#004085
    classDef settings fill:#fff3cd,stroke:#856404,color:#856404
    classDef module fill:#d4edda,stroke:#155724,color:#155724
    classDef py fill:#e7d6f7,stroke:#553375,color:#553375
    classDef cpp fill:#ffe0b3,stroke:#a14400,color:#7a3500
    classDef facade fill:#f5e6cc,stroke:#6b4423,color:#3a2200,stroke-width:3px

    %% Layer 1 — Public API
    A1["runtime_cache CM"]:::api
    A2["runtime_config CM"]:::api
    A3["enable_cudagraphs<br/>(cuda_graph_strategy=...)"]:::api
    A4["mod.runtime_settings = rs"]:::api

    %% Layer 2 — Data model
    RS["RuntimeSettings dataclass<br/>cuda_graph_strategy<br/>dynamic_shapes_kernel_specialization_strategy<br/>runtime_cache : None | str | RuntimeCache"]:::settings
    A1 --> RS
    A2 --> RS
    A3 --> RS
    A4 --> RS

    %% Layer 3 — Module (owner of implicit handle)
    MOD["TorchTensorRTModule<br/>_implicit_cache_handle : RuntimeCache<br/>_resolve_runtime_cache: builds + warm-loads disk → pending<br/>_send_to_engine"]:::module
    RS --> MOD

    %% Layer 3.5 — User-facing facade (sits BETWEEN module and the runtime split)
    RC{{"⭐ RuntimeCache &mdash; USER-FACING FACADE<br/>py/torch_tensorrt/runtime/_runtime_cache.py<br/>path / autosave_on_del<br/>load · save · load_from_stream · save_to_stream<br/>has_cache · is_cpp_runtime · ensure_cache<br/><i>same API regardless of runtime — forwards to ._handle</i>"}}:::facade
    MOD -. "owns" .-> RC

    %% Layer 4 — Runtime branch
    BR{cpp runtime<br/>available?}:::module
    MOD --> BR

    %% Layer 5/6/7 — Side-by-side language columns, each with engine → shim → inner handle
    subgraph PY ["Python runtime path"]
        direction TB
        PYENG["_TRTEngine<br/>.context (lazy @property)<br/>.update_runtime_settings(rs)"]:::py
        PYTRC["TRTRuntimeConfig (Python shim)<br/>_runtime_config.py<br/>owns trt.IRuntimeConfig (lazy)"]:::py
        PYINNER["<b>_RuntimeCacheHandle</b><br/>(python-rt inner — port of cpp class)<br/>_cache : trt.IRuntimeCache<br/>_pending_warm_bytes (drained on first ensure_materialized)<br/>_lock (mirrors cpp state_mu_)"]:::py
        PYENG --> PYTRC
        PYTRC -. "ensure_cache → setRuntimeCache" .-> PYINNER
    end

    subgraph CPP ["C++ runtime path"]
        direction TB
        CPPENG["torch.classes.tensorrt.Engine<br/>.update_runtime_settings(int, int, cache)"]:::cpp
        CPPTRC["TRTRuntimeConfig (C++ struct)<br/>core/runtime/TRTRuntimeConfig.{h,cpp}<br/>owns nvinfer1::IRuntimeConfig"]:::cpp
        CPPINNER["<b>torch.classes.tensorrt.RuntimeCacheHandle</b><br/>(cpp-rt inner — torchbind class, used directly, no wrapper)<br/>core/runtime/RuntimeSettings.{h,cpp}<br/>trt_handle_ : shared_ptr&lt;IRuntimeCache&gt;<br/>pending_warm_bytes_ (drained on ensure_materialized)<br/>state_mu_ (mirrors python _lock)"]:::cpp
        CPPENG --> CPPTRC
        CPPTRC -. "ensure_materialized → setRuntimeCache" .-> CPPINNER
    end

    BR -- "No" --> PYENG
    BR -- "Yes" --> CPPENG

    %% The facade uniformly exposes whichever inner is appropriate — same API surface either side.
    RC == "_handle (python rt)" ==> PYINNER
    RC == "_handle (cpp rt)" ==> CPPINNER
Loading

Color key

  • 🟦 Blue — Public API entry points
  • 🟨 Amber — RuntimeSettings dataclass (data model)
  • 🟩 Green — TorchTensorRTModule orchestration (owns the implicit handle)
  • 🟫 Tan ⭐ — RuntimeCache user-facing facade (thick border; runtime-agnostic API; the only handle users touch)
  • 🟪 Purple — Python runtime path: _TRTEngine → Python shim → _RuntimeCacheHandle (inner)
  • 🟧 Orange — C++ runtime path: torchbind engine → C++ struct → torch.classes.tensorrt.RuntimeCacheHandle (inner)

How to read the diagram. Settings flow top → down. The two language columns are mirror images of each other (engine → shim → inner cache handle); the bold inner-handle nodes (_RuntimeCacheHandle in purple, torch.classes.tensorrt.RuntimeCacheHandle in orange) are 1:1 ports — same public surface (serialize / deserialize / has_cache / ensure_materialized), both with a pending-bytes stash and their own lock guarding the GIL-releasing create-cache race. The RuntimeCache facade (the thick-bordered ⭐ node) is the only handle users touch; it forwards every call (load, save, has_cache, etc.) uniformly to whichever inner ._handle references — bold double-arrows on either side show the dispatch.

Implementation

  • RuntimeSettings dataclass + TRTRuntimeConfig shim (Python + a mirroring C++ struct in core/runtime/) own the live IRuntimeConfig. All ENABLED_FEATURES.tensorrt_rtx gates live inside the shim.
  • RuntimeCache is a facade wrapping either _RuntimeCacheHandle (Python-rt port of the cpp class) or torch.classes.tensorrt.RuntimeCacheHandle (cpp torchbind, used directly). Both implement a common _RuntimeCacheHandleProtocol; the facade forwards without isinstance branching.
  • Deferred materialization on both inners: deserialize stashes bytes into a pending buffer if the underlying IRuntimeCache is not yet created; the first ensure_materialized call (driven by the python or cpp _apply_settings) creates the cache and drains the pending bytes atomically. Disk bytes for engine-implicit handles are pre-loaded into the pending buffer at handle construction time (_TorchTensorRTModule._resolve_runtime_cache) — one warm-load callsite covers both runtimes.
  • Filelocked atomic-rename disk persistence (load / save) plus load_from_stream / save_to_stream primitives. Engine-implicit handles autosave on __del__.
  • TorchTensorRTModule._implicit_cache_handle is the canonical owner; RuntimeCache.is_cpp_runtime() lets external callers detect which inner is in use.
  • IExecutionContext is strictly lazy on both runtimes. Python exposes engine.context as a write-protected @property; the C++ engine exposes a single exec_ctx() getter. Runtime knobs are NOT serialized into the engine tuple.

Tests

Added tests/py/dynamo/runtime/: test_000_runtime_cache.py, test_001_cuda_graph_strategy.py, test_001_dynamic_shapes_kernel_strategy.py, test_004_runtime_settings.py. The build's selected runtime determines whether the cpp or Python inner path runs; whitebox introspection tests skip on the other side.

Verified locally on TRT-RTX 1.5.0.114 (A100): cpp-rt 41 passed / 20 skipped / 0 failed; python-rt (libs hidden to force the python path) 58 passed / 3 skipped / 0 failed.

Notes

The cudagraphs wrapper's warm_up() materializes the engine's context with whatever settings are in effect at that moment. enable_cudagraphs(target, cuda_graph_strategy=...) applies the strategy before the wrapper's warm-up, preserving the "one createExecutionContext per setup" invariant.

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

Refs pytorch#4310. Moves ``cuda_graph_strategy``,
``dynamic_shapes_kernel_specialization_strategy``, and ``runtime_cache``
off ``CompilationSettings`` and onto runtime-mode context managers, so
callers toggle them without recompiling.

Public API
----------
- ``torch_tensorrt.runtime.runtime_config(target, **overrides)`` -- pool CM
  that applies any ``RuntimeSettings`` field to every TRT submodule.
- ``torch_tensorrt.runtime.runtime_cache(target, path_or_stream)`` --
  attaches a shared ``IRuntimeCache`` across one or more modules. Accepts
  ``str``, ``os.PathLike``, or a file-like (``io.BytesIO``, opened file
  handles, etc.).
- Sugar: ``set_cuda_graph_strategy`` /
  ``set_dynamic_shapes_kernel_strategy``.
- ``module.runtime_settings = RuntimeSettings(...)`` for direct assignment.
- Compile-time hint via ``torchtrt.compile(..., runtime_settings=...)``
  primes the engine without an extra CM enter/exit.

Implementation
--------------
- ``RuntimeSettings`` dataclass + ``TRTRuntimeConfig`` shim (Python + a
  mirroring C++ struct in ``core/runtime/``) own the live
  ``IRuntimeConfig`` and apply settings. All
  ``ENABLED_FEATURES.tensorrt_rtx`` gates live inside the shim; callers
  in ``_TRTEngine`` and ``_TorchTensorRTModule`` stay uniform.
- ``RuntimeCacheHandle`` (Python wrapper + C++ torchbind sibling) owns the
  per-engine ``IRuntimeCache`` plus filelocked atomic-rename disk
  persistence. Three construction modes: engine-implicit
  (``autosave_on_del=True``), runtime CM (``autosave_on_del=False``,
  explicit save on ``__exit__``), and user-built (default
  ``autosave_on_del=False``).
- Stream support: ``load_from_stream`` / ``save_to_stream`` are the byte
  primitives; the path-mode ``load`` / ``save`` delegate to them.
- ``TorchTensorRTModule._implicit_cache_handle`` is the single owner
  across Python and cpp runtimes; ``TRTRuntimeConfig`` is a pure-execution
  shim.
- C++ strategy fields are typed ``enum class : int32_t`` mirroring the
  ``nvinfer1`` enum integers; ``int32_t`` crosses the torchbind boundary
  for ABI stability, with reverse-lookup helpers for logging.
- Lazy ``IExecutionContext`` creation in ``TRTEngine``; runtime knobs
  are NOT serialized into the engine tuple (per the issue contract).

Tests
-----
65 new tests under ``tests/py/dynamo/runtime/``, parameterized over
python and cpp runtimes where applicable. Covers compile-time hints,
CM enter/exit, settings-swap save semantics, file-handle and
``BytesIO`` round-trip for the shared cache, and the lazy-context
regression.
@meta-cla meta-cla Bot added the cla signed label Jun 9, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jun 9, 2026
@github-actions github-actions Bot requested a review from cehongwang June 9, 2026 22:43
@tp5uiuc tp5uiuc marked this pull request as draft June 9, 2026 23:00
Comment thread core/runtime/BUILD Outdated
@tp5uiuc tp5uiuc self-assigned this Jun 9, 2026
tp5uiuc added 2 commits June 9, 2026 22:59
…pile-time hint

Python (``_TRTEngine``):
- ``self.context`` becomes a ``@property`` backed by a private
  ``self._context`` field. Reads lazily materialize on first access;
  the property has no setter (write raises ``AttributeError``) so external
  code cannot stash an arbitrary context.
- Add ``invalidate_context()`` (drops the cached context; next read
  rebuilds) and ``has_context()`` (probes without triggering creation).
- ``_setup_engine`` no longer creates the context. Distributed engines
  still materialize eagerly (mirrors cpp) by reading ``self.context``
  before the NCCL barrier.
- All five recreate sites (``update_runtime_settings``, device-memory
  budget setter, ``use_dynamically_allocated_resources``,
  ``disable_profiling``, internal) collapse to ``invalidate_context()``.
- All read sites (forward, ``infer_outputs``, ``enable_profiling``,
  ``setup_nccl_comm``, ``_is_monolithic_capturable``) are unchanged --
  the property's lazy semantics absorb the laziness.

C++ (``TRTEngine``):
- ``exec_ctx`` field moves from public to private (renamed ``exec_ctx_``).
- Single public getter ``exec_ctx()`` returns a raw pointer, lazy-creating
  via the existing private ``recreate_execution_context()``. Drop public
  ``ensure_execution_context()`` -- the getter IS the ensure.
- Rename ``invalidate_execution_context()`` to ``invalidate_exec_ctx()``;
  add ``has_exec_ctx()`` for null-safe introspection.
- All 5-6 call sites in ``TRTEngine.cpp`` collapse to ``exec_ctx()->...``;
  ``execute_engine.cpp`` swaps ``->exec_ctx->`` for ``->exec_ctx()->``.

Drop the compile-time ``runtime_settings`` kwarg:
- The kwarg existed to dodge an old 2-create regression on cpp; with both
  runtimes strictly lazy, that motivation is gone. Users apply settings via
  ``mod.runtime_settings = rs`` after compile, or use a runtime CM.
- Removed from ``compile``, ``compile_module``, ``convert_module``,
  ``TorchTensorRTModule.__init__``, ``_TRTEngine.__init__``.
- Documented composition contract on ``set_cuda_graph_strategy`` and
  ``enable_cudagraphs`` docstrings: nest ``with runtime_config(...) as m:``
  outside ``with enable_cudagraphs(m) as w:`` so settings are applied
  state-only before the wrapper's warm-up materializes the context.

Tests:
- Two assertions for ``engine.context is not None`` flip to
  ``engine.has_context()`` so they probe the lazy field without forcing
  materialization.
- Tests that passed ``runtime_settings=...`` to ``torchtrt.compile``
  switch to a small ``_apply_runtime_settings(compiled, rs)`` helper
  that walks the compiled module and assigns ``mod.runtime_settings = rs``
  per inner ``TorchTensorRTModule``.
- ``test_one_context_create_with_default_settings`` now expects 0
  contexts at setup on both runtimes (was 0 cpp / 1 python).
- ``test_one_context_create_with_compile_time_settings`` was redundant
  once the hint is gone; replaced with
  ``test_post_compile_settings_then_execute_is_one_create``.

All 33 runtime tests pass on TRT-RTX 1.5.0.103 (20 skips unchanged).
Two stale references on the Python-runtime branch of TorchTensorRTModule
were left over from earlier in this PR and broke every test that runs a
compiled module on the Python runtime:

- ``setup_engine`` and ``set_extra_state`` were calling
  ``torch.ops.tensorrt.execute_engine_python``, but the custom op is
  registered as ``tensorrt::execute_engine`` (in _TRTEngine.py). The
  ``_python`` suffix was an intermediate name during PR pytorch#4222's dev
  cycle and never made it to main. Fixed both call sites to use the
  single shared op name, matching the C++-runtime branch and the
  docstring at the top of the class.

- ``set_extra_state`` was still passing
  ``runtime_settings=self._runtime_settings`` to ``TRTEngine.__init__``,
  but the previous follow-up commit dropped that kwarg. Engine now
  constructs with default settings (matching what the caller assigned
  to ``self._runtime_settings`` two lines above) and applies any
  non-default settings via the post-load setter, same as the live
  ``setup_engine`` path.
Comment thread core/runtime/RuntimeSettings.cpp Outdated
Comment thread py/torch_tensorrt/runtime/_cudagraphs.py

@tp5uiuc tp5uiuc left a comment

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technical note : why runtime_cache() needs module arguments

runtime_cache(...) exists to attach a shared IRuntimeCache to one or more engines for the duration of a with block. The "shared" + "for the duration of" parts both require knowing which engines.

A standalone "just give me a cache" API, also exists : we construct RuntimeCache(path="...") directly. Then pass it into mod.runtime_settings = RuntimeSettings(runtime_cache=handle) (single module). The CM is the convenience wrapper that does all three things — construct + attach + auto-load/save — in one block.

Three reasons it can't be standalone:

  1. The cache has to be wired to engines, and not just constructed like the example above
  2. The "shared across modules" semantic only exists with multiple targets listed in the first argument.
  3. Bootstrap depends on a module's engine. This is the most annoying but technically still a valid reason. We can't create a standalone runtime cache today with TRT-RTX APIs, but we need an engine to bootstrap it. In this case tThe CM walks target.named_modules() to find a TorchTensorRTModule whose engine it can use to query runtime properties (engine.runtime_config, the cpp IRuntimeConfig, etc.) and create a runtime cache. Having modules to attach the cache to makes this much easier to manage.

Comment thread py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Comment thread py/torch_tensorrt/runtime/_runtime_cache.py Outdated
Comment thread py/torch_tensorrt/dynamo/runtime/_TRTEngine.py
Comment thread py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py Outdated
Comment thread py/torch_tensorrt/dynamo/_defaults.py
Comment thread core/runtime/register_jit_hooks.cpp
Comment thread py/torch_tensorrt/runtime/_runtime_cache.py Outdated
Comment thread py/torch_tensorrt/runtime/_cudagraphs.py Outdated
Comment thread tests/py/dynamo/runtime/test_004_runtime_settings.py
Comment thread py/torch_tensorrt/runtime/_runtime_config.py Outdated
tp5uiuc added 3 commits June 10, 2026 01:39
Round of fixes off the latest PR review feedback:

Blocking
- Initialize ``_implicit_cache_handle`` in ``TorchTensorRTModule.__init__``
  so the slot exists on every construction path, not just ``setup_engine``.
  Drops a regression where ``set_extra_state`` (post-load) skipped the
  init and any subsequent ``mod.runtime_settings = ...`` raised
  AttributeError. Removes the matching ``# type: ignore[has-type]``.

High-priority
- ``_to_torchbind_handle`` now rejects a mixed-runtime case loudly:
  a Python ``RuntimeCacheHandle`` with a live pybind ``IRuntimeCache``
  but no torchbind sibling crossing into the C++ runtime path would
  silently orphan the cache.
- Also gates the str -> torchbind path on a truthy-string check
  (mirrors ``_materialize_implicit_handle``) so ``""`` doesn't
  construct a no-op torchbind handle.
- ``runtime_settings.setter`` now reconciles the handle that
  ``_dispatch_runtime_settings_to_engine`` substituted in, matching the
  ``setup_engine`` post-condition (``self._runtime_settings`` agrees
  with what the engine actually saw).

Medium
- ``runtime`` Bazel cc_library now exports ``TRTRuntimeConfig.h``
  alongside ``TRTEngine.h`` for symmetry with ``runtime_base`` and the
  ``include_files`` filegroup.
- ``_RuntimeCacheContextManager.__exit__`` synchronizes CUDA before
  save (avoids the detach-then-save race against a concurrent execute),
  and wraps save in try/except+warn so a transient filesystem failure
  on exit can't mask the with-block's actual exception.
- ``to_*_strategy`` now take ``int64_t`` and bounds-check before
  narrowing, so an out-of-range Python caller can't slip past the check
  via a silent ``int32_t`` overflow.

Low
- Drop the explicit local_defines select on ``ENABLE_FEATURE_DISABLE_RUNTIME_ALLOCATION``
  in ``core/runtime/BUILD`` (no longer needed on TRT-RTX 1.5+).
- ``.size()`` -> ``std::size(...)`` on the std::array uses in
  ``RuntimeSettings.cpp`` for consistency.
- Trim docstrings on ``enable_cudagraphs`` and ``set_cuda_graph_strategy``
  -- composition guidance moves to the runtime-settings doc page.
- Drop ``_NON_RTX_WARNING_EMITTED`` module global; emit unconditionally
  so test isolation can assert on it. Users wanting once-per-process
  behavior filter via ``warnings.simplefilter("once", UserWarning)``.
- Document multi-process lost-update semantics of the default
  ``runtime_cache`` path in the dataclass docstring.
- Skip the ``num_execution_contexts_created == 0`` assertion for NCCL
  engines in ``test_one_context_create_with_default_settings``: NCCL
  engines eagerly bind the comm at setup, which materializes the
  context. Comment in ``_TRTEngine`` flags the same divergence.
Originally this PR hard-removed ``cuda_graph_strategy``,
``dynamic_shapes_kernel_specialization_strategy``, and ``runtime_cache_path``
from ``torch_tensorrt.compile()``. That left two issues:

- **B1**: the model-suite tests
  (``test_cuda_graph_strategy_models.py``,
  ``test_dynamic_shapes_kernel_strategy_models.py``,
  ``test_runtime_cache_models.py``) still passed those kwargs to
  ``compile()``; they were silently dropped via ``**kwargs`` and the
  test intent was lost.
- **B3**: downstream callers of ``torch_tensorrt.compile(model, ...,
  cuda_graph_strategy=...)`` would see the same silent drop — no
  deprecation warning, no error.

Fix combines both halves:

1. Re-accept the three kwargs via ``**kwargs`` in ``compile()`` with a
   single ``DeprecationWarning``, and route them through
   ``mod.runtime_settings = RuntimeSettings(...)`` after ``compile_module``
   returns. ``runtime_cache_path`` → ``runtime_cache`` rename is applied
   inside the shim.

2. Port the three model-suite test files to the new pattern (post-compile
   ``_apply_runtime_settings(compiled, RuntimeSettings(...))``), matching
   the runtime-suite tests. This keeps them clean of deprecation warnings
   while serving as the canonical example of the new API.

3. **L1**: drop the now-stale ``runtime_cache_path`` parameter line from
   ``MutableTorchTensorRTModule.__init__``'s docstring. Touched up a
   parallel docstring in ``_TRTEngine`` (``Set cuda_graph_strategy at
   compile time`` -> ``Apply RuntimeSettings(...) via the runtime_config
   CM or mod.runtime_settings setter``).
Reverts the ``compile()`` deprecation shim from the previous commit.
Tests in the model and runtime suites have all been ported to the
post-compile ``mod.runtime_settings = RuntimeSettings(...)`` pattern;
no in-tree caller of ``compile()`` still passes the old kwargs. Going
straight to hard-removal is cleaner than carrying the warning machinery.

@tp5uiuc tp5uiuc left a comment

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Placeholder.

Comment thread py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
Comment thread py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py Outdated
tp5uiuc added 2 commits June 10, 2026 10:20
CI fix
- ``RuntimeSettings.cpp`` was relying on ``ATen/core/Tensor.h`` transitively
  pulling in factory functions; torch_l4t (Jetpack) only ships the minimal
  ``Tensor.h``, so ``at::empty`` failed to resolve. Add ``ATen/ATen.h``
  explicitly. Unblocks the Jetpack lane on this PR.

Follow-up review items
- **F1**: ``set_extra_state`` now also resets ``_implicit_cache_handle`` to
  ``None`` alongside the ``_runtime_settings`` reset. Otherwise a stale
  wrapper from prior use survives ``load_state_dict`` and the next setter
  could silently write the fresh engine's cache bytes to the old path.
- **F2**: added the requested invariant comment above the conditional
  reconciliation in ``runtime_settings.setter`` so a future reviewer
  doesn't mistake the ``if`` for an unconditional merge.

Tests pinning the fixes
- ``TestSaveLoadRuntimeSettingsRoundTrip``: save -> load -> setter must not
  raise. Catches a regression of B2 if the ``_implicit_cache_handle = None``
  init ever moves back out of ``__init__``.
- ``TestNestedRuntimeConfigCudagraphs``: asserts the central perf invariant
  of the runtime CM + cudagraphs composition -- nested form (runtime CM
  outside) yields one ``createExecutionContext`` call, inverted form
  yields two. Pins the contract documented in the PR description.
- ``TestToTorchbindHandleOrphanGuard``: exercises the H1 raise path so a
  future change to the silent-orphan fallback gets caught.
…ytorch#5)

The class-level skipIf on ``TestRuntimeCacheStreamSupport`` only fenced
the round-trip tests (which legitimately can't run on cpp because the
``IRuntimeCache`` materializes lazily on context creation and bytes
loaded before that don't survive). The first-run flavor -- CM enter ->
forward -> exit-with-bytes-in-buf -- works the same on both runtimes
and exercises the cpp dispatch glue (handle construction, attach to
torchbind engine, save-on-exit).

No content assertion on the saved bytes -- workload-dependent on cpp.
The non-raising exit is the contract the test protects.
Comment thread tests/py/dynamo/runtime/test_004_runtime_settings.py
The prior ``test_setter_after_save_load_does_not_raise`` used
``torch.save`` / ``torch.load`` and so didn't actually exercise the B2
bug path. Pickle goes through ``nn.Module.__setstate__`` (wholesale
``__dict__`` restore); ``set_extra_state`` is never called. So the
test would have passed even with the B2 fix reverted, because the slot
was preserved as a regular ``__dict__`` entry on the saved side.

Swap to the ``state_dict`` / ``load_state_dict`` round-trip, which is
the path that actually goes through ``set_extra_state``. Renamed class
to ``TestStateDictRoundTripRuntimeSettings`` and updated the docstring
to be honest about what it pins down.
Comment thread core/runtime/BUILD
Comment thread core/runtime/BUILD
Comment thread core/runtime/execute_engine.cpp Outdated
Comment thread core/runtime/execute_engine.cpp Outdated
Comment thread core/runtime/register_jit_hooks.cpp Outdated
@tp5uiuc tp5uiuc requested a review from narendasan June 12, 2026 16:30
Comment thread core/runtime/RuntimeSettings.cpp Outdated

// Reverse-lookup tables. Indices match the enum integer values (which mirror
// the nvinfer1 enums). Out-of-range -> "<unknown>".
constexpr std::array<std::string_view, 3> kDsStrategyNames = {"lazy", "eager", "none"};

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there like a TRT-RTX enum or something that defines these names?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not I would prefer to create one enum type for our library (kind of rust style) and then have methods to go between strings / ints etc. This is a common pattern in the library if you search around. Things like dtype core/util or the similar pattern in python as well https://github.com/pytorch/TensorRT/blob/main/core/util/trt_util.h

enum Type { kITensor, kIValue, kNone };

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Naren, yes. Since we don't have public headers (TensorRT-OSS strips RTX symbols), you can see them here:
https://docs.nvidia.com/deeplearning/tensorrt-rtx/latest/inference-library/work-with-dynamic-shapes.html#setting-the-kernel-specialization-strategy

@narendasan narendasan left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went mostly through the C++ so far. Also would it be possible to add examples using the apis?

Comment thread core/runtime/RuntimeSettings.cpp Outdated

// Reverse-lookup tables. Indices match the enum integer values (which mirror
// the nvinfer1 enums). Out-of-range -> "<unknown>".
constexpr std::array<std::string_view, 3> kDsStrategyNames = {"lazy", "eager", "none"};

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not I would prefer to create one enum type for our library (kind of rust style) and then have methods to go between strings / ints etc. This is a common pattern in the library if you search around. Things like dtype core/util or the similar pattern in python as well https://github.com/pytorch/TensorRT/blob/main/core/util/trt_util.h

enum Type { kITensor, kIValue, kNone };

std::memcpy(tensor.data_ptr(), host_mem->data(), host_mem->size());
return tensor;
#else
return empty();

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the right error handling in this case?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Silent seems wrong, maybe a LOG_WARN?

auto const* p = static_cast<uint8_t const*>(contig.data_ptr());
pending_warm_bytes_.assign(p, p + contig.numel());
}
#endif

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here

}

#ifdef TRT_MAJOR_RTX
std::shared_ptr<nvinfer1::IRuntimeCache> RuntimeCacheHandle::ensure_materialized(nvinfer1::IRuntimeConfig* config) {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this method removed but the others are defined but just noop?

Comment thread core/runtime/RuntimeSettings.h Outdated
// through the data model -- only the ``static_cast`` to the nvinfer1 type
// (inside ``TRTRuntimeConfig::ensure_initialized``) is RTX-only. Integer
// values must stay in sync with the nvinfer1 enums.
enum class DynamicShapesKernelSpecializationStrategy : int32_t {

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

re the comments above, lets use these. I think in general we shouldnt have semantic values passed around except at the API boundary. We should use typed values


// Returns true iff any of the listed input bindings (including shape tensors) has a
// dynamic dimension.
[[nodiscard]] bool engine_has_dynamic_inputs(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need to be in a anonymous namespace? Seems like a useful enough function to have. @cehongwang might help your optimization profile work

@cehongwang cehongwang Jun 12, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the other PR I put this function in TRTInterpreter to minimize the runtime overhead:

@tp5uiuc can we reuse num_optimization_profile after that PR goes in? Basically 0 means static and >= 1 means dynamic


TORCHTRT_CHECK(exec_ctx.get() != nullptr, "Cannot bind NCCL communicator: execution context is null");
exec_ctx->setCommunicator(reinterpret_cast<void*>(comm_ptr));
// Distributed engines must hold a live IExecutionContext at bind time.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did this drop the check?

Fixes the L1 dynamo compile regression:

  RuntimeError: Tried to deepcopy object
  __torch__.torch.classes.tensorrt.RuntimeCacheHandle which does not have a
  __getstate__ method defined!

surfaced by ``tests/py/dynamo/models/test_export_serde.py`` and
``test_export_kwargs_serde.py``. Before this PR the cpp torchbind
``RuntimeCacheHandle`` was constructed lazily inside the cpp engine and
never lived in Python's attribute graph. With the new design,
``_TorchTensorRTModule._implicit_cache_handle`` is a Python ``RuntimeCache``
whose ``_handle`` field IS the torchbind class on cpp-rt builds, so
``torch.export.save``/``deepcopy`` walks hit it.

Registers ``__getstate__`` / ``__setstate__`` via ``def_pickle`` that
serialize only the ``path`` string. The underlying ``IRuntimeCache`` is
GPU-side state that can't cross a process boundary anyway, and the new
process's first ``_resolve_runtime_cache`` warm-loads the disk file
through the standard ``load`` -> ``pending_warm_bytes`` flow.

Verified locally: ``test_export_serde.py`` / ``test_export_kwargs_serde.py``
pass after the cpp rebuild; the 4 runtime test files still pass (41/20/0
cpp-rt, 58/3/0 python-rt) -- no regressions.

@cehongwang cehongwang left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some minor comments

invalidate_exec_ctx();
// Existing recreate sites set runtime_states.context_changed for cudagraph
// re-record; do the same here so a settings flip inside an active CM forces
// the next enqueue to re-record any captured graph.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to reset the nccl state?

Comment thread core/runtime/execute_engine.cpp Outdated

auto expected_type =
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx->getEngine().getTensorDataType(name.c_str()));
util::TRTDataTypeToScalarType(compiled_engine->exec_ctx()->getEngine().getTensorDataType(name.c_str()));

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we do

ctx = compiled_engine->exec_ctx()
ctx->getEngine()
ctx->setTensorAddress
...

to avoid checking a million times whether it is null

tp5uiuc added 5 commits June 12, 2026 18:02
Addresses Naren's review feedback on the strategy types
(``RuntimeSettings.cpp:23`` + ``RuntimeSettings.h:78``). Replaces the
two ``enum class : int32_t`` types with "newtype" wrapping classes that
hold a nested ``enum Value``, mirroring the ``Var`` / ``Var::Type``
pattern in ``core/conversion/var/Var.h``. All conversions
(``to_string`` / ``to_underlying`` / ``from_underlying`` / ``from_string``)
live on the type as member or static functions; the four free helpers
(``to_dynamic_shapes_kernel_strategy`` / ``to_cuda_graph_strategy`` /
``ds_strategy_name`` / ``cg_strategy_name``) are removed.

The implicit ``operator Value() const noexcept`` unwrap keeps existing
usage compiling unchanged:

  - ``switch (s) { case CudaGraphStrategy::kDISABLED: ... }`` -- unchanged
  - ``static_cast<nvinfer1::CudaGraphStrategy>(s)`` -- unchanged via unwrap
  - ``s == CudaGraphStrategy::kDISABLED`` -- unchanged via implicit ctor
  - ``RuntimeSettings`` field defaults (``... = Strategy::kFOO``) -- unchanged

Net cpp delta: ~70 lines (class defs in header) + 4 callsite updates
(``register_jit_hooks.cpp`` boundary cast goes through ``::from_underlying``,
``RuntimeSettings::to_str`` and ``TRTRuntimeConfig::ensure_initialized``
log via ``.to_string()``). No behavior change.
Addresses Naren's review: calling ``RuntimeCacheHandle::serialize()`` or
``::deserialize()`` on a non-RTX build is a logical error (no IRuntimeCache
exists; nothing to read or write). Today both methods silently return an
empty tensor / discard the bytes, which can mask configuration mistakes.

Emit a ``LOG_WARNING`` in each ``#else`` branch so the call is visible in
dev logs without escalating to a hard error (the methods are still callable
from Python via the torchbind binding on RTX builds, and unconditionally
callable from C++).
Addresses cehongwang's review on ``TRTEngine::runtime_settings``: the NCCL
communicator is bound onto the ``IExecutionContext`` via
``setCommunicator``. When ``runtime_settings`` invalidates the context, the
communicator binding goes with it, but ``nccl_initialized`` stayed true --
so the next ``execute_engine`` skipped the re-bind, leaving the new
context without a communicator.

Reset ``nccl_initialized = false`` alongside ``invalidate_exec_ctx()`` so
the next forward triggers ``bind_nccl_comm`` on the freshly materialized
context. Guarded by ``ENABLE_TRT_NCCL_COLLECTIVES`` to match the field's
own gate.
Addresses cehongwang's review: the TRT API path in ``execute_engine``
called ``compiled_engine->exec_ctx()`` ~20 times across the function and
its three helpers (``setup_input_tensors``, ``create_output_tensors``,
``create_output_allocator``), each repeating the null-check + lazy-create
branch. The lock on ``compiled_engine->mu`` makes the pointer stable for
the duration of the call, so a single materialization suffices.

Capture ``auto* ctx = compiled_engine->exec_ctx()`` once at the top of
``execute_engine`` (after the lazy NCCL bind) and thread it as a parameter
through the three file-static helpers. All ``compiled_engine->exec_ctx()
->X`` callsites collapse to ``ctx->X``. No behavior change.
Addresses cehongwang's review suggestion (verbatim, minor typo fix):
collapse the 8-line warm-load comment to a 2-line statement of intent.
The mechanics it described (python vs cpp pending-bytes buffer, prior
``_apply_settings``-only path) are now adequately covered by the cpp
class and ``RuntimeCache`` docstrings.
Comment thread core/runtime/register_jit_hooks.cpp Outdated
Comment thread core/runtime/TRTRuntimeConfig.cpp
Comment thread core/runtime/RuntimeSettings.cpp Outdated
Comment thread core/runtime/RuntimeSettings.cpp Outdated
// caught here -- casting to int32_t first would silently wrap.
// ---- DynamicShapesKernelSpecializationStrategy -----------------------------

std::string_view DynamicShapesKernelSpecializationStrategy::to_string() const noexcept {

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be constexpr

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 13472dd29.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should revert this, I don't want to expose the inline variables outside the RuntimeSettings.cpp TU.

Comment thread core/runtime/RuntimeSettings.cpp
Comment thread core/runtime/RuntimeSettings.cpp Outdated
Comment thread core/runtime/RuntimeSettings.cpp
Comment thread core/runtime/RuntimeSettings.cpp
Follow-up to the LOG_WARN pass on non-RTX paths. The TRT-RTX branch of
``RuntimeCacheHandle::serialize()`` had one more silent failure: when
``trt_handle_->serialize()`` returns ``nullptr`` (TRT-level allocation or
internal failure), we returned an empty tensor without surfacing the
error.

Add a ``LOG_WARNING`` so the host-memory allocation failure is visible.
The pre-materialize ``!trt_handle_`` branch above stays silent on
purpose: it fires on normal lifecycle states (autosave-on-del before any
forward, CM exit pre-execute) and would be noise as a warning.
``deserialize()``'s ``data.numel() == 0`` early-return likewise stays
silent: the Python wrapper already filters empty inputs upstream.
Comment thread core/runtime/execute_engine.cpp Outdated
tp5uiuc added 7 commits June 13, 2026 15:43
Follow-up to the prior LOG_WARN pass. Per reviewer follow-up, the two
remaining silent branches surface as warnings too:

- ``RuntimeCacheHandle::serialize()`` with ``!trt_handle_`` (wrapper exists
  but the underlying ``IRuntimeCache`` was never materialized -- e.g.
  autosave-on-del before any forward, CM exit pre-execute). Previously
  returned empty silently; users had no signal that the saved file was
  empty by design vs by bug.
- ``RuntimeCacheHandle::deserialize()`` with an empty input tensor.
  Reachable via direct torchbind calls that bypass the Python
  ``load_from_stream`` filter; useful for catching accidental empty
  loads.
Addresses reviewer follow-up: the ``def_pickle`` comment on the
torchbind ``RuntimeCacheHandle`` claimed the underlying ``IRuntimeCache``
is GPU-side state. It is CPU-side -- the cache holds host-memory
kernel-compilation metadata, not device buffers. Correct the wording so
the rationale for persisting only the ``path`` (no in-memory bytes)
matches reality.
Addresses reviewer comment ("Make both code-paths similar"). The
dynamic-shapes-kernel-specialization strategy path emits a ``LOG_DEBUG``
on every successful set; the cuda_graph_strategy path only warned on
failure with no success-side debug log, so the two paths read
asymmetrically.

Add the success-branch ``LOG_DEBUG`` to ``setCudaGraphStrategy`` so both
strategy attachments produce a uniform "X set to <value>" debug trail.
The failure ``LOG_WARNING`` stays -- ``setCudaGraphStrategy`` returns
bool unlike its DS counterpart, so the genuine failure signal is
preserved.
Per reviewer comment: both ``DynamicShapesKernelSpecializationStrategy``
and ``CudaGraphStrategy`` already have an ``operator<<`` overload that
forwards to ``to_string()``. Use the streaming overload directly in
``to_str()`` instead of the explicit ``.to_string()`` calls -- shorter
and avoids the redundant conversion at the print site.
Per reviewer comment ("can be constexpr"). The ``to_string()`` methods
on ``DynamicShapesKernelSpecializationStrategy`` and ``CudaGraphStrategy``
are pure index lookups over compile-time-known arrays, so they should be
constexpr; this lets a future constant-folded ``static_assert(s.to_string() == "lazy")``
work and avoids a function call at print sites.

The reverse-lookup arrays (``kDsStrategyNames`` / ``kCgStrategyNames``)
move from the .cpp anonymous namespace into the header as ``inline
constexpr`` so the inline ``to_string()`` definitions can see them at
compile time. ``from_underlying`` / ``from_string`` in the .cpp still
reference the same arrays via the header.

No behavior change at runtime; the change is purely "values become
usable in constant expressions".
Per reviewer suggestion: the four validator call sites
(``DynamicShapesKernelSpecializationStrategy`` / ``CudaGraphStrategy``,
each with ``from_underlying`` + ``from_string``) repeated the
"|"-joined name tail (``"lazy|eager|none"``, ``"disabled|whole_graph_capture"``)
in literal form. Adding or renaming a strategy required touching two
strings per type.

Introduce two ``constexpr`` -fold-friendly templates in the anonymous
namespace:

- ``join_string_views(sep, parts)`` for ``"a|b|c"``.
- ``format_expected_strategy(names)`` for ``"(expected 0..N-1 mapping to a|b|c)"``.
- ``format_expected_name(names)`` for ``"(expected a|b|c)"`` (the name-only
  variant used by ``from_string``).

Validator messages now render from ``kDsStrategyNames`` / ``kCgStrategyNames``
directly, so a new strategy value requires only one array edit.
Reviewer suggestion: instead of threading a ``nvinfer1::IExecutionContext*``
parameter through ``setup_input_tensors`` / ``create_output_tensors`` /
``create_output_allocator``, keep the original signatures and hoist
``auto* ctx = compiled_engine->exec_ctx();`` at the top of each helper.

Lower diff to the call sites; pays one extra ``exec_ctx()`` call per
helper invocation (still vastly fewer than the ~20 per-call rate the
original code paid). The top-level hoist inside ``execute_engine`` is
unchanged.
Comment thread core/runtime/TRTRuntimeConfig.cpp
Comment on lines +23 to +33
template <size_t N>
std::string join_string_views(std::string_view sep, std::array<std::string_view, N> const& parts) {
std::ostringstream os;
for (auto it = parts.begin(); it != parts.end(); ++it) {
if (it != parts.begin()) {
os << sep;
}
os << *it;
}
return os.str();
}

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

template <size_t N>
std::string join_string_views(std::string_view sep, std::array<std::string_view, N> const& parts) {
  if (N == 0) {
    return {};
  }
  std::ostringstream os;
  os << *std::cbegin(parts);
  for (auto it = std::next(std::cbegin(parts)); it != std::cend(parts); ++it) {
    os << sep << *it;
  }
  return os.str();
}

a but more cleaner

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: runtime component: tests Issues re: Tests documentation Improvements or additions to documentation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants