fix(sampler): equalize per-rank sample counts to avoid inter-epoch deadlock #25
…adlock

When num_samples % comm_size != 0, dlio_sampler used math.ceil(N/size) and clamped the last rank to fewer samples than its peers. With drop_last=True the last rank produced fewer batches per epoch, so the per-step and end-of-epoch barriers in main._train() matched across iterations on the last rank: both are plain MPI_Barrier and MPI cannot distinguish call sites. Subsequent steps would then match an iteration-barrier on the older ranks against a reduce() from the next-epoch reconfigure() on the last rank, producing a permanent CPU-spinning deadlock.

Replace ceil+clamp with floor division at the three matching call sites (torch_data_loader.dlio_sampler, config.build_sample_map_iter, config.get_global_map_index). Every rank now gets the same floor(N/comm_size) samples. Up to comm_size-1 trailing samples per epoch are dropped on purpose; rank 0 emits a warning telling the user this happened and recommending they pick total_samples as a multiple of comm_size to use every sample.

Reported in mlcommons/storage#455 (uneven multi-rank checkpointing runs hang silently at the epoch boundary).
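The barrier mismatch described in this commit message can be illustrated with a toy model. This is only a sketch: `collective_trace` and `first_mismatch` are hypothetical helpers, and the loop shape (one barrier per step, one barrier at the end of each epoch, one `reduce` at the start of every later epoch) is a simplification of what the message says about `main._train()` and `reconfigure()`, not code taken from DLIO. It relies on the fact that collectives on one communicator are matched purely by call order.

```python
def collective_trace(batches_per_epoch, epochs):
    """Ordered list of collectives one rank would issue in the toy loop."""
    trace = []
    for epoch in range(epochs):
        if epoch > 0:
            trace.append("reduce")  # stand-in for the next-epoch reconfigure()
        trace.extend(["barrier"] * batches_per_epoch)  # per-step barriers
        trace.append("barrier")  # end-of-epoch barrier
    return trace


def first_mismatch(batches_per_rank, epochs=2):
    """Return (position, calls) where ranks first disagree, or None."""
    traces = [collective_trace(b, epochs) for b in batches_per_rank]
    # Walk the traces in lockstep, the way order-based matching pairs calls.
    for position, calls in enumerate(zip(*traces)):
        if len(set(calls)) > 1:
            return position, calls
    return None


# N=100, size=7, batch_size=3 with drop_last=True:
uneven = [15 // 3] * 6 + [10 // 3]  # ceil+clamp shards: 5 batches x6, 3 on last rank
even = [14 // 3] * 7                # floor shards: 4 batches on every rank

print(first_mismatch(uneven))  # last rank issues "reduce" while the others issue "barrier"
print(first_mismatch(even))    # None
```

In the uneven case the last rank's end-of-epoch barrier is silently paired with a per-step barrier on the other ranks, and the first visible disagreement is a `reduce` facing a `barrier`, which is the hang described above.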
@russfellows @idevasena
One of the targeted regression tests fails, as shown below:
(.venv) smrc@dskbd029:~/DLIO_local_changes$ cat > tests/test_pr25_sampler_equalization.py <<'PY'
from dlio_benchmark.data_loader.torch_data_loader import dlio_sampler


def test_dlio_sampler_equalizes_uneven_rank_counts():
    total = 100
    size = 7
    batch_size = 3
    per_rank_samples = [
        len(list(dlio_sampler(rank, size, total, epochs=1)))
        for rank in range(size)
    ]
    per_rank_batches = [n // batch_size for n in per_rank_samples]

    assert per_rank_samples == [14] * 7
    assert per_rank_batches == [4] * 7
    assert total - sum(per_rank_samples) == 2


def test_dlio_sampler_len_matches_iterator_length():
    total = 100
    size = 7
    for rank in range(size):
        sampler = dlio_sampler(rank, size, total, epochs=1)
        assert len(sampler) == len(list(iter(sampler)))
PY
python -m pytest tests/test_pr25_sampler_equalization.py -q
.F [100%]
============================================================================== FAILURES ==============================================================================
___________________________________________________________ test_dlio_sampler_len_matches_iterator_length ____________________________________________________________
    def test_dlio_sampler_len_matches_iterator_length():
        total = 100
        size = 7
        for rank in range(size):
            sampler = dlio_sampler(rank, size, total, epochs=1)
>           assert len(sampler) == len(list(iter(sampler)))
E           assert 100 == 14
E            +  where 100 = len(<dlio_benchmark.data_loader.torch_data_loader.dlio_sampler object at 0x7c476a541280>)
E            +  and   14 = len([0, 1, 2, 3, 4, 5, ...])
E            +    where [0, 1, 2, 3, 4, 5, ...] = list(<generator object dlio_sampler.__iter__ at 0x7c476a0ed6c0>)
E            +      where <generator object dlio_sampler.__iter__ at 0x7c476a0ed6c0> = iter(<dlio_benchmark.data_loader.torch_data_loader.dlio_sampler object at 0x7c4
tests/test_pr25_sampler_equalization.py:24: AssertionError
------------------------------------------------------------------------- Captured log call --------------------------------------------------------------------------
WARNING DLIO:torch_data_loader.py:431 2026-06-17T23:26:12.651188 dlio_sampler: dropping 2 sample(s) — num_samples (100) is not a multiple of comm_size (7). Each rank
====================================================================== short test summary info =======================================================================
FAILED tests/test_pr25_sampler_equalization.py::test_dlio_sampler_len_matches_iterator_length - assert 100 == 14
1 failed, 1 passed in 1.43s
Reason: dlio_sampler.__len__() still returns self.num_samples, while __iter__() yields only the per-rank floored shard. The changed sampler code now builds self.indices from num_samples // size, but __len__ still returns the global sample count.
Fix:
Change:
return self.num_samples
to
return len(self.indices)
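The invariant behind this fix can be shown without torch or DLIO. The class below is a hypothetical stand-in, not the real `dlio_sampler`: the name `ShardSampler`, the contiguous-shard layout, and the constructor signature are all assumptions made for illustration. The only point is that `__len__` reports the same shard that `__iter__` walks.

```python
class ShardSampler:
    """Hypothetical per-rank sampler; NOT the real dlio_sampler."""

    def __init__(self, rank, size, num_samples):
        self.num_samples = num_samples  # global count, kept for reference only
        per_rank = num_samples // size  # floor division: same shard size on every rank
        start = rank * per_rank         # assumed contiguous layout
        self.indices = list(range(start, start + per_rank))

    def __iter__(self):
        # Yield only this rank's shard.
        yield from self.indices

    def __len__(self):
        # Report the shard length, not the global num_samples.
        return len(self.indices)


for rank in range(7):
    sampler = ShardSampler(rank, size=7, num_samples=100)
    assert len(sampler) == len(list(sampler)) == 14
```

Returning `self.num_samples` from `__len__` in this sketch would reproduce the `assert 100 == 14` failure from the test run above.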
Good catch, thank you — applied in c3a3508.
You're right that the fix is return len(self.indices). This is a pre-existing latent bug: the old ceil+clamp also disagreed, just less sharply. Floor division makes it provable, which is what your test trips on.
I also confirmed it didn't matter for DLIO's main training loop (it bounds iteration with its own max_steps = floor(num_samples * num_files_train / batch_size / comm_size) at main.py:361, not len(loader)), but len(DataLoader) is derived from len(sampler), so any third-party consumer of that was over-reporting by a factor of comm_size.
Included your two test cases plus one for even division (N=100, size=10 → unchanged) as tests/test_dlio_sampler.py. All three pass locally.
Reviewer (@idevasena) caught that dlio_sampler.__len__ returns self.num_samples (global) while __iter__ yields only the per-rank shard (len(self.indices)). Pre-existing latent bug: the previous ceil+clamp also disagreed, just less obviously. With floor division it's sharp: a Sampler test of the form `len(sampler) == len(list(iter(sampler)))` trips it every time.

DLIO's training loop bounds iteration with max_steps computed independently, so the loop didn't depend on len(sampler). But len(DataLoader) is derived from len(sampler), so any consumer of that (progress bars, third-party tools) sees an over-report by a factor of comm_size.

Switch __len__ to return len(self.indices) and add three regression tests:
- uneven counts equalized (N=100, size=7 → [14]*7, drops 2)
- __len__ matches iterated length (the reviewer's test)
- even division unchanged (N=100, size=10 → [10]*10)
@FileSystemGuy There was one more issue i.e. related to Thus, I updated the expected sums to use |
Summary
Fixes the inter-epoch deadlock reported in mlcommons/storage#455.
When `num_samples % comm_size != 0`, `dlio_sampler` used `math.ceil(N/size)` and clamped the last rank to fewer samples. With `drop_last=True` the last rank produced fewer batches per epoch, and the per-step / end-of-epoch barriers in `main._train()` matched across iterations on the last rank: both are plain `MPI_Barrier` and MPI cannot distinguish call sites. The next step then matched an iteration-barrier on the older ranks against a `reduce()` from the next-epoch `reconfigure()` on the last rank, producing a permanent CPU-spinning deadlock with no diagnostic.

Change
Replaced `ceil+clamp` with floor division at the three matching call sites:
- `dlio_benchmark/data_loader/torch_data_loader.py`: `dlio_sampler`
- `dlio_benchmark/utils/config.py`: `build_sample_map_iter`
- `dlio_benchmark/utils/config.py`: `get_global_map_index`

Every rank now gets exactly `total_samples // comm_size` samples. Up to `comm_size - 1` trailing samples per epoch are dropped on purpose; rank 0 emits a warning telling the user this happened and that they should pick `total_samples` as a multiple of `comm_size` to use every sample.

Verification (issue example, N=100, size=7)
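The per-rank counts for this example can be reproduced with a short standalone sketch. The function names are hypothetical, and the ceil+clamp formula is reconstructed from the description in this PR rather than copied from the DLIO source, so treat it as an approximation of the old behavior.

```python
import math


def ceil_clamp_counts(total_samples, comm_size):
    """Old scheme (as described): ceil-sized shards, last rank clamped."""
    per_rank = math.ceil(total_samples / comm_size)
    return [
        max(0, min(per_rank, total_samples - rank * per_rank))
        for rank in range(comm_size)
    ]


def floor_counts(total_samples, comm_size):
    """New scheme: every rank gets total_samples // comm_size samples."""
    return [total_samples // comm_size] * comm_size


def dropped_samples(total_samples, comm_size):
    """Trailing samples left unused per epoch under floor division."""
    return total_samples % comm_size


print(ceil_clamp_counts(100, 7))  # [15, 15, 15, 15, 15, 15, 10]
print(floor_counts(100, 7))       # [14, 14, 14, 14, 14, 14, 14]
print(dropped_samples(100, 7))    # 2
```

With a batch size of 3 and `drop_last=True`, the old counts give 5 batches on six ranks and 3 on the last, while the floored counts give 4 batches on every rank.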
Before: `[15,15,15,15,15,15,10]`
After: `[14,14,14,14,14,14,14]`

Other cases checked: `N=100, size=3` → `[33,33,33]`; `N=100, size=10` (already even) → unchanged at `[10]*10`.

Notes / follow-ups
- `ConfigArguments.training_steps` / `eval_steps` still use `ceil()` (`config.py:705-706`). They aren't read by `main.py`, so they don't drive the deadlock, but they're now a stale upper bound and should be aligned with the floored per-rank count in a follow-up if anything starts consuming them.
- `N=100, size=7, batch=3, drop_last=True` produces equal batch counts) if you'd prefer it land here.

Storage repo follow-up
Storage currently pins this fork at `1d11f9820` (`pyproject.toml`). Once this PR merges, bump the storage pin to consume the fix.