Skip to content

fix(sampler): equalize per-rank sample counts to avoid inter-epoch deadlock#25

Merged
idevasena merged 3 commits into
mainfrom
fix/sampler-deadlock-uneven-ranks
Jun 18, 2026
Merged

fix(sampler): equalize per-rank sample counts to avoid inter-epoch deadlock#25
idevasena merged 3 commits into
mainfrom
fix/sampler-deadlock-uneven-ranks

Conversation

@FileSystemGuy

Copy link
Copy Markdown

Summary

Fixes the inter-epoch deadlock reported in mlcommons/storage#455.

When num_samples % comm_size != 0, dlio_sampler used math.ceil(N/size) and clamped the last rank to fewer samples. With drop_last=True the last rank produced fewer batches per epoch, and the per-step / end-of-epoch barriers in main._train() matched across iterations on the last rank — both are plain MPI_Barrier and MPI cannot distinguish call sites. The next step then matched an iteration-barrier on the older ranks against a reduce() from the next-epoch reconfigure() on the last rank, producing a permanent CPU-spinning deadlock with no diagnostic.

Change

Replaced ceil+clamp with floor division at the three matching call sites:

  • dlio_benchmark/data_loader/torch_data_loader.pydlio_sampler
  • dlio_benchmark/utils/config.pybuild_sample_map_iter
  • dlio_benchmark/utils/config.pyget_global_map_index

Every rank now gets exactly total_samples // comm_size samples. Up to comm_size - 1 trailing samples per epoch are dropped on purpose; rank 0 emits a warning telling the user this happened and that they should pick total_samples as a multiple of comm_size to use every sample.

Verification (issue example, N=100, size=7)

Ranks Old per-rank lens New per-rank lens
0..6 [15,15,15,15,15,15,10] [14,14,14,14,14,14,14]

Other cases checked: N=100,size=3[33,33,33]; N=100,size=10 (already even) → unchanged at [10]*10.

Notes / follow-ups

  • ConfigArguments.training_steps/eval_steps still use ceil() (config.py:705-706). They aren't read by main.py, so they don't drive the deadlock — but they're now a stale upper bound and should be aligned with the floored per-rank count in a follow-up if anything starts consuming them.
  • No new unit tests in this PR — the sampler logic is exercised by every multi-rank run and is now demonstrably equal across ranks. Happy to add a focused regression test (the original report includes an acceptance test: N=100, size=7, batch=3, drop_last=True produces equal batch counts) if you'd prefer it land here.

Storage repo follow-up

Storage currently pins this fork at 1d11f9820 (pyproject.toml). Once this PR merges, bump the storage pin to consume the fix.

…adlock

When num_samples % comm_size != 0, dlio_sampler used math.ceil(N/size) and
clamped the last rank to fewer samples than its peers. With drop_last=True
the last rank produced fewer batches per epoch, so the per-step and
end-of-epoch barriers in main._train() matched across iterations on the
last rank — both are plain MPI_Barrier and MPI cannot distinguish call
sites. Subsequent steps would then match an iteration-barrier on the
older ranks against a reduce() from the next-epoch reconfigure() on the
last rank, producing a permanent CPU-spinning deadlock.

Replace ceil+clamp with floor division at the three matching call sites
(torch_data_loader.dlio_sampler, config.build_sample_map_iter,
config.get_global_map_index). Every rank now gets the same
floor(N/comm_size) samples. Up to comm_size-1 trailing samples per epoch
are dropped on purpose; rank 0 emits a warning telling the user this
happened and recommending they pick total_samples as a multiple of
comm_size to use every sample.

Reported in mlcommons/storage#455 (uneven multi-rank checkpointing runs
hang silently at the epoch boundary).
@FileSystemGuy

Copy link
Copy Markdown
Author

@russfellows @idevasena
Please review and approve.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One of the targeted regression test fails as below:

(.venv) smrc@dskbd029:~/DLIO_local_changes$ cat > tests/test_pr25_sampler_equalization.py <<'PY'
from dlio_benchmark.data_loader.torch_data_loader import dlio_sampler

def test_dlio_sampler_equalizes_uneven_rank_counts():
    total = 100
    size = 7
    batch_size = 3

    per_rank_samples = [
        len(list(dlio_sampler(rank, size, total, epochs=1)))
        for rank in range(size)
    ]
    per_rank_batches = [n // batch_size for n in per_rank_samples]

    assert per_rank_samples == [14] * 7
    assert per_rank_batches == [4] * 7
    assert total - sum(per_rank_samples) == 2

def test_dlio_sampler_len_matches_iterator_length():
    total = 100
    size = 7

    for rank in range(size):
        sampler = dlio_sampler(rank, size, total, epochs=1)
        assert len(sampler) == len(list(iter(sampler)))
PY

python -m pytest tests/test_pr25_sampler_equalization.py -q
.F                                                                                                                                                             [100%]
============================================================================== FAILURES ==============================================================================
___________________________________________________________ test_dlio_sampler_len_matches_iterator_length ____________________________________________________________

    def test_dlio_sampler_len_matches_iterator_length():
        total = 100
        size = 7
    
        for rank in range(size):
            sampler = dlio_sampler(rank, size, total, epochs=1)
>           assert len(sampler) == len(list(iter(sampler)))
E           assert 100 == 14
E            +  where 100 = len(<dlio_benchmark.data_loader.torch_data_loader.dlio_sampler object at 0x7c476a541280>)
E            +  and   14 = len([0, 1, 2, 3, 4, 5, ...])
E            +    where [0, 1, 2, 3, 4, 5, ...] = list(<generator object dlio_sampler.__iter__ at 0x7c476a0ed6c0>)
E            +      where <generator object dlio_sampler.__iter__ at 0x7c476a0ed6c0> = iter(<dlio_benchmark.data_loader.torch_data_loader.dlio_sampler object at 0x7c4

tests/test_pr25_sampler_equalization.py:24: AssertionError
------------------------------------------------------------------------- Captured log call --------------------------------------------------------------------------
WARNING  DLIO:torch_data_loader.py:431 2026-06-17T23:26:12.651188 dlio_sampler: dropping 2 sample(s) — num_samples (100) is not a multiple of comm_size (7). Each rank
====================================================================== short test summary info =======================================================================
FAILED tests/test_pr25_sampler_equalization.py::test_dlio_sampler_len_matches_iterator_length - assert 100 == 14
1 failed, 1 passed in 1.43s

Reason: Because dlio_sampler.__len__() still returns self.num_samples, while __iter__() yields only the per-rank floored shard. The changed sampler code now builds self.indices from num_samples // size, but __len__ still returns the global sample count.

Fix:
Change:

return self.num_samples

to

return len(self.indices)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thank you — applied in c3a3508.

You're right that the right fix is return len(self.indices). Pre-existing latent bug: the old ceil+clamp also disagreed, just less sharply. Floor division makes it provable, which is what your test trips on.

I also confirmed it didn't matter for DLIO's main training loop (it bounds iteration with its own max_steps = floor(num_samples * num_files_train / batch_size / comm_size) at main.py:361, not len(loader)), but len(DataLoader) is derived from len(sampler), so any third-party consumer of that was over-reporting by a factor of comm_size.

Included your two test cases plus one for even division (N=100, size=10 → unchanged) as tests/test_dlio_sampler.py. All three pass locally.

FileSystemGuy and others added 2 commits June 17, 2026 17:04
Reviewer (@idevasena) caught that dlio_sampler.__len__ returns self.num_samples
(global) while __iter__ yields only the per-rank shard (len(self.indices)).
Pre-existing latent bug — the previous ceil+clamp also disagreed, just less
obviously. With floor division it's sharp: a Sampler test of the form
`len(sampler) == len(list(iter(sampler)))` trips it every time.

DLIO's training loop bounds iteration with max_steps computed independently,
so the loop didn't depend on len(sampler). But len(DataLoader) is derived
from len(sampler), so any consumer of that (progress bars, third-party tools)
sees an over-report by a factor of comm_size.

Switch __len__ to return len(self.indices) and add three regression tests:
- uneven counts equalized (N=100, size=7 → [14]*7, drops 2)
- __len__ matches iterated length (the reviewer's test)
- even division unchanged (N=100, size=10 → [10]*10)
@idevasena

idevasena commented Jun 18, 2026

Copy link
Copy Markdown

@FileSystemGuy There was one more issue i.e. related to ConfigArguments.reconfigure() still expects the full sample sum
The current changes intentionally drop up to comm_size - 1 trailing samples, but derive_configurations() still computes train_sample_index_sum / eval_sample_index_sum from the full total_samples_*, and reconfigure() still raises if the reduced local sums do not match those full totals. For N=100, comm_size=7, the new code processes samples 0..97, whose sum is 4753, but the existing expected sum is 0..99 = 4950. That means the exact uneven case this PR targets can fail with “missing samples” before training.

Thus, I updated the expected sums to use effective_samples = (total_samples // comm_size) * comm_size, or we can update the validation to explicitly allow and account for the intentional tail drop.

@idevasena idevasena merged commit cbe2001 into main Jun 18, 2026
7 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants