169 changes: 169 additions & 0 deletions src/maxtext/eval/README.md
# MaxText vLLM Eval Framework

A vLLM-native evaluation framework for MaxText models supporting harness-based eval (lm-eval, evalchemy) and custom datasets.

## Quick Start

All runners share a single entry point:

```bash
python -m maxtext.eval.runner.run --runner <eval|lm_eval|evalchemy> [flags]
```

### Custom dataset (e.g. MLPerf OpenOrca with ROUGE scoring)

```bash
python -m maxtext.eval.runner.run \
--runner eval \
--config src/maxtext/eval/configs/mlperf.yml \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name llama3.1-8b \
--hf_path meta-llama/Llama-3.1-8B-Instruct \
--base_output_directory gs://<bucket>/ \
--run_name eval_run \
--max_model_len 8192 \
--hf_token $HF_TOKEN
```

HF safetensors mode (no MaxText checkpoint):

```bash
python -m maxtext.eval.runner.run \
--runner eval \
--config src/maxtext/eval/configs/mlperf.yml \
--hf_path TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--model_name tinyllama \
--base_output_directory gs://<bucket>/ \
--run_name eval_test \
--hf_mode \
--num_samples 20 \
--max_model_len 2048 \
--tensor_parallel_size 1
```

### LM Eval

Requires: `pip install "lm_eval[api]"`

```bash
python -m maxtext.eval.runner.run \
--runner lm_eval \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name qwen3-30b-a3b \
--hf_path Qwen/Qwen3-30B-A3B \
--tasks gsm8k \
--base_output_directory gs://<bucket>/ \
--run_name my_run \
--max_model_len 8192 \
--tensor_parallel_size 8 \
--expert_parallel_size 8 \
--hf_token $HF_TOKEN
```

### Evalchemy

Requires: `pip install git+https://github.com/mlfoundations/evalchemy.git`

```bash
python -m maxtext.eval.runner.run \
--runner evalchemy \
--checkpoint_path gs://<bucket>/checkpoints/0/items \
--model_name llama3.1-8b \
--hf_path meta-llama/Llama-3.1-8B-Instruct \
--tasks ifeval math500 gpqa_diamond \
--base_output_directory gs://<bucket>/ \
--run_name eval_run \
--max_model_len 8192 \
--tensor_parallel_size 4 \
--hf_token $HF_TOKEN
```

## Common Flags

| Flag | Description |
|---|---|
| `--checkpoint_path` | MaxText Orbax checkpoint path. Enables `MaxTextForCausalLM` mode. |
| `--model_name` | MaxText model name (e.g. `llama3.1-8b`) |
| `--hf_path` | HF model ID or local path |
| `--max_model_len` | vLLM max context length. |
| `--tensor_parallel_size` | Chips per model replica |
| `--expert_parallel_size` | Chips for the expert mesh axis |
| `--data_parallel_size` | Number of model replicas |
| `--hbm_memory_utilization` | Fraction of HBM reserved for KV cache |
| `--hf_token` | HF token (or set `HF_TOKEN` env var) |
| `--hf_mode` | HF safetensors mode, no MaxText checkpoint loading |
| `--server_host` / `--server_port` | vLLM server address (default: localhost:8000) |
| `--max_num_batched_tokens` | vLLM tokens per scheduler step |
| `--max_num_seqs` | vLLM max concurrent sequences |
| `--gcs_results_path` | GCS path to upload results JSON |
| `--log_level` | Logging verbosity (default: INFO) |

Flags specific to the custom `eval` runner:

| Flag | Description |
|---|---|
| `--config` | Benchmark YAML config (required) |
| `--num_samples` | Limit eval samples |
| `--max_tokens` | Max tokens per generation |
| `--temperature` | Sampling temperature (default: 0.0) |
| `--concurrency` | HTTP request concurrency (default: 64) |

Flags specific to the harness runners (`lm_eval` / `evalchemy`):

| Flag | Description |
|---|---|
| `--tasks` | Space-separated task names |
| `--num_fewshot` | Few-shot examples per task (default: 0) |
| `--num_samples` | Limit samples per task (default: full dataset) |

## Eval on RL Checkpoints

RL training runs save per-step actor checkpoints (e.g. `checkpoints/actor/<step>/model_params`); point `--checkpoint_path` at the step you want to evaluate.

Example (Qwen3-30B-A3B, v6e-8):

```bash
STEP=244
MODEL=qwen3-30b-a3b
HF_PATH=Qwen/Qwen3-30B-A3B
CHECKPOINT=gs://<bucket>/run/checkpoints/actor/${STEP}/model_params
OUTPUT=gs://<bucket>/eval/

python -m maxtext.eval.runner.run \
--runner lm_eval \
--checkpoint_path ${CHECKPOINT} \
--model_name ${MODEL} \
--hf_path ${HF_PATH} \
--tasks gsm8k \
--base_output_directory ${OUTPUT} \
--run_name rl_${MODEL}_step${STEP} \
--max_model_len 4096 \
--tensor_parallel_size 8 \
--expert_parallel_size 8 \
--num_samples 20 \
--hf_token $HF_TOKEN
```


## Adding a Custom Benchmark

1. Implement `BenchmarkDataset` in `src/maxtext/eval/datasets/`:

```python
from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

class MyDataset(BenchmarkDataset):
  name = "my_benchmark"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # load dataset, build prompts, return SampleRequest list
    ...
```

2. Register in `src/maxtext/eval/datasets/registry.py`:

```python
from maxtext.eval.datasets.my_dataset import MyDataset
DATASET_REGISTRY["my_benchmark"] = MyDataset
```

3. Add a scorer in `src/maxtext/eval/scoring/` and register it in `src/maxtext/eval/scoring/registry.py`.
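The scorer side of step 3 is not shown in this diff, so the class name, `score()` signature, and registry shape below are assumptions about what `src/maxtext/eval/scoring/` expects; treat this as a minimal exact-match sketch rather than the actual interface:

```python
# Hypothetical sketch: the real scorer interface lives in
# src/maxtext/eval/scoring/ and is not shown in this diff, so the
# class name, score() signature, and registry dict are assumptions.


class ExactMatchScorer:
  """Scores a generation as 1.0 on an exact (whitespace-stripped) match."""

  name = "exact_match"

  def score(self, prediction: str, reference: str) -> float:
    return 1.0 if prediction.strip() == reference.strip() else 0.0


# Registration pattern, mirroring DATASET_REGISTRY in datasets/registry.py.
SCORER_REGISTRY: dict[str, type] = {}
SCORER_REGISTRY[ExactMatchScorer.name] = ExactMatchScorer
```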
8 changes: 8 additions & 0 deletions src/maxtext/eval/configs/base_eval.yml
# Base evaluation configuration.
**Collaborator:** Shall we also include the benchmark name or eval dataset in this base yml?

**Collaborator (author):** The base config covers server and generation parameters that are shared across all eval runs (temperature, concurrency, tensor_parallel_size, etc.). Benchmark name and dataset-specific settings (e.g. num_samples, max_tokens) live in task-specific configs like mlperf.yml, so we can reuse the base config for different benchmarks without modification. For the harness-based runners (lm_eval, evalchemy), benchmark/task selection is handled by the `--tasks` CLI arg (`--tasks gsm8k mmlu gpqa`) rather than config files (examples in the bug I shared above).

temperature: 0.0
concurrency: 64
server_host: "localhost"
server_port: 8000
tensor_parallel_size: 4
num_samples: null
5 changes: 5 additions & 0 deletions src/maxtext/eval/configs/mlperf.yml
# MLPerf OpenOrca evaluation config.

benchmark: "mlperf_openorca"
**Collaborator:** Could you provide some context? i.e. an example of perf configs for inference?

**Collaborator (author):** This config is passed via `--config` to the eval runner. It sets the benchmark name, generation length, and sample count; all server, model, and parallelism params come from CLI args.

Example:

```bash
python -m maxtext.eval.runner.run \
  --runner eval \
  --config src/maxtext/eval/configs/mlperf.yml \
  --checkpoint_path $CHECKPOINT \
  --model_name llama3.1-8b \
  --hf_path meta-llama/Llama-3.1-8B-Instruct \
  --hf_token $HF_TOKEN \
  --base_output_directory gs://<bucket>/ \
  --run_name $RUN \
  --gcs_results_path gs://<bucket>/results/mlperf.json \
  --max_model_len 8192 \
  --tensor_parallel_size 4 \
  --expert_parallel_size 1 \
  --data_parallel_size 1 \
  --hbm_memory_utilization 0.7 \
  --max_num_batched_tokens 4096 \
  --max_num_seqs 256 \
  --num_samples 5000 \
  --max_tokens 1024 \
  --temperature 0.0 \
  --concurrency 64 \
  --log_level INFO
```

Only `--config`, `--hf_path`, `--base_output_directory`, `--run_name`, and `--max_model_len` are required; the others have defaults or can be inherited from the config file.
max_tokens: 1024
num_samples: 5000
56 changes: 56 additions & 0 deletions src/maxtext/eval/datasets/base.py
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Abstract base classes for benchmark datasets."""

from __future__ import annotations

import abc
from typing import NamedTuple


class SampleRequest(NamedTuple):
  """A single inference request with its ground-truth reference.

  Attributes:
    prompt: The full text prompt to send to the model (after chat templating).
    reference: Ground-truth answer/label used by the scorer.
    metadata: Optional dict of extra fields forwarded to the scorer
      (e.g. {"subject": "college_math"} for per-subject MMLU stats).
  """

  prompt: str
  reference: str
  metadata: dict | None = None


class BenchmarkDataset(abc.ABC):
  """Abstract base class for benchmark datasets."""

  name: str

  @abc.abstractmethod
  def sample_requests(
      self,
      num_samples: int | None,
      tokenizer,
  ) -> list[SampleRequest]:
    """Load the dataset and return a list of SampleRequests.

    Args:
      num_samples: If not None, truncate to this number of samples.
      tokenizer: A HuggingFace tokenizer used for chat templating.

    Returns:
      List of SampleRequest objects ready for inference.
    """

**Collaborator** (on the `tokenizer` parameter): should we also annotate the tokenizer type here? Is the base class `PreTrainedTokenizerBase` from transformers?
63 changes: 63 additions & 0 deletions src/maxtext/eval/datasets/mlperf.py
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MLPerf OpenOrca summarization dataset."""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset, SampleRequest

_SYSTEM_PROMPT = (
    "You are a helpful assistant. Summarize the following conversation."
)


class MlperfOpenOrcaDataset(BenchmarkDataset):
  """MLPerf OpenOrca — summarization benchmark used in MLPerf Inference.

  Uses the Open-Orca/OpenOrca HuggingFace dataset.
  """

  name = "mlperf_openorca"

  def sample_requests(self, num_samples, tokenizer) -> list[SampleRequest]:
    # pylint: disable=import-outside-toplevel
    import datasets as hf_datasets

    ds = hf_datasets.load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)

    requests = []
    for row in ds:
      if not row.get("response", "").strip():
        continue

      system_prompt = row.get("system_prompt", _SYSTEM_PROMPT) or _SYSTEM_PROMPT
      question = row["question"]
      reference = row["response"]

      if tokenizer is not None:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": question},
        ]
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
      else:
        prompt = f"{system_prompt}\n\nUser: {question}\nAssistant:"

      requests.append(SampleRequest(prompt=prompt, reference=reference))

      if num_samples is not None and len(requests) >= num_samples:
        break

    return requests

**Collaborator** (on `apply_chat_template`): Out of curiosity, is it usually preferred to use the tokenizer's chat_template if available?

**Collaborator (author):** It is preferred for any custom benchmarks. We defer all of the tokenizer chat templating to the harness if it already supports the benchmark in question.
60 changes: 60 additions & 0 deletions src/maxtext/eval/datasets/registry.py
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Registry mapping benchmark names to BenchmarkDataset classes.

This can be used to define custom dataset loaders for benchmarks not covered by lm_eval and evalchemy.
"""

from __future__ import annotations

from maxtext.eval.datasets.base import BenchmarkDataset
from maxtext.eval.datasets.mlperf import MlperfOpenOrcaDataset

DATASET_REGISTRY: dict[str, type[BenchmarkDataset]] = {
    "mlperf_openorca": MlperfOpenOrcaDataset,
    "openorca": MlperfOpenOrcaDataset,
}


def get_dataset(benchmark_name: str) -> BenchmarkDataset:
  """Instantiate and return the dataset registered for benchmark_name.

  Args:
    benchmark_name: Benchmark identifier (e.g. "mlperf_openorca").

  Returns:
    An instance of the corresponding BenchmarkDataset subclass.

  Raises:
    KeyError: If no dataset is registered for the given name.
  """
  key = benchmark_name.lower()
  if key not in DATASET_REGISTRY:
    raise KeyError(
        f"No dataset registered for benchmark '{benchmark_name}'. "
        f"Available: {sorted(DATASET_REGISTRY)}. "
        f"For MMLU, GPQA, MATH etc. use lm_eval_runner or evalchemy_runner instead."
    )
  return DATASET_REGISTRY[key]()


def register_dataset(benchmark_name: str, dataset_cls: type[BenchmarkDataset]) -> None:
  """Register a custom dataset class for benchmark_name.

  Args:
    benchmark_name: Lowercase benchmark identifier.
    dataset_cls: A BenchmarkDataset subclass.
  """
  DATASET_REGISTRY[benchmark_name.lower()] = dataset_cls
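The case-insensitive registration and lookup flow above can be sketched self-containedly (local stand-ins for `DATASET_REGISTRY` and a dataset class, so the snippet runs outside a MaxText checkout):

```python
# Local stand-ins mirroring the registry logic above.
REGISTRY: dict[str, type] = {}


def register_dataset(name: str, cls: type) -> None:
  REGISTRY[name.lower()] = cls


def get_dataset(name: str):
  key = name.lower()
  if key not in REGISTRY:
    raise KeyError(f"No dataset registered for benchmark '{name}'.")
  return REGISTRY[key]()  # returns an *instance*, not the class


class MyDataset:
  name = "my_benchmark"


register_dataset("My_Benchmark", MyDataset)
ds = get_dataset("MY_BENCHMARK")  # keys are lower-cased on both ends
```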