Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/celeste/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
class Constraint(BaseModel, ABC):
"""Base constraint for parameter validation."""

description: str | None = None

@computed_field # type: ignore[prop-decorator]
@property
def type(self) -> str:
Expand Down
19 changes: 14 additions & 5 deletions src/celeste/modalities/audio/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Parameters for audio modality."""

from enum import StrEnum
from typing import Annotated

from pydantic import Field

from celeste.parameters import Parameters

Expand All @@ -18,11 +21,17 @@ class AudioParameter(StrEnum):
class AudioParameters(Parameters, total=False):
"""Parameters for audio operations."""

voice: str
speed: float
output_format: str
prompt: str
language: str
voice: Annotated[
str, Field(description="Voice identifier for text-to-speech output.")
]
speed: Annotated[
float, Field(description="Playback speed multiplier (1.0 = normal).")
]
output_format: Annotated[str, Field(description="Audio file format.")]
prompt: Annotated[
str, Field(description="Style or delivery instruction for the voice.")
]
language: Annotated[str, Field(description="BCP-47 language tag, e.g. 'en-US'.")]


__all__ = [
Expand Down
5 changes: 4 additions & 1 deletion src/celeste/modalities/embeddings/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Parameters for embeddings modality."""

from enum import StrEnum
from typing import Annotated

from pydantic import Field

from celeste.parameters import Parameters

Expand All @@ -17,7 +20,7 @@ class EmbeddingsParameter(StrEnum):
class EmbeddingsParameters(Parameters, total=False):
"""Parameters for embeddings operations."""

dimensions: int | None
dimensions: Annotated[int | None, Field(description="Embedding vector length.")]


__all__ = [
Expand Down
46 changes: 31 additions & 15 deletions src/celeste/modalities/images/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Parameters for images modality."""

from enum import StrEnum
from typing import Annotated

from pydantic import Field

from celeste.artifacts import ImageArtifact
from celeste.parameters import Parameters
Expand Down Expand Up @@ -29,21 +32,34 @@ class ImageParameter(StrEnum):
class ImageParameters(Parameters, total=False):
"""Parameters for images operations."""

aspect_ratio: str
num_images: int
partial_images: int
quality: str
watermark: bool
reference_images: list[ImageArtifact]
prompt_upsampling: bool
negative_prompt: str
seed: int
safety_tolerance: int
output_format: str
steps: int
guidance: float
mask: ImageArtifact
thinking_level: str
aspect_ratio: Annotated[
str, Field(description="Output image dimensions or aspect ratio.")
]
num_images: Annotated[int, Field(description="How many images to return.")]
partial_images: Annotated[
int, Field(description="Number of progressive partial outputs to stream.")
]
quality: Annotated[str, Field(description="Output quality tier.")]
watermark: Annotated[bool, Field(description="Embed a watermark in the output.")]
reference_images: Annotated[
list[ImageArtifact],
Field(description="Additional images for composition or style reference."),
]
prompt_upsampling: Annotated[
bool, Field(description="Let the model rewrite the prompt for better results.")
]
negative_prompt: Annotated[
str, Field(description="Concepts to avoid in the output.")
]
seed: Annotated[int, Field(description="Seed for deterministic output.")]
safety_tolerance: Annotated[int, Field(description="Safety filter threshold.")]
output_format: Annotated[str, Field(description="Output file format.")]
steps: Annotated[int, Field(description="Number of denoising steps.")]
guidance: Annotated[float, Field(description="Prompt-adherence strength.")]
mask: Annotated[
ImageArtifact, Field(description="Mask image for inpainting a region.")
]
thinking_level: Annotated[str, Field(description="Model reasoning depth.")]


__all__ = [
Expand Down
47 changes: 34 additions & 13 deletions src/celeste/modalities/text/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
"""

from enum import StrEnum
from typing import Annotated

from pydantic import BaseModel
from pydantic import BaseModel, Field

from celeste.parameters import Parameters
from celeste.tools import ToolChoiceOption, ToolDefinition
Expand Down Expand Up @@ -45,23 +46,43 @@ class TextParameters(Parameters, total=False):
"""Parameters for text operations."""

# Common parameters
temperature: float
max_tokens: int
seed: int
temperature: Annotated[
float, Field(description="Sampling randomness; 0.0 is deterministic.")
]
max_tokens: Annotated[int, Field(description="Maximum tokens to generate.")]
seed: Annotated[int, Field(description="Seed for deterministic output.")]

# Text-specific parameters
thinking_budget: int | str
thinking_level: str
output_schema: type[BaseModel]
tools: list[ToolDefinition]
tool_choice: ToolChoiceOption
verbosity: str
thinking_budget: Annotated[
int | str,
Field(
description="Reasoning budget — integer token count, or a preset tier for models that accept one."
),
]
thinking_level: Annotated[str, Field(description="Model reasoning depth.")]
output_schema: Annotated[
type[BaseModel],
Field(description="Pydantic model constraining the output shape."),
]
tools: Annotated[
list[ToolDefinition],
Field(description="Tools the model may call during generation."),
]
tool_choice: Annotated[
ToolChoiceOption,
Field(description="Controls whether and which tool the model must call."),
]
verbosity: Annotated[str, Field(description="Output verbosity level.")]

# Deprecated: use tools=[WebSearch()], tools=[XSearch()], tools=[CodeExecution()] instead.
# TODO(deprecation): Remove on 2026-06-07.
web_search: bool
x_search: bool
code_execution: bool
web_search: Annotated[
bool, Field(description="Deprecated. Use tools=[WebSearch()].")
]
x_search: Annotated[bool, Field(description="Deprecated. Use tools=[XSearch()].")]
code_execution: Annotated[
bool, Field(description="Deprecated. Use tools=[CodeExecution()].")
]


__all__ = [
Expand Down
24 changes: 18 additions & 6 deletions src/celeste/modalities/videos/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Parameters for videos modality."""

from enum import StrEnum
from typing import Annotated

from pydantic import Field

from celeste.artifacts import ImageArtifact
from celeste.parameters import Parameters
Expand All @@ -20,12 +23,21 @@ class VideoParameter(StrEnum):
class VideoParameters(Parameters, total=False):
"""Parameters for video generation operations."""

aspect_ratio: str
resolution: str
duration: int
reference_images: list[ImageArtifact]
first_frame: ImageArtifact
last_frame: ImageArtifact
aspect_ratio: Annotated[
str, Field(description="Output video dimensions or aspect ratio.")
]
resolution: Annotated[str, Field(description="Vertical resolution tier.")]
duration: Annotated[int, Field(description="Clip length in seconds.")]
reference_images: Annotated[
list[ImageArtifact],
Field(description="Additional images conditioning the video."),
]
first_frame: Annotated[
ImageArtifact, Field(description="Image to use as the video's first frame.")
]
last_frame: Annotated[
ImageArtifact, Field(description="Image to use as the video's last frame.")
]


__all__ = [
Expand Down
73 changes: 72 additions & 1 deletion tests/unit_tests/test_parameters.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,88 @@
"""High-value tests for celeste.parameters module."""

from enum import StrEnum
from typing import Any
from typing import Any, get_type_hints

import pytest
from pydantic.fields import FieldInfo

from celeste.constraints import Range, Str
from celeste.core import Parameter, Provider
from celeste.exceptions import ConstraintViolationError
from celeste.modalities.audio.parameters import AudioParameters
from celeste.modalities.images.parameters import ImageParameters
from celeste.modalities.text.parameters import TextParameters
from celeste.modalities.videos.parameters import VideoParameters
from celeste.models import Model
from celeste.types import TextContent


def _field_description(params_type: type, field_name: str) -> str | None:
hints = get_type_hints(params_type, include_extras=True)
hint = hints[field_name]
for meta in getattr(hint, "__metadata__", ()):
if isinstance(meta, FieldInfo):
return meta.description
return None


class TestParameterTypedDictAnnotations:
"""Each modality's Parameters TypedDict must carry per-field Field descriptions."""

def test_image_parameters_describe_every_field(self) -> None:
hints = get_type_hints(ImageParameters, include_extras=True)
for name in hints:
assert _field_description(ImageParameters, name), (
f"ImageParameters.{name} missing Field(description=...)"
)

def test_audio_parameters_describe_every_field(self) -> None:
hints = get_type_hints(AudioParameters, include_extras=True)
for name in hints:
assert _field_description(AudioParameters, name), (
f"AudioParameters.{name} missing Field(description=...)"
)

def test_video_parameters_describe_every_field(self) -> None:
hints = get_type_hints(VideoParameters, include_extras=True)
for name in hints:
assert _field_description(VideoParameters, name), (
f"VideoParameters.{name} missing Field(description=...)"
)

def test_text_parameters_describe_every_field(self) -> None:
hints = get_type_hints(TextParameters, include_extras=True)
for name in hints:
assert _field_description(TextParameters, name), (
f"TextParameters.{name} missing Field(description=...)"
)

def test_base_type_is_preserved_under_annotated(self) -> None:
hints = get_type_hints(ImageParameters, include_extras=True)
assert hints["aspect_ratio"].__origin__ is str
assert hints["num_images"].__origin__ is int
assert hints["guidance"].__origin__ is float


class TestConstraintDescription:
"""Constraint carries optional per-model description metadata."""

def test_description_defaults_to_none(self) -> None:
assert Range(min=0, max=1).description is None

def test_description_is_settable(self) -> None:
constraint = Range(
min=0, max=1, description="Temperature bounds for this model."
)
assert constraint.description == "Temperature bounds for this model."

def test_str_constraint_inherits_description(self) -> None:
constraint = Str(
max_length=500, description="Prompt length capped by provider."
)
assert constraint.description == "Prompt length capped by provider."


class DefaultParseOutputMapper:
"""Mapper that uses default parse_output behavior (returns content unchanged)."""

Expand Down
Loading