feat: standardize logits in mot #1261
Signed-off-by: AngeloDanducci <angelo.danducci.ii@ibm.com>
cc @ajbozarth as it relates to your MOT redesign work
ajbozarth left a comment
A couple of follow-up nits to my comment above.
    """Per-token logit scores from the backend, or ``None`` if not requested or unavailable.

    Populated when ``ModelOption.LOGITS=True`` and the backend supports it.
    For the HuggingFace backend this is a tuple of 1-D tensors of shape
    ``(vocab_size,)``, one per generated token.
    """
I'm not sure a long docstring here is good practice, though if this moves into generation it would belong there.
    # Additional fields that should be standardized across apis.
    self.tool_calls = tool_calls
    self._thinking: str | None = None
    self.logits: Any | None = None
As noted in my previous comment (and in #909), we should consider moving this into generation instead of keeping it on the MOT directly.
ajbozarth left a comment
Some feedback from Claude:
Main ask is to land the field on GenerationMetadata rather than as a top-level attr on ModelOutputThunk — per #909 and the #793 precedent, that's the standardized home for backend-execution metadata. The implementation logic (squeeze, clone-per-batch-item, cached vs. non-cached branches) looks correct; the rest of my comments are smaller items.
    # Additional fields that should be standardized across apis.
    self.tool_calls = tool_calls
    self._thinking: str | None = None
    self.logits: Any | None = None
Move this onto GenerationMetadata as generation.logits: tuple[torch.Tensor, ...] | None = None, and drop self.logits along with its _copy_from, __copy__, and __deepcopy__ lines. GenerationMetadata is already deep-copied as a unit.
Also: the type should be the concrete tuple[torch.Tensor, ...] | None, not Any | None. Use TYPE_CHECKING to keep torch out of the runtime import path.
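A minimal sketch of what that could look like, assuming GenerationMetadata is a plain dataclass. Its real fields are not shown in this thread, so everything except the `logits` annotation is hypothetical. The `TYPE_CHECKING` guard means `torch` is only imported by static type checkers, and the quoted annotation is never evaluated at runtime.

```python
import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers; torch is not imported at runtime here.
    import torch


@dataclass
class GenerationMetadata:
    """Hypothetical stand-in; the real class has other fields not shown in this review."""

    # Quoted so the annotation is not evaluated at runtime (torch may be absent).
    logits: "tuple[torch.Tensor, ...] | None" = None


meta = GenerationMetadata()
meta_copy = copy.deepcopy(meta)  # copied as a unit, no per-field copy hooks needed
```

With the field living on the metadata object, the thunk's own copy methods would not need a dedicated `logits` line.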
    # squeeze(0): hf_output.scores is (1, vocab_size) per token; normalise to (vocab_size,)
    mot.logits = tuple(s.squeeze(0) for s in hf_output.scores)

    # Clear KV cache and scores from HF output - they're now owned by the LRU cache
Stale comment now: when LOGITS=True, the views in mot.logits also pin these tensors, so they aren't solely owned by the LRU cache. Worth tweaking the wording.
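The ownership point can be illustrated without torch. The sketch below uses hypothetical stand-in classes (`Storage`, `FakeTensor`, `FakeOutput`) to model a view that shares its backing memory with the original, which is how `squeeze` behaves in PyTorch. It is not the backend's code, only a picture of why clearing `hf_output.scores` does not by itself release the memory.

```python
import gc
import weakref


class Storage:
    """Stand-in for the memory backing a tensor."""


class FakeTensor:
    """Hypothetical tensor: squeeze() returns a new object sharing the same storage."""

    def __init__(self, storage):
        self.storage = storage

    def squeeze(self):
        return FakeTensor(self.storage)  # a view, not a copy


class FakeOutput:
    """Stand-in for the HF generate output object."""

    scores = None


hf_output = FakeOutput()
hf_output.scores = (FakeTensor(Storage()),)
storage_ref = weakref.ref(hf_output.scores[0].storage)

logits = tuple(s.squeeze() for s in hf_output.scores)
hf_output.scores = None  # mirrors "clear scores from HF output"
gc.collect()
alive_after_clear = storage_ref() is not None  # the views still hold the storage

del logits
gc.collect()
alive_after_drop = storage_ref() is not None  # nothing references the storage now
```

In this model the storage survives the clear and is only released once the tuple of views is dropped, which is the situation the reworded comment should describe.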
    Only supported by the HuggingFace local backend. Ignored silently by
    backends that cannot return logits (OpenAI, Ollama, LiteLLM, WatsonX).

    **Streaming not supported**: when ``ModelOption.STREAM=True``, logit
    - **Streaming not supported**: when ``ModelOption.STREAM=True``, logit
    + **Streaming not supported**: when ``ModelOption.STREAM=True``, logit
    + scores are not available and ``ModelOutputThunk.generation.logits`` will be ``None``.
    + Backends that cannot return logits (OpenAI, Ollama, LiteLLM, WatsonX) log
    + a warning when this option is set and leave ``generation.logits`` as ``None``.
Pair this with adding the warning in each non-HF backend's options handler — happy to take that as a follow-up if you'd rather not touch four backends in this PR. A silent None is hard to debug.
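One possible shape for that warning, sketched as a shared helper each non-HF options handler could call. The helper name, the logger name, and the `"logits"` key (standing in for `ModelOption.LOGITS`) are all assumptions made for illustration; the actual option handling in those backends is not shown in this thread.

```python
import logging

logger = logging.getLogger("backend.options")  # hypothetical logger name

LOGITS_KEY = "logits"  # hypothetical stand-in for ModelOption.LOGITS


def warn_if_logits_unsupported(model_options: dict, backend_name: str) -> bool:
    """Log a warning when logits are requested from a backend that cannot return them.

    Returns True if a warning was emitted, False otherwise.
    """
    if model_options.get(LOGITS_KEY):
        logger.warning(
            "%s backend cannot return logits; generation.logits will be None.",
            backend_name,
        )
        return True
    return False
```

A single helper would keep the message identical across the four backends, so the follow-up could be one call per handler.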
    assert captured["generate_input"]["do_sample"] is False
    assert "temperature" not in captured["generate_input"]
This only covers the elif (caching-disabled) branch. Add a sibling test with _use_caches=True and a fake past_key_values so the first branch — which sets mot.logits before hf_output.scores = None — also has coverage. That's the production hot path.
    if __name__ == "__main__":
        pytest.main([__file__])
Drop the if __name__ == "__main__" block: it is uncommon in this repo, and pytest already discovers the file via its test_*.py name.
Also missing: a test that generation.logits stays None when both LOGITS=True and STREAM=True are set, since that's a documented contract.
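A rough sketch of the shape such a test could take. `fake_generate` is a hypothetical stand-in that mirrors only the documented contract, and the `"logits"` / `"stream"` keys stand in for `ModelOption.LOGITS` and `ModelOption.STREAM`. A real test would call the patched HuggingFace backend in place of the stand-in.

```python
from types import SimpleNamespace

# Hypothetical option keys standing in for ModelOption.LOGITS / ModelOption.STREAM.
LOGITS, STREAM = "logits", "stream"


def fake_generate(options: dict) -> SimpleNamespace:
    """Stand-in for the backend call; mirrors only the documented contract."""
    wants_logits = options.get(LOGITS, False) and not options.get(STREAM, False)
    logits = ((0.1, 0.9),) if wants_logits else None  # dummy per-token scores
    return SimpleNamespace(generation=SimpleNamespace(logits=logits))


def test_logits_none_when_streaming():
    # Both options set: the documented contract says logits stay None.
    mot = fake_generate({LOGITS: True, STREAM: True})
    assert mot.generation.logits is None


def test_logits_present_without_streaming():
    # Control case so the streaming test cannot pass trivially.
    mot = fake_generate({LOGITS: True})
    assert mot.generation.logits is not None
```

Pairing the streaming case with a non-streaming control guards against a test that passes simply because logits are never populated.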
Pull Request
Issue
Fixes #123
Description
Standardize logits in MOT.