diff --git a/src/celeste/constraints.py b/src/celeste/constraints.py index 2d9f65b..d79e5ce 100644 --- a/src/celeste/constraints.py +++ b/src/celeste/constraints.py @@ -27,6 +27,8 @@ class Constraint(BaseModel, ABC): """Base constraint for parameter validation.""" + description: str | None = None + @computed_field # type: ignore[prop-decorator] @property def type(self) -> str: diff --git a/src/celeste/modalities/audio/parameters.py b/src/celeste/modalities/audio/parameters.py index 82a5718..51e14c6 100644 --- a/src/celeste/modalities/audio/parameters.py +++ b/src/celeste/modalities/audio/parameters.py @@ -1,6 +1,9 @@ """Parameters for audio modality.""" from enum import StrEnum +from typing import Annotated + +from pydantic import Field from celeste.parameters import Parameters @@ -18,11 +21,17 @@ class AudioParameter(StrEnum): class AudioParameters(Parameters, total=False): """Parameters for audio operations.""" - voice: str - speed: float - output_format: str - prompt: str - language: str + voice: Annotated[ + str, Field(description="Voice identifier for text-to-speech output.") + ] + speed: Annotated[ + float, Field(description="Playback speed multiplier (1.0 = normal).") + ] + output_format: Annotated[str, Field(description="Audio file format.")] + prompt: Annotated[ + str, Field(description="Style or delivery instruction for the voice.") + ] + language: Annotated[str, Field(description="BCP-47 language tag, e.g. 'en-US'.")] __all__ = [ diff --git a/src/celeste/modalities/embeddings/parameters.py b/src/celeste/modalities/embeddings/parameters.py index fcbad6f..f7fec92 100644 --- a/src/celeste/modalities/embeddings/parameters.py +++ b/src/celeste/modalities/embeddings/parameters.py @@ -1,6 +1,9 @@ """Parameters for embeddings modality.""" from enum import StrEnum +from typing import Annotated + +from pydantic import Field from celeste.parameters import Parameters @@ -17,7 +20,7 @@ class EmbeddingsParameter(StrEnum): class EmbeddingsParameters(Parameters, total=False): """Parameters for embeddings operations.""" - dimensions: int | None + dimensions: Annotated[int | None, Field(description="Embedding vector length.")] __all__ = [ diff --git a/src/celeste/modalities/images/parameters.py b/src/celeste/modalities/images/parameters.py index 655d06f..858677f 100644 --- a/src/celeste/modalities/images/parameters.py +++ b/src/celeste/modalities/images/parameters.py @@ -1,6 +1,9 @@ """Parameters for images modality.""" from enum import StrEnum +from typing import Annotated + +from pydantic import Field from celeste.artifacts import ImageArtifact from celeste.parameters import Parameters @@ -29,21 +32,34 @@ class ImageParameter(StrEnum): class ImageParameters(Parameters, total=False): """Parameters for images operations.""" - aspect_ratio: str - num_images: int - partial_images: int - quality: str - watermark: bool - reference_images: list[ImageArtifact] - prompt_upsampling: bool - negative_prompt: str - seed: int - safety_tolerance: int - output_format: str - steps: int - guidance: float - mask: ImageArtifact - thinking_level: str + aspect_ratio: Annotated[ + str, Field(description="Output image dimensions or aspect ratio.") + ] + num_images: Annotated[int, Field(description="How many images to return.")] + partial_images: Annotated[ + int, Field(description="Number of progressive partial outputs to stream.") + ] + quality: Annotated[str, Field(description="Output quality tier.")] + watermark: Annotated[bool, Field(description="Embed a watermark in the output.")] + reference_images: Annotated[ + list[ImageArtifact], + Field(description="Additional images for composition or style reference."), + ] + prompt_upsampling: Annotated[ + bool, Field(description="Let the model rewrite the prompt for better results.") + ] + negative_prompt: Annotated[ + str, Field(description="Concepts to avoid in the output.") + ] + seed: Annotated[int, Field(description="Seed for deterministic output.")] + safety_tolerance: Annotated[int, Field(description="Safety filter threshold.")] + output_format: Annotated[str, Field(description="Output file format.")] + steps: Annotated[int, Field(description="Number of denoising steps.")] + guidance: Annotated[float, Field(description="Prompt-adherence strength.")] + mask: Annotated[ + ImageArtifact, Field(description="Mask image for inpainting a region.") + ] + thinking_level: Annotated[str, Field(description="Model reasoning depth.")] __all__ = [ diff --git a/src/celeste/modalities/text/parameters.py b/src/celeste/modalities/text/parameters.py index 93352b0..7875404 100644 --- a/src/celeste/modalities/text/parameters.py +++ b/src/celeste/modalities/text/parameters.py @@ -5,8 +5,9 @@ """ from enum import StrEnum +from typing import Annotated -from pydantic import BaseModel +from pydantic import BaseModel, Field from celeste.parameters import Parameters from celeste.tools import ToolChoiceOption, ToolDefinition @@ -45,23 +46,43 @@ class TextParameters(Parameters, total=False): """Parameters for text operations.""" # Common parameters - temperature: float - max_tokens: int - seed: int + temperature: Annotated[ + float, Field(description="Sampling randomness; 0.0 is deterministic.") + ] + max_tokens: Annotated[int, Field(description="Maximum tokens to generate.")] + seed: Annotated[int, Field(description="Seed for deterministic output.")] # Text-specific parameters - thinking_budget: int | str - thinking_level: str - output_schema: type[BaseModel] - tools: list[ToolDefinition] - tool_choice: ToolChoiceOption - verbosity: str + thinking_budget: Annotated[ + int | str, + Field( + description="Reasoning budget — integer token count, or a preset tier for models that accept one." + ), + ] + thinking_level: Annotated[str, Field(description="Model reasoning depth.")] + output_schema: Annotated[ + type[BaseModel], + Field(description="Pydantic model constraining the output shape."), + ] + tools: Annotated[ + list[ToolDefinition], + Field(description="Tools the model may call during generation."), + ] + tool_choice: Annotated[ + ToolChoiceOption, + Field(description="Controls whether and which tool the model must call."), + ] + verbosity: Annotated[str, Field(description="Output verbosity level.")] # Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead. # TODO(deprecation): Remove on 2026-06-07. - web_search: bool - x_search: bool - code_execution: bool + web_search: Annotated[ + bool, Field(description="Deprecated. Use tools=[WebSearch()].") + ] + x_search: Annotated[bool, Field(description="Deprecated. Use tools=[XSearch()].")] + code_execution: Annotated[ + bool, Field(description="Deprecated. Use tools=[CodeExecution()].") + ] __all__ = [ diff --git a/src/celeste/modalities/videos/parameters.py b/src/celeste/modalities/videos/parameters.py index e69c45f..f724245 100644 --- a/src/celeste/modalities/videos/parameters.py +++ b/src/celeste/modalities/videos/parameters.py @@ -1,6 +1,9 @@ """Parameters for videos modality.""" from enum import StrEnum +from typing import Annotated + +from pydantic import Field from celeste.artifacts import ImageArtifact from celeste.parameters import Parameters @@ -20,12 +23,21 @@ class VideoParameter(StrEnum): class VideoParameters(Parameters, total=False): """Parameters for video generation operations.""" - aspect_ratio: str - resolution: str - duration: int - reference_images: list[ImageArtifact] - first_frame: ImageArtifact - last_frame: ImageArtifact + aspect_ratio: Annotated[ + str, Field(description="Output video dimensions or aspect ratio.") + ] + resolution: Annotated[str, Field(description="Vertical resolution tier.")] + duration: Annotated[int, Field(description="Clip length in seconds.")] + reference_images: Annotated[ + list[ImageArtifact], + Field(description="Additional images conditioning the video."), + ] + first_frame: Annotated[ + ImageArtifact, Field(description="Image to use as the video's first frame.") + ] + last_frame: Annotated[ + ImageArtifact, Field(description="Image to use as the video's last frame.") + ] __all__ = [ diff --git a/tests/unit_tests/test_parameters.py b/tests/unit_tests/test_parameters.py index 23b7b7c..3c090ab 100644 --- a/tests/unit_tests/test_parameters.py +++ b/tests/unit_tests/test_parameters.py @@ -1,17 +1,88 @@ """High-value tests for celeste.parameters module.""" from enum import StrEnum -from typing import Any +from typing import Any, get_type_hints import pytest +from pydantic.fields import FieldInfo from celeste.constraints import Range, Str from celeste.core import Parameter, Provider from celeste.exceptions import ConstraintViolationError +from celeste.modalities.audio.parameters import AudioParameters +from celeste.modalities.images.parameters import ImageParameters +from celeste.modalities.text.parameters import TextParameters +from celeste.modalities.videos.parameters import VideoParameters from celeste.models import Model from celeste.types import TextContent +def _field_description(params_type: type, field_name: str) -> str | None: + hints = get_type_hints(params_type, include_extras=True) + hint = hints[field_name] + for meta in getattr(hint, "__metadata__", ()): + if isinstance(meta, FieldInfo): + return meta.description + return None + + +class TestParameterTypedDictAnnotations: + """Each modality's Parameters TypedDict must carry per-field Field descriptions.""" + + def test_image_parameters_describe_every_field(self) -> None: + hints = get_type_hints(ImageParameters, include_extras=True) + for name in hints: + assert _field_description(ImageParameters, name), ( + f"ImageParameters.{name} missing Field(description=...)" + ) + + def test_audio_parameters_describe_every_field(self) -> None: + hints = get_type_hints(AudioParameters, include_extras=True) + for name in hints: + assert _field_description(AudioParameters, name), ( + f"AudioParameters.{name} missing Field(description=...)" + ) + + def test_video_parameters_describe_every_field(self) -> None: + hints = get_type_hints(VideoParameters, include_extras=True) + for name in hints: + assert _field_description(VideoParameters, name), ( + f"VideoParameters.{name} missing Field(description=...)" + ) + + def test_text_parameters_describe_every_field(self) -> None: + hints = get_type_hints(TextParameters, include_extras=True) + for name in hints: + assert _field_description(TextParameters, name), ( + f"TextParameters.{name} missing Field(description=...)" + ) + + def test_base_type_is_preserved_under_annotated(self) -> None: + hints = get_type_hints(ImageParameters, include_extras=True) + assert hints["aspect_ratio"].__origin__ is str + assert hints["num_images"].__origin__ is int + assert hints["guidance"].__origin__ is float + + +class TestConstraintDescription: + """Constraint carries optional per-model description metadata.""" + + def test_description_defaults_to_none(self) -> None: + assert Range(min=0, max=1).description is None + + def test_description_is_settable(self) -> None: + constraint = Range( + min=0, max=1, description="Temperature bounds for this model." + ) + assert constraint.description == "Temperature bounds for this model." + + def test_str_constraint_inherits_description(self) -> None: + constraint = Str( + max_length=500, description="Prompt length capped by provider." + ) + assert constraint.description == "Prompt length capped by provider." + + class DefaultParseOutputMapper: """Mapper that uses default parse_output behavior (returns content unchanged)."""