76 changes: 76 additions & 0 deletions docs/guides/data_input_pipeline/olmo_grain.md
@@ -0,0 +1,76 @@
# OLMo numpy pipeline (`dataset_type=olmo_grain`)

Grain-based input pipeline for AI2's pre-tokenized OLMo data mixes (e.g.
`OLMo-mix-0925-official.txt`). Reads headerless flat `.npy` token streams
from a gcsfuse mount, shards across hosts, optionally masks repeated-n-gram
instances, and yields the shapes the MaxText pretrain trainer expects.
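
A minimal sketch of the read path, assuming headerless uint32 tokens and a
fixed sequence length; the helper names below are illustrative, not the
pipeline's actual API:

```python
import numpy as np


def remap_path(path: str, remap_from: str, remap_to: str) -> str:
  """Rewrite a gs:// path from the index to its gcsfuse mount location."""
  if remap_from and path.startswith(remap_from):
    return remap_to + path[len(remap_from):]
  return path


def read_instance(path: str, seq_len: int, index: int) -> np.ndarray:
  """Return the index-th fixed-length instance from a raw token file."""
  # The files carry no numpy header, so memmap with an explicit dtype is
  # required; np.load would reject them.
  tokens = np.memmap(path, dtype=np.uint32, mode="r")
  start = index * seq_len
  return np.asarray(tokens[start : start + seq_len])
```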

## Quick start

1. **Download the data** to a GCS bucket:

```bash
python tools/data_generation/download_olmo_data_to_gcs.py \
--mix-file /path/to/OLMo-mix-0925-official.txt \
--gcs-dest gs://my-bucket/dataset/ \
--staging-dir /mnt/local-ssd/olmo-staging \
--workers 16
```

2. **Mount it read-only** with gcsfuse (`np.memmap` needs a local path):

```bash
gcsfuse --implicit-dirs -o ro my-bucket /mnt/olmo-readonly
```

3. **Build the index**:

```bash
python tools/data_generation/build_olmo_npy_index.py \
--mix-file /path/to/OLMo-mix-0925-official.txt \
--gcs-base gs://my-bucket/dataset/ \
--tokenizer allenai/dolma3-tokenizer \
--sequence-length 8192 \
--output /path/to/olmo_index_seq8192.json
```

4. **Configure + run** the trainer:

```yaml
dataset_type: olmo_grain
olmo_index_path: /path/to/olmo_index_seq8192.json
olmo_path_remap_from: "gs://my-bucket/"
olmo_path_remap_to: "/mnt/olmo-readonly/"
max_target_length: 8192 # must equal index sequence_length
tokenizer_type: huggingface
tokenizer_path: allenai/Olmo-3-7B-Instruct
```

See `scripts/run_olmo3_7b_grain_smoke.sh` for a runnable smoke launcher.

## Resume

Stateless sampler: the record at step *k* is a pure function of
`(seed, shard, k)`. On startup, the trainer adapter reads the latest step
from `config.checkpoint_dir` and shifts the sampler so the data stream
picks up where it left off; no Grain iterator state is stored in the
checkpoint.
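
A rough sketch of that shift, assuming the stateless mapping described
above; both helpers are hypothetical stand-ins, not the trainer's real
adapter:

```python
def resume_offset(latest_ckpt_step: int, per_host_batch: int) -> int:
  """Number of records this host consumed before the restored step."""
  return latest_ckpt_step * per_host_batch


def record_id(seed: int, shard: int, step: int, num_records: int) -> int:
  """Pure mapping (seed, shard, step) -> record id. Because the mapping is
  deterministic, resume needs only the restored step count, never saved
  iterator state. The hash is a stand-in for a seeded permutation."""
  return hash((seed, shard, step)) % num_records
```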

`scripts/run_olmo3_7b_grain_resume_test.sh` validates this end-to-end.

## Notes

- Files are headerless raw uint32 by default, matching AI2's published
  format; despite the `.npy` extension there is no numpy header, so read
  them with an explicit dtype (e.g. `np.memmap`) rather than `np.load`.
- Documents may span instance boundaries; this matches OLMo-core.
- `olmo_apply_ngram_filter: True` (default) zeroes loss on instances with
  ≥ 32 repetitions of any 1–13-gram, per OLMo-core (a rough sketch of the
  check follows this list).
- To mix pretraining and midtraining data, build a combined index by
  concatenating the two `.txt` mix files.
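
As referenced in the n-gram note above, here is a sketch of the repetition
check. It assumes "repetitions" means back-to-back occurrences of the same
n-gram; the 32 and 1–13 thresholds come from that note, and the helper name
is hypothetical:

```python
import numpy as np


def has_repetitive_ngram(tokens, max_n: int = 13, min_reps: int = 32) -> bool:
  """True if any 1..max_n-gram repeats back-to-back >= min_reps times."""
  t = np.asarray(tokens)
  for n in range(1, max_n + 1):
    for offset in range(n):  # a run can start at any position mod n
      run = 1
      prev = None
      for s in range(offset, len(t) - n + 1, n):
        window = t[s : s + n]
        if prev is not None and np.array_equal(window, prev):
          run += 1
          if run >= min_reps:
            return True
        else:
          run = 1
        prev = window
  return False
```

A flagged instance keeps its tokens; only its loss contribution is zeroed.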

## Troubleshooting

| Symptom | Fix |
| ------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------- |
| `OLMo index sequence_length=N but config.max_target_length=M` | Rebuild the index with `--sequence-length M`. |
| `q_block_size=512 should divide q_seq_len=…` | Set `max_target_length` to a multiple of 512. |
| OOM during compile on a small TPU | Shrink the model with `override_model_config=True base_num_decoder_layers=N` and use `weight_dtype=bfloat16`. |
| Resume restarts at step 0 | Iterator log should print `resumed_step=N initial_step=…`; if both 0, `checkpoint_dir` is empty or wrong. |
135 changes: 135 additions & 0 deletions scripts/run_olmo3_7b_grain_resume_test.sh
@@ -0,0 +1,135 @@
#!/bin/bash
# End-to-end resume test for the OLMo grain pipeline (stateless sampler +
# step-derived initial_step). See scripts/run_olmo3_7b_grain_smoke.sh for
# the env-var contract; this script accepts the same vars.
#
# Plan:
# Run A: train 50 steps from scratch, save checkpoint at step 50, exit.
# Run B: relaunch with the SAME run_name (so the checkpoint dir is reused).
# The trainer restores model state at step 50; our iterator factory
# detects the latest checkpoint step and sets ``initial_step`` so
# the data stream picks up at absolute position 50 * per_host_batch.
# Train 25 more steps (to step 75).
#
# What success looks like:
# * Run B's first step (step 51) reports a loss similar to Run A's step 50
# loss. A spike or jump → model state didn't restore.
# * No repeats: Run B's batches are NOT the same as Run A's batches at the
# same absolute step. (Hard to assert without batch-content hashing in
# the trainer; for the smoke we rely on the unit tests + loss continuity.)
# * No regression: Run B's loss continues to decrease.
#
# Outputs:
# ${LOG_A} — first 50 steps
# ${LOG_B} — resumed 25 steps
# $OUTPUT_DIR/<run_name>/checkpoints/ — Orbax checkpoint(s)

set -euo pipefail

MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}"
HF_SECRETS="${HF_SECRETS:-}"
INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}"
GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}"
LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_resume_test_out}"
RUN_NAME="${RUN_NAME:-olmo_resume_$(date +%Y%m%d-%H%M%S)}"

# Where each run's stdout is teed. Keep them under OUTPUT_DIR so the
# script doesn't depend on a hard-coded absolute path.
LOG_A="${LOG_A:-${OUTPUT_DIR}/${RUN_NAME}.runA.log}"
LOG_B="${LOG_B:-${OUTPUT_DIR}/${RUN_NAME}.runB.log}"

PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}"
SEQ_LEN="${SEQ_LEN:-8192}"
WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}"
NUM_LAYERS="${NUM_LAYERS:-4}"
DATA_SEED="${DATA_SEED:-42}"

# Run A trains 50 steps + saves a checkpoint at step 50; Run B continues to 75.
STEPS_A="${STEPS_A:-50}"
STEPS_B="${STEPS_B:-75}"
CHECKPOINT_PERIOD="${CHECKPOINT_PERIOD:-50}"

# shellcheck disable=SC1090,SC1091
source "${VENV_PATH}/bin/activate"
if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then
# shellcheck disable=SC1090
source "${HF_SECRETS}"
fi
: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}"
export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}"
export PYTHONUNBUFFERED=1

mkdir -p "${OUTPUT_DIR}"

TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}"

run_train() {
local steps="$1"
local logfile="$2"
echo "----- launching: steps=${steps} log=${logfile} -----"
python -m maxtext.trainers.pre_train.train \
"${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \
model_name=olmo3-7b-pt \
run_name="${RUN_NAME}" \
base_output_directory="${OUTPUT_DIR}" \
dataset_type=olmo_grain \
olmo_index_path="${INDEX_PATH}" \
olmo_path_remap_from="${GCS_BASE}" \
olmo_path_remap_to="${LOCAL_MOUNT}" \
data_shuffle_seed="${DATA_SEED}" \
olmo_apply_ngram_filter=True \
grain_worker_count=0 \
per_device_batch_size="${PER_DEVICE_BATCH}" \
max_target_length="${SEQ_LEN}" \
steps="${steps}" \
enable_checkpointing=True \
async_checkpointing=False \
checkpoint_period="${CHECKPOINT_PERIOD}" \
save_checkpoint_on_completion=True \
tokenizer_type=huggingface \
tokenizer_path="${TOKENIZER_PATH}" \
weight_dtype="${WEIGHT_DTYPE}" \
override_model_config=True \
base_num_decoder_layers="${NUM_LAYERS}" \
sharding_tolerance=0.05 \
2>&1 | tee "${logfile}"
}

echo "=== OLMo 3 grain resume test ==="
echo " run_name : ${RUN_NAME}"
echo " output_dir : ${OUTPUT_DIR}/${RUN_NAME}"
echo " per_device_bs : ${PER_DEVICE_BATCH}"
echo " seq_len : ${SEQ_LEN}"
echo " num_layers : ${NUM_LAYERS}"
echo " Run A steps : ${STEPS_A} (will checkpoint at step ${CHECKPOINT_PERIOD})"
echo " Run B steps : ${STEPS_B} (resumed via initial_step)"
echo

# Run A
run_train "${STEPS_A}" "${LOG_A}"

echo
echo "=== Run A done. Last 3 step events: ==="
grep -E "completed step:" "${LOG_A}" | tail -3
echo

# Run B (resume)
run_train "${STEPS_B}" "${LOG_B}"

echo
echo "=== Run B done ==="
echo "First 3 step events from Run B (expect step >= ${STEPS_A}):"
grep -E "completed step:" "${LOG_B}" | head -3
echo
echo "Last 3 step events from Run B:"
grep -E "completed step:" "${LOG_B}" | tail -3
echo

echo "=== Pass criteria (manual check): ==="
echo " 1. Run B's first step number >= ${STEPS_A} (model state restored)"
echo " 2. Run B's first step loss within ~5% of Run A's last step loss"
echo " (model continued, no re-init)"
echo " 3. Loss continues to decrease across Run B"
echo " 4. iterator log line shows 'resumed_step=${STEPS_A} initial_step=...' on Run B"
96 changes: 96 additions & 0 deletions scripts/run_olmo3_7b_grain_smoke.sh
@@ -0,0 +1,96 @@
#!/bin/bash
# Smoke training run for OLMo 3 7B on the OLMo numpy grain pipeline.
#
# Validates that dataset_type=olmo_grain wires through the trainer, that
# OlmoNpyDataSource reads .npy data via a gcsfuse mount, and that 50 steps
# run without crashes or shape mismatches, with the loss trending downward.
#
# Required env vars:
# INDEX_PATH JSON index from tools/data_generation/build_olmo_npy_index.py
# GCS_BASE gs:// prefix recorded in the index (e.g. gs://my-bucket/)
# LOCAL_MOUNT gcsfuse mount of GCS_BASE on this host
# HF_TOKEN HuggingFace token for the tokenizer (or HF_SECRETS=<file>)
# Optional: VENV_PATH, OUTPUT_DIR, PER_DEVICE_BATCH, SEQ_LEN, STEPS,
# WEIGHT_DTYPE, NUM_LAYERS.
#
# Usage:
# INDEX_PATH=/path/to/olmo_index_seq8192.json \
# LOCAL_MOUNT=/mnt/your-mount/ \
# GCS_BASE=gs://your-bucket/ \
# HF_TOKEN=hf_... \
# bash scripts/run_olmo3_7b_grain_smoke.sh

set -euo pipefail

MAXTEXT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"

VENV_PATH="${VENV_PATH:-${MAXTEXT_ROOT}/maxtext_venv}"
HF_SECRETS="${HF_SECRETS:-}"
INDEX_PATH="${INDEX_PATH:?INDEX_PATH is required (path to olmo index JSON)}"
GCS_BASE="${GCS_BASE:?GCS_BASE is required (e.g. gs://my-bucket/)}"
LOCAL_MOUNT="${LOCAL_MOUNT:?LOCAL_MOUNT is required (gcsfuse mount path of GCS_BASE)}"
OUTPUT_DIR="${OUTPUT_DIR:-/tmp/olmo_smoke_out}"

PER_DEVICE_BATCH="${PER_DEVICE_BATCH:-1}"
SEQ_LEN="${SEQ_LEN:-8192}"
STEPS="${STEPS:-50}"
DATA_SEED="${DATA_SEED:-42}"
# Smoke test uses a reduced model (bf16, 4 layers) so it fits small TPU
# slices; we're validating the data path, not full-size convergence.
WEIGHT_DTYPE="${WEIGHT_DTYPE:-bfloat16}"
NUM_LAYERS="${NUM_LAYERS:-4}"

RUN_NAME="${RUN_NAME:-olmo_grain_smoke_$(date +%Y%m%d-%H%M%S)}"

# Activate venv + load HF secrets.
# shellcheck disable=SC1090,SC1091
source "${VENV_PATH}/bin/activate"
if [[ -n "${HF_SECRETS:-}" && -f "${HF_SECRETS}" ]]; then
# shellcheck disable=SC1090
source "${HF_SECRETS}"
fi
: "${HF_TOKEN:?HF_TOKEN must be set (or HF_SECRETS pointing at a file that exports it)}"
export PYTHONPATH="${MAXTEXT_ROOT}/src:${PYTHONPATH:-}"
export PYTHONUNBUFFERED=1

mkdir -p "${OUTPUT_DIR}"

echo "=== OLMo 3 7B + olmo_grain smoke run ==="
echo " run_name : ${RUN_NAME}"
echo " index : ${INDEX_PATH}"
echo " path remap : ${GCS_BASE} → ${LOCAL_MOUNT}"
echo " per_device_bs : ${PER_DEVICE_BATCH}"
echo " seq_len : ${SEQ_LEN}"
echo " steps : ${STEPS}"
echo " weight_dtype : ${WEIGHT_DTYPE}"
echo " num_layers : ${NUM_LAYERS} (full 7B has 32)"
echo " output_dir : ${OUTPUT_DIR}"
echo

# Data is already tokenized; the tokenizer is loaded only for pad/eos IDs +
# vocab_size checks. Olmo-3-7B-Instruct uses the same dolma3 tokenizer.
TOKENIZER_PATH="${TOKENIZER_PATH:-allenai/Olmo-3-7B-Instruct}"

python -m maxtext.trainers.pre_train.train \
"${MAXTEXT_ROOT}/src/maxtext/configs/base.yml" \
model_name=olmo3-7b-pt \
run_name="${RUN_NAME}" \
base_output_directory="${OUTPUT_DIR}" \
dataset_type=olmo_grain \
olmo_index_path="${INDEX_PATH}" \
olmo_path_remap_from="${GCS_BASE}" \
olmo_path_remap_to="${LOCAL_MOUNT}" \
data_shuffle_seed="${DATA_SEED}" \
olmo_apply_ngram_filter=True \
grain_worker_count=0 \
per_device_batch_size="${PER_DEVICE_BATCH}" \
max_target_length="${SEQ_LEN}" \
steps="${STEPS}" \
enable_checkpointing=False \
tokenizer_type=huggingface \
tokenizer_path="${TOKENIZER_PATH}" \
weight_dtype="${WEIGHT_DTYPE}" \
override_model_config=True \
base_num_decoder_layers="${NUM_LAYERS}" \
sharding_tolerance=0.05
8 changes: 8 additions & 0 deletions src/maxtext/configs/base.yml
@@ -724,6 +724,14 @@ grain_shuffle_buffer_size: 100 # shuffle buffer when using sequential access for
# for using pathways
colocated_python_data_input: False # experimental feature, under testing

# OLMo numpy pipeline (dataset_type=olmo_grain). Worker count, buffer size,
# and shuffle seed reuse grain_worker_count / grain_per_worker_buffer_size /
# data_shuffle_seed.
olmo_index_path: '' # JSON from tools/data_generation/build_olmo_npy_index.py
olmo_path_remap_from: '' # rewrite index paths starting with this prefix...
olmo_path_remap_to: '' # ...to this one (e.g. gs://bucket/ -> /mnt/.../ for gcsfuse).
olmo_apply_ngram_filter: True # mask instances with repetitive n-grams (OLMo-core filter)

# Training loop
steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
log_period: 100 # The frequency of Tensorboard flush, gcs metrics writing, and managed profiler metrics updating.
22 changes: 22 additions & 0 deletions src/maxtext/configs/types.py
Expand Up @@ -176,6 +176,7 @@ class DatasetType(str, Enum):
GRAIN = "grain"
TFDS = "tfds"
C4MLPERF = "c4_mlperf"
OLMO_GRAIN = "olmo_grain"


class SamplingStrategy(str, Enum):
@@ -1128,6 +1129,26 @@ class GrainDataset(BaseModel):
grain_shuffle_buffer_size: int = Field(100, description="Shuffle buffer size when using Parquet or TFRecord.")


class OlmoGrainDataset(BaseModel):
"""Configuration for the OLMo numpy fixed-seq-length input pipeline (dataset_type=olmo_grain).

Worker count, per-worker buffer size, and shuffle seed reuse the standard
grain flags (``grain_worker_count``, ``grain_per_worker_buffer_size``,
``data_shuffle_seed``); only OLMo-specific fields are listed here.
"""

olmo_index_path: PathStr = Field("", description="Path or gs:// URI to the JSON index from build_olmo_npy_index.py.")
olmo_path_remap_from: PathStr = Field(
"",
description="If set, rewrite index file paths starting with this prefix to olmo_path_remap_to.",
)
olmo_path_remap_to: PathStr = Field(
"",
description="Replacement prefix used together with olmo_path_remap_from (e.g. /mnt/disks/.../).",
)
olmo_apply_ngram_filter: bool = Field(True, description="Mask repetitive instances per OLMo-core's repetition filter.")


class FineTuning(BaseModel):
"""Configuration for fine-tuning methods like DPO, SFT, and GRPO."""

@@ -2092,6 +2113,7 @@ class MaxTextConfig(
TfdsDataset,
HfDataset,
GrainDataset,
OlmoGrainDataset,
Tokenizer,
# Inference
InferenceGeneral,
5 changes: 4 additions & 1 deletion src/maxtext/input_pipeline/input_pipeline_interface.py
@@ -23,6 +23,8 @@
from maxtext.input_pipeline.grain_data_processing import make_grain_eval_iterator
from maxtext.input_pipeline.hf_data_processing import make_hf_train_iterator
from maxtext.input_pipeline.hf_data_processing import make_hf_eval_iterator
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_train_iterator
from maxtext.input_pipeline.olmo_grain_data_processing import make_olmo_grain_eval_iterator
from maxtext.input_pipeline.tfds_data_processing import make_tfds_train_iterator
from maxtext.input_pipeline.tfds_data_processing import make_tfds_eval_iterator
from maxtext.input_pipeline.tfds_data_processing_c4_mlperf import make_c4_mlperf_train_iterator
@@ -71,10 +73,11 @@ def create_data_iterator(config: pyconfig.HyperParameters, mesh):
"grain": (make_grain_train_iterator, make_grain_eval_iterator),
"hf": (make_hf_train_iterator, make_hf_eval_iterator),
"c4_mlperf": (make_c4_mlperf_train_iterator, make_c4_mlperf_eval_iterator),
"olmo_grain": (make_olmo_grain_train_iterator, make_olmo_grain_eval_iterator),
}

# Collect train and eval iterators
if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf"]:
if config.dataset_type in ["tfds", "grain", "hf", "c4_mlperf", "olmo_grain"]:
if config.dataset_type == "c4_mlperf":
assert config.packing, "c4_mlperf dataloader only works with packing. For padded version, use tfds dataloader"
train_iterator, eval_iterator = dataset_type_to_train_eval_iterator[config.dataset_type]