Skip to content

vLLM based Eval framework#3531

Open
dipannita08 wants to merge 1 commit intomainfrom
eval-framework-01
Open

vLLM based Eval framework#3531
dipannita08 wants to merge 1 commit intomainfrom
eval-framework-01

Conversation

@dipannita08
Copy link
Copy Markdown
Collaborator

@dipannita08 dipannita08 commented Mar 31, 2026

Description

Implement a evaluation framework with vllm backend. Requirements, design, further details: go/eval-framework-vllm

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a Github issue, please include a link, e.g.,:
FIXES: b/508639301

Tests

  • Unit tests
  • E2E tests b/508639301

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov
Copy link
Copy Markdown

codecov Bot commented Mar 31, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@github-actions
Copy link
Copy Markdown

🤖 Hi @Rohan-Bierneni, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

Copy link
Copy Markdown

@github-actions github-actions Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

## 📋 Review Summary

The implementation of the vLLM-based evaluation framework is a strong addition to MaxText, providing native support for custom benchmarks, lm-evaluation-harness, and evalchemy. The code is well-structured, but it needs critical updates to correctly support multi-host TPU environments whereリード (lead) rank coordination is essential.

🔍 General Feedback

  • Rank Coordination: In multi-host TPU setups, client-side operations (warmup, generation, reporting) must be restricted to jax.process_index() == 0 to avoid redundant work and failures on non-lead ranks.
  • Configurability: Key parameters like request timeouts should be made configurable via the CLI/config files rather than being hardcoded.
  • Efficiency: Minor optimizations in NLTK data handling and FastAPI request processing would improve the overall robustness and performance of the evaluation tool.

Comment thread src/maxtext/eval/runner/eval_runner.py Outdated
Comment thread src/maxtext/eval/runner/lm_eval_runner.py Outdated
Comment thread src/maxtext/eval/runner/evalchemy_runner.py Outdated
Comment thread src/maxtext/eval/scoring/rouge_scorer.py Outdated
Comment thread src/maxtext/eval/runner/async_client.py Outdated
Comment thread src/maxtext/eval/runner/server_manager.py
Copy link
Copy Markdown
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! I’ve done a high-level pass, though I haven’t done a deep dive into the code just yet. Have you had a chance to run any benchmarks? I'm curious if you're seeing decent scores. Also, is multi-host benchmarking for large models within the scope of this PR?

Comment thread src/maxtext/eval/README.md Outdated
Requires: `pip install evalchemy`

```bash
python -m maxtext.eval.runner.evalchemy_runner \
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you tested if large scale works? i.e. a workload on v5p-64.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ran some tests on v5p-64, example results in b/508639301.

Comment thread src/maxtext/eval/README.md Outdated
| `--hf_token` | HuggingFace token for gated models |
| `--num_samples` | Limit number of eval samples |
| `--hf_mode` | Force HF safetensors mode (disables MaxTextForCausalLM mode) |
| `--tensor_parallel_size` | vLLM tensor parallelism |
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this the only sharding supported?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added and tested EP and DP as well. There were some issues with testing with EP on vLLM with scanned RL checkpoints, but this PR unblocked the remaining of these.

Comment thread src/maxtext/eval/README.md Outdated
python -m maxtext.eval.runner.eval_runner ...
```

### Configuration (eval_runner)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think it will be useful if we put common configs together? Then only put new ones for eval_runner, lm_eval_runner, and evalchemy_runner?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a first effort to move some common configs out but there are still more that are common between the runners, will clean this up in a follow up PR.

logger = logging.getLogger(__name__)

# Maps MaxText benchmark names to lm-eval task names.
_TASK_MAP: dict[str, str] = {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we could directly run lm-eval tasks instead of adding extra mapping layer here? So for any future benchmarks in lm-eval-harness, we could directly use. Similar comments for other framework if applies.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, updated.

@dipannita08 dipannita08 force-pushed the eval-framework-01 branch 2 times, most recently from cb0295b to 348c539 Compare April 23, 2026 01:12
@dipannita08 dipannita08 force-pushed the eval-framework-01 branch from c2d7d04 to 03ce83a Compare May 1, 2026 17:39
@dipannita08 dipannita08 requested a review from RissyRan May 1, 2026 18:13
def sample_requests(
self,
num_samples: int | None,
tokenizer,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also annotate the tokenizer type here? Is the base class PretrainedTokenizerBase from transformers?

@entrpn
Copy link
Copy Markdown
Collaborator

entrpn commented May 1, 2026

@dipannita08 has this code been tested with an nnx checkpoint? @hengtaoguo recently added support for this and it is needed in order to support distilled checkpoints. Hengtao's PR: #3188

Copy link
Copy Markdown
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the change!

Potentially we could have a run and see how it performance? This is Kurt's final presentation shows some period of time for common benchmarks using JetStream. It will be great to see if we could speed up using this tool.

@@ -0,0 +1,5 @@
# MLPerf OpenOrca evaluation config.

benchmark: "mlperf_openorca"
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide some context? i.e. example of perf configs for inference?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This config is passed via --config to the eval runner. It sets benchmark name, generation length, and sample count, all server, model, and parallelism params come from CLI args.

Example:

python -m maxtext.eval.runner.run \
    --runner eval \
    --config src/maxtext/eval/configs/mlperf.yml \
    --checkpoint_path $CHECKPOINT \
    --model_name llama3.1-8b \
    --hf_path meta-llama/Llama-3.1-8B-Instruct \
    --hf_token $HF_TOKEN \
    --base_output_directory gs://<bucket>/ \
    --run_name $RUN \
    --gcs_results_path gs://<bucket>/results/mlperf.json \
    --max_model_len 8192 \
    --tensor_parallel_size 4 \
    --expert_parallel_size 1 \
    --data_parallel_size 1 \
    --hbm_memory_utilization 0.7 \
    --max_num_batched_tokens 4096 \
    --max_num_seqs 256 \
    --num_samples 5000 \
    --max_tokens 1024 \
    --temperature 0.0 \
    --concurrency 64 \
    --log_level INFO

Only --config, --hf_path, --base_output_directory, --run_name, and --max_model_len are required, others have defaults or can be inherited from the config file.

@@ -0,0 +1,8 @@
# Base evaluation configuration.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we also include benchmark name or eval dataset in this base yml?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Base config covers server and generation parameters that are shared across all eval runs (temperature, concurrency, tensor_parallel_size, etc.). Benchmark name and dataset-specific settings (e.g. num_samples, max_tokens) live in task-specific configs like mlperf.yml so we can re-use the base config for different benchmarks without modification.

For the harness-based runners (lm_eval, evalchemy), benchmark/task selection is handled by the --tasks CLI arg (--tasks gsm8k mmlu gpqa) rather than config files (examples in the bug I shared above).

upload_results(output["local_path"], gcs_results_path)


def add_server_args(parser: argparse.ArgumentParser) -> None:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will those configs be included in base yaml config?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, targeting these type of refactoring changes in a follow up PR to merge the implementation ASAP.

description="MaxText model evaluation runner.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--config", required=True, help="Path to eval config file.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will those configs be included in base yaml config?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, there are several common configs that can be put in base. Targeting these type of refactoring changes in a follow up PR to merge the implementation ASAP.

@dipannita08
Copy link
Copy Markdown
Collaborator Author

@dipannita08 has this code been tested with an nnx checkpoint? @hengtaoguo recently added support for this and it is needed in order to support distilled checkpoints. Hengtao's PR: #3188

Yes, the checkpoint loading in model_creation_utils.py defaults to nnx and falls back to Linen if it detects the params.params double-nesting. MaxTextForCausalLM in the adapter is an nnx.Module. Some example runs with Qwen3-30B-A3B NNX checkpoints are in b/508639301.

@dipannita08
Copy link
Copy Markdown
Collaborator Author

Potentially we could have a run and see how it performance? This is Kurt's final presentation shows some period of time for common benchmarks using JetStream. It will be great to see if we could speed up using this tool.

Please see E2E results in the bug: b/508639301

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants