Add: AllGather and ReduceScatter distributed L3 examples #842
georgebisbas wants to merge 1 commit into
Add two new symmetric collective communication examples modelled on
allreduce_distributed, plus a shared comm_utils.h header.
New files:
examples/workers/l3/common/comm_utils.h
Shared CommRemotePtr<T> template extracted from the allreduce pattern.
New kernels include it as "common/comm_utils.h"; existing kernels are
not modified.
examples/workers/l3/allgather_distributed/
3-phase kernel: stage-in → barrier → gather.
Input: COUNT_PER_RANK=64 floats/rank.
Output: nranks*64 floats (rank-ordered concatenation, same on every rank).
Golden: output[r*C+i] = r*100 + i (closed-form, no reference run).
examples/workers/l3/reduce_scatter_distributed/
4-phase kernel: stage-in N chunks → barrier → reduce my chunk → stage-out.
Input: nranks*64 floats/rank.
Output: 64 floats/rank (rank-specific shard).
Golden per dest: nranks*(dest*C+j) + 100*nranks*(nranks-1)/2.
Scratch window is nranks-dependent; computed at runtime in run().
Both examples follow the orch.allocate_domain() API, the same
Worker/orch_fn/TaskArgs structure as allreduce_distributed, and include
pytest fixtures mirroring test_allreduce.py (a2a3sim/a2a3/a5sim,
n_devices 2 and 4).
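The two closed-form goldens above can be cross-checked against a plain host-side simulation. The sketch below is illustrative only: the function names are hypothetical, and the input patterns (rank r holding r*100 + i for AllGather, and k + 100*r for ReduceScatter) are assumptions inferred from the golden formulas, not taken from the example sources.

```python
COUNT_PER_RANK = 64  # "C" in the golden formulas


def allgather_golden(nranks, c=COUNT_PER_RANK):
    # Closed form from the description: output[r*C + i] = r*100 + i.
    return [r * 100 + i for r in range(nranks) for i in range(c)]


def reduce_scatter_golden(dest, nranks, c=COUNT_PER_RANK):
    # Closed form: nranks*(dest*C + j) + 100*nranks*(nranks-1)/2.
    offset = 100 * nranks * (nranks - 1) // 2
    return [nranks * (dest * c + j) + offset for j in range(c)]


def simulate_allgather(nranks, c=COUNT_PER_RANK):
    # Assumed input: rank r holds r*100 + i for i in [0, C).
    inputs = [[r * 100 + i for i in range(c)] for r in range(nranks)]
    # AllGather is the rank-ordered concatenation of every rank's input.
    return [value for rank_input in inputs for value in rank_input]


def simulate_reduce_scatter(dest, nranks, c=COUNT_PER_RANK):
    # Assumed input: element k on rank r is k + 100*r, k in [0, nranks*C).
    inputs = [[k + 100 * r for k in range(nranks * c)] for r in range(nranks)]
    # Rank `dest` receives the element-wise sum of chunk `dest` over all ranks.
    return [
        sum(inputs[r][dest * c + j] for r in range(nranks))
        for j in range(c)
    ]
```

Under these assumed inputs the simulation and the closed forms agree for both tested device counts, which is why no reference run is needed.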
d1d4bcb to a9e7e16
Code Review
This pull request introduces end-to-end distributed implementations for AllGather and ReduceScatter operations, including AIV kernels, orchestration shims, and Python execution scripts. The feedback identifies critical synchronization risks in the barrier implementation, recommending a sense-reversing or phase-based mechanism to prevent race conditions when memory is reused. Additionally, there are suggestions to improve the clarity of the buffer specifications in the Python scripts by ensuring the element count accurately reflects the total allocated memory, including signal areas.
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
    pto::comm::Signal sig(remote_signal);
    pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    pto::comm::Signal sig(signal_base + peer);
    pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
}
pipe_barrier(PIPE_ALL);
The barrier implementation should use a sense-reversing or phase-based mechanism to prevent race conditions in distributed environments where memory is reused. Each participant must snapshot the current phase before incrementing the arrival counter. This ensures that fast participants do not interfere with the current barrier generation by looping back too quickly, which is a risk if the signal area is not managed correctly.
References
- When implementing a reusable barrier with atomic counters, use a sense-reversing or phase-based mechanism. Each participant should snapshot the current phase before incrementing the arrival counter to prevent race conditions where fast participants loop back and interfere with the current barrier generation.
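One phase-based variant can be sketched on the host with Python threads. This is a model of the idea, not of the pto::comm primitives: PhaseBarrier and run_rounds are hypothetical names, the lock and condition variable stand in for atomic adds and TWAIT, and the scheme shown (monotonic per-pair counters compared against a locally tracked phase) is a swapped-in alternative to a central counter with sense reversal. The point it illustrates is that waiting for a counter to reach the current phase, rather than a fixed 1, stays correct when the signal area is reused without being reset.

```python
import threading


class PhaseBarrier:
    """Reusable barrier built on monotonic per-pair arrival counters."""

    def __init__(self, nranks):
        self.nranks = nranks
        # signals[dst][src] counts how many times src has notified dst.
        self.signals = [[0] * nranks for _ in range(nranks)]
        # Each rank tracks how many barriers it has entered so far.
        self.phase = [0] * nranks
        self.cond = threading.Condition()

    def wait(self, my_rank):
        with self.cond:
            # Advance the local phase; this is the value peers must reach.
            self.phase[my_rank] += 1
            target = self.phase[my_rank]
            # Notify every peer (stand-in for an atomic add on a remote slot).
            for peer in range(self.nranks):
                if peer != my_rank:
                    self.signals[peer][my_rank] += 1
            self.cond.notify_all()
            # Wait until every peer has arrived at least `target` times.
            self.cond.wait_for(lambda: all(
                self.signals[my_rank][p] >= target
                for p in range(self.nranks) if p != my_rank
            ))


def run_rounds(nranks, rounds):
    """Reuse one barrier for several rounds and report stale reads."""
    barrier = PhaseBarrier(nranks)
    published = [0] * nranks
    violations = []

    def worker(rank):
        for rnd in range(1, rounds + 1):
            published[rank] = rnd      # stage-in for this round
            barrier.wait(rank)
            for peer in range(nranks):
                if published[peer] < rnd:  # peer data from an older round
                    violations.append((rank, peer, rnd))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(nranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return violations
```

Calling run_rounds(4, 50) reuses the same counters for 50 consecutive barriers; an empty result means no rank read a peer slot older than its own round. A fast rank that has already notified for the next phase only pushes a counter above the target, which the greater-or-equal comparison tolerates.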
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    __gm__ int32_t *remote_signal = CommRemotePtr(commCtx, signal_base + my_rank, peer);
    pto::comm::Signal sig(remote_signal);
    pto::comm::TNOTIFY(sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd);
}
for (int peer = 0; peer < nranks; ++peer) {
    if (peer == my_rank) continue;
    pto::comm::Signal sig(signal_base + peer);
    pto::comm::TWAIT(sig, (int32_t)1, pto::comm::WaitCmp::GE);
}
pipe_barrier(PIPE_ALL);
To ensure robust synchronization when scratch memory is reused, implement a sense-reversing or phase-based barrier. It is critical to snapshot the current phase before incrementing the arrival counter to prevent race conditions between successive barrier generations, ensuring that no rank proceeds based on stale signal values.
References
- When implementing a reusable barrier with atomic counters, use a sense-reversing or phase-based mechanism. Each participant should snapshot the current phase before incrementing the arrival counter to prevent race conditions where fast participants loop back and interfere with the current barrier generation.
    window_size=window_size,
    buffers=[CommBufferSpec(name="scratch", dtype="float32", count=COUNT_PER_RANK, nbytes=SCRATCH_NBYTES)],
) as handle:
    for i in range(nranks):
The count parameter in CommBufferSpec is set to COUNT_PER_RANK (64), but the actual buffer size SCRATCH_NBYTES (320) includes additional space for signals. While nbytes is correctly provided, having a count that only reflects the data portion and not the full buffer capacity (including the signal tail) can be misleading for maintenance. Consider setting count to the total number of elements or ensuring it aligns with the intended usage of the buffer.
            count=nranks * COUNT_PER_RANK,
            nbytes=scratch_nbytes,
        )
    ],
) as handle:
    for i in range(nranks):
        domain = handle[i]
The count in CommBufferSpec for the scratch buffer only accounts for the float data (nranks * COUNT_PER_RANK), but the buffer also contains the signal area. It is better to define the buffer specification such that the count and dtype accurately represent the allocated memory or rely solely on nbytes if the buffer is heterogeneous.
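The size arithmetic behind both comments can be made explicit with a small helper. This is a sketch under stated assumptions: scratch_layout is a hypothetical name, and the 16-slot int32 signal tail is inferred from 320 - 64*4 = 64 bytes in the AllGather case; whether ReduceScatter uses the same tail size is not confirmed by the PR.

```python
FLOAT32_NBYTES = 4
INT32_NBYTES = 4
SIGNAL_SLOTS = 16  # assumption: 64-byte signal tail, inferred from 320 - 256


def scratch_layout(data_count, signal_slots=SIGNAL_SLOTS):
    """Describe a scratch buffer made of float32 data plus an int32 signal tail."""
    data_nbytes = data_count * FLOAT32_NBYTES
    signal_nbytes = signal_slots * INT32_NBYTES
    nbytes = data_nbytes + signal_nbytes
    return {
        "data_count": data_count,        # float32 payload elements only
        "signal_offset": data_nbytes,    # byte offset where the signal tail starts
        "nbytes": nbytes,                # full allocation, data plus signals
        # Element count that matches nbytes when viewed as 4-byte elements.
        "total_count": nbytes // FLOAT32_NBYTES,
    }
```

For AllGather, scratch_layout(64) gives nbytes = 320 and total_count = 80, so a count of 64 describes only the payload. Passing either total_count or nbytes alone would make the specification match the allocation.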
Summary
Adds two new symmetric collective communication examples following the allreduce_distributed pattern.

New files

examples/workers/l3/allgather_distributed/
3-phase symmetric AllGather on the HCCL window scratch pattern:
- Stage-in: input[0..C) → my scratch slot
- Barrier: TNOTIFY/TWAIT pairwise signal matrix
- Gather: TLOAD(peer scratch) → TSTORE(output[r*C])
- COUNT_PER_RANK = 64, window = 4 KB
- Golden: output[r*C + i] = r*100 + i (closed-form)
- CommRemotePtr<T> defined inline (same pattern as allreduce_distributed)

examples/workers/l3/reduce_scatter_distributed/
4-phase symmetric ReduceScatter on the HCCL window scratch pattern:
- Barrier: TNOTIFY/TWAIT pairwise signal matrix
- Reduce: scratch[my_rank*C]; TADD all peers via CommRemotePtr
- Stage-out: TSTORE acc → output
- COUNT_PER_RANK = 64; scratch window is nranks-dependent (computed at runtime)
- Golden per dest: nranks*(dest*C+j) + 100*nranks*(nranks-1)/2
- CommRemotePtr<T> defined inline (same pattern as allreduce_distributed)

Testing
Both examples mirror test_allreduce.py:
- a2a3sim, a2a3, a5sim
- tensormap_and_ringbuffer

    pytest examples/workers/l3/allgather_distributed/ -v
    pytest examples/workers/l3/reduce_scatter_distributed/ -v
    # or run standalone:
    python examples/workers/l3/allgather_distributed/main.py -p a2a3sim -d 0-1
    python examples/workers/l3/reduce_scatter_distributed/main.py -p a2a3sim -d 0-3