Add: AllGather and ReduceScatter distributed L3 examples #842

Open
georgebisbas wants to merge 1 commit into hw-native-sys:main from georgebisbas:feat/allgather-reduce-scatter-examples

@georgebisbas
Contributor

Summary

Adds two new symmetric collective communication examples following the
allreduce_distributed pattern.

New files

examples/workers/l3/allgather_distributed/

3-phase symmetric AllGather on the HCCL window scratch pattern:

| Phase | Description |
| --- | --- |
| 1 stage-in | input[0..C) → my scratch slot |
| 2 barrier | TNOTIFY/TWAIT pairwise signal matrix |
| 3 gather | for r in 0..N-1: TLOAD(peer scratch) → TSTORE(output[r*C]) |
  • COUNT_PER_RANK = 64, window = 4 KB
  • Output is identical on every rank (rank-ordered concatenation)
  • Golden: output[r*C + i] = r*100 + i (closed-form)
  • CommRemotePtr<T> defined inline (same pattern as allreduce_distributed)
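
The three phases and the closed-form golden can be checked with a small host-side model. This is only a sketch in plain Python, not the AIV kernel: `make_input`, `allgather_model`, and `golden` are hypothetical names, and the input pattern `input[i] = rank*100 + i` is an assumption inferred from the golden formula above.

```python
COUNT_PER_RANK = 64  # C in the description above


def make_input(rank, count=COUNT_PER_RANK):
    # Assumed input pattern, inferred from the golden: input[i] = rank*100 + i.
    return [float(rank * 100 + i) for i in range(count)]


def allgather_model(nranks, count=COUNT_PER_RANK):
    # Phase 1 (stage-in): each rank copies its input into its own scratch slot.
    scratch = [make_input(r, count) for r in range(nranks)]
    # Phase 2 (barrier) is implicit: every stage-in finishes before any read.
    # Phase 3 (gather): each rank reads every peer's scratch in rank order.
    outputs = []
    for _my_rank in range(nranks):
        out = [0.0] * (nranks * count)
        for r in range(nranks):
            out[r * count:(r + 1) * count] = scratch[r]
        outputs.append(out)
    return outputs


def golden(nranks, count=COUNT_PER_RANK):
    # Closed form from the description: output[r*C + i] = r*100 + i.
    return [float(r * 100 + i) for r in range(nranks) for i in range(count)]
```

Every rank's output in the model should equal `golden(nranks)`, which is the "identical on every rank" property stated above.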

examples/workers/l3/reduce_scatter_distributed/

4-phase symmetric ReduceScatter on the HCCL window scratch pattern:

| Phase | Description |
| --- | --- |
| 1 stage-in | all N chunks → scratch (nranks × C floats) |
| 2 barrier | TNOTIFY/TWAIT pairwise signal matrix |
| 3 reduce | acc = own scratch[my_rank*C]; TADD all peers via CommRemotePtr |
| 4 stage-out | TSTORE acc → output |
  • COUNT_PER_RANK = 64; scratch window is nranks-dependent (computed at runtime)
  • Each rank produces its own shard; golden per dest: nranks*(dest*C+j) + 100*nranks*(nranks-1)/2
  • CommRemotePtr<T> defined inline (same pattern as allreduce_distributed)
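
The per-destination golden can be derived by summing chunk `dest` over all ranks. The following is a host-side sketch in plain Python, not the kernel itself: the function names are hypothetical, and the input pattern `input[k] = rank*100 + k` over `nranks*C` elements is an assumption inferred from the golden formula. Under that assumption, summing `rank*100 + dest*C + j` over `rank = 0..nranks-1` gives `nranks*(dest*C+j) + 100*nranks*(nranks-1)/2`.

```python
COUNT_PER_RANK = 64  # C in the description above


def make_input(rank, nranks, count=COUNT_PER_RANK):
    # Assumed input pattern: input[k] = rank*100 + k for k in [0, nranks*C).
    return [float(rank * 100 + k) for k in range(nranks * count)]


def reduce_scatter_model(nranks, count=COUNT_PER_RANK):
    # Phase 1 (stage-in): each rank copies all N chunks into its scratch.
    scratch = [make_input(r, nranks, count) for r in range(nranks)]
    # Phase 2 (barrier) is implicit in this sequential model.
    outputs = []
    for dest in range(nranks):
        lo = dest * count
        # Phase 3 (reduce): start from the rank's own chunk, add each peer's.
        acc = list(scratch[dest][lo:lo + count])
        for peer in range(nranks):
            if peer == dest:
                continue
            for j in range(count):
                acc[j] += scratch[peer][lo + j]
        # Phase 4 (stage-out): acc becomes this rank's output shard.
        outputs.append(acc)
    return outputs


def golden_shard(dest, nranks, count=COUNT_PER_RANK):
    # Closed form from the description above.
    bias = 100 * nranks * (nranks - 1) // 2
    return [float(nranks * (dest * count + j) + bias) for j in range(count)]
```

All values are small integers, so the float sums are exact and the model can be compared to the closed form with plain equality.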

Testing

Both examples mirror test_allreduce.py:

  • Platforms: a2a3sim, a2a3, a5sim
  • Runtime: tensormap_and_ringbuffer
  • Parametrized: 2 and 4 devices
pytest examples/workers/l3/allgather_distributed/ -v
pytest examples/workers/l3/reduce_scatter_distributed/ -v
# or run standalone:
python examples/workers/l3/allgather_distributed/main.py -p a2a3sim -d 0-1
python examples/workers/l3/reduce_scatter_distributed/main.py -p a2a3sim -d 0-3

Add two new symmetric collective communication examples modelled on
allreduce_distributed, plus a shared comm_utils.h header.

New files:
  examples/workers/l3/common/comm_utils.h
    Shared CommRemotePtr<T> template extracted from the allreduce pattern.
    New kernels include it as "common/comm_utils.h"; existing kernels are
    not modified.

  examples/workers/l3/allgather_distributed/
    3-phase kernel: stage-in → barrier → gather.
    Input: COUNT_PER_RANK=64 floats/rank.
    Output: nranks*64 floats (rank-ordered concatenation, same on every rank).
    Golden: output[r*C+i] = r*100 + i  (closed-form, no reference run).

  examples/workers/l3/reduce_scatter_distributed/
    4-phase kernel: stage-in N chunks → barrier → reduce my chunk → stage-out.
    Input: nranks*64 floats/rank.
    Output: 64 floats/rank (rank-specific shard).
    Golden per dest: nranks*(dest*C+j) + 100*nranks*(nranks-1)/2.
    Scratch window is nranks-dependent; computed at runtime in run().

Both examples follow the orch.allocate_domain() API, the same
Worker/orch_fn/TaskArgs structure as allreduce_distributed, and include
pytest fixtures mirroring test_allreduce.py (a2a3sim/a2a3/a5sim,
n_devices 2 and 4).
@georgebisbas force-pushed the feat/allgather-reduce-scatter-examples branch from d1d4bcb to a9e7e16 on May 21, 2026 15:38

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces end-to-end distributed implementations for AllGather and ReduceScatter operations, including AIV kernels, orchestration shims, and Python execution scripts. The feedback identifies critical synchronization risks in the barrier implementation, recommending a sense-reversing or phase-based mechanism to prevent race conditions when memory is reused. Additionally, there are suggestions to improve the clarity of the buffer specifications in the Python scripts by ensuring the element count accurately reflects the total allocated memory, including signal areas.

Comment on lines +109 to +120
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
    pto::comm::Signal sig(remote_signal);
    pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    pto::comm::Signal sig(signal_base + peer);
    pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
}
pipe_barrier(PIPE_ALL);

high

The barrier implementation should use a sense-reversing or phase-based mechanism to prevent race conditions in distributed environments where memory is reused. Each participant must snapshot the current phase before incrementing the arrival counter. This ensures that fast participants do not interfere with the current barrier generation by looping back too quickly, which is a risk if the signal area is not managed correctly.

References
  1. When implementing a reusable barrier with atomic counters, use a sense-reversing or phase-based mechanism. Each participant should snapshot the current phase before incrementing the arrival counter to prevent race conditions where fast participants loop back and interfere with the current barrier generation.
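
One way to make the pairwise signal matrix reusable is to never reset the counters and to wait on a per-rank generation number instead of a fixed `1`. The sketch below simulates that idea with Python threads. It is a monotonic-generation variant of a phase-based barrier, not necessarily the exact scheme the reviewer has in mind, and `SignalMatrix`, `PhaseBarrier`, and `run_rounds` are hypothetical names that stand in for the TNOTIFY/TWAIT calls.

```python
import threading


class SignalMatrix:
    """slots[dst][src] models the int32 signal that rank src bumps on rank dst."""

    def __init__(self, nranks):
        self.nranks = nranks
        self.slots = [[0] * nranks for _ in range(nranks)]
        self.cond = threading.Condition()

    def notify(self, dst, src):
        # Stand-in for TNOTIFY with AtomicAdd: increment, never reset.
        with self.cond:
            self.slots[dst][src] += 1
            self.cond.notify_all()

    def wait_ge(self, dst, src, value, timeout=10.0):
        # Stand-in for TWAIT with GE, compared against a generation target.
        with self.cond:
            ok = self.cond.wait_for(lambda: self.slots[dst][src] >= value, timeout)
        if not ok:
            raise TimeoutError(f"rank {dst} timed out waiting for rank {src}")


class PhaseBarrier:
    def __init__(self, signals, my_rank):
        self.signals = signals
        self.my_rank = my_rank
        self.generation = 0  # number of barriers this rank has entered

    def arrive_and_wait(self):
        self.generation += 1
        target = self.generation  # fixed before any notify is sent
        peers = [p for p in range(self.signals.nranks) if p != self.my_rank]
        for peer in peers:
            self.signals.notify(peer, self.my_rank)
        for peer in peers:
            # Peer has entered barrier `target` iff its counter reached it.
            self.signals.wait_ge(self.my_rank, peer, target)


def run_rounds(nranks, rounds):
    signals = SignalMatrix(nranks)
    scratch = [None] * nranks
    seen = [[] for _ in range(nranks)]

    def worker(rank):
        barrier = PhaseBarrier(signals, rank)
        for rnd in range(rounds):
            scratch[rank] = rnd
            barrier.arrive_and_wait()  # all writes of this round are done
            seen[rank].append([scratch[p] for p in range(nranks)])
            barrier.arrive_and_wait()  # all reads done before the next write

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(nranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return seen
```

With a fixed target of `1`, the second call would pass immediately on a counter left over from the first call. Comparing against the generation avoids that without needing a reset step, at the cost of counters that grow for the lifetime of the window.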

Comment on lines +117 to +128
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
    pto::comm::Signal sig(remote_signal);
    pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    pto::comm::Signal sig(signal_base + peer);
    pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
}
pipe_barrier(PIPE_ALL);

high

To ensure robust synchronization when scratch memory is reused, implement a sense-reversing or phase-based barrier. It is critical to snapshot the current phase before incrementing the arrival counter to prevent race conditions between successive barrier generations, ensuring that no rank proceeds based on stale signal values.

References
  1. When implementing a reusable barrier with atomic counters, use a sense-reversing or phase-based mechanism. Each participant should snapshot the current phase before incrementing the arrival counter to prevent race conditions where fast participants loop back and interfere with the current barrier generation.

    window_size=window_size,
    buffers=[CommBufferSpec(name="scratch", dtype="float32", count=COUNT_PER_RANK, nbytes=SCRATCH_NBYTES)],
) as handle:
    for i in range(nranks):

medium

The count parameter in CommBufferSpec is set to COUNT_PER_RANK (64), but the actual buffer size SCRATCH_NBYTES (320) includes additional space for signals. While nbytes is correctly provided, having a count that only reflects the data portion and not the full buffer capacity (including the signal tail) can be misleading for maintenance. Consider setting count to the total number of elements or ensuring it aligns with the intended usage of the buffer.
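
The mismatch is easy to see when the layout is written out. The helper below is a hypothetical sketch, not part of the PR or of the `CommBufferSpec` API. It assumes the signal tail is 16 int32 slots, a figure inferred only from 320 − 64 × 4 = 64 bytes.

```python
FLOAT32_BYTES = 4
INT32_BYTES = 4


def scratch_layout(data_count, signal_slots):
    """Return byte sizes for a scratch buffer: float32 data, then int32 signals."""
    data_nbytes = data_count * FLOAT32_BYTES
    signal_nbytes = signal_slots * INT32_BYTES
    nbytes = data_nbytes + signal_nbytes
    return {
        "data_nbytes": data_nbytes,
        "signal_offset": data_nbytes,  # signals start right after the data
        "signal_nbytes": signal_nbytes,
        "nbytes": nbytes,
        # Element count if the whole buffer is described as 4-byte elements.
        "total_count": nbytes // FLOAT32_BYTES,
    }


# AllGather case from the comment: 64 floats plus an assumed 16-slot signal tail.
layout = scratch_layout(data_count=64, signal_slots=16)
print(layout["nbytes"], layout["total_count"])
```

Under these assumptions a `count` that covers the whole allocation would be 80 rather than 64, which is the gap the comment points at.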

Comment on lines +165 to +171
            count=nranks * COUNT_PER_RANK,
            nbytes=scratch_nbytes,
        )
    ],
) as handle:
    for i in range(nranks):
        domain = handle[i]

medium

The count in CommBufferSpec for the scratch buffer only accounts for the float data (nranks * COUNT_PER_RANK), but the buffer also contains the signal area. It is better to define the buffer specification such that the count and dtype accurately represent the allocated memory or rely solely on nbytes if the buffer is heterogeneous.
