Add: dep_gen overflow chain — capture submits with arbitrary explicit dep count #859

Open
ChaoWao wants to merge 1 commit into hw-native-sys:main from ChaoWao:feat/dep-gen-overflow-chain

Collaborator

@ChaoWao ChaoWao commented May 26, 2026

Summary

DepGenRecord::explicit_deps[] held at most 16 entries, so submits with >16 explicit deps were silently truncated in deps.json — leaving big-fanin barriers (many-to-one barriers built with Arg::set_dependencies) underspecified in the captured graph. Runtime correctness was always fine; this was a diagnostic blind spot, but a load-bearing one for swimlane viewers.

This PR:

  • Bumps the inline base capacity from 16 → 64. Covers all realistic fan-in at +17% buffer footprint (record 2240 → 2624 B, tensors[] offset 192 → 576). Added exact-size + exact-offset static_assert guards so future layout drift fails at compile time.
  • Adds a wire-format overflow chain for submits past 64. DepGenOverflowRecord reinterprets a buffer slot (same size + alignment) as { task_id, flags, dep_count, deps[326] }, distinguished by DEP_GEN_FLAG_OVERFLOW. Base records set HAS_OVERFLOW; the final overflow sets LAST_OVERFLOW. Chain slots share the base task_id and are contiguous within one DepGenBuffer — writer reserves the full chain up front (switching buffer if needed) and publishes via one buf->count store, so the host sees either old count (chain invisible) or new count with everything committed.
  • Replay reads the chain back by skipping OVERFLOW slots in the main scan and assembling base + chain into a single deps_data buffer before driving compute_task_fanin and the annot mirror. count_outputs() also skips overflow slots so their reinterpreted bytes don't get misread as tensor_count/arg_types.
  • Fast path unchanged: dc ≤ 64 still writes exactly one record with no chain bookkeeping; replay points straight at rec.explicit_deps.

A chain that won't fit even in a fresh buffer is truncated to the largest dc that fits, and the truncation is logged via LOG_ERROR. Theoretical max per submit ≈ 64 + 31 × 326 = 10170 deps.
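The sizes quoted above can be cross-checked with a small layout sketch. This is not the struct from dep_gen.h: the field widths, the `reserved` padding word, the 8-byte dep id type, and the 32-records-per-buffer figure (inferred from the "31 × 326" term) are all assumptions made for illustration. Only the numbers 64, 326, 2624, and 10170 come from the description.

```cpp
#include <cstddef>
#include <cstdint>

// Assumed constants; only 64, 326, and 2624 appear in the PR description.
constexpr int kBaseDeps = 64;               // inline explicit_deps[] capacity
constexpr int kOverflowDeps = 326;          // deps[] capacity of one overflow slot
constexpr std::size_t kRecordBytes = 2624;  // size of one buffer slot
constexpr int kRecordsPerBuffer = 32;       // assumption: inferred from "31 x 326"

// Hypothetical overflow layout: a 16-byte header followed by 8-byte dep ids.
struct OverflowRecordSketch {
    std::uint64_t task_id;
    std::uint16_t flags;
    std::uint16_t dep_count;
    std::uint32_t reserved;  // assumption: padding that keeps deps[] 8-byte aligned
    std::uint64_t deps[kOverflowDeps];
};

// Same idea as the PR's guards: layout drift fails at compile time.
static_assert(sizeof(OverflowRecordSketch) == kRecordBytes, "overflow slot must match record size");
static_assert(offsetof(OverflowRecordSketch, deps) == 16, "header must stay 16 bytes");

// One base record plus a chain that fills the rest of a fresh buffer.
constexpr int max_deps_per_submit() {
    return kBaseDeps + (kRecordsPerBuffer - 1) * kOverflowDeps;
}
static_assert(max_deps_per_submit() == 10170, "matches the quoted theoretical max");
```

Under these assumptions the 16-byte header plus 326 eight-byte ids lands exactly on 2624 bytes, which is consistent with the overflow record being able to overlay a base record slot.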

Reconciliation fix (caught by the new test)

reconcile_counters() checked collected + dropped == total_record_count, but total_collected_ counts physical buffer slots while total_record_count increments once per submit_task. Every chained submit over-counted the LHS and tripped the mismatch warning, and the host runner gates deps.json emission on a clean reconcile — so the chain feature silently produced no output for any submit with >64 deps before this fix.

Added DepGenBufferState::total_overflow_record_count (sized into the existing _pad[] so the 192-byte struct invariant holds). New reconciliation invariant:

collected + dropped == total_record_count + total_overflow_record_count
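A minimal sketch of why the old check failed, assuming a plain counter struct. `CountersSketch` and both helper functions are hypothetical names, not the real `DepGenBufferState` or `reconcile_counters()`; only the two equations are taken from the description.

```cpp
#include <cstdint>

// Hypothetical mirror of the counters involved in reconciliation.
struct CountersSketch {
    std::uint64_t collected;                    // physical buffer slots drained by the host
    std::uint64_t dropped;                      // records dropped by the writer
    std::uint64_t total_record_count;           // incremented once per submit_task
    std::uint64_t total_overflow_record_count;  // extra chain slots committed
};

// Pre-fix invariant: ignores chain slots, so any chained submit breaks it.
inline bool reconciles_old(const CountersSketch &c) {
    return c.collected + c.dropped == c.total_record_count;
}

// Post-fix invariant from the PR description.
inline bool reconciles_new(const CountersSketch &c) {
    return c.collected + c.dropped == c.total_record_count + c.total_overflow_record_count;
}

// Example: three submits, one of which needed two overflow slots.
// The host therefore collects 3 + 2 = 5 physical slots.
inline CountersSketch example_counters() {
    return CountersSketch{5, 0, 3, 2};
}
```

With the example counters the old check compares 5 against 3 and reports a mismatch, while the new check compares 5 against 5 and passes.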

Test plan

  • Local pip install -e . builds clean on a2a3sim (darwin-arm64).
  • All static_asserts in dep_gen.h hold (DepGenRecord: 2624 B, tensors offset: 576; DepGenOverflowRecord size + alignment match base).
  • New tests/st/a2a3/tensormap_and_ringbuffer/dfx/dep_gen/test_dep_gen_chain.py passes on a2a3sim for n ∈ {64, 65, 200, 391}: barrier task has exactly N explicit predecessors in deps.json for every case.
  • No reconcile warnings emitted for any chain case (verified in run log).
  • CI a2a3 onboard run (please trigger).

Cross-referenced changes

  • src/a2a3/platform/include/common/dep_gen.h — capacity bump, chain wire format, helpers, total_overflow_record_count.
  • src/a2a3/platform/{include,src}/aicpu/dep_gen_collector_aicpu.{h,cpp} — writer rewrite to emit chain.
  • src/a2a3/platform/src/host/dep_gen_collector.cpp — reconciliation invariant.
  • src/a2a3/runtime/tensormap_and_ringbuffer/host/dep_gen_replay.cpp — chain join in replay scan + count_outputs() skip.
  • docs/dfx/dep_gen.md — new §7 documenting chain shape, atomicity, truncation tail.
  • tests/st/.../dfx/dep_gen/ — new test + reusable orchestration kernel.

…y explicit dep count

Before this change `DepGenRecord::explicit_deps[]` held at most 16 deps,
and submits with more were silently truncated in `deps.json` — leaving
big-fanin barriers underspecified in the captured graph. The runtime
itself (`Arg::set_dependencies`) has no such cap, so this was a
diagnostic blind spot only, but a load-bearing one for swimlane viewers.

Bumped the inline base capacity to 64 (covers all realistic fan-in
patterns at +17% buffer footprint: record 2240 → 2624 B, tensors[]
offset 192 → 576). Added exact-size / exact-offset static_asserts so
future layout drift trips at compile time.

For submits past 64, added a wire-format chain:

- `DepGenOverflowRecord` reinterprets a normal record slot (same size +
  alignment) as { task_id, flags, dep_count, deps[326] }, distinguished
  by `DEP_GEN_FLAG_OVERFLOW`. Base records set `HAS_OVERFLOW`; the last
  overflow sets `LAST_OVERFLOW`. Chain slots share the base's task_id
  and are always contiguous within one DepGenBuffer.
- Writer reserves all chain slots up front (switching buffer if needed).
  `buf->count` is published with one trailing store, so the host either
  sees the old count (chain invisible) or the new count with the full
  chain committed.
- A chain that would not fit even in a fresh buffer is truncated to the
  largest dc that fits, logged via `LOG_ERROR`. Runtime correctness is
  unaffected — `Arg::set_dependencies` keeps the full dep list; only the
  diagnostic replay drops the tail.
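The reserve-then-publish pattern described in the list above can be sketched as follows. This is an illustration under assumptions, not the AICPU writer: it swaps in `std::atomic` with a release store where the real code publishes `buf->count` after its own barrier, the flag values are invented, and `SlotSketch`, `BufferSketch`, and `publish_chain` are hypothetical names. Buffer switching and truncation are omitted; the sketch simply refuses a chain that does not fit.

```cpp
#include <atomic>
#include <cstdint>

// Assumed flag bits; the real values live in dep_gen.h.
constexpr std::uint32_t kHasOverflow = 1u << 0;   // set on the base record
constexpr std::uint32_t kOverflow = 1u << 1;      // set on every overflow slot
constexpr std::uint32_t kLastOverflow = 1u << 2;  // set on the final overflow slot
constexpr std::uint32_t kSlots = 32;              // assumption: records per buffer

struct SlotSketch {
    std::uint64_t task_id;
    std::uint32_t flags;
};

struct BufferSketch {
    SlotSketch records[kSlots];
    std::atomic<std::uint32_t> count{0};
};

// Writes `needed` contiguous slots for one submit, then publishes them with
// a single store. A reader that loads `count` with acquire ordering sees
// either the old count (chain invisible) or the new count with every slot
// fully written.
inline bool publish_chain(BufferSketch &buf, std::uint64_t task_id, std::uint32_t needed) {
    const std::uint32_t start = buf.count.load(std::memory_order_relaxed);  // single writer
    if (needed == 0 || start >= kSlots || needed > kSlots - start) {
        return false;  // the real writer would switch buffers or truncate here
    }
    for (std::uint32_t i = 0; i < needed; ++i) {
        SlotSketch &slot = buf.records[start + i];
        slot.task_id = task_id;  // chain slots share the base task_id
        if (i == 0) {
            slot.flags = (needed > 1) ? kHasOverflow : 0u;
        } else {
            slot.flags = kOverflow | ((i + 1 == needed) ? kLastOverflow : 0u);
        }
    }
    buf.count.store(start + needed, std::memory_order_release);  // one trailing store
    return true;
}
```

Keeping the count update as the last step is what makes a partially written chain unobservable: nothing before the store changes what a reader bounded by the old count can see.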

Replay (`dep_gen_replay_emit_deps_json`) now skips overflow slots in the
main scan and assembles base + chain into a single deps_data buffer before
driving `compute_task_fanin` and the annot mirror. `count_outputs()`
also skips overflow slots so their reinterpreted bytes don't get
misread as `tensor_count`/`arg_types`.

Fast path (dc ≤ 64) is bit-identical to the pre-chain version: one
record, no chain bookkeeping, replay points straight at
`rec.explicit_deps`.

### Reconciliation

Adding the test surfaced a real bug in the chain path: `reconcile_counters()`
checks `collected + dropped == total_record_count`, but `total_collected_`
counts physical buffer slots while `total_record_count` increments once
per `submit_task`. Every chained submit over-counted the LHS and tripped
the mismatch warning, and the host runner gates `deps.json` emission on
a clean reconcile — so the chain feature silently produced no output
for any submit with >64 deps.

Overflow slots are now tracked in a separate counter:
`DepGenBufferState::total_overflow_record_count` counts the extra slots
committed by chained submits. It is sized into the existing `_pad[11]` so the
192-byte struct invariant still holds. New reconciliation invariant:

  collected + dropped == total_record_count + total_overflow_record_count

### Test

New `tests/st/a2a3/tensormap_and_ringbuffer/dfx/dep_gen/test_dep_gen_chain.py`
+ `chain_barrier_orch.cpp` exercise four N values across the chain
boundaries:

- n=64: base only, no chain (baseline)
- n=65: base + 1 overflow record (1 dep in overflow)
- n=200: base + 1 overflow (136 deps in overflow)
- n=391: base + 2 overflow (326 + 1 deps across two overflows)
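The per-record split in the four cases above follows directly from the two capacities. A small sketch, assuming a base capacity of 64 and 326 deps per overflow record as stated in this PR; `chain_shape` is a hypothetical helper, not part of the test or the runtime, and it assumes `n >= 0`.

```cpp
#include <algorithm>
#include <vector>

constexpr int kBaseDeps = 64;       // inline capacity of the base record
constexpr int kOverflowDeps = 326;  // capacity of each overflow record

// Returns how many deps land in each record of the chain:
// element 0 is the base record, later elements are overflow records.
inline std::vector<int> chain_shape(int n) {
    std::vector<int> shape;
    shape.push_back(std::min(n, kBaseDeps));
    for (int left = n - shape[0]; left > 0; left -= kOverflowDeps) {
        shape.push_back(std::min(left, kOverflowDeps));
    }
    return shape;
}
```

For the tested values this yields {64}, {64, 1}, {64, 136}, and {64, 326, 1}, so n=64 and n=65 straddle the base boundary while n=391 is the smallest count that needs a second overflow record.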

Each case asserts the barrier task has exactly N explicit predecessors
in `deps.json` — the round-trip check that pre-chain code would have
failed by truncating at index 63. `_post_validate` asserts `deps.json`
existence rather than returning silently, so any future reconciliation
regression surfaces as a hard failure with a pointer to the log.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces an overflow chain mechanism to handle "Big-Fanin Submits" with more than 64 explicit dependencies, allowing excess dependencies to spill into contiguous DepGenOverflowRecord slots overlaying the buffer. It updates the shared-memory layout, the AICPU writer, the host collector, and the replay engine to support this chained wire format, and adds comprehensive regression tests. The review feedback highlights several critical security and robustness improvements: addressing potential Time-of-Check to Time-of-Use (TOCTOU) vulnerabilities by caching buf->count locally, defensively validating buffer capacity to prevent out-of-bounds writes, clamping dependency counts during replay to avoid out-of-bounds reads on corrupted data, and preventing integer overflow in the record-calculation helper.

// Reserve the whole chain up front. If it won't fit in the current
// buffer, switch first; if it still won't fit (chain larger than the
// buffer), cap dc to what the new buffer can hold and log truncation.
if (buf->count + static_cast<uint32_t>(needed) > static_cast<uint32_t>(PLATFORM_DEP_GEN_RECORDS_PER_BUFFER)) {

security-high high

Since buf->count resides in volatile shared memory, reading it multiple times can lead to Time-of-Check to Time-of-Use (TOCTOU) vulnerabilities or inconsistent state if the shared memory is concurrently modified or corrupted. We should read buf->count once into a local variable local_count and use it for the capacity check and subsequent writes.

Suggested change
- if (buf->count + static_cast<uint32_t>(needed) > static_cast<uint32_t>(PLATFORM_DEP_GEN_RECORDS_PER_BUFFER)) {
+ uint32_t local_count = buf->count;
+ if (local_count + static_cast<uint32_t>(needed) > static_cast<uint32_t>(PLATFORM_DEP_GEN_RECORDS_PER_BUFFER)) {
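A standalone sketch of the single-read idea, under assumptions: `first_slot_if_fits` and the capacity constant are hypothetical, and the function models only the capacity check, not the buffer switch that follows it in the writer. It also phrases the check as a subtraction so that a corrupted count close to `UINT32_MAX` cannot wrap the addition.

```cpp
#include <cstdint>

constexpr std::uint32_t kRecordsPerBuffer = 32;  // assumption: stand-in for the platform constant

// Reads the shared count exactly once and validates that snapshot.
// Returns the first free slot index, or -1 if the chain does not fit
// or the snapshot is already out of range.
inline std::int64_t first_slot_if_fits(const volatile std::uint32_t *shared_count, std::uint32_t needed) {
    const std::uint32_t local_count = *shared_count;  // the only read of shared memory
    if (local_count >= kRecordsPerBuffer) {
        return -1;  // corrupted or full: never index records[] with this value
    }
    if (needed == 0 || needed > kRecordsPerBuffer - local_count) {
        return -1;  // subtraction form cannot overflow, unlike local_count + needed
    }
    return static_cast<std::int64_t>(local_count);
}
```

Every later use (the slot index for the base record, the new published count) would then derive from the returned value rather than from another read of the shared field.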

Comment on lines +202 to +203
const int capacity = PLATFORM_DEP_GEN_RECORDS_PER_BUFFER - static_cast<int>(buf->count);
if (needed > capacity) {

security-high high

After switching the buffer, we must refresh our local copy of buf->count from the new buffer. Additionally, we must defensively validate that capacity is strictly positive. If capacity <= 0 (which can happen if the shared memory buf->count is corrupted or manipulated to be >= PLATFORM_DEP_GEN_RECORDS_PER_BUFFER), we must drop the record and return early to prevent a critical out-of-bounds write to buf->records[idx].

Suggested change
- const int capacity = PLATFORM_DEP_GEN_RECORDS_PER_BUFFER - static_cast<int>(buf->count);
- if (needed > capacity) {
+ local_count = buf->count;
+ const int capacity = PLATFORM_DEP_GEN_RECORDS_PER_BUFFER - static_cast<int>(local_count);
+ if (capacity <= 0) {
+     LOG_ERROR("dep_gen: invalid capacity %d (buf->count=%u), dropping record", capacity, local_count);
+     s_dep_gen_state->dropped_record_count += 1;
+     wmb();
+     return;
+ }
+ if (needed > capacity) {
References
  1. Defensively clamp counts or sizes read from shared memory to the maximum capacity of local buffers to prevent stack overflows caused by potential memory corruption or builder bypasses.

);
dc = DEP_GEN_MAX_EXPLICIT_DEPS;
// ---- Write base record ----
uint32_t idx = buf->count;

security-high high

Use the validated local variable local_count instead of re-reading buf->count from volatile shared memory to avoid TOCTOU issues.

Suggested change
- uint32_t idx = buf->count;
+ uint32_t idx = local_count;

Comment on lines +548 to 589
        if (rec.flags & DEP_GEN_FLAG_HAS_OVERFLOW) {
            full_deps_buf.clear();
            full_deps_buf.reserve(static_cast<size_t>(rec.explicit_dep_count) + DEP_GEN_OVERFLOW_DEPS_PER_RECORD);
            full_deps_buf.insert(full_deps_buf.end(), rec.explicit_deps, rec.explicit_deps + rec.explicit_dep_count);
            bool chain_complete = false;
            for (size_t j = rec_i + 1; j < num_records; j++) {
                const DepGenRecord &maybe = records[j];
                if (!(maybe.flags & DEP_GEN_FLAG_OVERFLOW)) {
                    LOG_ERROR(
                        "dep_gen replay: unterminated overflow chain at rec_idx=%zu (task_id=%" PRIu64 ")", rec_i,
                        rec.task_id
                    );
                    break;
                }
                if (maybe.task_id != rec.task_id) {
                    LOG_ERROR(
                        "dep_gen replay: orphan overflow at rec_idx=%zu (expected task_id=%" PRIu64 ", found %" PRIu64
                        ")",
                        j, rec.task_id, maybe.task_id
                    );
                    break;
                }
                const auto *over = reinterpret_cast<const DepGenOverflowRecord *>(&maybe);
                full_deps_buf.insert(full_deps_buf.end(), over->deps, over->deps + over->dep_count);
                if (over->flags & DEP_GEN_FLAG_LAST_OVERFLOW) {
                    chain_complete = true;
                    break;
                }
            }
            if (!chain_complete) {
                LOG_ERROR(
                    "dep_gen replay: chain for task_id=%" PRIu64 " missing LAST_OVERFLOW marker — "
                    "using partial dep list (%zu deps)",
                    rec.task_id, full_deps_buf.size()
                );
            }
            deps_data = full_deps_buf.data();
            dc = static_cast<int32_t>(full_deps_buf.size());
        } else {
            deps_data = rec.explicit_deps;
            dc = static_cast<int32_t>(rec.explicit_dep_count);
        }

security-high high

The values of rec.explicit_dep_count and over->dep_count are read directly from recorded data (which can be corrupted or manipulated). If they exceed their respective array bounds (DEP_GEN_MAX_EXPLICIT_DEPS and DEP_GEN_OVERFLOW_DEPS_PER_RECORD), the replay scan will perform out-of-bounds reads on rec.explicit_deps and over->deps, potentially causing segmentation faults or memory disclosure. We must defensively clamp these counts before using them as loop or copy bounds.

        if (rec.flags & DEP_GEN_FLAG_HAS_OVERFLOW) {
            full_deps_buf.clear();
            uint16_t base_dc = rec.explicit_dep_count;
            if (base_dc > DEP_GEN_MAX_EXPLICIT_DEPS) {
                base_dc = DEP_GEN_MAX_EXPLICIT_DEPS;
            }
            full_deps_buf.reserve(static_cast<size_t>(base_dc) + DEP_GEN_OVERFLOW_DEPS_PER_RECORD);
            full_deps_buf.insert(full_deps_buf.end(), rec.explicit_deps, rec.explicit_deps + base_dc);
            bool chain_complete = false;
            for (size_t j = rec_i + 1; j < num_records; j++) {
                const DepGenRecord &maybe = records[j];
                if (!(maybe.flags & DEP_GEN_FLAG_OVERFLOW)) {
                    LOG_ERROR(
                        "dep_gen replay: unterminated overflow chain at rec_idx=%zu (task_id=%" PRIu64 ")", rec_i,
                        rec.task_id
                    );
                    break;
                }
                if (maybe.task_id != rec.task_id) {
                    LOG_ERROR(
                        "dep_gen replay: orphan overflow at rec_idx=%zu (expected task_id=%" PRIu64 ", found %" PRIu64
                        ")",
                        j, rec.task_id, maybe.task_id
                    );
                    break;
                }
                const auto *over = reinterpret_cast<const DepGenOverflowRecord *>(&maybe);
                uint16_t over_dc = over->dep_count;
                if (over_dc > DEP_GEN_OVERFLOW_DEPS_PER_RECORD) {
                    over_dc = DEP_GEN_OVERFLOW_DEPS_PER_RECORD;
                }
                full_deps_buf.insert(full_deps_buf.end(), over->deps, over->deps + over_dc);
                if (over->flags & DEP_GEN_FLAG_LAST_OVERFLOW) {
                    chain_complete = true;
                    break;
                }
            }
            if (!chain_complete) {
                LOG_ERROR(
                    "dep_gen replay: chain for task_id=%" PRIu64 " missing LAST_OVERFLOW marker — "
                    "using partial dep list (%zu deps)",
                    rec.task_id, full_deps_buf.size()
                );
            }
            deps_data = full_deps_buf.data();
            dc = static_cast<int32_t>(full_deps_buf.size());
        } else {
            deps_data = rec.explicit_deps;
            dc = static_cast<int32_t>(rec.explicit_dep_count);
            if (dc > DEP_GEN_MAX_EXPLICIT_DEPS) {
                dc = DEP_GEN_MAX_EXPLICIT_DEPS;
            }
        }
References
  1. Defensively clamp counts or sizes read from shared memory to the maximum capacity of local buffers to prevent stack overflows caused by potential memory corruption or builder bypasses.

Comment on lines +167 to +171
inline int dep_gen_records_needed_for(int dc) {
    if (dc <= DEP_GEN_MAX_EXPLICIT_DEPS) return 1;
    int spill = dc - DEP_GEN_MAX_EXPLICIT_DEPS;
    return 1 + (spill + DEP_GEN_OVERFLOW_DEPS_PER_RECORD - 1) / DEP_GEN_OVERFLOW_DEPS_PER_RECORD;
}

medium

If dc is extremely large (e.g., close to INT_MAX), the expression spill + DEP_GEN_OVERFLOW_DEPS_PER_RECORD - 1 can overflow int, leading to undefined behavior or incorrect slot calculations. We should perform the arithmetic using int64_t to prevent any potential integer overflow.

inline int dep_gen_records_needed_for(int dc) {
    if (dc <= DEP_GEN_MAX_EXPLICIT_DEPS) return 1;
    int64_t spill = static_cast<int64_t>(dc) - DEP_GEN_MAX_EXPLICIT_DEPS;
    return static_cast<int>(1 + (spill + DEP_GEN_OVERFLOW_DEPS_PER_RECORD - 1) / DEP_GEN_OVERFLOW_DEPS_PER_RECORD);
}
