diff --git a/.gitignore b/.gitignore index e7f8718b7..e4ac9d9bf 100644 --- a/.gitignore +++ b/.gitignore @@ -85,6 +85,14 @@ venv.bak/ # Local scratch space .scratch/ +# Generated benchmark/report output +/artifacts/ +/reports/ +/scripts/benchmarks/benchmark_async_scheduling.py +/scripts/benchmarks/export_async_scheduling_perfetto.py +/scripts/benchmarks/generate_async_scheduling_idle_report.py +/scripts/benchmarks/run_async_scheduling_idle_regression.py + docs/notebooks/ docs/notebook_source/*.ipynb docs/notebook_source/*.csv diff --git a/architecture/dataset-builders.md b/architecture/dataset-builders.md index 825a2a392..fc3981543 100644 --- a/architecture/dataset-builders.md +++ b/architecture/dataset-builders.md @@ -35,7 +35,7 @@ Preparation (`_prepare_async_run`): 4. Constructs `CompletionTracker`, `RowGroupBufferManager`, `AsyncTaskScheduler` 5. Hooks `ProcessorRunner` for pre-batch and post-batch stages -`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, semaphore-based capacity limits, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks are admitted through a virtual-time fair queue so one hot column or model-backed generator cannot consume the whole submission window before peer work gets a turn. +`AsyncTaskScheduler` runs on a dedicated async loop with frontier-driven dispatch, task-admission leases, salvage rounds for failed tasks, and order-dependent locks for columns that must execute sequentially. Ready frontier tasks enter `FairTaskQueue`, are selected through virtual-time ordering, and are committed only after `TaskAdmissionController` acquires the required scheduler resources. 
### Execution Graph @@ -121,19 +121,24 @@ DatasetBuilder.build() → _prepare_async_run() → ExecutionGraph.create() → CompletionTracker.with_graph() - → AsyncTaskScheduler(semaphores, salvage_rounds) + → AsyncTaskScheduler(task admission, fair queue, salvage_rounds) → scheduler.run() - → for each row group, fairly admit ready tasks from frontier + → admit row groups under the configured row-group cap + → fairly admit ready tasks from the frontier through task admission → tasks execute generators, update CompletionTracker → checkpoints via RowGroupBufferManager → collect TaskTraces, emit telemetry ``` +Row-group admission is fixed by default in the dataset-builder path: the configured row-group concurrency is the hard in-flight cap. The scheduler also has an internal adaptive row-group mode for direct use that only raises a soft target up to that cap; it is additive ramp-up, not AIMD shrink/recovery behavior. + +When request admission is available, async scheduling may use request-pressure snapshots as a read-only advisory during fair-queue selection. A request-pressured task can be skipped for an eligible peer without mutating request-admission state; provider/model/domain request limits remain owned by request admission. + ## Design Decisions - **Dual execution engines behind one API.** The sequential engine is simpler and easier to debug; the async engine adds row-group parallelism for throughput. Users switch via an environment variable without changing their code. - **DAG-driven ordering** ensures columns with dependencies (e.g., a judge column that depends on a text column) are generated in the correct order, regardless of the order they appear in the config. -- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. Global semaphores still bound memory/coroutine growth, while per-group virtual-time queues prevent a large ready frontier from degenerating into a column-by-column wave. 
LLM admission caps are peer-sensitive: a solo model group can fill available global capacity, but once another scheduling group has queued work the saturated group yields until peers get admission slots or admitted tasks complete. +- **Fair async admission** keeps the scheduler flowing across ready columns and model groups. `FairTaskQueue.select_next(...)` chooses eligible ready work, `TaskAdmissionController` leases scheduler resources before spawn, and `FairTaskQueue.commit(...)` removes the selected task only after admission succeeds. Per-group virtual-time ordering prevents a large ready frontier from degenerating into a column-by-column wave, and scheduler-resource accounting remains separate from provider/model request admission. - **Salvage rounds in async mode** retry failed tasks after all other tasks in a round complete, improving resilience against transient LLM failures without blocking the entire generation. - **Unified DAG construction.** `topologically_sort_column_configs` (in `execution_graph.py`) determines column ordering using Kahn's algorithm; the runtime `ExecutionGraph` adds strategy-aware dependency tracking for the async scheduler. diff --git a/architecture/models.md b/architecture/models.md index 870f5f84c..de449c61b 100644 --- a/architecture/models.md +++ b/architecture/models.md @@ -1,6 +1,6 @@ # Models -The model subsystem provides a unified interface for LLM access: chat completions, embeddings, and image generation. It handles client creation, retry, rate-limit throttling, usage tracking, and MCP tool integration. +The model subsystem provides a unified interface for LLM access: chat completions, embeddings, and image generation. It handles client creation, retry, request admission, usage tracking, and MCP tool integration. 
Source: `packages/data-designer-engine/src/data_designer/engine/models/` @@ -11,12 +11,12 @@ The model subsystem is layered: ``` ModelRegistry (lazy facade-per-alias) └── ModelFacade (completion, embeddings, image gen, MCP tool loops) - └── ThrottledModelClient (AIMD rate limiting) + └── ModelRequestExecutor (request admission + provider execution) └── ModelClient (OpenAI-compatible or Anthropic adapter) └── RetryTransport (httpx-level retries) ``` -Generators never interact with HTTP clients directly. They request a `ModelFacade` by alias from the `ModelRegistry`, which handles lazy construction and shared throttle state. +Generators never interact with HTTP clients directly. They request a `ModelFacade` by alias from the `ModelRegistry`, which handles lazy construction, request-resource canonicalization, and shared adaptive request admission state. ## Key Components @@ -31,13 +31,13 @@ Defines the contract: sync/async chat, embeddings, image generation, `supports_* `create_model_client` routes by provider type to the appropriate adapter. Optionally wraps with: - **`RetryTransport`** — httpx-level retries via `httpx_retries.RetryTransport`. `HttpModelClient` sets `strip_rate_limit_codes=True` for the async client and `False` for the sync client (`http_model_client.py`), which controls whether 429 responses are eligible for transport-layer retries. -- **`ThrottledModelClient`** — AIMD (Additive Increase, Multiplicative Decrease) concurrency control per throttle domain. +- **`ModelRequestExecutor`** — maps model-call attempts to request-admission items, acquires request leases, invokes the provider client, and releases the exact lease on every terminal path. -### ThrottleManager +### Request Admission -Manages concurrency limits per `ThrottleDomain` (CHAT, EMBEDDING, IMAGE, HEALTHCHECK), keyed by `(provider_name, model_id)`. Thread-safe with a shared lock for sync/async access. +`RequestAdmissionController` manages provider/model/domain request resources. 
`AdaptiveRequestAdmissionController` adds AIMD (Additive Increase, Multiplicative Decrease) adaptation per `RequestDomain` (`chat`, `embedding`, `image`, `healthcheck`) under the provider/model static cap. -`ThrottledModelClient` wraps each API call in a context manager that acquires/releases throttle capacity and adjusts limits on success (additive increase) or rate-limit errors (multiplicative decrease). +`ModelRequestExecutor` wraps each provider call with a request-admission lease and feeds success or rate-limit outcomes back to the controller. `RequestResourceResolver` owns canonical provider/model/domain identity so aliases that target the same endpoint share request capacity. -When `rampup_seconds` is configured, `ThrottleManager` starts new domains at one concurrent request, climbs linearly toward the peak, and aborts to normal AIMD behavior on the first 429. +When `startup_ramp_seconds` is configured, adaptive request admission starts each new request resource at one concurrent request, climbs linearly toward the configured ceiling, and on the first 429 aborts the ramp and falls back to normal AIMD behavior. @@ -52,7 +52,7 @@ The primary interface for generators. Holds a `ModelConfig`, `ModelClient`, opti ### ModelRegistry -Lazy `ModelFacade` construction per alias. Registers a shared `ThrottleManager` across all facades for coordinated rate limiting. Provides `get_model_usage_stats` and `log_model_usage` for post-build reporting. +Lazy `ModelFacade` construction per alias. Registers shared request-admission state across all facades for coordinated provider/model/domain capacity. Provides `get_model_usage_stats` and `log_model_usage` for post-build reporting. ### Usage Tracking @@ -61,18 +61,18 @@ Lazy `ModelFacade` construction per alias. Registers a shared `ThrottleManager` ## Data Flow 1. Generator requests a model by alias from `ModelRegistry` -2. Registry lazily creates `ModelFacade` with the appropriate client and throttle config +2. Registry lazily creates `ModelFacade` with the appropriate client and request-admission executor 3. Generator calls `completion()` with prompt/messages -4. `ModelFacade` builds kwargs, calls `ThrottledModelClient` -5. 
Throttle layer acquires capacity, delegates to `ModelClient` +4. `ModelFacade` builds kwargs, calls `ModelRequestExecutor` +5. Request admission acquires a provider/model/domain lease, delegates to `ModelClient` 6. `ModelClient` makes the HTTP request through `RetryTransport` 7. Response flows back; usage is tracked; if MCP tools are configured, tool calls are executed and results fed back for another completion round ## Design Decisions -- **Facade pattern** hides HTTP, retry, throttle, and MCP complexity from generators. Generators see `completion()` and get back parsed results. -- **AIMD throttling at the application layer** rather than relying solely on HTTP retries. This provides smoother throughput under rate limits — the transport layer still handles many transient failures, while the throttle manager adjusts concurrency to avoid sustained 429 storms. -- **429 handling depends on sync vs async `HttpModelClient`** — The async client uses `strip_rate_limit_codes=True`, so 429s are not retried at the transport layer and rate-limit signals reach `ThrottledModelClient` / AIMD quickly. The sync client uses `strip_rate_limit_codes=False`, so 429s may still be retried transparently at the transport layer before surfacing to callers. +- **Facade pattern** hides HTTP, retry, request admission, and MCP complexity from generators. Generators see `completion()` and get back parsed results. +- **AIMD request admission at the application layer** rather than relying solely on HTTP retries. This provides smoother throughput under rate limits: the transport layer still handles many transient failures, while adaptive request admission adjusts concurrency to avoid sustained 429 storms. +- **429 handling depends on sync vs async `HttpModelClient`** — The async client uses `strip_rate_limit_codes=True`, so 429s are not retried at the transport layer and rate-limit signals reach `ModelRequestExecutor` / request admission quickly. 
The sync client uses `strip_rate_limit_codes=False`, so 429s may still be retried transparently at the transport layer before surfacing to callers. - **Distribution-valued inference parameters** (`temperature`, `top_p` as `UniformDistribution` or `ManualDistribution`) enable controlled randomness across a dataset without per-row config changes. - **Lazy facade construction** avoids health-checking or connecting to models that are configured but never used in a particular generation run. diff --git a/architecture/overview.md b/architecture/overview.md index 30c91bdfb..10bde6c90 100644 --- a/architecture/overview.md +++ b/architecture/overview.md @@ -30,7 +30,7 @@ Users declare what their data should look like through config objects (columns, | `DataDesigner` | `data-designer` | Public API — `create()`, `preview()`, `validate()` | | `DataDesignerConfigBuilder` | `data-designer-config` | Fluent builder for dataset configs | | `DatasetBuilder` | `data-designer-engine` | Orchestrates generation (sync or async) | -| `ModelFacade` / `ModelRegistry` | `data-designer-engine` | LLM client abstraction with retry, throttle, usage tracking | +| `ModelFacade` / `ModelRegistry` | `data-designer-engine` | LLM client abstraction with retry, request admission, usage tracking | | `MCPFacade` / `MCPRegistry` | `data-designer-engine` | Tool execution via Model Context Protocol | | `ColumnGeneratorRegistry` | `data-designer-engine` | Maps column types to generator implementations | | `PluginRegistry` | `data-designer-config` | Discovers and registers entry-point plugins | @@ -44,7 +44,7 @@ Users declare what their data should look like through config objects (columns, 3. **Generation** — `DatasetBuilder` instantiates column generators from the registry, then executes one of two paths: - **Sequential** (default): batch loop over columns in topological order. Each generator produces its column via `CELL_BY_CELL` (threaded fan-out) or `FULL_COLUMN` strategy. 
- - **Async** (`DATA_DESIGNER_ASYNC_ENGINE=1`): builds an `ExecutionGraph`, partitions rows into groups, and dispatches tasks via `AsyncTaskScheduler` with semaphore-based concurrency, salvage rounds, and per-row-group checkpointing. + - **Async** (`DATA_DESIGNER_ASYNC_ENGINE=1`): builds an `ExecutionGraph`, partitions rows into groups, and dispatches tasks via `AsyncTaskScheduler` with `FairTaskQueue` selection, `TaskAdmissionController` scheduler-resource leases, salvage rounds, and per-row-group checkpointing. 4. **Post-processing** — `ProcessorRunner` applies transformations (pre-batch, post-batch, after-generation). Profilers analyze the generated dataset. @@ -61,7 +61,7 @@ Users declare what their data should look like through config objects (columns, - [Config Layer](config.md) — builder API, column types, model configs, plugin system - [Engine Layer](engine.md) — compilation, generators, registries -- [Models](models.md) — model facade, adapters, retry/throttle +- [Models](models.md) — model facade, adapters, retry, request admission - [Dataset Builders](dataset-builders.md) — sync/async orchestration, DAG, batching - [MCP](mcp.md) — tool execution, session pooling - [Sampling](sampling.md) — statistical generators, person/entity data diff --git a/docs/concepts/architecture-and-performance.md b/docs/concepts/architecture-and-performance.md index 31ab502aa..f03b549d8 100644 --- a/docs/concepts/architecture-and-performance.md +++ b/docs/concepts/architecture-and-performance.md @@ -48,7 +48,7 @@ This guide explains the architecture, execution model, and how to tune performan ## Execution Model !!! note "Two execution engines" - The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. 
The configuration knobs documented below (`buffer_size`, `max_parallel_requests`, AIMD throttle config, error handling) apply to both engines; the differences are flagged inline. + The default execution path is the **async engine**, which dispatches work at the cell level and overlaps independent columns — see [Async Engine](#async-engine) below for its semantics. The legacy **sync engine** is still available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0` and is what this section describes. The public configuration knobs documented below (`buffer_size`, `max_parallel_requests`, error handling) apply to both engines; the differences are flagged inline. The sync engine processes datasets in **batches**, with **parallel** operations within each batch. @@ -104,23 +104,23 @@ At any moment, the number of concurrent LLM requests is: ```python concurrent_requests = min( buffer_size, # Records in current batch - current_throttle_limit, # AIMD-managed limit (≤ max_parallel_requests) + current_request_limit, # AIMD-managed limit (≤ max_parallel_requests) remaining_cells_in_column # Cells left to generate ) ``` -`max_parallel_requests` sets the **ceiling**. The actual limit (`current_throttle_limit`) is managed at runtime by an AIMD (Additive Increase / Multiplicative Decrease) controller that reacts to rate-limit signals from the inference server: +`max_parallel_requests` sets the **ceiling**. The actual limit (`current_request_limit`) is managed at runtime by adaptive request admission that reacts to rate-limit signals from the inference server: -- **During optional startup ramp**: when `rampup_seconds` is greater than 0, a new throttle domain starts at one concurrent request and increases linearly toward `max_parallel_requests` over that duration. +- **During optional startup ramp**: when `startup_ramp_seconds` is greater than 0, a new request resource starts at one concurrent request and increases linearly toward `max_parallel_requests` over that duration. 
- **On the first 429 in a burst**: the limit is reduced by a configurable factor (default: 25% reduction) and a cooldown is applied. Further 429s from already in-flight requests in the same burst do not reduce the limit again — they release their permits and hold the limit steady. - **After consecutive successes**: the limit increases by 1 (by default) until it reaches the ceiling or a stabilized rate-limit threshold. This means Data Designer automatically finds the right concurrency level for your server without manual tuning. !!! note "Engine paths" - AIMD adaptive concurrency is fully active on the default **async engine**. The legacy **sync engine** is available for one transitional release via `DATA_DESIGNER_ASYNC_ENGINE=0`; on that path 429s are first retried at the HTTP transport layer and AIMD only engages as a fallback. See [Async engine](#async-engine) below. + Request admission wraps model requests on both sync and async paths. When request admission is active, provider 429 responses propagate to the AIMD controller instead of being hidden by HTTP transport retries. See [Async engine](#async-engine) below. -**Example**: With `buffer_size=100` and `max_parallel_requests=32`, Data Designer can send up to 32 requests in parallel. If `rampup_seconds=30`, it starts at one request and climbs linearly toward 32 over 30 seconds. If the server returns 429s, startup ramp stops, concurrency drops automatically (e.g., to 24, then 18), and normal AIMD recovery takes over once the server catches up. +**Example**: With `buffer_size=100` and `max_parallel_requests=32`, Data Designer can send up to 32 requests in parallel. If `startup_ramp_seconds=30`, it starts at one request and climbs linearly toward 32 over 30 seconds. If the server returns 429s, startup ramp stops, concurrency drops automatically (e.g., to 24, then 18), and normal AIMD recovery takes over once the server catches up. 
--- @@ -154,7 +154,7 @@ designer.set_run_config(run_config) ### `max_parallel_requests` (InferenceParams) -Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that the AIMD throttle controller can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits. +Sets the **maximum** concurrent LLM API calls **per model**. This is the ceiling that adaptive request admission can ramp up to — the actual concurrency at runtime may be lower if the server signals rate limits. ```python import data_designer.config as dd @@ -171,14 +171,14 @@ model = dd.ModelConfig( **Default**: 4 -**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With AIMD, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync. +**When to increase**: Your inference backend has high throughput capacity, you're using a cloud API with generous rate limits, or you're running vLLM/TensorRT-LLM with multiple GPUs. With adaptive request admission, setting an aggressively high value is safer than before — the system will self-correct downward if the server can't keep up. The salvage queue on the async engine (default) reclaims failed rows; on the sync engine the initial burst of 429s before AIMD stabilizes can drop rows, so start with a more conservative ceiling if you've opted into sync. **When to decrease**: You want to cap resource usage to a known safe level, or you want more predictable/debuggable execution. !!! tip "Finding the optimal value" The right value depends on your inference stack and model. 
Self-hosted vLLM servers can often handle values as high as 256, 512, or even 1024 depending on your hardware. - With AIMD, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the throttle controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any throttle activity, you may have room to increase the ceiling further. + With adaptive request admission, a practical approach is to set `max_parallel_requests` to the **upper bound** you're comfortable with and let the request controller find the sustainable level automatically. If you see frequent 429 → recovery cycles in the logs, your ceiling is above the server's true capacity but the system is handling it. If you never see any request-admission activity, you may have room to increase the ceiling further. **Benchmark approach**: Run a small dataset (e.g., 100 records) with increasing `max_parallel_requests` values (4 → 8 → 16 → 32 → ...) and measure generation time. Stop increasing when the runtime stops decreasing—that's when your inference server is saturated. @@ -199,25 +199,24 @@ designer.set_run_config(run_config) --- -### Adaptive Throttling (RunConfig) +### Adaptive Request Admission -Data Designer uses an AIMD (Additive Increase / Multiplicative Decrease) controller to automatically adjust concurrency per model based on rate-limit feedback from the inference server. The defaults work well for most workloads. Override them via `ThrottleConfig` only when you understand the trade-offs. +Data Designer uses AIMD (Additive Increase / Multiplicative Decrease) request admission to automatically adjust concurrency per provider/model/domain based on rate-limit feedback from the inference server. 
For most workloads, set `max_parallel_requests` as the user-facing ceiling and inspect `AsyncCapacityPlan`/logs to understand the effective runtime limits. Advanced AIMD tuning is available through `RequestAdmissionTuningConfig`. !!! note "Engine paths" - Adaptive throttling is fully active on the default **async engine**, where 429 responses propagate directly to the AIMD controller. On the legacy **sync engine** (`DATA_DESIGNER_ASYNC_ENGINE=0`), 429s are first retried at the HTTP transport layer; `ThrottleConfig` settings only take effect as a fallback if transport retries are exhausted. + Request admission wraps model requests on both sync and async paths. When request admission is active, provider 429 responses propagate to the AIMD controller instead of being hidden by HTTP transport retries. ```python import data_designer.config as dd from data_designer.interface import DataDesigner run_config = dd.RunConfig( - throttle=dd.ThrottleConfig( - reduce_factor=0.75, # Multiply limit by this on a 429 (default: 0.75) - additive_increase=1, # Add this many slots after success_window successes (default: 1) - success_window=25, # Consecutive successes before increasing (default: 25) - cooldown_seconds=2.0, # Pause after a 429 when no Retry-After header (default: 2.0) - ceiling_overshoot=0.10, # Probe 10% above observed server limit (default: 0.10) - rampup_seconds=0.0, # Optional startup ramp duration; 0 disables it (default: 0.0) + request_admission=dd.RequestAdmissionTuningConfig( + multiplicative_decrease_factor=0.75, # Multiply limit by this on a 429 + additive_increase_step=1, # Slots added after each success window + successes_until_increase=25, # Successful releases before increasing + cooldown_seconds=2.0, # Fallback pause when no Retry-After header is present + startup_ramp_seconds=0.0, # Optional startup ramp duration; 0 disables it ), ) @@ -227,15 +226,16 @@ designer.set_run_config(run_config) | Parameter | Default | Effect | |-----------|---------|--------| 
-| `reduce_factor` | 0.75 | How aggressively to cut concurrency on a 429. Lower = more aggressive. | -| `additive_increase` | 1 | Slots added per recovery step. Higher = faster ramp-up, but riskier. | -| `success_window` | 25 | Consecutive successes required before each increase step. | -| `cooldown_seconds` | 2.0 | Pause duration after a 429 (used when the server doesn't send `Retry-After`). | -| `ceiling_overshoot` | 0.10 | Fraction above the observed rate-limit ceiling the controller is allowed to probe. | -| `rampup_seconds` | 0.0 | Optional startup ramp duration. When greater than 0, domains start at one concurrent request and linearly climb to the configured ceiling unless a 429 aborts the ramp. | +| `multiplicative_decrease_factor` | 0.75 | How aggressively to cut concurrency on a 429. Lower = more aggressive. | +| `additive_increase_step` | 1 | Slots added per recovery step. Higher = faster recovery, but riskier. | +| `successes_until_increase` | 25 | Successful releases required before each increase step. | +| `cooldown_seconds` | 2.0 | Pause duration after a 429 when the server does not send `Retry-After`. | +| `startup_ramp_seconds` | 0.0 | Optional startup ramp duration. When greater than 0, resources start at one concurrent request and linearly climb to the configured ceiling unless a 429 aborts the ramp. | + +`RunConfig.throttle` and `ThrottleConfig` remain as deprecated compatibility shims. Existing `reduce_factor`, `additive_increase`, `success_window`, `cooldown_seconds`, and `rampup_seconds` values are translated to `RequestAdmissionTuningConfig`; `ceiling_overshoot` is accepted for compatibility but is no longer forwarded because request admission does not expose that knob. !!! tip "How it works in practice" - When a model endpoint returns HTTP 429, the controller reduces the concurrency limit for that model and pauses briefly. After enough consecutive successes, it begins ramping back up. 
If the server rate-limits again, the controller records that level as a ceiling and stabilizes just below it, with a small overshoot band to detect when the server can handle more load. + When a model endpoint returns HTTP 429, the controller reduces the concurrency limit for that request resource and pauses briefly. After enough consecutive successes, it begins ramping back up. If the server rate-limits again, the controller records that level as a ceiling and stabilizes at a lower sustainable limit. You can observe this in the logs — look for messages like `concurrency reduced from X → Y` and `concurrency increased from X → Y`. @@ -266,11 +266,11 @@ designer.set_run_config(run_config) ## Async Engine -The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and per-(provider, model) AIMD pools tune themselves independently. See the [Async All the Way Down](../devnotes/posts/async-all-the-way-down.md) dev note for the full architecture. +The async engine is the default execution path. It dispatches work at the cell level rather than the column level, so independent columns overlap in time and provider/model/domain request resources tune themselves independently. See the [Async All the Way Down](../devnotes/posts/async-all-the-way-down.md) dev note for the full architecture. ### Per-model timeouts drive every deadline -The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and AIMD's adaptive concurrency. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised: +The `inference_parameters.timeout` field on a `ModelConfig` sets the per-request HTTP timeout. 
The same value also drives the sync→async bridge that custom columns use when they call `model.generate()`. There is no separate queue-wait deadline — waits scale with provider speed and adaptive request admission. Slow self-hosted endpoints (e.g. large models on a single GPU) only need this one knob raised: ```python import data_designer.config as dd @@ -318,8 +318,8 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py | Problem | Symptom | Solution | |---------|---------|----------| -| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If the throttle has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. | -| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth or tune `reduce_factor` / `success_window`. | +| **Low throughput** | Low GPU utilization | Increase `max_parallel_requests` and/or `buffer_size`. If request admission has self-reduced due to earlier 429s (check logs for "concurrency reduced" messages), the server may need more capacity or you can wait for AIMD recovery. | +| **Frequent 429 → recovery cycles** | Logs show repeated concurrency drops and ramp-ups | The `max_parallel_requests` ceiling is above the server's sustained capacity. This is handled automatically, but you can lower the ceiling to reduce the sawtooth. 
| | **Long tail of slow generations** | Most records fast, few very slow | Reduce `max_conversation_restarts`, simplify schemas, improve prompts | | **Multi-model idle periods** | One model busy, others idle | Reduce `buffer_size` for faster cycling, or consolidate models | | **Memory errors** | OOM crashes | Reduce `buffer_size` and `max_parallel_requests` | @@ -329,10 +329,10 @@ DATA_DESIGNER_ASYNC_ENGINE=0 python my_pipeline.py ## Tuning Workflow -1. **Start with defaults** for initial development — AIMD handles rate-limit adaptation automatically +1. **Start with defaults** for initial development — adaptive request admission handles rate-limit adaptation automatically 2. **Profile your workload**: How many LLM columns? How many records? What models? -3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (AIMD will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings. -4. **Check throttle logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck +3. **Identify bottleneck**: Low GPU util → increase `max_parallel_requests` (request admission will self-correct if you overshoot). Memory issues → decrease `buffer_size`. Long tails → tune retry settings. +4. **Check request-admission logs**: Look for "concurrency reduced" / "concurrency increased" messages to understand whether rate limits are the bottleneck 5. 
**Iterate**: Make one change at a time, measure impact before next change --- diff --git a/packages/data-designer-config/src/data_designer/config/__init__.py b/packages/data-designer-config/src/data_designer/config/__init__.py index eb385e15a..a8f683aa3 100644 --- a/packages/data-designer-config/src/data_designer/config/__init__.py +++ b/packages/data-designer-config/src/data_designer/config/__init__.py @@ -58,7 +58,12 @@ ProcessorType, SchemaTransformProcessorConfig, ) - from data_designer.config.run_config import JinjaRenderingEngine, RunConfig, ThrottleConfig # noqa: F401 + from data_designer.config.run_config import ( # noqa: F401 + JinjaRenderingEngine, + RequestAdmissionTuningConfig, + RunConfig, + ThrottleConfig, + ) from data_designer.config.sampler_constraints import ( # noqa: F401 ColumnInequalityConstraint, ConstraintType, @@ -82,6 +87,7 @@ UniformSamplerParams, UUIDSamplerParams, ) + from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError # noqa: F401 from data_designer.config.seed import ( # noqa: F401 IndexRange, PartitionBlock, @@ -176,8 +182,12 @@ "SchemaTransformProcessorConfig": (_MOD_PROCESSORS, "SchemaTransformProcessorConfig"), # run_config "JinjaRenderingEngine": (f"{_MOD_BASE}.run_config", "JinjaRenderingEngine"), + "RequestAdmissionTuningConfig": (f"{_MOD_BASE}.run_config", "RequestAdmissionTuningConfig"), "RunConfig": (f"{_MOD_BASE}.run_config", "RunConfig"), "ThrottleConfig": (f"{_MOD_BASE}.run_config", "ThrottleConfig"), + # scheduling metadata + "SchedulingMetadata": (f"{_MOD_BASE}.scheduling", "SchedulingMetadata"), + "SchedulingMetadataError": (f"{_MOD_BASE}.scheduling", "SchedulingMetadataError"), # sampler_constraints "ColumnInequalityConstraint": (_MOD_SAMPLER_CONSTRAINTS, "ColumnInequalityConstraint"), "ConstraintType": (_MOD_SAMPLER_CONSTRAINTS, "ConstraintType"), diff --git a/packages/data-designer-config/src/data_designer/config/run_config.py 
b/packages/data-designer-config/src/data_designer/config/run_config.py index dc29016b8..ea3393b26 100644 --- a/packages/data-designer-config/src/data_designer/config/run_config.py +++ b/packages/data-designer-config/src/data_designer/config/run_config.py @@ -3,7 +3,8 @@ from __future__ import annotations -from typing import ClassVar +import warnings +from typing import Any from pydantic import Field, model_validator from typing_extensions import Self @@ -19,76 +20,101 @@ class JinjaRenderingEngine(StrEnum): SECURE = "secure" -class ThrottleConfig(ConfigBase): - """AIMD throttle tuning parameters for adaptive concurrency control. +_THROTTLE_DEPRECATION_MESSAGE = ( + "RunConfig.throttle and ThrottleConfig are deprecated. Use RunConfig.request_admission with " + "RequestAdmissionTuningConfig for supported advanced request-admission tuning." +) - These knobs configure the ``ThrottleManager`` that wraps every outbound - model HTTP request. The defaults are conservative and suitable for most - workloads; override only when you understand the trade-offs. - Attributes: - reduce_factor: Multiplicative decrease factor applied to the per-domain - concurrency limit on a 429 / rate-limit signal. Must be in (0, 1). - Default is 0.75 (reduce by 25% on rate-limit). - additive_increase: Additive increase step applied after every - ``success_window`` consecutive successes. Default is 1. - success_window: Number of consecutive successful releases before - the additive increase is applied. Default is 25. - cooldown_seconds: Default cooldown duration (seconds) applied after a - rate-limit when the provider does not include a ``Retry-After`` - header. Default is 2.0. - ceiling_overshoot: Fraction above the observed rate-limit ceiling - that additive increase is allowed to probe before capping. - Default is 0.10 (10% overshoot). - rampup_seconds: Optional startup ramp duration. 
When greater than - zero, each throttle domain starts at one concurrent request and - linearly ramps to its configured peak over this many seconds. - A 429 aborts the startup ramp and switches to normal AIMD recovery. - Default is 0.0 (disabled). +class RequestAdmissionTuningConfig(ConfigBase): + """Advanced request-admission AIMD tuning for model API calls. + + Most workloads should tune model capacity with ``max_parallel_requests`` on + inference parameters. These fields adjust the adaptive recovery behavior + below that cap and are intended for provider/runtime support cases. """ - DEFAULT_REDUCE_FACTOR: ClassVar[float] = 0.75 - DEFAULT_ADDITIVE_INCREASE: ClassVar[int] = 1 - DEFAULT_SUCCESS_WINDOW: ClassVar[int] = 25 - DEFAULT_COOLDOWN_SECONDS: ClassVar[float] = 2.0 - DEFAULT_CEILING_OVERSHOOT: ClassVar[float] = 0.10 - DEFAULT_RAMPUP_SECONDS: ClassVar[float] = 0.0 + multiplicative_decrease_factor: float = Field( + default=0.75, + gt=0.0, + lt=1.0, + description="Factor applied to the adaptive concurrency limit after a provider rate-limit signal.", + ) + additive_increase_step: int = Field( + default=1, + ge=1, + description="Slots added to the adaptive concurrency limit after each successful recovery window.", + ) + successes_until_increase: int = Field( + default=25, + ge=1, + description="Successful releases required before additive recovery increases the adaptive limit.", + ) + cooldown_seconds: float = Field( + default=2.0, + gt=0.0, + description="Fallback cooldown after a rate-limit signal when the provider omits Retry-After.", + ) + startup_ramp_seconds: float = Field( + default=0.0, + ge=0.0, + description=( + "Startup ramp duration. When greater than zero, each request resource starts at one " + "concurrent request and linearly ramps to its configured cap unless a rate-limit aborts the ramp." + ), + ) + + +class ThrottleConfig(ConfigBase): + """Deprecated compatibility DTO for request-admission tuning. 
+ + Use ``RequestAdmissionTuningConfig`` via ``RunConfig.request_admission`` + instead. ``ceiling_overshoot`` is accepted for compatibility but is not + forwarded because request admission no longer exposes an overshoot knob. + """ reduce_factor: float = Field( - default=DEFAULT_REDUCE_FACTOR, + default=0.75, gt=0.0, lt=1.0, - description="Multiplicative decrease factor applied to the per-domain concurrency limit on a 429 signal.", + description="Deprecated alias for RequestAdmissionTuningConfig.multiplicative_decrease_factor.", ) additive_increase: int = Field( - default=DEFAULT_ADDITIVE_INCREASE, + default=1, ge=1, - description="Additive increase step applied after every `success_window` consecutive successes.", + description="Deprecated alias for RequestAdmissionTuningConfig.additive_increase_step.", ) success_window: int = Field( - default=DEFAULT_SUCCESS_WINDOW, + default=25, ge=1, - description="Number of consecutive successful releases before the additive increase is applied.", + description="Deprecated alias for RequestAdmissionTuningConfig.successes_until_increase.", ) cooldown_seconds: float = Field( - default=DEFAULT_COOLDOWN_SECONDS, + default=2.0, gt=0.0, - description="Default cooldown duration (seconds) after a rate-limit when no Retry-After header is present.", + description="Deprecated alias for RequestAdmissionTuningConfig.cooldown_seconds.", ) ceiling_overshoot: float = Field( - default=DEFAULT_CEILING_OVERSHOOT, + default=0.10, ge=0.0, - description="Fraction above the rate-limit ceiling that additive increase is allowed to probe.", + description="Deprecated compatibility field; not forwarded to request admission.", ) rampup_seconds: float = Field( - default=DEFAULT_RAMPUP_SECONDS, + default=0.0, ge=0.0, - description=( - "Startup ramp duration in seconds. When greater than zero, each throttle domain starts at one " - "concurrent request and linearly ramps to the configured peak. A 429 aborts the startup ramp." 
- ), + description=("Deprecated alias for RequestAdmissionTuningConfig.startup_ramp_seconds."), ) + def to_request_admission_tuning(self) -> RequestAdmissionTuningConfig: + """Translate legacy throttle tuning into the request-admission DTO.""" + return RequestAdmissionTuningConfig( + multiplicative_decrease_factor=self.reduce_factor, + additive_increase_step=self.additive_increase, + successes_until_increase=self.success_window, + cooldown_seconds=self.cooldown_seconds, + startup_ramp_seconds=self.rampup_seconds, + ) + class RunConfig(ConfigBase): """Runtime configuration for dataset generation. @@ -126,7 +152,13 @@ class RunConfig(ConfigBase): fewer Data Designer-specific restrictions. ``secure`` uses Data Designer's hardened sandbox with additional AST, filter, and output guards. Default is ``secure``. - throttle: AIMD throttle tuning parameters. See ``ThrottleConfig`` for details. + request_admission: Advanced AIMD request-admission tuning for provider/model calls. + Most users should leave this unset and tune ``max_parallel_requests`` instead. + + Notes: + Request-admission controller internals remain engine-owned. This field + exposes only the supported tuning DTO and does not expose controller + mutation APIs, leases, queues, or pressure snapshots. """ disable_early_shutdown: bool = False @@ -146,7 +178,31 @@ class RunConfig(ConfigBase): "`native` uses Jinja2's built-in sandbox; `secure` uses Data Designer's hardened sandbox." ), ) - throttle: ThrottleConfig = Field(default_factory=ThrottleConfig) + request_admission: RequestAdmissionTuningConfig | None = None + + @model_validator(mode="before") + @classmethod + def translate_deprecated_throttle_config(cls, data: Any) -> Any: + if isinstance(data, dict) and "throttle" in data: + normalized = dict(data) + throttle = normalized.pop("throttle") + if normalized.get("request_admission") is not None: + raise ValueError( + "Specify either RunConfig.throttle or RunConfig.request_admission, not both. 
" + "RunConfig.throttle is deprecated." + ) + if throttle is not None: + throttle_config = ( + throttle if isinstance(throttle, ThrottleConfig) else ThrottleConfig.model_validate(throttle) + ) + normalized["request_admission"] = throttle_config.to_request_admission_tuning() + warnings.warn( + _THROTTLE_DEPRECATION_MESSAGE, + DeprecationWarning, + stacklevel=2, + ) + return normalized + return data @model_validator(mode="after") def normalize_shutdown_settings(self) -> Self: diff --git a/packages/data-designer-config/src/data_designer/config/scheduling.py b/packages/data-designer-config/src/data_designer/config/scheduling.py new file mode 100644 index 000000000..84e36b3a0 --- /dev/null +++ b/packages/data-designer-config/src/data_designer/config/scheduling.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Literal + +SchedulingMetadataKind = Literal["local", "model", "custom_model"] + + +@dataclass(frozen=True) +class SchedulingMetadata: + """Static generator-facing scheduling metadata. + + The metadata describes broad resource shape only. It intentionally does + not expose ready-queue state, task-admission state, request-admission + pressure, provider cooldowns, or adaptive request limits. + """ + + kind: SchedulingMetadataKind = "local" + identity: tuple[str, ...] 
= ("local", "default") + weight: int = 1 + diagnostics: dict[str, object] = field(default_factory=dict) + + @classmethod + def local(cls, resource_name: str = "default", *, weight: int = 1) -> SchedulingMetadata: + return cls(kind="local", identity=("local", resource_name), weight=weight) + + @classmethod + def model( + cls, + provider_name: str, + model_id: str, + generation_kind: str, + *, + weight: int, + diagnostics: dict[str, object] | None = None, + ) -> SchedulingMetadata: + return cls( + kind="model", + identity=("model", provider_name, model_id, generation_kind), + weight=weight, + diagnostics=diagnostics or {}, + ) + + @classmethod + def custom_model( + cls, + plugin_namespace: str, + resource_name: str, + version: str, + *, + weight: int = 1, + diagnostics: dict[str, object] | None = None, + ) -> SchedulingMetadata: + return cls( + kind="custom_model", + identity=("custom_model", plugin_namespace, resource_name, version), + weight=weight, + diagnostics=diagnostics or {}, + ) + + def __post_init__(self) -> None: + if self.kind not in {"local", "model", "custom_model"}: + raise SchedulingMetadataError( + code="invalid_kind", + message=f"Unknown scheduling metadata kind: {self.kind!r}", + diagnostics={"kind": self.kind}, + ) + if not isinstance(self.identity, tuple) or not self.identity: + raise SchedulingMetadataError( + code="invalid_identity", + message="Scheduling metadata identity must be a non-empty tuple of non-empty strings.", + diagnostics={"identity": self.identity}, + ) + if any(not isinstance(part, str) or not part for part in self.identity): + raise SchedulingMetadataError( + code="invalid_identity", + message="Scheduling metadata identity must contain only non-empty strings.", + diagnostics={"identity": self.identity}, + ) + expected_identity_lengths = {"local": 2, "model": 4, "custom_model": 4} + if self.identity[0] != self.kind or len(self.identity) != expected_identity_lengths[self.kind]: + raise SchedulingMetadataError( + 
code="invalid_identity", + message=f"Scheduling metadata identity for kind {self.kind!r} has an invalid shape.", + diagnostics={ + "kind": self.kind, + "identity": self.identity, + "expected_prefix": self.kind, + "expected_length": expected_identity_lengths[self.kind], + }, + ) + if isinstance(self.weight, bool) or not isinstance(self.weight, int) or self.weight <= 0: + raise SchedulingMetadataError( + code="invalid_weight", + message="Scheduling metadata weight must be a positive integer.", + diagnostics={"weight": self.weight}, + ) + + +class SchedulingMetadataError(ValueError): + """Typed scheduling metadata resolution error.""" + + def __init__( + self, + *, + code: str, + message: str, + fallback: SchedulingMetadata | None = None, + diagnostics: dict[str, object] | None = None, + ) -> None: + super().__init__(message) + self.code = code + self.message = message + self.fallback = fallback + self.diagnostics = diagnostics or {} diff --git a/packages/data-designer-config/tests/config/test_run_config.py b/packages/data-designer-config/tests/config/test_run_config.py index 870e370e1..9d216025c 100644 --- a/packages/data-designer-config/tests/config/test_run_config.py +++ b/packages/data-designer-config/tests/config/test_run_config.py @@ -4,8 +4,15 @@ from __future__ import annotations import pytest +from pydantic import ValidationError -from data_designer.config.run_config import JinjaRenderingEngine, RunConfig, ThrottleConfig +import data_designer.config as dd +from data_designer.config.run_config import ( + JinjaRenderingEngine, + RequestAdmissionTuningConfig, + RunConfig, + ThrottleConfig, +) def test_run_config_defaults_to_secure_jinja_renderer() -> None: @@ -17,6 +24,111 @@ def test_run_config_accepts_native_renderer() -> None: assert JinjaRenderingEngine(run_config.jinja_rendering_engine) == JinjaRenderingEngine.NATIVE +def test_run_config_throttle_shim_rejects_unknown_legacy_fields() -> None: + with pytest.raises(ValidationError, 
match="max_concurrent_requests"): + RunConfig(throttle={"max_concurrent_requests": 1}) + + +def test_run_config_throttle_shim_translates_to_request_admission() -> None: + with pytest.warns(DeprecationWarning, match="RunConfig.throttle.*RequestAdmissionTuningConfig"): + run_config = RunConfig( + throttle=ThrottleConfig( + reduce_factor=0.5, + additive_increase=2, + success_window=7, + cooldown_seconds=1.5, + ceiling_overshoot=0.2, + rampup_seconds=30.0, + ) + ) + + assert run_config.request_admission is not None + assert run_config.request_admission.multiplicative_decrease_factor == 0.5 + assert run_config.request_admission.additive_increase_step == 2 + assert run_config.request_admission.successes_until_increase == 7 + assert run_config.request_admission.cooldown_seconds == 1.5 + assert run_config.request_admission.startup_ramp_seconds == 30.0 + + +def test_run_config_throttle_shim_accepts_legacy_dict() -> None: + with pytest.warns(DeprecationWarning, match="RunConfig.throttle.*RequestAdmissionTuningConfig"): + run_config = RunConfig( + throttle={ + "reduce_factor": 0.5, + "additive_increase": 2, + "success_window": 7, + "cooldown_seconds": 1.5, + "rampup_seconds": 30.0, + } + ) + + assert run_config.request_admission is not None + assert run_config.request_admission.multiplicative_decrease_factor == 0.5 + assert run_config.request_admission.additive_increase_step == 2 + assert run_config.request_admission.successes_until_increase == 7 + assert run_config.request_admission.cooldown_seconds == 1.5 + assert run_config.request_admission.startup_ramp_seconds == 30.0 + + +def test_run_config_rejects_throttle_and_request_admission_together() -> None: + with pytest.raises(ValidationError, match="Specify either RunConfig.throttle or RunConfig.request_admission"): + RunConfig(throttle=ThrottleConfig(), request_admission=RequestAdmissionTuningConfig()) + + +def test_request_admission_tuning_config_accepts_canonical_fields() -> None: + config = RequestAdmissionTuningConfig( + 
multiplicative_decrease_factor=0.5, + additive_increase_step=2, + successes_until_increase=7, + cooldown_seconds=1.5, + startup_ramp_seconds=30.0, + ) + + assert config.multiplicative_decrease_factor == 0.5 + assert config.additive_increase_step == 2 + assert config.successes_until_increase == 7 + assert config.cooldown_seconds == 1.5 + assert config.startup_ramp_seconds == 30.0 + + +def test_request_admission_tuning_config_rejects_throttle_era_field_names() -> None: + with pytest.raises(ValidationError, match="success_window"): + RequestAdmissionTuningConfig(success_window=7) + + +def test_run_config_accepts_request_admission_tuning() -> None: + run_config = RunConfig(request_admission=RequestAdmissionTuningConfig(startup_ramp_seconds=10.0)) + + assert run_config.request_admission is not None + assert run_config.request_admission.startup_ramp_seconds == 10.0 + + +def test_run_config_accepts_request_admission_tuning_dict() -> None: + run_config = RunConfig( + request_admission={ + "multiplicative_decrease_factor": 0.5, + "successes_until_increase": 7, + "startup_ramp_seconds": 10.0, + } + ) + + assert run_config.request_admission is not None + assert run_config.request_admission.multiplicative_decrease_factor == 0.5 + assert run_config.request_admission.successes_until_increase == 7 + assert run_config.request_admission.startup_ramp_seconds == 10.0 + + +def test_request_admission_tuning_config_is_exported_from_config_package() -> None: + assert dd.RequestAdmissionTuningConfig is RequestAdmissionTuningConfig + + +def test_deprecated_throttle_config_is_exported_from_config_package() -> None: + assert dd.ThrottleConfig is ThrottleConfig + namespace: dict[str, object] = {} + exec("from data_designer.config import ThrottleConfig", namespace) + assert namespace["ThrottleConfig"] is ThrottleConfig + + def test_throttle_config_accepts_rampup_seconds() -> None: config = ThrottleConfig(rampup_seconds=30.0) assert config.rampup_seconds == 30.0 diff --git 
a/packages/data-designer-config/tests/config/test_scheduling.py b/packages/data-designer-config/tests/config/test_scheduling.py new file mode 100644 index 000000000..e219daddd --- /dev/null +++ b/packages/data-designer-config/tests/config/test_scheduling.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError + + +@pytest.mark.parametrize( + "metadata", + [ + SchedulingMetadata.local(), + SchedulingMetadata.model("nvidia", "nemotron", "chat", weight=2), + SchedulingMetadata.custom_model("plugin", "resource", "v1"), + ], +) +def test_scheduling_metadata_accepts_normative_shapes(metadata: SchedulingMetadata) -> None: + assert metadata.weight >= 1 + + +@pytest.mark.parametrize( + "kwargs", + [ + {"identity": ["local", "default"]}, + {"weight": True}, + {"kind": "model", "identity": ("local", "default")}, + {"kind": "local", "identity": ("local", "default", "extra")}, + {"kind": "custom_model", "identity": ("custom_model", "plugin")}, + ], +) +def test_scheduling_metadata_rejects_non_normative_direct_construction(kwargs: dict[str, object]) -> None: + with pytest.raises(SchedulingMetadataError): + SchedulingMetadata(**kwargs) # type: ignore[arg-type] diff --git a/packages/data-designer-engine/src/data_designer/engine/capacity.py b/packages/data-designer-engine/src/data_designer/engine/capacity.py new file mode 100644 index 000000000..e10a729e7 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/capacity.py @@ -0,0 +1,123 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass, field +from typing import Generic, Literal, TypeVar + +from data_designer.engine.dataset_builders.scheduling.resources import SchedulerResourceKey, TaskGroupKey +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + +_T = TypeVar("_T") + +CapacityValueSource = Literal[ + "default", + "run_config", + "dataset_builder", + "model_metadata", + "engine_internal_config", + "adapter_config", + "environment", + "runtime_snapshot", + "benchmark_override", +] + + +@dataclass(frozen=True) +class CapacityValue(Generic[_T]): + value: _T | None + source: CapacityValueSource + fallback_from: str | None = None + missing_reason: str | None = None + + +@dataclass(frozen=True) +class RowGroupAdmission: + row_group_concurrency: CapacityValue[int] + observed_in_flight: int | None = None + mode: Literal["fixed", "adaptive"] = "fixed" + target_in_flight: int | None = None + observed_max_target: int | None = None + max_admitted_rows: int | None = None + blocked_reasons: Mapping[str, int] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RequestAdmissionConfigSnapshot: + resources: Sequence[RequestResourceKey] + initial_limits: Mapping[RequestResourceKey, int] + max_limit_clamps: Mapping[RequestResourceKey, int | None] + cooldown_seconds: float + multiplicative_decrease_factor: float + additive_increase_step: int + successes_until_increase: int + startup_ramp_seconds: float + default_queue_wait_timeout_seconds: float | None + + @classmethod + def from_config(cls, config: RequestAdmissionConfig) -> RequestAdmissionConfigSnapshot: + resources = tuple(sorted({*config.initial_limits, 
*config.max_limit_clamps})) + return cls( + resources=resources, + initial_limits=dict(config.initial_limits), + max_limit_clamps=dict(config.max_limit_clamps), + cooldown_seconds=config.cooldown_seconds, + multiplicative_decrease_factor=config.multiplicative_decrease_factor, + additive_increase_step=config.additive_increase_step, + successes_until_increase=config.successes_until_increase, + startup_ramp_seconds=config.startup_ramp_seconds, + default_queue_wait_timeout_seconds=config.default_queue_wait_timeout_seconds, + ) + + +@dataclass(frozen=True) +class AsyncCapacityConfigured: + buffer_size: CapacityValue[int] + row_group_admission: RowGroupAdmission + submission_capacity: CapacityValue[int] + task_resource_limits: CapacityValue[Mapping[SchedulerResourceKey, int]] + request_resources: CapacityValue[Sequence[RequestResourceKey]] + provider_model_static_caps: CapacityValue[Mapping[ProviderModelKey, ProviderModelStaticCap]] + request_domain_initial_limits: CapacityValue[Mapping[RequestResourceKey, int]] + request_admission_config: CapacityValue[RequestAdmissionConfigSnapshot] + transport_pool_limits: CapacityValue[Mapping[ProviderModelKey, int]] + + +@dataclass(frozen=True) +class AsyncCapacityRuntimeSnapshot: + request_domain_current_limits: Mapping[RequestResourceKey, int] | None = None + request_domain_effective_max: Mapping[RequestResourceKey, int] | None = None + request_domain_blocked_until: Mapping[RequestResourceKey, float | None] | None = None + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityObservedMaxima: + row_groups_in_flight: int = 0 + queued_tasks_by_group: Mapping[TaskGroupKey | str, int] = field(default_factory=dict) + task_leases_by_resource: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) + request_waiters_by_resource: Mapping[RequestResourceKey, int] = field(default_factory=dict) + request_in_flight_by_resource: Mapping[RequestResourceKey, int] 
= field(default_factory=dict) + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] = field(default_factory=dict) + request_domain_current_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) + transport_pool_utilization: Mapping[ProviderModelKey, int] | None = None + + +@dataclass(frozen=True) +class AsyncCapacityPlan: + configured: AsyncCapacityConfigured + runtime_snapshot: AsyncCapacityRuntimeSnapshot + observed_maxima: AsyncCapacityObservedMaxima + + +def missing_capacity_value( + *, + source: CapacityValueSource, + missing_reason: str, + fallback_from: str | None = None, +) -> CapacityValue[object]: + return CapacityValue(value=None, source=source, fallback_from=fallback_from, missing_reason=missing_reason) diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py index 2431c0eb6..ba432ce2c 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py +++ b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/base.py @@ -6,18 +6,21 @@ import asyncio import concurrent.futures import functools +import hashlib import logging from abc import ABC, abstractmethod +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Coroutine, TypeVar, overload from data_designer.config.column_configs import GenerationStrategy +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError from data_designer.engine.configurable_task import ConfigurableTask, DataT, TaskConfigT from data_designer.logging import LOG_DOUBLE_INDENT, LOG_INDENT _T = TypeVar("_T") # Preserved deliberately. 
Two other 300s deadlines were retired in the -# async-default flip (PR #592): the throttle queue-wait and the +# async-default flip (PR #592): the request-admission queue wait and the # ``_AsyncBridgedModelFacade`` bridge in ``custom.py`` — both have # ``ModelFacade`` context and could derive a per-call deadline from # ``inference_parameters.timeout``. This generic ``ColumnGenerator.generate()`` @@ -26,6 +29,20 @@ # tracked as a structural follow-up. SYNC_BRIDGE_TIMEOUT = 300 + +@dataclass +class _EndpointBucket: + aliases: list[str] = field(default_factory=list) + caps: list[int] = field(default_factory=list) + + +def _scheduling_generation_kind(generation_type: object) -> str: + value = getattr(generation_type, "value", generation_type) + if value == "chat-completion": + return "chat" + return str(value) + + if TYPE_CHECKING: import pandas as pd @@ -65,10 +82,14 @@ class ColumnGenerator(ConfigurableTask[TaskConfigT], ABC): def can_generate_from_scratch(self) -> bool: return False - @property - def is_llm_bound(self) -> bool: - """Whether this generator makes model/API calls during generation.""" - return False + def get_scheduling_metadata(self) -> SchedulingMetadata: + """Return static scheduler metadata for this generator. + + Generators that do not declare model-backed behavior use the documented + local default. Model-aware base classes override this with provider/model + resource identity derived from registered model aliases. 
+ """ + return SchedulingMetadata.local() @property def is_order_dependent(self) -> bool: @@ -143,10 +164,6 @@ async def agenerate_from_scratch(self, num_records: int) -> pd.DataFrame: class ColumnGeneratorWithModelRegistry(ColumnGenerator[TaskConfigT], ABC): - @property - def is_llm_bound(self) -> bool: - return True - @property def model_registry(self) -> ModelRegistry: return self.resource_provider.model_registry @@ -161,6 +178,90 @@ def get_model_provider_name(self, model_alias: str) -> str: provider = self.model_registry.get_model_provider(model_alias=model_alias) return provider.name + def get_scheduling_metadata(self) -> SchedulingMetadata: + aliases = self._get_scheduling_model_aliases() + if not aliases: + raise SchedulingMetadataError( + code="missing_model_alias", + message=f"{type(self).__name__} has no model aliases for scheduling metadata.", + fallback=SchedulingMetadata.local(), + diagnostics={"generator_type": type(self).__name__}, + ) + + endpoints: dict[tuple[str, str, str], _EndpointBucket] = {} + for alias in aliases: + try: + model_config = self.get_model_config(model_alias=alias) + provider_name = self.get_model_provider_name(model_alias=alias) + except Exception as exc: + raise SchedulingMetadataError( + code="alias_resolution_failed", + message=f"Could not resolve model alias {alias!r} for scheduling metadata.", + diagnostics={"alias": alias, "generator_type": type(self).__name__}, + ) from exc + + endpoint = ( + provider_name, + str(model_config.model), + _scheduling_generation_kind(model_config.generation_type), + ) + max_parallel = getattr(model_config.inference_parameters, "max_parallel_requests", 1) + cap = max_parallel if isinstance(max_parallel, int) and max_parallel > 0 else 1 + bucket = endpoints.setdefault(endpoint, _EndpointBucket()) + bucket.aliases.append(alias) + bucket.caps.append(cap) + + if len(endpoints) != 1: + raw_caps = tuple(cap for bucket in endpoints.values() for cap in bucket.caps) + return 
SchedulingMetadata.custom_model( + _scheduling_plugin_namespace(type(self)), + _scheduling_alias_set_resource_name(aliases), + "v1", + weight=max(1, sum(raw_caps)), + diagnostics={ + "aliases": tuple(sorted(aliases)), + "endpoints": tuple(sorted(str(endpoint) for endpoint in endpoints)), + "fallback_reason": "multi_endpoint_alias_set", + "raw_caps": raw_caps, + }, + ) + + endpoint, bucket = next(iter(endpoints.items())) + provider_name, model_id, generation_kind = endpoint + effective_cap = max(1, min(bucket.caps)) + return SchedulingMetadata.model( + provider_name, + model_id, + generation_kind, + weight=effective_cap, + diagnostics={ + "aliases": tuple(bucket.aliases), + "raw_caps": tuple(bucket.caps), + "merge_rule": "min_same_endpoint", + }, + ) + + def _get_scheduling_model_aliases(self) -> list[str]: + get_aliases = getattr(self.config, "get_model_aliases", None) + if callable(get_aliases): + aliases = get_aliases() + else: + aliases = [] + if (alias := getattr(self.config, "model_alias", None)) is not None: + aliases.append(alias) + aliases.extend(getattr(self.config, "model_aliases", []) or []) + return list(dict.fromkeys(str(alias) for alias in aliases if alias)) + + +def _scheduling_plugin_namespace(generator_type: type[object]) -> str: + return f"{generator_type.__module__}.{generator_type.__qualname__}" + + +def _scheduling_alias_set_resource_name(aliases: list[str]) -> str: + alias_key = "\0".join(sorted(aliases)).encode() + digest = hashlib.sha1(alias_key).hexdigest()[:16] + return f"alias-set-{digest}" + class ColumnGeneratorWithModel(ColumnGeneratorWithModelRegistry[TaskConfigT], ABC): @functools.cached_property diff --git a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py index b4c863542..08c78120b 100644 --- a/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py +++ 
b/packages/data-designer-engine/src/data_designer/engine/column_generators/generators/custom.py @@ -13,6 +13,7 @@ import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import CustomColumnConfig, GenerationStrategy +from data_designer.config.scheduling import SchedulingMetadata from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.column_generators.utils.errors import CustomColumnGenerationError from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, ModelTimeoutError @@ -105,7 +106,7 @@ def generate(self, *args: Any, **kwargs: Any) -> tuple[Any, list]: except concurrent.futures.TimeoutError as exc: future.cancel() # Demoted to debug: the raised ModelTimeoutError already surfaces - # the timeout at the scheduler with full context, and the throttled + # the timeout at the scheduler with full context, and the request-admission # degraded-provider WARN is the user-facing signal under sustained # bridge timeouts. Per-event WARN was noise on top of those. logger.debug("Async model bridge timed out after %.0fs; coroutine cancelled", bridge_timeout) @@ -137,10 +138,18 @@ class CustomColumnGenerator(ColumnGenerator[CustomColumnConfig]): The models dict provides direct access to ModelFacade instances keyed by alias. 
""" - @property - def is_llm_bound(self) -> bool: - """Custom generators with model_aliases make LLM calls and need the handoff.""" - return bool(self.config.model_aliases) + def get_scheduling_metadata(self) -> SchedulingMetadata: + """Return custom-model metadata when the custom column declares model aliases.""" + if not self.config.model_aliases: + return SchedulingMetadata.local() + identity = "-".join(sorted(str(alias) for alias in self.config.model_aliases)) + return SchedulingMetadata.custom_model( + "custom_column", + identity or self.config.name, + "v1", + weight=max(1, len(self.config.model_aliases)), + diagnostics={"aliases": tuple(sorted(str(alias) for alias in self.config.model_aliases))}, + ) def get_generation_strategy(self) -> GenerationStrategy: """Return strategy based on config.""" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py index 778501da1..9109eafcc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/async_scheduler.py @@ -5,51 +5,80 @@ import asyncio import contextlib +import hashlib import logging import time -from collections import defaultdict, deque -from collections.abc import Coroutine +import uuid +from collections import Counter, defaultdict, deque +from collections.abc import Coroutine, Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import GenerationStrategy +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) from data_designer.engine.context import 
current_row_group from data_designer.engine.dataset_builders.errors import DatasetGenerationError from data_designer.engine.dataset_builders.multi_column_configs import MultiColumnConfig +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.queue import ( + FairTaskQueue, +) +from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import ( + TaskAdmissionConfig, + TaskAdmissionController, + TaskAdmissionDenied, + TaskAdmissionLease, +) +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task, TaskTrace from data_designer.engine.dataset_builders.utils.async_progress_reporter import ( DEFAULT_REPORT_INTERVAL, AsyncProgressReporter, ) -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta -from data_designer.engine.dataset_builders.utils.fair_task_queue import ( - FairTaskQueue, - TaskGroupKey, - TaskGroupSpec, -) from data_designer.engine.dataset_builders.utils.progress_tracker import ProgressTracker -from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver from data_designer.engine.dataset_builders.utils.skip_evaluator import should_skip_column_for_record from data_designer.engine.dataset_builders.utils.skip_tracker import ( apply_skip_to_record, strip_skip_metadata_from_records, ) from data_designer.engine.dataset_builders.utils.sticky_progress_bar import StickyProgressBar -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task, TaskTrace -from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS +from data_designer.engine.errors import DataDesignerError +from 
data_designer.engine.models.clients.errors import ProviderError +from data_designer.engine.models.errors import RETRYABLE_MODEL_ERRORS, GenerationValidationFailureError +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap +from data_designer.engine.observability import ( + RuntimeCorrelation, + SchedulerAdmissionEvent, + SchedulerAdmissionEventSink, + runtime_correlation_provider, +) if TYPE_CHECKING: from data_designer.engine.column_generators.generators.base import ColumnGenerator from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager + from data_designer.engine.models.request_admission.pressure import RequestPressureSnapshotProvider logger = logging.getLogger(__name__) DEFAULT_TASK_POOL_SIZE: int = 256 -# Global LLM wait-pool headroom sizes the memory-safety semaphore above provider capacity. -GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER: int = 2 -# Per-group admission backlog caps how many ready LLM tasks one fair-queue group can hold. -LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 +MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER: int = 2 +MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER: int = 2 # Degraded-provider WARN: emit at most one warning per interval when the # rolling fraction of retryable errors exceeds the threshold. 
Distinct from @@ -58,21 +87,27 @@ DEGRADED_WARN_RATE: float = 0.5 DEGRADED_WARN_WINDOW: int = 20 DEGRADED_WARN_INTERVAL_S: float = 60.0 +INTERNAL_BUG_EXCEPTIONS = (KeyError, TypeError, AttributeError, AssertionError) -class TrackingSemaphore(asyncio.Semaphore): - """``asyncio.Semaphore`` subclass that exposes available permits publicly.""" +def _identity_hash(identity: tuple[str, ...]) -> str: + return hashlib.sha1("\0".join(identity).encode()).hexdigest()[:16] - @property - def available_permits(self) -> int: - return self._value # type: ignore[attr-defined] - def try_acquire(self) -> bool: - """Non-blocking acquire. Returns ``True`` if a permit was taken.""" - if self._value > 0: # type: ignore[attr-defined] - self._value -= 1 # type: ignore[attr-defined] - return True - return False +def _request_resource_label(resource: object | None) -> str | None: + if resource is None: + return None + provider = getattr(resource, "provider_name", None) + model = getattr(resource, "model_id", None) + domain = getattr(resource, "domain", None) + domain_value = getattr(domain, "value", domain) + if provider is None or model is None or domain_value is None: + return str(resource) + return f"{provider}/{model}/{domain_value}" + + +def _string_keyed_counts(values: Mapping[object, int]) -> dict[str, int]: + return {str(key): int(value) for key, value in values.items()} @dataclass @@ -90,8 +125,7 @@ class _DispatchOutcome: """Result of one fair-dispatch pass over the persistent ready queue.""" dispatched: bool = False - submission_full: bool = False - group_blocked: bool = False + admission_blocked: bool = False class AsyncTaskScheduler: @@ -111,7 +145,8 @@ def __init__( *, max_concurrent_row_groups: int = 3, max_submitted_tasks: int = DEFAULT_TASK_POOL_SIZE, - max_llm_wait_tasks: int = DEFAULT_TASK_POOL_SIZE, + max_model_task_admission: int = DEFAULT_TASK_POOL_SIZE, + task_admission_config: TaskAdmissionConfig | None = None, salvage_max_rounds: int = 2, on_finalize_row_group: 
Callable[[int], None] | None = None, on_seeds_complete: Callable[[int, int], FrontierDelta | None] | None = None, @@ -127,6 +162,12 @@ def __init__( buffer_size: int = 0, progress_interval: float | None = None, progress_bar: bool = False, + scheduler_event_sink: SchedulerAdmissionEventSink | None = None, + run_id: str | None = None, + adaptive_row_group_admission: bool = False, + adaptive_row_group_initial_target: int = 1, + request_pressure_provider: RequestPressureSnapshotProvider | None = None, + request_pressure_advisory: bool = False, ) -> None: self._generators = generators self._graph = graph @@ -135,22 +176,29 @@ def __init__( self._buffer_manager = buffer_manager self._rg_semaphore = asyncio.Semaphore(max_concurrent_row_groups) - self._submission_semaphore = TrackingSemaphore(max_submitted_tasks) - self._llm_wait_semaphore = TrackingSemaphore(max_llm_wait_tasks) - self._max_llm_wait_tasks = max_llm_wait_tasks - self._llm_bound_lookup = build_llm_bound_lookup(generators) - self._scheduling_hints = SchedulingHintResolver(generators) + self._task_scheduling = TaskSchedulingResolver( + generators, + model_group_limit_multiplier=MODEL_GROUP_ADMISSION_BACKLOG_MULTIPLIER, + model_group_limit_cap=max_model_task_admission, + ) + admission_config = task_admission_config or TaskAdmissionConfig( + submission_capacity=max_submitted_tasks, + resource_limits={"llm_wait": max_model_task_admission, "local": max_submitted_tasks}, + ) + self._task_admission = TaskAdmissionController(admission_config) + self._task_admission_config = admission_config self._fair_queue = FairTaskQueue() self._pending_pre_batch_ready: defaultdict[int, list[Task]] = defaultdict(list) self._pending_pre_batch_ready_tasks: set[Task] = set() - # Task group specs are derived from per-generator scheduling hints and flow identity. 
- self._task_group_spec_cache: dict[int, TaskGroupSpec] = {} self._dispatched: set[Task] = set() self._in_flight: set[Task] = set() self._worker_tasks: set[asyncio.Task] = set() self._wake_event = asyncio.Event() + self._run_id = run_id or f"run-{uuid.uuid4().hex}" + self._scheduler_event_sink = scheduler_event_sink + self._scheduler_event_sequence = 0 self._salvage_max_rounds = salvage_max_rounds self._on_finalize_row_group = on_finalize_row_group self._on_seeds_complete = on_seeds_complete @@ -202,7 +250,7 @@ def __init__( self._all_rgs_admitted = False # Degraded-provider WARN: separate window tracking retryable-vs-not for - # every outcome (success or failure), throttled to one log per interval. + # every outcome (success or failure), rate-limited to one log per interval. self._degraded_warn_rate = degraded_warn_rate self._degraded_warn_window = degraded_warn_window self._degraded_warn_interval_s = degraded_warn_interval_s @@ -224,9 +272,38 @@ def __init__( # context naturally because the from_scratch task raised; the async # engine drops rows and continues, losing the cause unless we capture it. 
self._first_non_retryable_error: Exception | None = None + self._fatal_worker_error: BaseException | None = None # Pre-compute row-group sizes for O(1) lookup self._rg_size_map: dict[int, int] = dict(row_groups) + self._max_concurrent_row_groups = max_concurrent_row_groups + self._max_submitted_tasks = max_submitted_tasks + self._max_model_task_admission = max_model_task_admission + self._num_records = num_records + self._buffer_size = buffer_size + self._observed_max_row_groups_in_flight = 0 + self._observed_max_task_leases_by_resource: dict[str, int] = {} + self._observed_max_queued_by_group: dict[str, int] = {} + self._observed_max_request_waiters_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_request_in_flight_by_resource: dict[RequestResourceKey, int] = {} + self._observed_max_provider_model_aggregate_in_flight: dict[ProviderModelKey, int] = {} + self._observed_max_request_domain_current_limits: dict[RequestResourceKey, int] = {} + self._adaptive_row_group_admission = adaptive_row_group_admission + self._row_group_admission_hard_cap = max(1, max_concurrent_row_groups) + self._row_group_admission_target = ( + max(1, min(self._row_group_admission_hard_cap, adaptive_row_group_initial_target)) + if adaptive_row_group_admission + else self._row_group_admission_hard_cap + ) + self._observed_max_row_group_admission_target = self._row_group_admission_target + self._row_group_admission_event = asyncio.Event() + self._row_group_admission_event.set() + self._row_group_admission_pressure_ticks = 0 + self._row_group_admission_blocked_reasons: Counter[str] = Counter() + self._adaptive_max_admitted_rows = self._max_admitted_rows_guardrail() + self._request_pressure_provider = request_pressure_provider + self._request_pressure_advisory = request_pressure_advisory and request_pressure_provider is not None + self._request_pressure_advisory_skips = 0 # Pre-compute seed columns (graph is static) self._seed_cols: tuple[str, ...] 
= tuple(c for c in graph.columns if not graph.get_upstream_columns(c)) @@ -293,6 +370,13 @@ def first_non_retryable_error(self) -> Exception | None: """ return self._first_non_retryable_error + def _raise_if_fatal_worker_error(self) -> None: + if self._fatal_worker_error is None: + return + raise DatasetGenerationError( + "Unexpected internal task failure in async scheduler." + ) from self._fatal_worker_error + def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: """Create a tracked worker task that auto-removes itself on completion.""" task = asyncio.create_task(coro) @@ -300,6 +384,230 @@ def _spawn_worker(self, coro: Coroutine[Any, Any, None]) -> asyncio.Task: task.add_done_callback(self._worker_tasks.discard) return task + def _emit_scheduler_event( + self, + event_kind: str, + *, + task: Task | None = None, + lease: TaskAdmissionLease | None = None, + task_execution_id: str | None = None, + scheduler_resource_key: str | None = None, + reason_or_result: str | None = None, + diagnostics: dict[str, object] | None = None, + ) -> None: + if self._scheduler_event_sink is None: + return + self._scheduler_event_sequence += 1 + correlation = None + event_diagnostics = dict(diagnostics or {}) + if task is not None: + schedulable = lease.item if lease is not None else self._schedulable_task(task) + group = schedulable.group + identity_hash = _identity_hash(group.key.identity) + event_diagnostics.setdefault("task_group_key", group.key) + event_diagnostics.setdefault("resource_request", dict(schedulable.resource_request.amounts)) + correlation = RuntimeCorrelation( + run_id=self._run_id, + row_group=task.row_group, + task_column=task.column, + task_type=task.task_type, + scheduling_group_kind=group.key.kind, + scheduling_group_identity_hash=identity_hash, + task_execution_id=task_execution_id, + ) + try: + self._scheduler_event_sink.emit_scheduler_event( + SchedulerAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + 
sequence=self._scheduler_event_sequence, + correlation=correlation, + task_id=stable_task_id(task) if task is not None else None, + task_execution_id=task_execution_id, + task_lease_id=lease.lease_id if lease is not None else None, + scheduler_resource_key=scheduler_resource_key, + reason_or_result=reason_or_result, + snapshot=self.task_admission_snapshot(), + diagnostics=event_diagnostics, + ) + ) + except Exception: + logger.warning("Scheduler admission event sink raised; dropping event.", exc_info=True) + return + + def _record_observed_task_state(self) -> None: + self._observed_max_row_groups_in_flight = max(self._observed_max_row_groups_in_flight, len(self._rg_states)) + view = self._task_admission.view() + for resource, count in view.leased_resources.items(): + self._observed_max_task_leases_by_resource[resource] = max( + self._observed_max_task_leases_by_resource.get(resource, 0), + count, + ) + queue_view = self._fair_queue.view() + for group, count in queue_view.queued_by_group.items(): + label = f"{group.kind}:{'/'.join(group.identity)}" + self._observed_max_queued_by_group[label] = max(self._observed_max_queued_by_group.get(label, 0), count) + if self._request_pressure_provider is None: + return + for resource, snapshot in self._request_pressure_provider.snapshots().items(): + self._observed_max_request_waiters_by_resource[resource] = max( + self._observed_max_request_waiters_by_resource.get(resource, 0), + snapshot.waiters, + ) + self._observed_max_request_in_flight_by_resource[resource] = max( + self._observed_max_request_in_flight_by_resource.get(resource, 0), + snapshot.in_flight_count, + ) + self._observed_max_request_domain_current_limits[resource] = max( + self._observed_max_request_domain_current_limits.get(resource, 0), + snapshot.current_limit, + ) + for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items(): + self._observed_max_provider_model_aggregate_in_flight[provider_model] = max( + 
self._observed_max_provider_model_aggregate_in_flight.get(provider_model, 0), + snapshot.aggregate_in_flight, + ) + + def _emit_scheduler_health_snapshot(self, reason: str) -> None: + self._emit_scheduler_event( + "scheduler_health_snapshot", + diagnostics=self._scheduler_health_diagnostics(reason=reason), + ) + + def _scheduler_health_diagnostics(self, *, reason: str) -> dict[str, object]: + queue_view = self._fair_queue.view() + task_view = self._task_admission.view() + return { + "reason": reason, + "active_row_groups": len(self._rg_states), + "target_row_groups": self._row_group_admission_target, + "hard_cap_row_groups": self._row_group_admission_hard_cap, + "active_admitted_rows": self._active_admitted_row_count(), + "max_admitted_rows": self._adaptive_max_admitted_rows, + "all_row_groups_admitted": self._all_rgs_admitted, + "queued_total": queue_view.queued_total, + "queued_by_group": _string_keyed_counts(queue_view.queued_by_group), + "queued_demand_by_resource": dict(queue_view.queued_peer_demand_by_resource), + "leased_resources": dict(task_view.leased_resources), + "resource_limits": dict(task_view.resource_limits), + "resources_available": dict(task_view.resources_available), + "in_flight_tasks": len(self._in_flight), + "active_workers": self.active_worker_count, + "deferred_tasks": len(self._deferred), + "pending_pre_batch_tasks": len(self._pending_pre_batch_ready_tasks), + "dispatched_tasks": len(self._dispatched), + "request_pressure_advisory_enabled": self._request_pressure_advisory, + "request_pressure_advisory_skips": self._request_pressure_advisory_skips, + "row_group_admission_blocked_reasons": dict(self._row_group_admission_blocked_reasons), + "request_pressure": self._request_pressure_diagnostics(), + } + + def _scheduler_job_diagnostics(self) -> dict[str, object]: + row_group_sizes = [size for _rg_id, size in self._row_groups] + strategies = {column: self._graph.get_strategy(column).value for column in self._graph.columns} + 
task_count_by_strategy = Counter(strategies.values()) + return { + "run_id": self._run_id, + "num_records": self._num_records, + "buffer_size": self._buffer_size, + "row_group_count": len(self._row_groups), + "row_group_total_rows": sum(row_group_sizes), + "row_group_min_size": min(row_group_sizes, default=0), + "row_group_max_size": max(row_group_sizes, default=0), + "graph_column_count": len(self._graph.columns), + "graph_root_columns": tuple(self._graph.get_root_columns()), + "graph_depth": len(self._graph.get_longest_dependency_chain()), + "task_count_by_strategy": dict(task_count_by_strategy), + "column_scheduling": self._column_scheduling_diagnostics(strategies), + "resource_limits": dict(self._task_admission_config.resource_limits), + "submission_capacity": self._task_admission_config.submission_capacity, + "adaptive_row_group_admission": self._adaptive_row_group_admission, + "row_group_initial_target": self._row_group_admission_target, + "row_group_hard_cap": self._row_group_admission_hard_cap, + "max_admitted_rows": self._adaptive_max_admitted_rows, + "request_pressure_advisory_enabled": self._request_pressure_advisory, + } + + def _column_scheduling_diagnostics(self, strategies: dict[str, str]) -> tuple[dict[str, object], ...]: + diagnostics = [] + for column in self._graph.columns: + task_type = "batch" if self._graph.get_strategy(column) != GenerationStrategy.CELL_BY_CELL else "cell" + row_index = None if task_type == "batch" else 0 + task = Task(column=column, row_group=0, row_index=row_index, task_type=task_type) + resolved = self._task_scheduling.scheduling_for_task(task, self._task_flow_identity(task)) + diagnostics.append( + { + "column": column, + "strategy": strategies[column], + "group_kind": resolved.group.key.kind, + "group_identity_hash": _identity_hash(resolved.group.key.identity), + "group_weight": resolved.group.weight, + "group_admitted_limit": resolved.group.admitted_limit, + "resource_request": dict(resolved.resource_request.amounts), + 
"request_resource": _request_resource_label(resolved.request_resource_key), + } + ) + return tuple(diagnostics) + + def _request_pressure_diagnostics(self) -> dict[str, object]: + if self._request_pressure_provider is None: + return {"enabled": False} + return { + "enabled": True, + "resources": { + _request_resource_label(resource): { + "effective_max": snapshot.effective_max, + "current_limit": snapshot.current_limit, + "in_flight_count": snapshot.in_flight_count, + "active_lease_count": snapshot.active_lease_count, + "waiters": snapshot.waiters, + "blocked": snapshot.blocked_until_monotonic is not None, + "cooldown_remaining_seconds": snapshot.cooldown_remaining_seconds, + "last_outcome": snapshot.last_outcome, + } + for resource, snapshot in self._request_pressure_provider.snapshots().items() + }, + "provider_models": { + f"{provider_model.provider_name}/{provider_model.model_id}": { + "static_cap": snapshot.static_cap, + "aggregate_in_flight": snapshot.aggregate_in_flight, + "aggregate_active_lease_count": snapshot.aggregate_active_lease_count, + "domains": {domain.value: count for domain, count in snapshot.domains.items()}, + } + for provider_model, snapshot in self._request_pressure_provider.global_snapshots().items() + }, + } + + def _request_pressure_item_diagnostics(self, item: SchedulableTask) -> dict[str, object]: + if item.request_resource_key is None or self._request_pressure_provider is None: + return {"request_resource": None} + snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) + global_snapshot = self._request_pressure_provider.global_snapshot( + item.request_resource_key.provider_name, + item.request_resource_key.model_id, + ) + diagnostics: dict[str, object] = { + "request_resource": _request_resource_label(item.request_resource_key), + "pressure_reason": self._request_pressure_reason(item), + "resource_snapshot": None, + "provider_model_snapshot": None, + } + if snapshot is not None: + 
diagnostics["resource_snapshot"] = { + "effective_max": snapshot.effective_max, + "current_limit": snapshot.current_limit, + "in_flight_count": snapshot.in_flight_count, + "waiters": snapshot.waiters, + "blocked": snapshot.blocked_until_monotonic is not None, + "cooldown_remaining_seconds": snapshot.cooldown_remaining_seconds, + } + if global_snapshot is not None: + diagnostics["provider_model_snapshot"] = { + "static_cap": global_snapshot.static_cap, + "aggregate_in_flight": global_snapshot.aggregate_in_flight, + "aggregate_active_lease_count": global_snapshot.aggregate_active_lease_count, + } + return diagnostics + async def _cancel_workers(self) -> None: """Cancel all tracked worker tasks and wait for them to finish.""" for t in self._worker_tasks: @@ -313,114 +621,373 @@ def _apply_frontier_delta(self, delta: FrontierDelta) -> None: return for task in delta.removed: self._discard_ready_task(task) - for task in delta.added: - self._enqueue_ready_task(task) + self._enqueue_ready_tasks(delta.added) def _enqueue_ready_task(self, task: Task) -> None: - if task in self._dispatched or task.row_group not in self._rg_states: - return - if not self._tracker.is_frontier_task(task): - return - state = self._rg_states[task.row_group] - if self._on_seeds_complete is not None and not state.pre_batch_done: - if task not in self._pending_pre_batch_ready_tasks: - self._pending_pre_batch_ready[task.row_group].append(task) - self._pending_pre_batch_ready_tasks.add(task) + self._enqueue_ready_tasks((task,)) + + def _enqueue_ready_tasks(self, tasks: tuple[Task, ...]) -> None: + schedulables: list[SchedulableTask] = [] + accepted_tasks_by_id: dict[str, Task] = {} + for task in tasks: + if task in self._dispatched or task.row_group not in self._rg_states: + continue + if not self._tracker.is_frontier_task(task): + continue + self._emit_scheduler_event("dependency_ready", task=task) + state = self._rg_states[task.row_group] + if self._on_seeds_complete is not None and not 
state.pre_batch_done and task.column not in self._seed_cols: + if task not in self._pending_pre_batch_ready_tasks: + self._pending_pre_batch_ready[task.row_group].append(task) + self._pending_pre_batch_ready_tasks.add(task) + continue + schedulable = self._schedulable_task(task) + schedulables.append(schedulable) + accepted_tasks_by_id[schedulable.task_id] = task + + if not schedulables: return - self._fair_queue.enqueue(task, self._task_group_spec(task)) + accepted = self._fair_queue.enqueue(schedulables) + if accepted: + self._tracker.mark_enqueued(accepted) + for task_id in accepted: + self._emit_scheduler_event("ready_enqueued", task=accepted_tasks_by_id[task_id]) + self._record_observed_task_state() + self._wake_event.set() def _discard_ready_task(self, task: Task) -> None: - self._fair_queue.discard(task) + self._fair_queue.discard(stable_task_id(task)) self._pending_pre_batch_ready_tasks.discard(task) def _flush_pre_batch_ready(self, row_group: int) -> None: pending = self._pending_pre_batch_ready.pop(row_group, []) + ready = [] for task in pending: if task not in self._pending_pre_batch_ready_tasks: continue self._pending_pre_batch_ready_tasks.discard(task) - self._enqueue_ready_task(task) + ready.append(task) + self._enqueue_ready_tasks(tuple(ready)) def _drop_pending_ready_for_row_group(self, row_group: int) -> None: pending = self._pending_pre_batch_ready.pop(row_group, []) for task in pending: self._pending_pre_batch_ready_tasks.discard(task) - self._fair_queue.discard_where(lambda task: task.row_group == row_group) + self._fair_queue.discard_where(lambda item: item.payload.row_group == row_group) def _dispatch_queued_tasks(self) -> _DispatchOutcome: dispatched = False while self._fair_queue.has_queued_tasks: - if not self._submission_semaphore.try_acquire(): - return _DispatchOutcome(dispatched=dispatched, submission_full=True) - - selection = self._fair_queue.admit_next() + selection = self._fair_queue.select_next(self._is_dispatch_eligible) if 
selection is None: - self._submission_semaphore.release() - return _DispatchOutcome(dispatched=dispatched, group_blocked=True) + summary = self._task_admission.explain_blocked(self._fair_queue.view()) + if "group_cap" in summary.dominant_denial_reasons: + event_kind = "group_capped" + elif summary.dominant_denial_reasons: + event_kind = "admission_blocked" + else: + event_kind = "queue_empty" + self._emit_scheduler_event( + event_kind, + diagnostics={ + "queued_count": summary.queued_count, + "reasons": dict(summary.dominant_denial_reasons), + }, + ) + self._emit_scheduler_health_snapshot(event_kind) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) + + self._emit_scheduler_event("selected", task=selection.item.payload) + decision = self._task_admission.try_acquire(selection.item, selection.queue_view) + if isinstance(decision, TaskAdmissionDenied): + self._emit_scheduler_event( + "admission_denied", + task=selection.item.payload, + reason_or_result=decision.reason, + diagnostics=dict(decision.diagnostics), + ) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) + self._emit_scheduler_event("task_lease_acquired", task=selection.item.payload, lease=decision) + + committed = self._fair_queue.commit(selection) + if committed is None: + result = self._task_admission.release(decision) + self._emit_scheduler_event( + "stale_selection", + task=selection.item.payload, + lease=decision, + reason_or_result=result.reason, + ) + return _DispatchOutcome(dispatched=dispatched, admission_blocked=True) - self._dispatch_selected_task(selection.task) + self._dispatch_selected_task(committed, decision) dispatched = True + self._record_observed_task_state() + if dispatched: + self._emit_scheduler_event("queue_drained") + self._emit_scheduler_health_snapshot("queue_drained") return _DispatchOutcome(dispatched=dispatched) - def _dispatch_selected_task(self, task: Task) -> None: + def _is_dispatch_eligible(self, item: SchedulableTask, view: 
Any) -> bool: + if not self._task_admission.is_eligible(item, view): + return False + if not self._request_pressure_advisory: + return True + if not self._is_request_pressure_limited(item): + return True + open_peer = self._request_pressure_open_peer(item, view) + if open_peer is not None: + self._request_pressure_advisory_skips += 1 + self._emit_scheduler_event( + "request_pressure_advisory_skipped", + task=item.payload, + diagnostics=self._request_pressure_item_diagnostics(item) + | { + "open_peer_task_id": open_peer.task_id, + "open_peer_column": open_peer.payload.column, + "open_peer_request_resource": _request_resource_label(open_peer.request_resource_key), + "skip_count": self._request_pressure_advisory_skips, + }, + ) + return False + return True + + def _is_request_pressure_limited(self, item: SchedulableTask) -> bool: + return self._request_pressure_reason(item) is not None + + def _request_pressure_reason(self, item: SchedulableTask) -> str | None: + if item.request_resource_key is None or self._request_pressure_provider is None: + return None + snapshot = self._request_pressure_provider.snapshot(item.request_resource_key) + global_snapshot = self._request_pressure_provider.global_snapshot( + item.request_resource_key.provider_name, + item.request_resource_key.model_id, + ) + if ( + global_snapshot is not None + and global_snapshot.static_cap > 0 + and global_snapshot.aggregate_in_flight >= global_snapshot.static_cap + ): + return "provider_model_aggregate_cap" + if snapshot is None: + return None + if snapshot.cooldown_remaining_seconds > 0.0 or snapshot.blocked_until_monotonic is not None: + return "cooldown" + if snapshot.waiters > 0: + return "waiters" + if snapshot.current_limit > 0 and snapshot.in_flight_count >= snapshot.current_limit: + return "resource_limit" + return None + + def _has_request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> bool: + return self._request_pressure_open_peer(item, view) is not None + + def 
_request_pressure_open_peer(self, item: SchedulableTask, view: Any) -> SchedulableTask | None: + for peer in view.first_candidate_tasks_by_group.values(): + if peer.task_id == item.task_id: + continue + if not self._task_admission.is_eligible(peer, view): + continue + if not self._is_request_pressure_limited(peer): + return peer + return None + + def _dispatch_selected_task(self, item: SchedulableTask, lease: TaskAdmissionLease) -> None: + task = item.payload + task_execution_id = f"task-exec-{uuid.uuid4().hex}" self._dispatched.add(task) self._in_flight.add(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count += 1 - self._spawn_worker(self._execute_task(task)) - - def _task_group_spec(self, task: Task) -> TaskGroupSpec: - generator = self._generators[task.column] - generator_id = id(generator) - cached = self._task_group_spec_cache.get(generator_id) - if cached is not None: - return cached - - spec = self._task_group_spec_from_hint( - self._scheduling_hints.hint_for(generator), - self._task_flow_identity(task), - ) - self._task_group_spec_cache[generator_id] = spec - return spec - - def _task_group_spec_from_hint(self, hint: SchedulingHint, flow_identity: tuple[str, ...]) -> TaskGroupSpec: - if hint.group_kind == "local": - return TaskGroupSpec(key=TaskGroupKey(kind="local", identity=flow_identity)) - - if hint.group_kind == "custom_model": - identity = (*flow_identity, *hint.identity_suffix) - else: - identity = (*hint.identity_prefix, *flow_identity, *hint.identity_suffix) + try: + self._spawn_worker(self._execute_task(task, lease, task_execution_id)) + self._emit_scheduler_event("worker_spawned", task=task, lease=lease, task_execution_id=task_execution_id) + except Exception: + result = self._task_admission.release(lease) + self._in_flight.discard(task) + self._dispatched.discard(task) + if (s := self._rg_states.get(task.row_group)) is not None: + s.in_flight_count = max(0, s.in_flight_count - 1) + self._emit_scheduler_event( + 
"worker_spawn_failed", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=result.reason, + ) + raise - weight = max(1, hint.weight) - return TaskGroupSpec( - key=TaskGroupKey(kind=hint.group_kind, identity=identity), - weight=float(weight), - admitted_limit=self._llm_group_admitted_limit(weight), - ) + def _schedulable_task(self, task: Task) -> SchedulableTask: + return self._task_scheduling.schedulable_task(task, self._task_flow_identity(task)) def _task_flow_identity(self, task: Task) -> tuple[str, ...]: generator = self._generators[task.column] output_columns = self._gen_instance_to_columns.get(id(generator), [task.column]) return tuple(output_columns) - def _llm_group_admitted_limit(self, weight: int) -> int: - return max(1, min(self._max_llm_wait_tasks, LLM_GROUP_ADMISSION_BACKLOG_MULTIPLIER * weight)) + def _max_admitted_rows_guardrail(self) -> int: + if self._num_records > 0 and self._buffer_size > 0: + return min(self._num_records, max(3 * self._buffer_size, 8192)) + total_rows = sum(size for _rg_id, size in self._row_groups) + return max(1, total_rows) + + async def _wait_for_row_group_admission_capacity(self, row_group_size: int) -> None: + while True: + target_blocked = len(self._rg_states) >= self._row_group_admission_target + row_guard_blocked = not self._row_group_row_guard_allows(row_group_size) + if not target_blocked and not row_guard_blocked: + return + self._row_group_admission_event.clear() + target_blocked = len(self._rg_states) >= self._row_group_admission_target + row_guard_blocked = not self._row_group_row_guard_allows(row_group_size) + if not target_blocked and not row_guard_blocked: + return + if row_guard_blocked: + self._row_group_admission_blocked_reasons["max_admitted_rows"] += 1 + self._emit_scheduler_event( + "row_group_admission_blocked", + diagnostics=self._row_group_admission_diagnostics(reason="max_admitted_rows"), + ) + self._emit_scheduler_health_snapshot("row_group_admission_blocked") + await 
self._row_group_admission_event.wait() + self._raise_if_fatal_worker_error() + + def _row_group_row_guard_allows(self, row_group_size: int) -> bool: + if not self._adaptive_row_group_admission: + return True + admitted_rows = self._active_admitted_row_count() + return admitted_rows == 0 or admitted_rows + row_group_size <= self._adaptive_max_admitted_rows + + def _active_admitted_row_count(self) -> int: + return sum(state.size for state in self._rg_states.values()) + + def _maybe_update_adaptive_row_group_target(self) -> None: + if not self._adaptive_row_group_admission: + return + if self._all_rgs_admitted or self._early_shutdown or self._fatal_worker_error is not None: + return + if len(self._rg_states) >= self._row_group_admission_hard_cap: + self._row_group_admission_pressure_ticks = 0 + return + reason = self._adaptive_row_group_block_reason() + if reason is not None: + self._row_group_admission_blocked_reasons[reason] += 1 + self._row_group_admission_pressure_ticks = 0 + self._emit_scheduler_event( + "row_group_admission_blocked", + diagnostics=self._row_group_admission_diagnostics(reason=reason), + ) + self._emit_scheduler_health_snapshot("row_group_admission_blocked") + return + + self._row_group_admission_pressure_ticks += 1 + if self._fair_queue.view().queued_total > 0 and self._row_group_admission_pressure_ticks < 2: + return + old_target = self._row_group_admission_target + self._row_group_admission_target = min(self._row_group_admission_hard_cap, old_target + 1) + self._observed_max_row_group_admission_target = max( + self._observed_max_row_group_admission_target, + self._row_group_admission_target, + ) + self._row_group_admission_pressure_ticks = 0 + if self._row_group_admission_target != old_target: + self._emit_scheduler_event( + "row_group_admission_target_changed", + diagnostics=self._row_group_admission_diagnostics(reason="horizon_limited") + | {"old_target": old_target, "new_target": self._row_group_admission_target}, + ) + 
self._emit_scheduler_health_snapshot("row_group_admission_target_changed") + self._row_group_admission_event.set() + + def _adaptive_row_group_block_reason(self) -> str | None: + if self._deferred: + return "deferred_tasks" + next_size = self._next_unadmitted_row_group_size() + if next_size is None: + return "no_pending_row_groups" + if not self._row_group_row_guard_allows(next_size): + return "max_admitted_rows" + queue_view = self._fair_queue.view() + queue_guard = max(self._max_submitted_tasks * 4, self._max_model_task_admission * 2) + if queue_view.queued_total >= queue_guard: + return "queued_task_guardrail" + task_view = self._task_admission.view() + llm_limit = task_view.resource_limits.get("llm_wait", 0) + if llm_limit <= 0: + return "no_llm_wait_resource" + llm_available = task_view.resources_available.get("llm_wait", 0) + queued_llm = queue_view.queued_peer_demand_by_resource.get("llm_wait", 0) + if llm_available <= 0: + return "llm_wait_saturated" + if llm_available <= queued_llm and queue_view.queued_total > 0: + return "queued_llm_demand" + return None + + def _next_unadmitted_row_group_size(self) -> int | None: + for rg_id, rg_size in self._row_groups: + if rg_id not in self._rg_states and not self._tracker.is_row_group_complete( + rg_id, rg_size, self._graph.columns + ): + return rg_size + return None + + def _row_group_admission_diagnostics(self, *, reason: str) -> dict[str, object]: + queue_view = self._fair_queue.view() + task_view = self._task_admission.view() + admitted_rows = self._active_admitted_row_count() + return { + "mode": "adaptive" if self._adaptive_row_group_admission else "fixed", + "reason": reason, + "active_row_groups": len(self._rg_states), + "target_row_groups": self._row_group_admission_target, + "hard_cap": self._row_group_admission_hard_cap, + "admitted_rows": admitted_rows, + "max_admitted_rows": self._adaptive_max_admitted_rows, + "queued_total": queue_view.queued_total, + "queued_llm_wait_demand": 
queue_view.queued_peer_demand_by_resource.get("llm_wait", 0), + "llm_wait_limit": task_view.resource_limits.get("llm_wait", 0), + "llm_wait_leased": task_view.leased_resources.get("llm_wait", 0), + "llm_wait_available": task_view.resources_available.get("llm_wait", 0), + "blocked_reasons": dict(self._row_group_admission_blocked_reasons), + } async def _admit_row_groups(self) -> None: """Admit row groups as semaphore slots become available.""" + all_admitted = True for rg_id, rg_size in self._row_groups: + await self._wait_for_row_group_admission_capacity(rg_size) + if self._early_shutdown or self._fatal_worker_error is not None: + all_admitted = False + break await self._rg_semaphore.acquire() + if self._early_shutdown or self._fatal_worker_error is not None: + self._rg_semaphore.release() + all_admitted = False + break + if not self._row_group_row_guard_allows(rg_size): + self._rg_semaphore.release() + await self._wait_for_row_group_admission_capacity(rg_size) + await self._rg_semaphore.acquire() + if self._early_shutdown or self._fatal_worker_error is not None: + self._rg_semaphore.release() + all_admitted = False + break self._rg_states[rg_id] = _RowGroupState(size=rg_size) if self._buffer_manager is not None: self._buffer_manager.init_row_group(rg_id, rg_size) await self._dispatch_seeds(rg_id, rg_size) + self._emit_scheduler_event( + "row_group_admitted", + diagnostics=self._row_group_admission_diagnostics(reason="admitted") + | {"row_group": rg_id, "row_group_size": rg_size}, + ) + self._emit_scheduler_health_snapshot("row_group_admitted") self._wake_event.set() - self._all_rgs_admitted = True + self._all_rgs_admitted = all_admitted self._wake_event.set() async def run(self) -> None: @@ -440,6 +1007,9 @@ async def run(self) -> None: if self._reporter: self._reporter.log_start(num_row_groups=num_rgs) + self._emit_scheduler_event("scheduler_job_started", diagnostics=self._scheduler_job_diagnostics()) + self._emit_scheduler_health_snapshot("start") + # Launch 
admission as a background task so it interleaves with dispatch. admission_task = asyncio.create_task(self._admit_row_groups()) @@ -466,6 +1036,11 @@ async def run(self) -> None: if self._reporter: self._reporter.log_final() + self._emit_scheduler_health_snapshot("completed") + self._emit_scheduler_event( + "scheduler_job_completed", diagnostics=self._scheduler_health_diagnostics(reason="completed") + ) + if self._rg_states: incomplete = list(self._rg_states) logger.error( @@ -481,6 +1056,7 @@ async def _main_dispatch_loop( ) -> None: """Core dispatch loop extracted from ``run()``.""" while True: + self._raise_if_fatal_worker_error() if self._early_shutdown: logger.warning("Early shutdown triggered - non-retryable error rate exceeded threshold") if self._deferred: @@ -496,28 +1072,37 @@ async def _main_dispatch_loop( dispatch_outcome = self._dispatch_queued_tasks() self._checkpoint_completed_row_groups(all_columns) + self._maybe_update_adaptive_row_group_target() # Eagerly salvage any row groups that have only deferred tasks, # even if other row groups are still in-flight. This frees # semaphore slots so admission doesn't lose capacity. if self._deferred: await self._salvage_stalled_row_groups(seed_cols, has_pre_batch, all_columns) + self._maybe_update_adaptive_row_group_target() # Are we done? 
all_done = self._all_rgs_admitted and not self._rg_states and not self._in_flight if all_done: break + pending_pre_batch = has_pre_batch and any( + state.seeds_dispatched and not state.pre_batch_done for state in self._rg_states.values() + ) if not self._fair_queue.has_queued_tasks and not self._in_flight: - if self._all_rgs_admitted: + if self._all_rgs_admitted and not pending_pre_batch: break + if pending_pre_batch: + await asyncio.sleep(0) + continue - if ( - not self._fair_queue.has_queued_tasks - or dispatch_outcome.submission_full - or dispatch_outcome.group_blocked - ): + if not self._fair_queue.has_queued_tasks or dispatch_outcome.admission_blocked: + if self._fair_queue.has_queued_tasks and not dispatch_outcome.dispatched and not self._in_flight: + raise RuntimeError( + "Ready frontier is admission-blocked with no in-flight task to release scheduler capacity." + ) await self._wake_event.wait() + self._raise_if_fatal_worker_error() async def _salvage_rounds( self, @@ -549,34 +1134,10 @@ async def _salvage_rounds( self._dispatched.discard( Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch") ) - # Acquire stateful lock (mirrors _dispatch_seeds) so - # _execute_seed_task can safely release it in finally. - if gid in self._stateful_locks: - await self._stateful_locks[gid].acquire() - await self._submission_semaphore.acquire() - self._dispatched.add(task) - # Re-register batch alias to mirror _dispatch_seeds and prevent - # duplicate dispatch if the frontier contains a stale batch task. - self._dispatched.add( - Task(column=task.column, row_group=task.row_group, row_index=None, task_type="batch") - ) - # Re-mark sibling columns as dispatched to mirror _dispatch_seeds - # and prevent _drain_frontier from re-dispatching them. 
- for sibling in self._gen_instance_to_columns.get(gid, []): - if sibling != task.column: - self._dispatched.add( - Task(column=sibling, row_group=task.row_group, row_index=None, task_type="from_scratch") - ) - self._dispatched.add( - Task(column=sibling, row_group=task.row_group, row_index=None, task_type="batch") - ) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_seed_task(task, gid)) + self._apply_frontier_delta(self._tracker.add_ready_tasks((task,))) else: self._dispatched.discard(task) - self._enqueue_ready_task(task) + self._apply_frontier_delta(self._tracker.add_ready_tasks((task,))) # Drain: dispatch frontier tasks and any newly-ready downstream tasks # until nothing remains in-flight or in the frontier. await self._drain_frontier(seed_cols, has_pre_batch) @@ -585,6 +1146,7 @@ async def _salvage_rounds( async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) -> None: """Dispatch all frontier tasks and their downstream until quiescent.""" while True: + self._raise_if_fatal_worker_error() if has_pre_batch: self._run_seeds_complete_check(seed_cols) dispatch_outcome = self._dispatch_queued_tasks() @@ -599,6 +1161,7 @@ async def _drain_frontier(self, seed_cols: tuple[str, ...], has_pre_batch: bool) continue self._wake_event.clear() await self._wake_event.wait() + self._raise_if_fatal_worker_error() async def _salvage_stalled_row_groups( self, @@ -657,6 +1220,9 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: if self._tracker.is_row_group_complete(rg_id, state.size, all_columns) ] for rg_id, rg_size in completed: + dropped_rows = sum(1 for ri in range(rg_size) if self._tracker.is_dropped(rg_id, ri)) + checkpointed = False + checkpoint_result = "unknown" try: if self._on_before_checkpoint: try: @@ -670,17 +1236,36 @@ def _checkpoint_completed_row_groups(self, all_columns: list[str]) -> None: # Remove from 
tracking only after the callback succeeds. del self._rg_states[rg_id] # If all rows were dropped (e.g. seed failure), free instead of finalizing - if all(self._tracker.is_dropped(rg_id, ri) for ri in range(rg_size)): + if dropped_rows == rg_size: if self._buffer_manager: self._buffer_manager.free_row_group(rg_id) + checkpoint_result = "all_rows_dropped" elif self._on_finalize_row_group is not None: self._on_finalize_row_group(rg_id) + checkpoint_result = "finalized" + else: + checkpoint_result = "completed" + checkpointed = True except DatasetGenerationError: raise except Exception: logger.error(f"Failed to checkpoint row group {rg_id}.", exc_info=True) finally: self._rg_semaphore.release() + self._row_group_admission_event.set() + if checkpointed: + self._emit_scheduler_event( + "row_group_checkpointed", + diagnostics={ + "row_group": rg_id, + "row_group_size": rg_size, + "dropped_rows": dropped_rows, + "surviving_rows": rg_size - dropped_rows, + "result": checkpoint_result, + "active_row_groups": len(self._rg_states), + }, + ) + self._emit_scheduler_health_snapshot("row_group_checkpointed") # Clean up deferred tasks for checkpointed row groups if completed: @@ -803,7 +1388,7 @@ def _check_error_rate(self, *, success: bool) -> None: self._early_shutdown = True def _record_retryable_outcome(self, *, retryable: bool) -> None: - """Track retryable-error rate and emit a throttled WARN under provider degradation. + """Track retryable-error rate and emit a rate-limited WARN under provider degradation. 
Distinct from ``_check_error_rate``: every LLM-bound task outcome (success or failure) feeds this window so the rate reflects the provider's overall @@ -832,7 +1417,7 @@ def _record_retryable_outcome(self, *, retryable: bool) -> None: ) async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: - """Dispatch from_scratch tasks for a row group.""" + """Make from-scratch/root tasks ready for a row group.""" self._rg_states[rg_id].seeds_dispatched = True seed_cols = self._seed_cols if not seed_cols: @@ -841,6 +1426,7 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: width = len(str(num_rgs)) logger.info(f"🚀 ({rg_id + 1:0{width}d}/{num_rgs}) Dispatching with {rg_size} records") seen_instances: set[int] = set() + root_columns: list[str] = [] for col in seed_cols: gen = self._generators[col] @@ -848,64 +1434,38 @@ async def _dispatch_seeds(self, rg_id: int, rg_size: int) -> None: if gid in seen_instances: continue seen_instances.add(gid) + root_columns.append(col) - task = Task(column=col, row_group=rg_id, row_index=None, task_type="from_scratch") - # Also mark the "batch" variant as dispatched to prevent duplicate - # scheduling for this column. - batch_alias = Task(column=col, row_group=rg_id, row_index=None, task_type="batch") - if task in self._dispatched or batch_alias in self._dispatched: - continue - - # Seeds bypass fair-queue admission while row groups are being admitted; - # direct dispatch preserves stateful lock ordering across row groups. - # Acquire stateful lock *before* submission semaphore to preserve - # row-group ordering. Held until generation completes (_execute_seed_task). 
- if gid in self._stateful_locks: - await self._stateful_locks[gid].acquire() - - await self._submission_semaphore.acquire() - self._dispatched.add(task) - self._dispatched.add(batch_alias) - # Also mark all sibling output columns as dispatched (multi-column dedup) - for sibling_col in self._gen_instance_to_columns.get(gid, []): - if sibling_col != col: - self._dispatched.add( - Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="from_scratch") - ) - self._dispatched.add(Task(column=sibling_col, row_group=rg_id, row_index=None, task_type="batch")) - self._in_flight.add(task) - if (s := self._rg_states.get(task.row_group)) is not None: - s.in_flight_count += 1 - self._spawn_worker(self._execute_seed_task(task, gid)) - - async def _execute_seed_task(self, task: Task, generator_id: int) -> None: - """Execute a from_scratch task and release stateful lock if held.""" - try: - await self._execute_task_inner(task) - finally: - if generator_id in self._stateful_locks: - self._stateful_locks[generator_id].release() + self._apply_frontier_delta(self._tracker.add_root_tasks(rg_id, rg_size, columns=tuple(root_columns))) - async def _execute_task(self, task: Task) -> None: + async def _execute_task(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: """Execute a single task (cell or batch).""" - await self._execute_task_inner(task) + await self._execute_task_inner(task, lease, task_execution_id) - async def _execute_task_inner(self, task: Task) -> None: - """Core task execution logic. - - For LLM-bound tasks, uses a one-way semaphore handoff: acquires the - LLM-wait slot while still holding the submission slot, then releases - the submission slot (never reacquired). This prevents cross-key - starvation while bounding live coroutines. 
- """ + async def _execute_task_inner(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: + """Core task execution logic.""" num_rgs = len(self._row_groups) token = current_row_group.set((task.row_group, num_rgs)) + group = lease.item.group + identity_hash = hashlib.sha1("\0".join(group.key.identity).encode()).hexdigest()[:16] + correlation_token = runtime_correlation_provider.set( + RuntimeCorrelation( + run_id=self._run_id, + row_group=task.row_group, + task_column=task.column, + task_type=task.task_type, + scheduling_group_kind=group.key.kind, + scheduling_group_identity_hash=identity_hash, + task_execution_id=task_execution_id, + ) + ) try: - await self._execute_task_inner_impl(task) + await self._execute_task_inner_impl(task, lease, task_execution_id) finally: + runtime_correlation_provider.reset(correlation_token) current_row_group.reset(token) - async def _execute_task_inner_impl(self, task: Task) -> None: + async def _execute_task_inner_impl(self, task: Task, lease: TaskAdmissionLease, task_execution_id: str) -> None: trace: TaskTrace | None = None if self._trace: trace = TaskTrace.from_task(task) @@ -914,12 +1474,12 @@ async def _execute_task_inner_impl(self, task: Task) -> None: generator = self._generators[task.column] output_cols = self._gen_instance_to_columns.get(id(generator), [task.column]) retryable = False + cancelled = False # When True, skip removing from _dispatched so the task isn't re-dispatched # from the frontier (it was never completed, so it stays in the frontier). 
skipped = False - is_llm = self._llm_bound_lookup.get(task.column, False) - holds_submission = True - holds_llm_wait = False + uses_model_stage_resource = "llm_wait" in lease.resources + stateful_lock_acquired = False try: # Skip tasks whose row group was already checkpointed (can happen @@ -929,11 +1489,9 @@ async def _execute_task_inner_impl(self, task: Task) -> None: skipped = True return - if is_llm: - await self._llm_wait_semaphore.acquire() - holds_llm_wait = True - self._submission_semaphore.release() - holds_submission = False + if task.task_type == "from_scratch" and id(generator) in self._stateful_locks: + await self._stateful_locks[id(generator)].acquire() + stateful_lock_acquired = True if self._trace and trace: trace.slot_acquired_at = time.perf_counter() @@ -962,7 +1520,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None: # window from LLM-bound tasks so a healthy non-model task mix # (samplers, expressions, non-LLM customs) doesn't dilute the # rate and silence the WARN under genuine provider stress. - if is_llm: + if uses_model_stage_resource: self._record_retryable_outcome(retryable=False) if self._reporter: if cell_skipped: @@ -972,6 +1530,13 @@ async def _execute_task_inner_impl(self, task: Task) -> None: if self._trace and trace: trace.status = "ok" + except asyncio.CancelledError: + cancelled = True + if self._trace and trace: + trace.status = "cancelled" + self._emit_scheduler_event("cancelled", task=task, lease=lease, task_execution_id=task_execution_id) + raise + except Exception as exc: retryable = self._is_retryable(exc) # Only non-retryable errors (auth, schema, code bugs) count toward @@ -980,7 +1545,7 @@ async def _execute_task_inner_impl(self, task: Task) -> None: # and would otherwise trip the gate even when salvage could recover. 
if not retryable: self._check_error_rate(success=False) - if is_llm: + if uses_model_stage_resource: self._record_retryable_outcome(retryable=retryable) if not retryable and self._reporter: self._reporter.record_failure(task.column) @@ -990,21 +1555,41 @@ async def _execute_task_inner_impl(self, task: Task) -> None: if retryable: self._deferred.append(task) + self._emit_scheduler_event( + "retry_deferred", task=task, lease=lease, task_execution_id=task_execution_id + ) else: # Capture the first non-retryable error for the interface to surface # as the root cause when the run produces 0 records (e.g. deterministic # seed failures). Subsequent failures are still logged below. if self._first_non_retryable_error is None: self._first_non_retryable_error = exc - # Non-retryable: drop the affected row(s) + log_message = ( + f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}" + ) + if self._is_expected_non_retryable(exc): + logger.warning(log_message) + elif self._is_internal_bug(exc): + logger.error("Unexpected fatal %s", log_message, exc_info=True) + self._fatal_worker_error = exc + self._wake_event.set() + raise + else: + logger.error("Unexpected %s", log_message, exc_info=True) + # Non-retryable data/user/provider failures drop the affected row(s); + # internal bug-shaped failures above abort the run instead. 
if task.row_index is not None: self._drop_row(task.row_group, task.row_index, exclude_columns={task.column}) else: # Batch/from_scratch failure: drop all rows in the row group rg_size = self._get_rg_size(task.row_group) self._drop_row_group(task.row_group, rg_size, exclude_columns={task.column}) - logger.warning( - f"Non-retryable failure on {task.column}[rg={task.row_group}, row={task.row_index}]: {exc}" + self._emit_scheduler_event( + "non_retryable_dropped", + task=task, + lease=lease, + task_execution_id=task_execution_id, + diagnostics={"error_type": type(exc).__name__}, ) finally: @@ -1012,18 +1597,70 @@ async def _execute_task_inner_impl(self, task: Task) -> None: trace.completed_at = time.perf_counter() self.traces.append(trace) - self._fair_queue.release(task) + self._tracker.mark_complete(task) + if not cancelled: + self._emit_scheduler_event( + "task_completed", + task=task, + lease=lease, + task_execution_id=task_execution_id, + ) self._in_flight.discard(task) if (s := self._rg_states.get(task.row_group)) is not None: s.in_flight_count = max(0, s.in_flight_count - 1) if not retryable and not skipped: self._dispatched.discard(task) - if holds_llm_wait: - self._llm_wait_semaphore.release() - if holds_submission: - self._submission_semaphore.release() + if stateful_lock_acquired: + self._stateful_locks[id(generator)].release() + release_result = self._task_admission.release(lease) + self._emit_scheduler_event( + "task_lease_released", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=release_result.reason, + ) + if not release_result.released: + self._emit_scheduler_event( + "release_diagnostic", + task=task, + lease=lease, + task_execution_id=task_execution_id, + reason_or_result=release_result.reason, + ) + self._record_observed_task_state() self._wake_event.set() + async def _run_generator_call(self, task: Task, operation: str, call: Coroutine[Any, Any, Any]) -> Any: + """Run user/plugin generator code while 
preserving scheduler-owned failures.""" + try: + return await call + except Exception as exc: + if self._is_retryable(exc) or self._is_expected_non_retryable(exc): + raise + raise DatasetGenerationError( + f"Generator failed for column '{task.column}' during {operation}: {exc}" + ) from exc + + def _require_dataframe_result( + self, + task: Task, + operation: str, + result: Any, + *, + expected_rows: int | None = None, + ) -> Any: + if not isinstance(result, lazy.pd.DataFrame): + raise DatasetGenerationError( + f"{operation} for column '{task.column}' must return a DataFrame, got {type(result).__name__}." + ) + if expected_rows is not None and len(result) != expected_rows: + raise DatasetGenerationError( + f"{operation} for column '{task.column}' returned {len(result)} rows " + f"but {expected_rows} were expected (rg={task.row_group})." + ) + return result + async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any: """Execute a from_scratch task.""" rg_size = self._get_rg_size(task.row_group) @@ -1031,7 +1668,12 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any from data_designer.engine.column_generators.generators.base import FromScratchColumnGenerator if isinstance(generator, FromScratchColumnGenerator): - result_df = await generator.agenerate_from_scratch(rg_size) + result_df = await self._run_generator_call( + task, + "from-scratch generation", + generator.agenerate_from_scratch(rg_size), + ) + result_operation = "From-scratch generator" else: # Non-FromScratch generators dispatched as seeds (no upstream columns) # operate on existing buffer rows — same contract as the sync engine's @@ -1043,7 +1685,18 @@ async def _run_from_scratch(self, task: Task, generator: ColumnGenerator) -> Any input_df = lazy.pd.DataFrame(records) else: input_df = lazy.pd.DataFrame(index=range(rg_size)) - result_df = await generator.agenerate(input_df) + result_df = await self._run_generator_call( + task, + "full-column 
generation", + generator.agenerate(input_df), + ) + result_operation = "Full-column generator" + result_df = self._require_dataframe_result( + task, + result_operation, + result_df, + expected_rows=rg_size, + ) # Write results to buffer (include side-effect columns) if self._buffer_manager is not None: @@ -1077,7 +1730,11 @@ async def _run_cell(self, task: Task, generator: ColumnGenerator) -> tuple[Any, # Copy for generation: agenerate crosses an await boundary, so the # generator must not hold a mutable reference to the live record. - result = await generator.agenerate(dict(record)) + result = await self._run_generator_call( + task, + "cell generation", + generator.agenerate(dict(record)), + ) # Write back to buffer (include side-effect columns) if self._buffer_manager is not None and not self._tracker.is_dropped(task.row_group, task.row_index): @@ -1147,17 +1804,22 @@ async def _run_batch(self, task: Task, generator: ColumnGenerator) -> Any: if len(batch_df) == 0: return batch_df - result_df = await generator.agenerate(batch_df) + active_rows = rg_size - len(pre_dropped) - len(pre_skipped) if self._buffer_manager is not None else None + result_df = await self._run_generator_call( + task, + "batch generation", + generator.agenerate(batch_df), + ) + result_df = self._require_dataframe_result( + task, + "Batch generator", + result_df, + expected_rows=active_rows, + ) # Merge result columns back to buffer (include side-effect columns) if self._buffer_manager is not None: write_cols = self._gen_instance_to_columns_including_side_effects.get(id(generator), [task.column]) - active_rows = rg_size - len(pre_dropped) - len(pre_skipped) - if len(result_df) != active_rows: - raise ValueError( - f"Batch generator for '{task.column}' returned {len(result_df)} rows " - f"but {active_rows} were expected (rg={task.row_group})." 
- ) result_idx = 0 for ri in range(rg_size): if ri in pre_dropped or ri in pre_skipped: @@ -1176,18 +1838,154 @@ def _get_rg_size(self, row_group: int) -> int: except KeyError: raise ValueError(f"Unknown row group: {row_group}") from None - def get_semaphore_permits(self) -> tuple[int, int]: - """Return ``(submission_available, llm_wait_available)`` for diagnostics.""" - return ( - self._submission_semaphore.available_permits, - self._llm_wait_semaphore.available_permits, + def task_admission_snapshot(self) -> object: + """Return the current scheduler task-admission snapshot for diagnostics.""" + return self._task_admission.view() + + def capacity_plan(self) -> AsyncCapacityPlan: + """Return the scheduler-side async capacity explanation for this run.""" + task_view = self._task_admission.view() + request_snapshots = ( + dict(self._request_pressure_provider.snapshots()) if self._request_pressure_provider is not None else {} + ) + provider_snapshots = ( + dict(self._request_pressure_provider.global_snapshots()) + if self._request_pressure_provider is not None + else {} + ) + request_resources = tuple(sorted(request_snapshots)) + provider_model_static_caps = { + provider_model: ProviderModelStaticCap( + cap=snapshot.static_cap, + aliases=snapshot.aliases, + raw_caps=snapshot.raw_caps, + ) + for provider_model, snapshot in provider_snapshots.items() + } + request_config = self._request_pressure_provider.config if self._request_pressure_provider is not None else None + request_config_snapshot = ( + RequestAdmissionConfigSnapshot.from_config(request_config) + if isinstance(request_config, RequestAdmissionConfig) + else None + ) + request_domain_initial_limits: dict[RequestResourceKey, int] = {} + if request_config_snapshot is not None: + request_domain_initial_limits.update(request_config_snapshot.initial_limits) + for resource, snapshot in request_snapshots.items(): + configured_initial = ( + request_config_snapshot.initial_limits.get(resource) if 
request_config_snapshot is not None else None + ) + request_domain_initial_limits[resource] = ( + max(1, min(configured_initial, snapshot.effective_max)) + if configured_initial is not None + else snapshot.effective_max + ) + request_domain_current_limits = { + resource: snapshot.current_limit for resource, snapshot in request_snapshots.items() + } + request_domain_effective_max = { + resource: snapshot.effective_max for resource, snapshot in request_snapshots.items() + } + request_domain_blocked_until = { + resource: snapshot.blocked_until_monotonic for resource, snapshot in request_snapshots.items() + } + provider_model_aggregate_in_flight = { + provider_model: snapshot.aggregate_in_flight for provider_model, snapshot in provider_snapshots.items() + } + return AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=self._buffer_size, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue( + value=self._max_concurrent_row_groups, + source="dataset_builder", + ), + observed_in_flight=len(self._rg_states), + mode="adaptive" if self._adaptive_row_group_admission else "fixed", + target_in_flight=self._row_group_admission_target, + observed_max_target=self._observed_max_row_group_admission_target, + max_admitted_rows=self._adaptive_max_admitted_rows, + blocked_reasons=dict(self._row_group_admission_blocked_reasons), + ), + submission_capacity=CapacityValue(value=self._max_submitted_tasks, source="dataset_builder"), + task_resource_limits=CapacityValue( + value=dict(self._task_admission_config.resource_limits), + source="engine_internal_config", + ), + request_resources=CapacityValue( + value=request_resources, + source="runtime_snapshot", + missing_reason=None if request_resources else "request admission has not observed any resources", + ), + provider_model_static_caps=CapacityValue( + value=provider_model_static_caps, + source="model_metadata", + missing_reason=None if 
provider_model_static_caps else "request admission has no registered models", + ), + request_domain_initial_limits=CapacityValue( + value=request_domain_initial_limits, + source="engine_internal_config" if request_config_snapshot is not None else "runtime_snapshot", + missing_reason=None + if request_domain_initial_limits + else "request admission has not observed any domain limits", + ), + request_admission_config=CapacityValue( + value=request_config_snapshot, + source="engine_internal_config", + missing_reason=None + if request_config_snapshot is not None + else "request admission config is not exposed by the pressure provider", + ), + transport_pool_limits=CapacityValue( + value={}, + source="adapter_config", + missing_reason="transport pool utilization is adapter-specific", + ), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits=request_domain_current_limits, + request_domain_effective_max=request_domain_effective_max, + request_domain_blocked_until=request_domain_blocked_until, + provider_model_aggregate_in_flight=provider_model_aggregate_in_flight, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=self._observed_max_row_groups_in_flight, + queued_tasks_by_group=dict(self._observed_max_queued_by_group), + task_leases_by_resource=dict(self._observed_max_task_leases_by_resource or task_view.leased_resources), + request_waiters_by_resource=dict( + self._observed_max_request_waiters_by_resource + or {resource: snapshot.waiters for resource, snapshot in request_snapshots.items()} + ), + request_in_flight_by_resource=dict( + self._observed_max_request_in_flight_by_resource + or {resource: snapshot.in_flight_count for resource, snapshot in request_snapshots.items()} + ), + provider_model_aggregate_in_flight=dict( + self._observed_max_provider_model_aggregate_in_flight or provider_model_aggregate_in_flight + ), + request_domain_current_limits=dict( + self._observed_max_request_domain_current_limits or 
request_domain_current_limits + ), + transport_pool_utilization=None, + ), ) @staticmethod - def _is_retryable(exc: Exception) -> bool: + def _is_retryable(exc: BaseException) -> bool: """Classify whether an exception is retryable.""" return isinstance(exc, RETRYABLE_MODEL_ERRORS) + @staticmethod + def _is_expected_non_retryable(exc: BaseException) -> bool: + return isinstance( + exc, + ( + DataDesignerError, + DatasetGenerationError, + GenerationValidationFailureError, + ProviderError, + ), + ) -def build_llm_bound_lookup(generators: dict[str, ColumnGenerator]) -> dict[str, bool]: - return {col: gen.is_llm_bound for col, gen in generators.items()} + def _is_internal_bug(self, exc: BaseException) -> bool: + return isinstance(exc, INTERNAL_BUG_EXCEPTIONS) and not self._is_expected_non_retryable(exc) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py index 72b939dee..8ce6c0cde 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/dataset_builder.py @@ -70,7 +70,7 @@ from data_designer.config.run_config import RunConfig from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry - from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + from data_designer.engine.dataset_builders.scheduling.task_model import TaskTrace from data_designer.engine.models.usage import ModelUsageStats logger = logging.getLogger(__name__) @@ -85,14 +85,14 @@ from data_designer.engine.dataset_builders.async_scheduler import ( DEFAULT_TASK_POOL_SIZE, - GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER, + MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER, AsyncTaskScheduler, ) + from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta 
from data_designer.engine.dataset_builders.utils.async_concurrency import ( AsyncConcurrentExecutor, ensure_async_engine_loop, ) - from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager @@ -1019,9 +1019,9 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: df = self._processor_runner.run_post_batch(df, current_batch_number=rg_id, strict_row_count=True) buffer_manager.replace_dataframe(rg_id, df) - # Coarse upper bound: sums all registered aliases, not just those used - # in this build. Oversizing is harmless - ThrottleManager enforces - # the real per-key limit; the semaphore is a memory-safety cap. + # Coarse upper bound used only for scheduler task-stage model admission. + # Concrete provider/model request capacity is enforced by request admission + # at the model-call boundary. aggregate = self._resource_provider.model_registry.get_aggregate_max_parallel_requests() scheduler = AsyncTaskScheduler( @@ -1031,7 +1031,7 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: row_groups=row_groups, buffer_manager=buffer_manager, max_submitted_tasks=DEFAULT_TASK_POOL_SIZE, - max_llm_wait_tasks=max(DEFAULT_TASK_POOL_SIZE, GLOBAL_LLM_WAIT_POOL_HEADROOM_MULTIPLIER * aggregate), + max_model_task_admission=max(DEFAULT_TASK_POOL_SIZE, MODEL_TASK_ADMISSION_HEADROOM_MULTIPLIER * aggregate), on_finalize_row_group=on_finalize_row_group, on_seeds_complete=( on_seeds_complete if self._processor_runner.has_processors_for(ProcessorStage.PRE_BATCH) else None @@ -1049,6 +1049,8 @@ def on_before_checkpoint(rg_id: int, rg_size: int) -> None: buffer_size=buffer_size, progress_interval=self._resource_provider.run_config.progress_interval, progress_bar=self._resource_provider.run_config.progress_bar, + request_pressure_provider=self._resource_provider.model_registry.request_admission, + request_pressure_advisory=True, ) return 
scheduler, buffer_manager diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py similarity index 83% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py rename to packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py index 2d35ec0be..b34ffe69a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/completion_tracker.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/completion.py @@ -8,7 +8,8 @@ from typing import TYPE_CHECKING from data_designer.config.column_configs import GenerationStrategy -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task +from data_designer.engine.dataset_builders.scheduling.resources import stable_task_id +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task if TYPE_CHECKING: from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph @@ -147,14 +148,32 @@ def is_row_group_complete( return False return True - def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: - """Return all currently dispatchable tasks from the frontier. + def ready_frontier(self) -> tuple[Task, ...]: + """Return dependency-ready tasks not yet acknowledged as enqueued.""" + return tuple(self._frontier) - Excludes already-dispatched/in-flight tasks and tasks for row groups - not yet admitted (if ``admitted_rgs`` is provided). 
- """ + def mark_enqueued(self, task_ids: set[str] | list[str] | tuple[str, ...]) -> None: + """Acknowledge tasks accepted by the ready queue.""" + wanted = set(task_ids) + self._frontier = {task for task in self._frontier if stable_task_id(task) not in wanted} + + def mark_complete(self, task: Task) -> None: + """Compatibility hook for scheduler terminal accounting.""" + + def add_ready_tasks(self, tasks: list[Task] | tuple[Task, ...]) -> FrontierDelta: + """Add ready tasks to the frontier idempotently.""" + added: list[Task] = [] + for task in tasks: + if self._add_frontier_task(task): + added.append(task) + return self._record_delta(added=added, removed=[]) + + def get_ready_tasks(self, dispatched: set[Task], admitted_rgs: set[int] | None = None) -> list[Task]: + """Return all currently dispatchable tasks from the frontier.""" return [ - t for t in self._frontier if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) + t + for t in self.ready_frontier() + if t not in dispatched and (admitted_rgs is None or t.row_group in admitted_rgs) ] def is_frontier_task(self, task: Task) -> bool: @@ -171,13 +190,36 @@ def seed_frontier(self) -> None: if self._graph is None: raise RuntimeError("This method requires a graph to be set.") for col in self._graph.get_root_columns(): - strategy = self._graph.get_strategy(col) for rg_id, rg_size in self._row_group_sizes.items(): - if strategy == GenerationStrategy.CELL_BY_CELL: - for ri in range(rg_size): - self._frontier.add(Task(column=col, row_group=rg_id, row_index=ri, task_type="cell")) - else: - self._frontier.add(Task(column=col, row_group=rg_id, row_index=None, task_type="batch")) + self.add_root_tasks(rg_id, rg_size, columns=(col,)) + + def add_root_tasks( + self, + row_group: int, + row_group_size: int, + *, + columns: tuple[str, ...] 
| None = None, + ) -> FrontierDelta: + """Add root/from-scratch tasks for one admitted row group.""" + if self._graph is None: + raise RuntimeError("This method requires a graph to be set.") + expected = self._validate_row_group(row_group) + if expected is not None and expected != row_group_size: + raise ValueError(f"Row-group size mismatch for rg={row_group}: got {row_group_size}, expected {expected}") + root_columns = columns or tuple(self._graph.get_root_columns()) + added: list[Task] = [] + for col in root_columns: + strategy = self._graph.get_strategy(col) + if strategy == GenerationStrategy.CELL_BY_CELL: + for ri in range(row_group_size): + task = Task(column=col, row_group=row_group, row_index=ri, task_type="cell") + if self._add_frontier_task(task): + added.append(task) + else: + task = Task(column=col, row_group=row_group, row_index=None, task_type="from_scratch") + if self._add_frontier_task(task): + added.append(task) + return self._record_delta(added=added, removed=[]) def _record_delta(self, *, added: list[Task], removed: list[Task]) -> FrontierDelta: return FrontierDelta(added=tuple(added), removed=tuple(removed)) @@ -204,7 +246,7 @@ def _enqueue_downstream(self, column: str, row_group: int, row_index: int | None rg_batch_complete = self._batch_complete.get(row_group, set()) rg_size = self._row_group_sizes[row_group] - for down in self._graph.get_downstream_columns(column): + for down in sorted(self._graph.get_downstream_columns(column)): batch_ups, cell_ups = self._graph.split_upstream_by_strategy(down) if any(up not in rg_batch_complete for up in batch_ups): diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py new file mode 100644 index 000000000..2cdd99b36 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/queue.py @@ -0,0 +1,202 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import heapq +from collections import Counter, defaultdict, deque +from collections.abc import Callable, Iterable, Mapping +from dataclasses import dataclass + +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + TaskGroupKey, + TaskGroupSpec, +) + + +@dataclass(frozen=True) +class QueueView: + """Read-only queue facts supplied to task admission policies.""" + + queued_total: int + queued_by_group: Mapping[TaskGroupKey, int] + queued_resource_demand_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + first_candidate_resources_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + first_candidate_tasks_by_group: Mapping[TaskGroupKey, SchedulableTask] + first_candidate_group_specs_by_group: Mapping[TaskGroupKey, TaskGroupSpec] + queued_peer_demand_by_resource: Mapping[SchedulerResourceKey, int] + + +@dataclass(frozen=True) +class QueueSelection: + """Non-mutating fair-queue selection returned to the scheduler.""" + + item: SchedulableTask + queue_view: QueueView + sequence_version: int + + +class FairTaskQueue: + """Virtual-time fair queue that owns ready membership and ordering only.""" + + def __init__(self) -> None: + self._queues: dict[TaskGroupKey, deque[SchedulableTask]] = {} + self._queued: dict[str, SchedulableTask] = {} + self._task_groups: dict[str, TaskGroupKey] = {} + self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + self._group_finish: dict[TaskGroupKey, float] = {} + self._heap: list[tuple[float, int, TaskGroupKey]] = [] + self._active_heap_keys: set[TaskGroupKey] = set() + self._active_heap_entries: dict[TaskGroupKey, tuple[float, int]] = {} + self._sequence = 0 + self._sequence_version = 0 + self._virtual_time = 0.0 + + @property + def has_queued_tasks(self) -> bool: + 
return bool(self._queued) + + def enqueue(self, items: Iterable[SchedulableTask]) -> tuple[str, ...]: + """Add ready tasks idempotently and return newly accepted task ids.""" + accepted: list[str] = [] + for item in items: + if item.task_id in self._queued: + continue + self._group_specs[item.group.key] = item.group + queue = self._queues.setdefault(item.group.key, deque()) + queue.append(item) + self._queued[item.task_id] = item + self._task_groups[item.task_id] = item.group.key + self._activate_group(item.group.key) + accepted.append(item.task_id) + if accepted: + self._sequence_version += 1 + return tuple(accepted) + + def discard(self, task_id: str) -> None: + """Remove a queued task lazily if it is no longer dispatchable.""" + if task_id in self._queued: + self._sequence_version += 1 + self._queued.pop(task_id, None) + self._task_groups.pop(task_id, None) + + def discard_where(self, predicate: Callable[[SchedulableTask], bool]) -> None: + """Remove queued tasks matching a predicate.""" + for task_id, item in tuple(self._queued.items()): + if predicate(item): + self.discard(task_id) + + def select_next(self, is_eligible: Callable[[SchedulableTask, QueueView], bool]) -> QueueSelection | None: + """Return the next eligible task without mutating queue state.""" + view = self.view() + heap_copy = list(self._heap) + heapq.heapify(heap_copy) + active_seen: set[TaskGroupKey] = set() + while heap_copy: + finish, sequence, key = heapq.heappop(heap_copy) + if key in active_seen: + continue + if self._active_heap_entries.get(key) != (finish, sequence): + continue + active_seen.add(key) + item = self._first_valid_item(key) + if item is None: + continue + if not is_eligible(item, view): + continue + return QueueSelection(item=item, queue_view=view, sequence_version=self._sequence_version) + return None + + def commit(self, selection: QueueSelection) -> SchedulableTask | None: + """Remove a previously selected task and advance fair-queue state.""" + if 
selection.sequence_version != self._sequence_version: + return None + item = selection.item + key = self._task_groups.get(item.task_id) + if key is None or key != item.group.key: + return None + queue = self._queues.get(key) + if queue is None: + return None + self._purge_queue_head(key) + if not queue or queue[0].task_id != item.task_id: + return None + + queue.popleft() + self._queued.pop(item.task_id, None) + self._task_groups.pop(item.task_id, None) + self._active_heap_keys.discard(key) + self._active_heap_entries.pop(key, None) + group = self._group_specs[key] + finish = self._group_finish.get(key, self._virtual_time) + self._virtual_time = max(self._virtual_time, finish) + self._group_finish[key] = self._virtual_time + (1.0 / max(group.weight, 1.0)) + self._sequence_version += 1 + self._purge_queue_head(key) + if queue: + self._activate_group(key) + return item + + def view(self) -> QueueView: + queued_by_group: Counter[TaskGroupKey] = Counter() + demand_by_group: dict[TaskGroupKey, dict[SchedulerResourceKey, int]] = defaultdict(lambda: defaultdict(int)) + first_by_group: dict[TaskGroupKey, Mapping[SchedulerResourceKey, int]] = {} + first_tasks_by_group: dict[TaskGroupKey, SchedulableTask] = {} + first_group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} + demand_by_resource: Counter[SchedulerResourceKey] = Counter() + + for item in self._queued.values(): + key = item.group.key + queued_by_group[key] += 1 + for resource, amount in item.resource_request.amounts.items(): + demand_by_group[key][resource] += amount + demand_by_resource[resource] += amount + + for key, queue in self._queues.items(): + first = self._first_valid_item(key) + if first is not None: + first_by_group[key] = dict(first.resource_request.amounts) + first_tasks_by_group[key] = first + first_group_specs[key] = first.group + + return QueueView( + queued_total=len(self._queued), + queued_by_group=dict(queued_by_group), + queued_resource_demand_by_group={key: dict(value) for key, value in 
demand_by_group.items()}, + first_candidate_resources_by_group=first_by_group, + first_candidate_tasks_by_group=first_tasks_by_group, + first_candidate_group_specs_by_group=first_group_specs, + queued_peer_demand_by_resource=dict(demand_by_resource), + ) + + def _activate_group(self, key: TaskGroupKey) -> None: + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue or key in self._active_heap_keys: + return + self._sequence += 1 + finish = self._group_finish.get(key, self._virtual_time) + heapq.heappush(self._heap, (finish, self._sequence, key)) + self._active_heap_keys.add(key) + self._active_heap_entries[key] = (finish, self._sequence) + + def _first_valid_item(self, key: TaskGroupKey) -> SchedulableTask | None: + queue = self._queues.get(key) + if queue is None: + return None + for item in queue: + if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: + return item + return None + + def _purge_queue_head(self, key: TaskGroupKey) -> None: + queue = self._queues.get(key) + if queue is None: + return + while queue: + item = queue[0] + if item.task_id in self._queued and self._task_groups.get(item.task_id) == key: + break + queue.popleft() diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py new file mode 100644 index 000000000..c2f61e1e1 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resolver.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey + +if TYPE_CHECKING: + from data_designer.engine.column_generators.generators.base import ColumnGenerator + + +@dataclass(frozen=True) +class ResolvedTaskScheduling: + """Scheduler inputs resolved from generator-facing metadata.""" + + group: TaskGroupSpec + resource_request: SchedulerResourceRequest + request_resource_key: RequestResourceKey | None = None + + +class TaskSchedulingResolver: + """Resolve generator metadata into scheduler-internal task inputs.""" + + def __init__( + self, + generators: Mapping[str, ColumnGenerator], + *, + model_group_limit_multiplier: int = 2, + model_group_limit_cap: int = 256, + ) -> None: + self._generators = generators + self._model_group_limit_multiplier = model_group_limit_multiplier + self._model_group_limit_cap = model_group_limit_cap + self._metadata_by_generator_id: dict[int, SchedulingMetadata] = {} + self._diagnostics: list[dict[str, object]] = [] + for generator in dict.fromkeys(generators.values()): + self._metadata_by_generator_id[id(generator)] = self._resolve_metadata(generator) + + @property + def diagnostics(self) -> tuple[dict[str, object], ...]: + return tuple(self._diagnostics) + + def scheduling_for_task(self, task: Task, flow_identity: tuple[str, ...]) -> ResolvedTaskScheduling: + generator = self._generators[task.column] + metadata = self._metadata_by_generator_id[id(generator)] + return 
self._resolved_from_metadata(metadata, flow_identity) + + def schedulable_task(self, task: Task, flow_identity: tuple[str, ...]) -> SchedulableTask: + resolved = self.scheduling_for_task(task, flow_identity) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=resolved.group, + resource_request=resolved.resource_request, + request_resource_key=resolved.request_resource_key, + ) + + def _resolve_metadata(self, generator: ColumnGenerator) -> SchedulingMetadata: + try: + return generator.get_scheduling_metadata() + except SchedulingMetadataError as exc: + if exc.fallback is None: + raise + self._diagnostics.append( + { + "code": exc.code, + "message": exc.message, + "fallback": exc.fallback.identity, + "diagnostics": exc.diagnostics, + } + ) + return exc.fallback + + def _resolved_from_metadata( + self, + metadata: SchedulingMetadata, + flow_identity: tuple[str, ...], + ) -> ResolvedTaskScheduling: + weight = max(1, metadata.weight) + if metadata.kind == "local": + key = TaskGroupKey(kind="local", identity=(*metadata.identity, *flow_identity)) + return ResolvedTaskScheduling( + group=TaskGroupSpec(key=key, weight=float(weight)), + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + identity = (*metadata.identity, *flow_identity) + admitted_limit = max(1, min(self._model_group_limit_cap, self._model_group_limit_multiplier * weight)) + request_resource_key = _request_resource_key(metadata) + return ResolvedTaskScheduling( + group=TaskGroupSpec( + key=TaskGroupKey(kind=metadata.kind, identity=identity), + weight=float(weight), + admitted_limit=admitted_limit, + ), + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}), + request_resource_key=request_resource_key, + ) + + +def _request_resource_key(metadata: SchedulingMetadata) -> RequestResourceKey | None: + if metadata.kind != "model": + return None + _kind, provider_name, model_id, generation_kind = metadata.identity + try: + domain = 
RequestDomain(generation_kind) + except ValueError: + return None + return RequestResourceKey(provider_name=provider_name, model_id=model_id, domain=domain) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py new file mode 100644 index 000000000..35a0ec18f --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/resources.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import hashlib +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestResourceKey + +SchedulerResourceKey = Literal["submission", "llm_wait", "local"] + + +@dataclass(frozen=True, order=True) +class TaskGroupKey: + """Stable identity for a stream of related scheduler tasks.""" + + kind: Literal["model", "custom_model", "local"] + identity: tuple[str, ...] 
+ + +@dataclass(frozen=True) +class TaskGroupSpec: + """Scheduler-internal task group metadata.""" + + key: TaskGroupKey + weight: float = 1.0 + admitted_limit: int | None = None + + +@dataclass(frozen=True) +class SchedulerResourceRequest: + """Scheduler task-stage resource request.""" + + amounts: Mapping[SchedulerResourceKey, int] = field(default_factory=lambda: {"submission": 1}) + + def __post_init__(self) -> None: + for resource, amount in self.amounts.items(): + if resource not in {"submission", "llm_wait", "local"}: + raise ValueError(f"Unknown scheduler resource key: {resource!r}") + if not isinstance(amount, int) or amount <= 0: + raise ValueError(f"Scheduler resource amount for {resource!r} must be a positive integer.") + + +@dataclass(frozen=True) +class SchedulableTask: + """Ready task plus scheduler-owned grouping and resource request.""" + + task_id: str + payload: Task + group: TaskGroupSpec + resource_request: SchedulerResourceRequest + request_resource_key: RequestResourceKey | None = None + + +def stable_task_id(task: Task) -> str: + """Return a stable scheduler task id for queue/admission membership.""" + raw = f"{task.column}\0{task.row_group}\0{task.row_index}\0{task.task_type}".encode() + digest = hashlib.sha1(raw).hexdigest()[:16] + return f"task-{digest}" diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py new file mode 100644 index 000000000..89fb3e280 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_admission.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import time +import uuid +from collections import Counter, defaultdict, deque +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + +from data_designer.engine.dataset_builders.scheduling.queue import QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + SchedulerResourceRequest, + TaskGroupKey, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import ( + BoundedBorrowTaskAdmissionPolicy, + BoundedBorrowTaskAdmissionPolicyConfig, + PolicyStateDelta, + StrictFairTaskAdmissionPolicy, + TaskAdmissionDenyReason, + TaskAdmissionPolicy, + TaskAdmissionPolicyDecision, +) + +ReleaseReason = Literal[ + "released", + "duplicate", + "stale_lease", + "wrong_controller_generation", + "unknown_lease", +] +RELEASED_TASK_LEASE_HISTORY_LIMIT = 8192 + + +@dataclass(frozen=True) +class TaskAdmissionConfig: + """Engine-internal scheduler task-stage admission configuration.""" + + submission_capacity: int = 256 + resource_limits: Mapping[SchedulerResourceKey, int] = field(default_factory=dict) + bounded_borrow: BoundedBorrowTaskAdmissionPolicyConfig | None = None + + def __post_init__(self) -> None: + if self.submission_capacity <= 0: + raise ValueError("submission_capacity must be positive.") + merged = {"submission": self.submission_capacity, **self.resource_limits} + for resource, limit in merged.items(): + if limit <= 0: + raise ValueError(f"Task admission limit for {resource!r} must be positive.") + object.__setattr__(self, "resource_limits", merged) + + +@dataclass(frozen=True) +class TaskAdmissionView: + resource_limits: Mapping[SchedulerResourceKey, int] + resources_available: Mapping[SchedulerResourceKey, int] + leased_resources: Mapping[SchedulerResourceKey, int] + 
leased_resources_by_group: Mapping[TaskGroupKey, Mapping[SchedulerResourceKey, int]] + running_counts_by_group: Mapping[TaskGroupKey, int] + policy_debt_by_group_resource: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] + + +@dataclass(frozen=True) +class TaskAdmissionLease: + lease_id: str + item: SchedulableTask + resources: Mapping[SchedulerResourceKey, int] + acquired_at: float + controller_generation: str + + +@dataclass(frozen=True) +class TaskAdmissionDenied: + item: SchedulableTask + reason: TaskAdmissionDenyReason + available_after: float | None = None + snapshot: TaskAdmissionView | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +TaskAdmissionDecision = TaskAdmissionLease | TaskAdmissionDenied + + +@dataclass(frozen=True) +class ReleaseResult: + released: bool + reason: ReleaseReason + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TaskAdmissionBlockSummary: + queued_count: int + dominant_denial_reasons: Mapping[TaskAdmissionDenyReason, int] + available_after: float | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +class TaskAdmissionController: + """Owns scheduler-level task leases and resource accounting.""" + + def __init__( + self, + config: TaskAdmissionConfig | None = None, + policy: TaskAdmissionPolicy | None = None, + ) -> None: + self._config = config or TaskAdmissionConfig() + self._generation = uuid.uuid4().hex + self._leases: dict[str, TaskAdmissionLease] = {} + self._released: set[str] = set() + self._released_order: deque[str] = deque(maxlen=RELEASED_TASK_LEASE_HISTORY_LIMIT) + self._leased_by_resource: Counter[SchedulerResourceKey] = Counter() + self._leased_by_group: dict[TaskGroupKey, Counter[SchedulerResourceKey]] = defaultdict(Counter) + self._running_by_group: Counter[TaskGroupKey] = Counter() + self._policy_debt: Counter[tuple[TaskGroupKey, SchedulerResourceKey]] = Counter() + self._release_diagnostics: 
Counter[str] = Counter() + if policy is not None: + self._policy = policy + elif self._config.bounded_borrow is not None: + self._policy = BoundedBorrowTaskAdmissionPolicy(self._config.bounded_borrow) + else: + self._policy = StrictFairTaskAdmissionPolicy() + + def is_eligible(self, item: SchedulableTask, queue_view: QueueView) -> bool: + return not isinstance(self.try_evaluate(item, queue_view), TaskAdmissionDenied) + + def try_evaluate( + self, item: SchedulableTask, queue_view: QueueView + ) -> TaskAdmissionPolicyDecision | TaskAdmissionDenied: + view = self.view() + missing = self._missing_resources(item, view) + if missing: + return TaskAdmissionDenied( + item=item, + reason="no_capacity", + snapshot=view, + diagnostics={"missing_resources": missing}, + ) + decision = self._policy.evaluate(item, queue_view, view) + if not decision.allowed: + return TaskAdmissionDenied( + item=item, + reason=decision.reason or "policy_denial", + available_after=decision.available_after, + snapshot=view, + diagnostics=decision.diagnostics, + ) + return decision + + def try_acquire(self, item: SchedulableTask, queue_view: QueueView) -> TaskAdmissionDecision: + evaluated = self.try_evaluate(item, queue_view) + if isinstance(evaluated, TaskAdmissionDenied): + return evaluated + lease = TaskAdmissionLease( + lease_id=uuid.uuid4().hex, + item=item, + resources=dict(item.resource_request.amounts), + acquired_at=time.monotonic(), + controller_generation=self._generation, + ) + for resource, amount in lease.resources.items(): + self._leased_by_resource[resource] += amount + self._leased_by_group[item.group.key][resource] += amount + self._running_by_group[item.group.key] += 1 + self._apply_delta(self._policy.on_acquire(lease, evaluated)) + self._leases[lease.lease_id] = lease + return lease + + def release(self, lease: TaskAdmissionLease) -> ReleaseResult: + if lease.controller_generation != self._generation: + self._release_diagnostics["wrong_controller_generation"] += 1 + return 
ReleaseResult(released=False, reason="wrong_controller_generation") + active = self._leases.pop(lease.lease_id, None) + if active is None: + reason: ReleaseReason = "duplicate" if lease.lease_id in self._released else "unknown_lease" + self._release_diagnostics[reason] += 1 + return ReleaseResult(released=False, reason=reason) + if active.item.task_id != lease.item.task_id: + self._leases[lease.lease_id] = active + self._release_diagnostics["stale_lease"] += 1 + return ReleaseResult(released=False, reason="stale_lease") + + self._remember_released(lease.lease_id) + for resource, amount in active.resources.items(): + self._leased_by_resource[resource] = max(0, self._leased_by_resource[resource] - amount) + self._leased_by_group[active.item.group.key][resource] = max( + 0, + self._leased_by_group[active.item.group.key][resource] - amount, + ) + self._running_by_group[active.item.group.key] = max(0, self._running_by_group[active.item.group.key] - 1) + self._apply_delta(self._policy.on_release(active)) + return ReleaseResult(released=True, reason="released") + + def view(self) -> TaskAdmissionView: + limits = dict(self._config.resource_limits) + leased = {resource: count for resource, count in self._leased_by_resource.items() if count > 0} + available = { + resource: max(0, limit - self._leased_by_resource.get(resource, 0)) for resource, limit in limits.items() + } + return TaskAdmissionView( + resource_limits=limits, + resources_available=available, + leased_resources=leased, + leased_resources_by_group={ + group: {resource: count for resource, count in counts.items() if count > 0} + for group, counts in self._leased_by_group.items() + }, + running_counts_by_group={group: count for group, count in self._running_by_group.items() if count > 0}, + policy_debt_by_group_resource={key: count for key, count in self._policy_debt.items() if count > 0}, + ) + + def explain_blocked(self, queue_view: QueueView) -> TaskAdmissionBlockSummary: + reasons: 
Counter[TaskAdmissionDenyReason] = Counter() + available_after_values: list[float] = [] + view = self.view() + for group_key, resources in queue_view.first_candidate_resources_by_group.items(): + for resource, amount in resources.items(): + if view.resources_available.get(resource, 0) < amount: + reasons["no_capacity"] += 1 + break + else: + group = queue_view.first_candidate_group_specs_by_group.get(group_key) + if group is None: + continue + task = SchedulableTask( + task_id=f"blocked-{group_key.kind}-{'-'.join(group_key.identity)}", + payload=Task(column="", row_group=-1, row_index=None, task_type="batch"), + group=group, + resource_request=SchedulerResourceRequest(dict(resources)), + ) + decision = self._policy.evaluate(task, queue_view, view) + if not decision.allowed: + reasons[decision.reason or "policy_denial"] += 1 + if decision.available_after is not None: + available_after_values.append(decision.available_after) + return TaskAdmissionBlockSummary( + queued_count=queue_view.queued_total, + dominant_denial_reasons=dict(reasons), + available_after=min(available_after_values) if available_after_values else None, + diagnostics={"snapshot": self.view()}, + ) + + def _missing_resources( + self, + item: SchedulableTask, + view: TaskAdmissionView, + ) -> dict[SchedulerResourceKey, dict[str, int]]: + missing: dict[SchedulerResourceKey, dict[str, int]] = {} + for resource, amount in item.resource_request.amounts.items(): + available = view.resources_available.get(resource, 0) + if available < amount: + missing[resource] = {"requested": amount, "available": available} + return missing + + def _apply_delta(self, delta: PolicyStateDelta) -> None: + for key, change in delta.debt_changes.items(): + self._policy_debt[key] = max(0, self._policy_debt[key] + change) + + def _remember_released(self, lease_id: str) -> None: + if lease_id in self._released: + return + maxlen = self._released_order.maxlen + if maxlen is not None and len(self._released_order) >= maxlen: + 
self._released.discard(self._released_order[0]) + self._released.add(lease_id) + self._released_order.append(lease_id) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py similarity index 100% rename from packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/task_model.py rename to packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_model.py diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py new file mode 100644 index 000000000..011e4e703 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/scheduling/task_policies.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Literal, Protocol + +from data_designer.engine.dataset_builders.scheduling.queue import QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceKey, + TaskGroupKey, +) + +if TYPE_CHECKING: + from data_designer.engine.dataset_builders.scheduling.task_admission import ( + TaskAdmissionLease, + TaskAdmissionView, + ) + +TaskAdmissionDenyReason = Literal[ + "no_capacity", + "group_cap", + "borrow_debt", + "shutdown", + "policy_denial", +] + + +@dataclass(frozen=True) +class BoundedBorrowTaskAdmissionPolicyConfig: + """Engine-internal bounded-borrow policy configuration. + + Borrow debt is tracked by task group and scheduler resource. 
Any completed + lease in the same group repays debt for the released resources; repayment is + not tied to the specific lease that originally borrowed. + """ + + borrow_ceiling_by_group_resource: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] = field( + default_factory=dict + ) + default_borrow_ceiling: int = 0 + strict_share_rounding: Literal["floor", "ceil"] = "floor" + repay_on_withheld_peer_pressure: bool = True + + +@dataclass(frozen=True) +class TaskAdmissionPolicyDecision: + allowed: bool + reason: TaskAdmissionDenyReason | None = None + available_after: float | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class PolicyStateDelta: + debt_changes: Mapping[tuple[TaskGroupKey, SchedulerResourceKey], int] = field(default_factory=dict) + diagnostic_counters: Mapping[str, int] = field(default_factory=dict) + + +class TaskAdmissionPolicy(Protocol): + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: ... + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: ... + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: ... 
+ + +class StrictFairTaskAdmissionPolicy: + """Behavior-preserving policy that enforces per-group admitted caps.""" + + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: + if item.group.admitted_limit is None: + return TaskAdmissionPolicyDecision(allowed=True) + leased_count = admission_view.running_counts_by_group.get(item.group.key, 0) + if leased_count < item.group.admitted_limit: + return TaskAdmissionPolicyDecision(allowed=True) + pressure_resources = _queued_peer_pressure_resources(item, queue_view, admission_view) + if not pressure_resources: + return TaskAdmissionPolicyDecision(allowed=True) + return TaskAdmissionPolicyDecision( + allowed=False, + reason="group_cap", + diagnostics={ + "admitted_limit": item.group.admitted_limit, + "leased_count": leased_count, + "pressure_resources": pressure_resources, + }, + ) + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: + return PolicyStateDelta() + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: + return PolicyStateDelta() + + +class BoundedBorrowTaskAdmissionPolicy(StrictFairTaskAdmissionPolicy): + """Strict policy with optional bounded borrow debt over peer pressure.""" + + def __init__(self, config: BoundedBorrowTaskAdmissionPolicyConfig) -> None: + self._config = config + + def evaluate( + self, + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, + ) -> TaskAdmissionPolicyDecision: + limit = item.group.admitted_limit + if limit is None: + return TaskAdmissionPolicyDecision(allowed=True) + + leased_count = admission_view.running_counts_by_group.get(item.group.key, 0) + if leased_count < limit: + return TaskAdmissionPolicyDecision(allowed=True) + + pressure_resources = _queued_peer_pressure_resources(item, queue_view, admission_view) + if pressure_resources: + for resource in 
pressure_resources: + debt_key = (item.group.key, resource) + debt = admission_view.policy_debt_by_group_resource.get(debt_key, 0) + if debt > 0: + return TaskAdmissionPolicyDecision( + allowed=False, + reason="borrow_debt", + diagnostics={"resource": resource, "debt": debt}, + ) + return TaskAdmissionPolicyDecision( + allowed=False, + reason="group_cap", + diagnostics={ + "admitted_limit": limit, + "leased_count": leased_count, + "pressure_resources": pressure_resources, + }, + ) + + borrow_resources: list[tuple[SchedulerResourceKey, int]] = [] + for resource, amount in item.resource_request.amounts.items(): + debt_key = (item.group.key, resource) + debt = admission_view.policy_debt_by_group_resource.get(debt_key, 0) + ceiling = self._config.borrow_ceiling_by_group_resource.get( + debt_key, + self._config.default_borrow_ceiling, + ) + if debt + amount > ceiling: + return TaskAdmissionPolicyDecision( + allowed=False, + reason="borrow_debt", + diagnostics={"resource": resource, "debt": debt, "requested": amount, "ceiling": ceiling}, + ) + borrow_resources.append((resource, amount)) + return TaskAdmissionPolicyDecision(allowed=True, diagnostics={"borrow_resources": tuple(borrow_resources)}) + + def on_acquire( + self, + lease: TaskAdmissionLease, + decision: TaskAdmissionPolicyDecision, + ) -> PolicyStateDelta: + borrow_resources = decision.diagnostics.get("borrow_resources") + if borrow_resources: + changes = { + (lease.item.group.key, resource): amount + for resource, amount in borrow_resources + if isinstance(resource, str) and isinstance(amount, int) + } + return PolicyStateDelta(debt_changes=changes) + return PolicyStateDelta() + + def on_release(self, lease: TaskAdmissionLease) -> PolicyStateDelta: + if not self._config.repay_on_withheld_peer_pressure: + return PolicyStateDelta() + # Borrow debt is group-level: any completed lease in the group repays it, clamped to zero by the controller. 
+ return PolicyStateDelta( + debt_changes={(lease.item.group.key, resource): -amount for resource, amount in lease.resources.items()} + ) + + +def _queued_peer_pressure_resources( + item: SchedulableTask, + queue_view: QueueView, + admission_view: TaskAdmissionView, +) -> tuple[SchedulerResourceKey, ...]: + candidate_resources = _fair_pressure_resources(item.resource_request.amounts) + pressure_resources: list[SchedulerResourceKey] = [] + for group_key, peer_resources in queue_view.first_candidate_resources_by_group.items(): + if group_key == item.group.key: + continue + if not _is_hard_resource_eligible(peer_resources, admission_view): + continue + for resource in candidate_resources: + if peer_resources.get(resource, 0) > 0 and resource not in pressure_resources: + pressure_resources.append(resource) + return tuple(pressure_resources) + + +def _fair_pressure_resources( + resources: Mapping[SchedulerResourceKey, int], +) -> tuple[SchedulerResourceKey, ...]: + typed_resources = tuple(resource for resource in resources if resource != "submission") + if typed_resources: + return typed_resources + return tuple(resources) + + +def _is_hard_resource_eligible( + resources: Mapping[SchedulerResourceKey, int], + admission_view: TaskAdmissionView, +) -> bool: + return all(admission_view.resources_available.get(resource, 0) >= amount for resource, amount in resources.items()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py index b090cf63d..29b7d99bc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py +++ b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/execution_graph.py @@ -15,8 +15,8 @@ DatasetBuilderColumnConfigT, MultiColumnConfig, ) +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from 
data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError -from data_designer.engine.dataset_builders.utils.task_model import SliceRef from data_designer.logging import LOG_INDENT logger = logging.getLogger(__name__) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py deleted file mode 100644 index 32301b767..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/fair_task_queue.py +++ /dev/null @@ -1,156 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import heapq -from collections import deque -from collections.abc import Callable -from dataclasses import dataclass -from typing import Literal - -from data_designer.engine.dataset_builders.utils.task_model import Task - - -@dataclass(frozen=True, order=True) -class TaskGroupKey: - """Stable identity for a stream of related scheduler tasks.""" - - kind: Literal["model", "custom_model", "local"] - identity: tuple[str, ...] 
- - -@dataclass(frozen=True) -class TaskGroupSpec: - """Scheduling metadata for a task group.""" - - key: TaskGroupKey - weight: float = 1.0 - admitted_limit: int | None = None - - -@dataclass(frozen=True) -class TaskSelection: - """A task selected for dispatch with the group metadata used to choose it.""" - - task: Task - group: TaskGroupSpec - - -class FairTaskQueue: - """Virtual-time fair queue with peer-sensitive per-group FIFO admission limits.""" - - def __init__(self) -> None: - self._queues: dict[TaskGroupKey, deque[Task]] = {} - self._queued: set[Task] = set() - self._task_groups: dict[Task, TaskGroupKey] = {} - self._group_specs: dict[TaskGroupKey, TaskGroupSpec] = {} - self._group_finish: dict[TaskGroupKey, float] = {} - self._admitted_by_group: dict[TaskGroupKey, int] = {} - self._admitted_task_groups: dict[Task, TaskGroupKey] = {} - self._heap: list[tuple[float, int, TaskGroupKey]] = [] - self._active_heap_keys: set[TaskGroupKey] = set() - self._sequence = 0 - self._virtual_time = 0.0 - - @property - def has_queued_tasks(self) -> bool: - return bool(self._queued) - - def enqueue(self, task: Task, group: TaskGroupSpec) -> None: - """Add one ready task to its fair scheduling group.""" - self._group_specs[group.key] = group - if task in self._queued: - return - queue = self._queues.setdefault(group.key, deque()) - queue.append(task) - self._queued.add(task) - self._task_groups[task] = group.key - self._activate_group(group.key) - - def discard(self, task: Task) -> None: - """Remove a queued task lazily if it is no longer dispatchable.""" - self._queued.discard(task) - self._task_groups.pop(task, None) - - def discard_where(self, predicate: Callable[[Task], bool]) -> None: - """Remove queued tasks matching a predicate.""" - for task in tuple(self._queued): - if predicate(task): - self.discard(task) - - def admit_next(self) -> TaskSelection | None: - """Admit the next eligible task, or ``None`` if no queued group can run.""" - blocked: list[TaskGroupKey] = 
[] - try: - while self._heap: - finish, _, key = heapq.heappop(self._heap) - self._active_heap_keys.discard(key) - self._purge_queue_head(key) - queue = self._queues.get(key) - if not queue: - continue - if not self._can_admit_group(key): - blocked.append(key) - continue - - task = queue.popleft() - self._queued.discard(task) - self._task_groups.pop(task, None) - self._admitted_task_groups[task] = key - self._admitted_by_group[key] = self._admitted_by_group.get(key, 0) + 1 - - group = self._group_specs[key] - self._virtual_time = max(self._virtual_time, finish) - self._group_finish[key] = self._virtual_time + (1.0 / max(group.weight, 1.0)) - self._purge_queue_head(key) - if queue: - self._activate_group(key) - return TaskSelection(task=task, group=group) - return None - finally: - for key in blocked: - self._activate_group(key) - - def release(self, task: Task) -> None: - """Release one previously admitted task from its group limit.""" - key = self._admitted_task_groups.pop(task, None) - if key is None: - return - admitted = self._admitted_by_group.get(key, 0) - if admitted <= 1: - self._admitted_by_group.pop(key, None) - else: - self._admitted_by_group[key] = admitted - 1 - self._activate_group(key) - - def _activate_group(self, key: TaskGroupKey) -> None: - self._purge_queue_head(key) - queue = self._queues.get(key) - if not queue or key in self._active_heap_keys: - return - self._sequence += 1 - finish = self._group_finish.get(key, self._virtual_time) - heapq.heappush(self._heap, (finish, self._sequence, key)) - self._active_heap_keys.add(key) - - def _purge_queue_head(self, key: TaskGroupKey) -> None: - queue = self._queues.get(key) - if queue is None: - return - while queue: - task = queue[0] - if task in self._queued and self._task_groups.get(task) == key: - break - queue.popleft() - - def _can_admit_group(self, key: TaskGroupKey) -> bool: - group = self._group_specs[key] - if group.admitted_limit is None: - return True - if self._admitted_by_group.get(key, 
0) < group.admitted_limit: - return True - return not self._has_queued_peer_group(key) - - def _has_queued_peer_group(self, key: TaskGroupKey) -> bool: - return any(queued_key != key for queued_key in self._task_groups.values()) diff --git a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py b/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py deleted file mode 100644 index dea66eeda..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/dataset_builders/utils/scheduling_hints.py +++ /dev/null @@ -1,123 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import logging -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Literal - -if TYPE_CHECKING: - from data_designer.engine.column_generators.generators.base import ColumnGenerator - -logger = logging.getLogger(__name__) - -SchedulingGroupKind = Literal["local", "model", "custom_model"] - - -@dataclass(frozen=True) -class SchedulingHint: - """Resolved task-scheduling metadata independent of graph flow identity.""" - - group_kind: SchedulingGroupKind - identity_prefix: tuple[str, ...] = () - identity_suffix: tuple[str, ...] 
= () - weight: int = 1 - - -class SchedulingHintResolver: - """Resolve generator/config/model metadata once for a scheduler run.""" - - def __init__(self, generators: dict[str, ColumnGenerator]) -> None: - self._hints_by_generator_id: dict[int, SchedulingHint] = {} - for column, generator in generators.items(): - generator_id = id(generator) - if generator_id not in self._hints_by_generator_id: - self._hints_by_generator_id[generator_id] = self._resolve_hint(column, generator) - - def hint_for(self, generator: ColumnGenerator) -> SchedulingHint: - return self._hints_by_generator_id[id(generator)] - - def _resolve_hint(self, column: str, generator: ColumnGenerator) -> SchedulingHint: - if not generator.is_llm_bound: - return SchedulingHint(group_kind="local") - - aliases = _model_aliases_for_generator(generator) - if not aliases: - return SchedulingHint(group_kind="model", identity_prefix=("unknown",), weight=1) - - model_parts: list[str] = [] - total_parallel = 0 - primary_alias = getattr(generator.config, "model_alias", None) - for alias in aliases: - try: - model_config = _get_model_config_for_alias(generator, alias) - provider_name = _get_model_provider_name_for_alias(generator, alias) - except Exception: - logger.debug( - "Falling back to custom-model scheduling group for column %r after failing to resolve " - "model alias %r from aliases %r.", - column, - alias, - aliases, - exc_info=True, - ) - return SchedulingHint( - group_kind="custom_model", - identity_suffix=tuple(sorted(aliases)), - weight=max(1, total_parallel), - ) - - max_parallel = getattr(model_config.inference_parameters, "max_parallel_requests", 1) - if not isinstance(max_parallel, int): - max_parallel = 1 - model_parts.extend( - ( - provider_name, - str(model_config.model), - str(model_config.generation_type), - alias, - ) - ) - total_parallel += max_parallel - - weight = max(1, total_parallel) - if len(aliases) == 1 and primary_alias == aliases[0]: - return SchedulingHint( - group_kind="model", 
- identity_prefix=tuple(model_parts[:3]), - weight=weight, - ) - - return SchedulingHint( - group_kind="custom_model", - identity_suffix=tuple(sorted(aliases)), - weight=weight, - ) - - -def _get_model_config_for_alias(generator: ColumnGenerator, alias: str) -> Any: - get_model_config = getattr(generator, "get_model_config", None) - if callable(get_model_config): - return get_model_config(model_alias=alias) - return generator.resource_provider.model_registry.get_model_config(model_alias=alias) - - -def _get_model_provider_name_for_alias(generator: ColumnGenerator, alias: str) -> str: - get_provider_name = getattr(generator, "get_model_provider_name", None) - if callable(get_provider_name): - return str(get_provider_name(model_alias=alias)) - provider = generator.resource_provider.model_registry.get_model_provider(model_alias=alias) - return str(provider.name) - - -def _model_aliases_for_generator(generator: ColumnGenerator) -> list[str]: - get_aliases = getattr(generator.config, "get_model_aliases", None) - if callable(get_aliases): - aliases = get_aliases() - else: - aliases = [] - if (alias := getattr(generator.config, "model_alias", None)) is not None: - aliases.append(alias) - aliases.extend(getattr(generator.config, "model_aliases", []) or []) - return list(dict.fromkeys(alias for alias in aliases if alias)) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py index 0e7d7907a..059c52b11 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/__init__.py @@ -12,9 +12,8 @@ map_http_status_to_provider_error_kind, ) from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry 
import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleDomain, ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionChoice, @@ -42,13 +41,11 @@ "ImageGenerationResponse", "ImagePayload", "ModelClient", + "ModelRequestExecutor", "OpenAICompatibleClient", "ProviderError", "ProviderErrorKind", "RetryConfig", - "ThrottleDomain", - "ThrottleManager", - "ThrottledModelClient", "ToolCall", "Usage", "create_model_client", diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py index 17abd8a88..204b46677 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/anthropic.py @@ -31,8 +31,8 @@ class AnthropicClient(HttpModelClient): """Native HTTP adapter for the Anthropic Messages API. Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP - calls. Concurrency / throttle policy is an orchestration concern and - is not managed here — see ``ThrottleManager`` and ``AsyncTaskScheduler``. + calls. Concurrency and request-admission policy are orchestration concerns + and are not managed here. 
""" _ROUTE_MESSAGES = "/messages" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py index 54f01961b..44ab1f1d5 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/adapters/openai_compatible.py @@ -33,8 +33,8 @@ class OpenAICompatibleClient(HttpModelClient): """Native HTTP adapter for OpenAI-compatible provider APIs. Uses ``httpx`` with ``httpx_retries.RetryTransport`` for resilient HTTP - calls. Concurrency / throttle policy is an orchestration concern and - is not managed here — see ``ThrottleManager`` and ``AsyncTaskScheduler``. + calls. Concurrency and request-admission policy are orchestration concerns + and are not managed here. """ _ROUTE_CHAT = "/chat/completions" diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py index 458ebfcad..398d151a4 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/factory.py @@ -10,13 +10,15 @@ from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient from 
data_designer.engine.models.errors import FormattedLLMErrorMessage +from data_designer.engine.models.request_admission.controller import RequestAdmissionController +from data_designer.engine.observability import RequestAdmissionEventSink from data_designer.engine.secret_resolver import SecretResolver _SUPPORTED_PROVIDER_TYPES = ("openai", "anthropic") +_NO_TRANSPORT_RETRY_CONFIG = RetryConfig(max_retries=0, retryable_status_codes=frozenset()) def create_model_client( @@ -26,7 +28,8 @@ def create_model_client( *, retry_config: RetryConfig | None = None, client_concurrency_mode: ClientConcurrencyMode = ClientConcurrencyMode.SYNC, - throttle_manager: ThrottleManager | None = None, + request_admission: RequestAdmissionController | None = None, + request_event_sink: RequestAdmissionEventSink | None = None, ) -> ModelClient: """Create a ``ModelClient`` for the given model configuration. @@ -40,12 +43,12 @@ def create_model_client( client_concurrency_mode: ``"sync"`` (default) for the sync engine path, ``"async"`` for the async engine path. Native HTTP adapters are constrained to a single concurrency mode. - throttle_manager: Optional throttle manager for per-request AIMD - concurrency control. When provided, the returned client is wrapped - with ``ThrottledModelClient``. + request_admission: Optional request-admission controller for per-request + provider/model/domain admission. When provided, the returned client + is wrapped with ``ModelRequestExecutor``. **Ordering invariant:** the ``(provider_name, model_id)`` pair must - be registered on the ``ThrottleManager`` via ``register()`` before + be registered on the request-admission controller via ``register()`` before the returned client makes its first request. In the standard flow, ``ModelRegistry._get_model()`` calls ``register()`` during model setup, which happens before any generation task invokes the client. 
@@ -69,13 +72,14 @@ def create_model_client( max_parallel = model_config.inference_parameters.max_parallel_requests raw_timeout = model_config.inference_parameters.timeout timeout_s = float(raw_timeout if raw_timeout is not None else 60) + adapter_retry_config = _NO_TRANSPORT_RETRY_CONFIG if request_admission is not None else retry_config if provider.provider_type == "openai": client: ModelClient = OpenAICompatibleClient( provider_name=provider.name, endpoint=provider.endpoint, api_key=api_key, - retry_config=retry_config, + retry_config=adapter_retry_config, max_parallel_requests=max_parallel, timeout_s=timeout_s, concurrency_mode=client_concurrency_mode, @@ -85,7 +89,7 @@ def create_model_client( provider_name=provider.name, endpoint=provider.endpoint, api_key=api_key, - retry_config=retry_config, + retry_config=adapter_retry_config, max_parallel_requests=max_parallel, timeout_s=timeout_s, concurrency_mode=client_concurrency_mode, @@ -102,12 +106,14 @@ def create_model_client( ) ) - if throttle_manager is not None: - client = ThrottledModelClient( + if request_admission is not None: + client = ModelRequestExecutor( inner=client, - throttle_manager=throttle_manager, + request_admission=request_admission, provider_name=provider.name, model_id=model_config.model, + event_sink=request_event_sink, + retry_config=retry_config, ) return client diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py new file mode 100644 index 000000000..721afa41a --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/model_request_executor.py @@ -0,0 +1,304 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import time +import uuid +from typing import TYPE_CHECKING, TypeVar + +from data_designer.engine.models.clients.base import ModelClient +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.models.clients.types import ( + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, +) +from data_designer.engine.models.request_admission.controller import ( + RequestAdmissionController, + RequestAdmissionError, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from data_designer.engine.models.request_admission.resolver import RequestResourceResolver +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestEventContext, + RequestGroupSpec, +) +from data_designer.engine.observability import ( + RequestAdmissionEvent, + RequestAdmissionEventSink, + runtime_correlation_provider, +) + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + +_T = TypeVar("_T") + +logger = logging.getLogger(__name__) + + +class ModelRequestExecutor(ModelClient): + """Model-call boundary that acquires/releases request-admission leases.""" + + def __init__( + self, + inner: ModelClient, + request_admission: RequestAdmissionController, + provider_name: str, + model_id: str, + event_sink: RequestAdmissionEventSink | None = None, + resource_resolver: RequestResourceResolver | None = None, + retry_config: RetryConfig | None = None, + ) -> None: + self._inner = inner + self._request_admission = request_admission + self._provider_name = provider_name + self._model_id = model_id + self._event_sink = event_sink + self._resource_resolver = 
resource_resolver or RequestResourceResolver() + self._retry_config = retry_config or RetryConfig() + self._event_sequence = 0 + + @property + def provider_name(self) -> str: + return self._inner.provider_name + + def supports_chat_completion(self) -> bool: + return self._inner.supports_chat_completion() + + def supports_embeddings(self) -> bool: + return self._inner.supports_embeddings() + + def supports_image_generation(self) -> bool: + return self._inner.supports_image_generation() + + def close(self) -> None: + self._inner.close() + + async def aclose(self) -> None: + await self._inner.aclose() + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + return self._execute_sync(RequestDomain.CHAT, lambda: self._inner.completion(request)) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + return await self._execute_async(RequestDomain.CHAT, lambda: self._inner.acompletion(request)) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return self._execute_sync(RequestDomain.EMBEDDING, lambda: self._inner.embeddings(request)) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return await self._execute_async(RequestDomain.EMBEDDING, lambda: self._inner.aembeddings(request)) + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return self._execute_sync(self._image_domain(request), lambda: self._inner.generate_image(request)) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return await self._execute_async(self._image_domain(request), lambda: self._inner.agenerate_image(request)) + + def _execute_sync(self, domain: RequestDomain, call: Callable[[], _T]) -> _T: + for attempt in range(self._max_attempts()): + try: + return self._execute_sync_attempt(domain, call) + except ProviderError as exc: + if not self._should_retry(exc, attempt): + raise + 
self._sleep_before_retry(attempt) + raise RuntimeError("unreachable request retry state") + + def _execute_sync_attempt(self, domain: RequestDomain, call: Callable[[], _T]) -> _T: + item = self._item(domain) + try: + lease = self._request_admission.acquire_sync(item) + except RequestAdmissionError as exc: + raise ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message=str(exc), + provider_name=self._provider_name, + model_name=self._model_id, + ) from exc + try: + self._emit_model_event("model_request_started", item=item, lease=lease) + result = call() + except ProviderError as exc: + self._release_provider_error(lease, exc) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": exc.kind.value} + ) + raise + except TimeoutError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="provider_timeout")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "provider_timeout"} + ) + raise + except BaseException as exc: + outcome = "local_cancelled" if isinstance(exc, KeyboardInterrupt) else "unexpected_exception" + self._request_admission.release(lease, RequestReleaseOutcome(kind=outcome)) + self._emit_model_event("model_request_completed", item=item, lease=lease, diagnostics={"outcome": outcome}) + raise + else: + self._request_admission.release(lease, RequestReleaseOutcome(kind="success")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "success"} + ) + return result + + async def _execute_async(self, domain: RequestDomain, call: Callable[[], Awaitable[_T]]) -> _T: + for attempt in range(self._max_attempts()): + try: + return await self._execute_async_attempt(domain, call) + except ProviderError as exc: + if not self._should_retry(exc, attempt): + raise + await self._async_sleep_before_retry(attempt) + raise RuntimeError("unreachable request retry state") + + async def _execute_async_attempt(self, 
domain: RequestDomain, call: Callable[[], Awaitable[_T]]) -> _T: + item = self._item(domain) + try: + lease = await self._request_admission.acquire_async(item) + except RequestAdmissionError as exc: + raise ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message=str(exc), + provider_name=self._provider_name, + model_name=self._model_id, + ) from exc + except asyncio.CancelledError: + raise + try: + self._emit_model_event("model_request_started", item=item, lease=lease) + result = await call() + except asyncio.CancelledError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="local_cancelled")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "local_cancelled"} + ) + raise + except ProviderError as exc: + self._release_provider_error(lease, exc) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": exc.kind.value} + ) + raise + except TimeoutError: + self._request_admission.release(lease, RequestReleaseOutcome(kind="provider_timeout")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "provider_timeout"} + ) + raise + except BaseException as exc: + outcome = "local_cancelled" if isinstance(exc, KeyboardInterrupt) else "unexpected_exception" + self._request_admission.release(lease, RequestReleaseOutcome(kind=outcome)) + self._emit_model_event("model_request_completed", item=item, lease=lease, diagnostics={"outcome": outcome}) + raise + else: + self._request_admission.release(lease, RequestReleaseOutcome(kind="success")) + self._emit_model_event( + "model_request_completed", item=item, lease=lease, diagnostics={"outcome": "success"} + ) + return result + + def _max_attempts(self) -> int: + return max(1, self._retry_config.max_retries + 1) + + def _should_retry(self, exc: ProviderError, attempt: int) -> bool: + if attempt >= self._max_attempts() - 1: + return False + if isinstance(exc.__cause__, 
RequestAdmissionError): + return False + if exc.kind == ProviderErrorKind.RATE_LIMIT: + return False + if exc.status_code is not None: + return exc.status_code in self._retry_config.retryable_status_codes + return exc.kind == ProviderErrorKind.API_CONNECTION + + def _sleep_before_retry(self, attempt: int) -> None: + delay = self._retry_delay_seconds(attempt) + if delay > 0.0: + time.sleep(delay) + + async def _async_sleep_before_retry(self, attempt: int) -> None: + delay = self._retry_delay_seconds(attempt) + if delay > 0.0: + await asyncio.sleep(delay) + + def _retry_delay_seconds(self, attempt: int) -> float: + if self._retry_config.backoff_factor <= 0.0: + return 0.0 + delay = self._retry_config.backoff_factor * (2**attempt) + return min(delay, self._retry_config.max_backoff_wait) + + def _release_provider_error(self, lease: RequestAdmissionLease, exc: ProviderError) -> None: + if exc.kind == ProviderErrorKind.RATE_LIMIT: + outcome = RequestReleaseOutcome(kind="rate_limited", retry_after_seconds=exc.retry_after) + elif exc.kind == ProviderErrorKind.TIMEOUT: + outcome = RequestReleaseOutcome(kind="provider_timeout") + else: + outcome = RequestReleaseOutcome(kind="provider_failure") + self._request_admission.release(lease, outcome) + + def _item(self, domain: RequestDomain) -> RequestAdmissionItem: + resolved = self._resource_resolver.resolve( + provider_name=self._provider_name, + model_id=self._model_id, + domain=domain, + ) + resource = resolved.resource + correlation = runtime_correlation_provider.current() + return RequestAdmissionItem( + resource=resource, + group=RequestGroupSpec(key=resource), + event_context=RequestEventContext( + captured_correlation=correlation, + task_execution_id=correlation.task_execution_id if correlation is not None else None, + request_attempt_id=f"request-{uuid.uuid4().hex}", + ), + ) + + @staticmethod + def _image_domain(request: ImageGenerationRequest) -> RequestDomain: + return RequestDomain.CHAT if request.messages is not 
None else RequestDomain.IMAGE + + def _emit_model_event( + self, + event_kind: str, + *, + item: RequestAdmissionItem, + lease: RequestAdmissionLease, + diagnostics: dict[str, object] | None = None, + ) -> None: + if self._event_sink is None: + return + self._event_sequence += 1 + context = item.event_context + try: + self._event_sink.emit_request_event( + RequestAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + sequence=self._event_sequence, + correlation=context.captured_correlation + if context is not None + else runtime_correlation_provider.current(), + request_attempt_id=context.request_attempt_id if context is not None else None, + request_lease_id=lease.lease_id, + request_resource_key=item.resource, + request_group_key=item.group.key, + pressure_snapshot=self._request_admission.pressure.snapshot(item.resource), + diagnostics=diagnostics or {}, + ) + ) + except Exception: + logger.warning("Model request event sink raised; dropping event.", exc_info=True) + return diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py index 56aa1eec4..9f51a48b2 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/clients/retry.py @@ -14,8 +14,8 @@ logger = logging.getLogger(__name__) -# 429 must not be retried at the transport layer so that rate-limit signals -# propagate to ThrottledModelClient for AIMD backoff. +# 429 must not be retried at the transport layer so rate-limit signals +# propagate to ModelRequestExecutor and request admission for AIMD backoff. _RESERVED_STATUS_CODES: frozenset[int] = frozenset({429}) @@ -25,7 +25,7 @@ class RetryConfig: Retries non-rate-limit transient failures (``502``, ``503``, ``504``) and connection/transport errors. 
``429`` is intentionally excluded so that - rate-limit signals reach the ``ThrottledModelClient`` wrapper for AIMD + rate-limit signals reach the ``ModelRequestExecutor`` boundary for AIMD backoff. If a caller includes ``429`` in ``retryable_status_codes``, ``create_retry_transport`` will strip it and log a warning. """ @@ -52,10 +52,8 @@ def create_retry_transport( config: Retry policy. Uses ``RetryConfig()`` defaults when ``None``. strip_rate_limit_codes: When ``True`` (default, used by the async engine), status codes in ``_RESERVED_STATUS_CODES`` (currently ``{429}``) are - stripped so that rate-limit responses reach the ``ThrottledModelClient`` - AIMD feedback loop. When ``False`` (used by the sync engine, which has - no salvage queue), 429 is kept in the retry list so the transport layer - retries it transparently. + stripped so that rate-limit responses reach the request-admission + AIMD feedback loop. transport: Optional pre-configured transport to pass directly to ``RetryTransport``. Pass ``httpx.HTTPTransport`` for sync clients or ``httpx.AsyncHTTPTransport`` for async clients — typically with a custom @@ -70,7 +68,7 @@ def create_retry_transport( if reserved_overlap: logger.warning( "Stripping reserved status codes %s from retryable_status_codes; " - "these must reach ThrottledModelClient for AIMD backoff.", + "these must reach ModelRequestExecutor/request admission for AIMD backoff.", sorted(reserved_overlap), ) status_codes = status_codes - _RESERVED_STATUS_CODES diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py deleted file mode 100644 index 4b8ae28a7..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttle_manager.py +++ /dev/null @@ -1,555 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -import logging -import math -import threading -import time -from dataclasses import dataclass, field -from enum import Enum - -from data_designer.config.run_config import ThrottleConfig - -logger = logging.getLogger(__name__) - - -class ThrottleDomain(str, Enum): - CHAT = "chat" - EMBEDDING = "embedding" - IMAGE = "image" - HEALTHCHECK = "healthcheck" - - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -DEFAULT_MIN_LIMIT: int = 1 -CAPACITY_POLL_INTERVAL: float = 0.05 - - -# --------------------------------------------------------------------------- -# Internal state containers -# --------------------------------------------------------------------------- - - -@dataclass -class DomainThrottleState: - """Per-domain AIMD concurrency state. - - All mutations must be performed while holding the owning - ``ThrottleManager._lock``. - """ - - current_limit: int - in_flight: int = 0 - blocked_until: float = 0.0 - success_streak: int = 0 - waiters: int = 0 - rate_limit_ceiling: int = 0 - consecutive_429s: int = 0 - rampup_started_at: float = 0.0 - rampup_active: bool = False - - -@dataclass -class GlobalCapState: - """Tracks the effective hard cap across aliases sharing a provider+model.""" - - limits_by_alias: dict[str, int] = field(default_factory=dict) - effective_max: int = 0 - - def register_alias(self, alias: str, max_parallel: int) -> None: - self.limits_by_alias[alias] = max_parallel - self.effective_max = min(self.limits_by_alias.values()) - - -# --------------------------------------------------------------------------- -# ThrottleManager -# --------------------------------------------------------------------------- - - -class ThrottleManager: - """Adaptive concurrency manager using AIMD (Additive Increase / - Multiplicative Decrease). 
- - Keyed at two levels: - - - **Global cap**: ``(provider_name, model_id)`` — shared hard ceiling. - - **Domain**: ``(provider_name, model_id, throttle_domain)`` — per-route - AIMD state that floats between 1 and the global effective max. - - **AIMD behaviour**: - - - *Decrease* — on a 429 / rate-limit signal the domain's concurrency limit - is multiplied by ``reduce_factor`` (default 0.75, i.e. reduced by 25%) - and a cooldown block is applied for ``retry_after`` seconds (or - ``default_cooldown_seconds``). - - *Increase* — after every ``success_window`` consecutive successful - releases the limit grows by ``additive_increase`` (default 1), up to - the *rate-limit ceiling* (or the global effective max if no 429 has - been observed yet). - - *Startup ramp* — when ``rampup_seconds`` is greater than zero, each new - domain starts at one concurrent request and linearly ramps to the global - effective max over that duration. The first 429 aborts the ramp and the - domain continues with regular AIMD decrease/recovery. - - *Stabilization* — each 429 records the pre-halving limit as - ``rate_limit_ceiling``. Subsequent additive increases stop at - ``ceiling * (1 + ceiling_overshoot)`` (default 10%) instead of - climbing all the way to ``effective_max``. The overshoot band lets - the system probe whether the endpoint can now handle more traffic - (e.g. after load drops) while dampening the sawtooth. If the probe - succeeds, the ceiling ratchets up; if it triggers another 429, the - ceiling lowers. - - Thread-safe: all state mutations are guarded by a single lock so that - sync and async callers co-throttle correctly. 
- """ - - def __init__( - self, - config: ThrottleConfig | None = None, - ) -> None: - tc = config or ThrottleConfig() - self._reduce_factor = tc.reduce_factor - self._additive_increase = tc.additive_increase - self._success_window = tc.success_window - self._default_cooldown_seconds = tc.cooldown_seconds - self._ceiling_overshoot = tc.ceiling_overshoot - self._rampup_seconds = tc.rampup_seconds - self._lock = threading.Lock() - self._global_caps: dict[tuple[str, str], GlobalCapState] = {} - self._domains: dict[tuple[str, str, str], DomainThrottleState] = {} - - # ------------------------------------------------------------------- - # Registration - # ------------------------------------------------------------------- - - def register( - self, - *, - provider_name: str, - model_id: str, - alias: str, - max_parallel_requests: int, - ) -> None: - """Register a model alias and its concurrency limit. - - If multiple aliases share the same ``(provider_name, model_id)`` the - effective max is ``min()`` of all registered limits. Existing domain - states are clamped to the new effective max. - - **Ordering invariant:** ``register()`` must be called for a - ``(provider_name, model_id)`` pair *before* any ``try_acquire()`` for - the same key. If ``try_acquire()`` runs first it creates a domain at - ``DEFAULT_MIN_LIMIT`` and ``_clamp_domains`` only *decreases* limits, - so a later ``register()`` will not raise the domain to the intended - capacity. 
- """ - with self._lock: - global_key = (provider_name, model_id) - cap = self._global_caps.setdefault(global_key, GlobalCapState()) - cap.register_alias(alias, max_parallel_requests) - self._clamp_domains(global_key, cap.effective_max) - logger.debug( - "Throttle registered alias=%r for %s/%s (max_parallel=%d, effective_max=%d)", - alias, - provider_name, - model_id, - max_parallel_requests, - cap.effective_max, - ) - - # ------------------------------------------------------------------- - # Core non-blocking primitives - # ------------------------------------------------------------------- - - def is_registered(self, provider_name: str, model_id: str) -> bool: - """Return ``True`` if ``register()`` has been called for this key.""" - with self._lock: - return (provider_name, model_id) in self._global_caps - - def try_acquire( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> float: - """Attempt to acquire a concurrency slot. - - Returns ``0.0`` if the slot was acquired, otherwise the number of - seconds the caller should wait before retrying. - - Raises ``RuntimeError`` if the ``(provider_name, model_id)`` pair - has not been registered via ``register()``. - """ - now = now if now is not None else time.monotonic() - with self._lock: - if (provider_name, model_id) not in self._global_caps: - raise RuntimeError( - f"ThrottleManager.try_acquire() called before register() " - f"for ({provider_name!r}, {model_id!r}). " - f"Call register() first to set the concurrency limit." 
- ) - state = self._get_or_create_domain(provider_name, model_id, domain, now=now) - self._apply_startup_ramp(state, self._effective_max_for(provider_name, model_id), now) - if now < state.blocked_until: - return state.blocked_until - now - if state.in_flight >= state.current_limit: - return CAPACITY_POLL_INTERVAL - state.in_flight += 1 - return 0.0 - - def release_success( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> None: - now = now if now is not None else time.monotonic() - with self._lock: - state = self._get_or_create_domain(provider_name, model_id, domain, now=now) - state.in_flight = max(0, state.in_flight - 1) - state.consecutive_429s = 0 - effective_max = self._effective_max_for(provider_name, model_id) - self._apply_startup_ramp(state, effective_max, now) - if state.rampup_active: - state.success_streak = 0 - return - state.success_streak += 1 - if state.success_streak >= self._success_window: - cap = self._compute_soft_ceiling(state, effective_max) - if state.current_limit < cap: - prev = state.current_limit - state.current_limit = min(state.current_limit + self._additive_increase, cap) - if state.current_limit >= cap: - if cap < effective_max: - logger.info( - "🔋✅ '%s' [%s] concurrency recovered to %d parallel requests", - model_id, - domain.value, - state.current_limit, - ) - else: - logger.info( - "🔋✅ '%s' [%s] concurrency fully recovered (%d parallel requests)", - model_id, - domain.value, - state.current_limit, - ) - else: - logger.info( - "🪫📈🔥 '%s' [%s] concurrency increased from %d → %d", - model_id, - domain.value, - prev, - state.current_limit, - ) - state.success_streak = 0 - - def release_rate_limited( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - retry_after: float | None = None, - now: float | None = None, - ) -> None: - now = now if now is not None else time.monotonic() - with self._lock: - state = self._get_or_create_domain(provider_name, 
model_id, domain, now=now) - state.in_flight = max(0, state.in_flight - 1) - state.rampup_active = False - prev_limit = state.current_limit - is_first_in_cascade = state.consecutive_429s == 0 - state.consecutive_429s += 1 - cooldown_duration = ( - retry_after if retry_after is not None and retry_after > 0 else self._default_cooldown_seconds - ) - state.blocked_until = now + cooldown_duration - state.success_streak = 0 - - if is_first_in_cascade: - state.current_limit = max(DEFAULT_MIN_LIMIT, math.floor(state.current_limit * self._reduce_factor)) - if state.current_limit < prev_limit: - if state.rate_limit_ceiling == 0: - state.rate_limit_ceiling = prev_limit - else: - state.rate_limit_ceiling = min(state.rate_limit_ceiling, prev_limit) - if state.rate_limit_ceiling < prev_limit: - logger.info( - "🪫📉 '%s' [%s] server rate-limited at %d (server limit ~%d) — concurrency reduced to %d (retrying in %.0fs)", - model_id, - domain.value, - prev_limit, - state.rate_limit_ceiling, - state.current_limit, - cooldown_duration, - ) - else: - logger.info( - "🪫📉 '%s' [%s] server rate-limited — concurrency reduced from %d → %d (retrying in %.0fs)", - model_id, - domain.value, - prev_limit, - state.current_limit, - cooldown_duration, - ) - else: - logger.info( - "🪫📉 '%s' [%s] server rate-limited at minimum concurrency %d (retrying in %.0fs)", - model_id, - domain.value, - state.current_limit, - cooldown_duration, - ) - else: - logger.debug( - "Throttle %s [%s] cascade 429 #%d (limit held at %d)", - model_id, - domain.value, - state.consecutive_429s, - state.current_limit, - ) - - def release_failure( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float | None = None, - ) -> None: - now = now if now is not None else time.monotonic() - with self._lock: - state = self._get_or_create_domain(provider_name, model_id, domain, now=now) - state.in_flight = max(0, state.in_flight - 1) - # Non-rate-limit failure breaks the 429 cascade: a sequence like - # 
429 → 500 → 429 should treat the second 429 as the start of a - # new cascade. But only after the prior burst has fully drained - # (in_flight == 0) - otherwise mixed responses from a single - # in-flight wave (429 → 500 → 429 with concurrent slots) would - # double-reduce the limit even though the provider hasn't - # recovered between the two 429s. - if state.in_flight == 0: - state.consecutive_429s = 0 - - # ------------------------------------------------------------------- - # Sync / async wrappers - # ------------------------------------------------------------------- - - def acquire_sync( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - timeout: float | None = None, - ) -> None: - """Block until a permit is available. - - ``timeout=None`` (the default) waits indefinitely; the per-request HTTP - timeout (``inference_parameters.timeout``) is the only deadline that bounds - actual work, and queue waits scale naturally with provider speed and - AIMD's adaptive concurrency. Pass an explicit float for tests or for - support cases where a queue-wait deadline is genuinely desired. - """ - now = time.monotonic() - deadline = (now + timeout) if timeout is not None else None - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain, now=now) - if wait == 0.0: - return - with self._lock: - # state is captured once and reused in the finally block; safe - # because DomainThrottleState objects are never replaced after creation. 
- state = self._get_or_create_domain(provider_name, model_id, domain, now=now) - state.waiters += 1 - if state.waiters == 1: - logger.debug( - "Throttle %s/%s [%s] queue forming (in_flight=%d/%d)", - provider_name, - model_id, - domain.value, - state.in_flight, - state.current_limit, - ) - try: - while True: - if deadline is not None: - remaining = deadline - time.monotonic() - if remaining <= 0 or wait > remaining: - raise TimeoutError( - f"Throttle acquire timed out after {timeout:.0f}s " - f"for {provider_name}/{model_id} [{domain.value}]" - ) - sleep_for = min(wait, remaining) - else: - sleep_for = wait - time.sleep(sleep_for) - now = time.monotonic() - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain, now=now) - if wait == 0.0: - return - finally: - with self._lock: - state.waiters -= 1 - if state.waiters == 0: - logger.debug( - "Throttle %s/%s [%s] queue drained", - provider_name, - model_id, - domain.value, - ) - - async def acquire_async( - self, - *, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - timeout: float | None = None, - ) -> None: - """Block until a permit is available. - - ``timeout=None`` (the default) waits indefinitely; the per-request HTTP - timeout (``inference_parameters.timeout``) is the only deadline that bounds - actual work, and queue waits scale naturally with provider speed and - AIMD's adaptive concurrency. Pass an explicit float for tests or for - support cases where a queue-wait deadline is genuinely desired. - """ - now = time.monotonic() - deadline = (now + timeout) if timeout is not None else None - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain, now=now) - if wait == 0.0: - return - with self._lock: - # state is captured once and reused in the finally block; safe - # because DomainThrottleState objects are never replaced after creation. 
- state = self._get_or_create_domain(provider_name, model_id, domain, now=now) - state.waiters += 1 - if state.waiters == 1: - logger.debug( - "Throttle %s/%s [%s] queue forming (in_flight=%d/%d)", - provider_name, - model_id, - domain.value, - state.in_flight, - state.current_limit, - ) - try: - while True: - if deadline is not None: - remaining = deadline - time.monotonic() - if remaining <= 0 or wait > remaining: - raise TimeoutError( - f"Throttle acquire timed out after {timeout:.0f}s " - f"for {provider_name}/{model_id} [{domain.value}]" - ) - sleep_for = min(wait, remaining) - else: - sleep_for = wait - await asyncio.sleep(sleep_for) - now = time.monotonic() - wait = self.try_acquire(provider_name=provider_name, model_id=model_id, domain=domain, now=now) - if wait == 0.0: - return - finally: - with self._lock: - state.waiters -= 1 - if state.waiters == 0: - logger.debug( - "Throttle %s/%s [%s] queue drained", - provider_name, - model_id, - domain.value, - ) - - # ------------------------------------------------------------------- - # Introspection (useful for tests and telemetry) - # ------------------------------------------------------------------- - - def get_domain_state( - self, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - ) -> DomainThrottleState | None: - with self._lock: - return self._domains.get((provider_name, model_id, domain.value)) - - def get_effective_max(self, provider_name: str, model_id: str) -> int: - with self._lock: - return self._effective_max_for(provider_name, model_id) - - # ------------------------------------------------------------------- - # Private helpers - # ------------------------------------------------------------------- - - def _compute_soft_ceiling(self, state: DomainThrottleState, effective_max: int) -> int: - """Return the upper bound for additive increase. - - If a rate-limit ceiling has been recorded, allow probing up to - ``ceiling * (1 + overshoot)`` (clamped to ``effective_max``). 
- Otherwise fall back to ``effective_max``. - """ - if state.rate_limit_ceiling <= 0: - return effective_max - soft = state.rate_limit_ceiling + max(1, math.floor(state.rate_limit_ceiling * self._ceiling_overshoot)) - return min(soft, effective_max) - - def _get_or_create_domain( - self, - provider_name: str, - model_id: str, - domain: ThrottleDomain, - now: float, - ) -> DomainThrottleState: - key = (provider_name, model_id, domain.value) - state = self._domains.get(key) - if state is None: - effective_max = self._effective_max_for(provider_name, model_id) - rampup_active = self._rampup_seconds > 0 and effective_max > DEFAULT_MIN_LIMIT - state = DomainThrottleState( - current_limit=DEFAULT_MIN_LIMIT if rampup_active else effective_max, - rampup_started_at=now, - rampup_active=rampup_active, - ) - self._domains[key] = state - return state - - def _apply_startup_ramp(self, state: DomainThrottleState, effective_max: int, now: float) -> None: - """Apply the configured startup ramp to a domain, if it is still active.""" - if not state.rampup_active: - return - if self._rampup_seconds <= 0 or effective_max <= DEFAULT_MIN_LIMIT: - state.current_limit = min(state.current_limit, effective_max) - state.rampup_active = False - return - elapsed = max(0.0, now - state.rampup_started_at) - if elapsed >= self._rampup_seconds: - state.current_limit = effective_max - state.rampup_active = False - return - fraction = elapsed / self._rampup_seconds - ramp_slots = math.floor((effective_max - DEFAULT_MIN_LIMIT) * fraction) - state.current_limit = min(effective_max, DEFAULT_MIN_LIMIT + ramp_slots) - - def _effective_max_for(self, provider_name: str, model_id: str) -> int: - cap = self._global_caps.get((provider_name, model_id)) - if cap is None or cap.effective_max <= 0: - return DEFAULT_MIN_LIMIT - return cap.effective_max - - def _clamp_domains(self, global_key: tuple[str, str], effective_max: int) -> None: - provider_name, model_id = global_key - for (pn, mid, _dom), state in 
self._domains.items(): - if pn == provider_name and mid == model_id and state.current_limit > effective_max: - state.current_limit = effective_max diff --git a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py b/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py deleted file mode 100644 index 797452c69..000000000 --- a/packages/data-designer-engine/src/data_designer/engine/models/clients/throttled.py +++ /dev/null @@ -1,222 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import contextlib -import logging -from typing import TYPE_CHECKING - -from data_designer.engine.models.clients.base import ModelClient -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind -from data_designer.engine.models.clients.throttle_manager import ThrottleDomain -from data_designer.engine.models.clients.types import ( - ChatCompletionRequest, - ChatCompletionResponse, - EmbeddingRequest, - EmbeddingResponse, - ImageGenerationRequest, - ImageGenerationResponse, -) - -if TYPE_CHECKING: - from collections.abc import AsyncIterator, Iterator - - from data_designer.engine.models.clients.throttle_manager import ThrottleManager - - -logger = logging.getLogger(__name__) - - -class ThrottledModelClient(ModelClient): - """Wraps a ``ModelClient`` with per-request throttle acquire/release. - - Inherits from ``ModelClient`` (a ``Protocol``) so that static type - checkers verify conformance and flag missing methods if the protocol - evolves. - - Every outbound HTTP call acquires a throttle permit from the - ``ThrottleManager`` and releases it on success, rate-limit, or failure. 
- The ``ThrottleDomain`` is determined by the method: - - - ``completion`` / ``acompletion`` -> ``CHAT`` - - ``embeddings`` / ``aembeddings`` -> ``EMBEDDING`` - - ``generate_image`` / ``agenerate_image`` -> ``IMAGE`` when - ``request.messages is None`` (diffusion), ``CHAT`` when messages are set - """ - - def __init__( - self, - inner: ModelClient, - throttle_manager: ThrottleManager, - provider_name: str, - model_id: str, - ) -> None: - self._inner = inner - self._tm = throttle_manager - self._provider_name = provider_name - self._model_id = model_id - - # --- ModelClient protocol delegation --- - - @property - def provider_name(self) -> str: - return self._inner.provider_name - - def supports_chat_completion(self) -> bool: - return self._inner.supports_chat_completion() - - def supports_embeddings(self) -> bool: - return self._inner.supports_embeddings() - - def supports_image_generation(self) -> bool: - return self._inner.supports_image_generation() - - def close(self) -> None: - self._inner.close() - - async def aclose(self) -> None: - await self._inner.aclose() - - # --- Throttled methods --- - - def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - with self._throttled_sync(ThrottleDomain.CHAT): - return self._inner.completion(request) - - async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: - async with self._athrottled(ThrottleDomain.CHAT): - return await self._inner.acompletion(request) - - def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - with self._throttled_sync(ThrottleDomain.EMBEDDING): - return self._inner.embeddings(request) - - async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - async with self._athrottled(ThrottleDomain.EMBEDDING): - return await self._inner.aembeddings(request) - - def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - domain = self._image_domain(request) - with 
self._throttled_sync(domain): - return self._inner.generate_image(request) - - async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: - domain = self._image_domain(request) - async with self._athrottled(domain): - return await self._inner.agenerate_image(request) - - # --- Context managers --- - - @contextlib.contextmanager - def _throttled_sync(self, domain: ThrottleDomain) -> Iterator[None]: - try: - self._tm.acquire_sync( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except TimeoutError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc - exc_to_reraise: BaseException | None = None - try: - yield - except ProviderError as exc: - exc_to_reraise = exc - try: - self._release_on_provider_error(domain, exc) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - except BaseException as exc: - exc_to_reraise = exc - try: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - else: - try: - self._tm.release_success( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release_success failed") - if exc_to_reraise is not None: - raise exc_to_reraise - - @contextlib.asynccontextmanager - async def _athrottled(self, domain: ThrottleDomain) -> AsyncIterator[None]: - try: - await self._tm.acquire_async( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except TimeoutError as exc: - raise ProviderError( - kind=ProviderErrorKind.TIMEOUT, - message=str(exc), - provider_name=self._provider_name, - model_name=self._model_id, - ) from exc - exc_to_reraise: BaseException | None = None - 
try: - yield - except ProviderError as exc: - exc_to_reraise = exc - try: - self._release_on_provider_error(domain, exc) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - except BaseException as exc: - exc_to_reraise = exc - try: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release failed; permit may leak") - else: - try: - self._tm.release_success( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - except Exception: - logger.exception("ThrottleManager release_success failed") - if exc_to_reraise is not None: - raise exc_to_reraise - - # --- Private helpers --- - - def _release_on_provider_error(self, domain: ThrottleDomain, exc: ProviderError) -> None: - if exc.kind == ProviderErrorKind.RATE_LIMIT: - self._tm.release_rate_limited( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - retry_after=exc.retry_after, - ) - else: - self._tm.release_failure( - provider_name=self._provider_name, - model_id=self._model_id, - domain=domain, - ) - - @staticmethod - def _image_domain(request: ImageGenerationRequest) -> ThrottleDomain: - return ThrottleDomain.CHAT if request.messages is not None else ThrottleDomain.IMAGE diff --git a/packages/data-designer-engine/src/data_designer/engine/models/factory.py b/packages/data-designer-engine/src/data_designer/engine/models/factory.py index e3836a912..61e1e70b9 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/factory.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/factory.py @@ -13,8 +13,9 @@ if TYPE_CHECKING: from data_designer.config.run_config import RunConfig from data_designer.engine.mcp.registry import MCPRegistry - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.registry import 
ModelRegistry + from data_designer.engine.models.request_admission.config import RequestAdmissionConfig + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController def create_model_registry( @@ -25,7 +26,7 @@ def create_model_registry( mcp_registry: MCPRegistry | None = None, client_concurrency_mode: ClientConcurrencyMode = ClientConcurrencyMode.SYNC, run_config: RunConfig | None = None, - throttle_manager: ThrottleManager | None = None, + request_admission: AdaptiveRequestAdmissionController | None = None, ) -> ModelRegistry: """Factory function for creating a ModelRegistry instance. @@ -42,24 +43,21 @@ def create_model_registry( client_concurrency_mode: ``"sync"`` (default) or ``"async"``. Forwarded to native HTTP adapters so each client is constrained to a single concurrency mode. - run_config: Optional runtime configuration. The nested - ``run_config.throttle`` (a ``ThrottleConfig``) is forwarded to the - ``ThrottleManager`` constructor. - throttle_manager: Optional shared throttle manager. When omitted, a new - manager is created for this registry. + run_config: Optional runtime configuration. Public request-admission + tuning is translated to the engine-internal request-admission config. + request_admission: Optional shared request-admission controller. When + omitted, a new controller is created from ``run_config``. Returns: A configured ModelRegistry instance. 
""" - from data_designer.config.run_config import RunConfig from data_designer.engine.models.clients.factory import create_model_client from data_designer.engine.models.clients.retry import RetryConfig - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.registry import ModelRegistry - if throttle_manager is None: - throttle_manager = ThrottleManager((run_config or RunConfig()).throttle) + if request_admission is None: + request_admission = create_request_admission_controller(run_config) def model_facade_factory( model_config: ModelConfig, @@ -73,7 +71,7 @@ def model_facade_factory( model_provider_registry, retry_config=retry_config, client_concurrency_mode=client_concurrency_mode, - throttle_manager=throttle_manager, + request_admission=request_admission, ) return ModelFacade( model_config, @@ -87,6 +85,36 @@ def model_facade_factory( secret_resolver=secret_resolver, model_provider_registry=model_provider_registry, model_facade_factory=model_facade_factory, - throttle_manager=throttle_manager, + request_admission=request_admission, retry_config=RetryConfig(), ) + + +def create_request_admission_controller( + run_config: RunConfig | None = None, +) -> AdaptiveRequestAdmissionController: + """Create a request-admission controller from public runtime tuning.""" + from data_designer.config.run_config import RunConfig + from data_designer.engine.models.request_admission.config import RequestAdmissionConfig + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController + + resolved_run_config = run_config or RunConfig() + return AdaptiveRequestAdmissionController( + _request_admission_config_from_run_config(resolved_run_config, RequestAdmissionConfig) + ) + + +def _request_admission_config_from_run_config( + run_config: RunConfig, + config_cls: type[RequestAdmissionConfig], +) -> RequestAdmissionConfig: + 
tuning = run_config.request_admission + if tuning is None: + return config_cls() + return config_cls( + cooldown_seconds=tuning.cooldown_seconds, + multiplicative_decrease_factor=tuning.multiplicative_decrease_factor, + additive_increase_step=tuning.additive_increase_step, + successes_until_increase=tuning.successes_until_increase, + startup_ramp_seconds=tuning.startup_ramp_seconds, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/registry.py b/packages/data-designer-engine/src/data_designer/engine/models/registry.py index de0ecc036..b333ff408 100644 --- a/packages/data-designer-engine/src/data_designer/engine/models/registry.py +++ b/packages/data-designer-engine/src/data_designer/engine/models/registry.py @@ -16,8 +16,8 @@ from collections.abc import Callable from data_designer.engine.models.clients.retry import RetryConfig - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.facade import ModelFacade + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController ModelFacadeFactory = Callable[ [ModelConfig, SecretResolver, ModelProviderRegistry, RetryConfig | None], @@ -47,13 +47,13 @@ def __init__( model_provider_registry: ModelProviderRegistry, model_configs: list[ModelConfig] | None = None, model_facade_factory: ModelFacadeFactory | None = None, - throttle_manager: ThrottleManager | None = None, + request_admission: AdaptiveRequestAdmissionController | None = None, retry_config: RetryConfig | None = None, ) -> None: self._secret_resolver = secret_resolver self._model_provider_registry = model_provider_registry self._model_facade_factory = model_facade_factory - self._throttle_manager = throttle_manager + self._request_admission = request_admission self._retry_config = retry_config self._model_configs: dict[str, ModelConfig] = {} self._models: dict[str, ModelFacade] = {} @@ -68,8 +68,8 @@ def models(self) -> dict[str, 
ModelFacade]: return self._models @property - def throttle_manager(self) -> ThrottleManager | None: - return self._throttle_manager + def request_admission(self) -> AdaptiveRequestAdmissionController | None: + return self._request_admission @property def retry_config(self) -> RetryConfig | None: @@ -215,10 +215,9 @@ def get_aggregate_max_parallel_requests(self) -> int: This is a coarse upper bound: it sums over *all* registered aliases, including those not referenced by the current generator set, and does not deduplicate aliases sharing a ``(provider_name, model_id)`` key. - The result is used to size the scheduler's LLM-wait semaphore, which - is a memory-safety cap — oversizing wastes a few coroutine slots but - does not affect correctness because the ``ThrottleManager`` enforces - the real per-key limit. + The result is used to size scheduler task-stage model admission, which + is a memory-safety cap. Concrete provider/model request capacity is + enforced by request admission at model-call time. """ return sum(mc.inference_parameters.max_parallel_requests for mc in self._model_configs.values()) @@ -351,8 +350,8 @@ def _get_model(self, model_config: ModelConfig) -> ModelFacade: self._model_provider_registry, self._retry_config, ) - if self._throttle_manager is not None: - self._throttle_manager.register( + if self._request_admission is not None: + self._request_admission.register( provider_name=facade.model_provider_name, model_id=model_config.model, alias=model_config.alias, diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py new file mode 100644 index 000000000..433e7942c --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/config.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field + +from data_designer.engine.models.request_admission.resources import RequestResourceKey + + +@dataclass(frozen=True) +class RequestAdmissionConfig: + initial_limits: Mapping[RequestResourceKey, int] = field(default_factory=dict) + max_limit_clamps: Mapping[RequestResourceKey, int | None] = field(default_factory=dict) + cooldown_seconds: float = 2.0 + multiplicative_decrease_factor: float = 0.75 + additive_increase_step: int = 1 + successes_until_increase: int = 25 + startup_ramp_seconds: float = 0.0 + default_queue_wait_timeout_seconds: float | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py new file mode 100644 index 000000000..38bbd0598 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/controller.py @@ -0,0 +1,788 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import math +import threading +import time +import uuid +from collections import Counter, deque +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal, Protocol + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.limits import AdaptiveRequestLimitState +from data_designer.engine.models.request_admission.outcomes import ReleaseResult, RequestReleaseOutcome +from data_designer.engine.models.request_admission.pressure import ( + ProviderModelPressureSnapshot, + RequestPressureSnapshot, + RequestPressureSnapshotProvider, +) +from data_designer.engine.models.request_admission.queue import RequestFairQueue, RequestWaiter +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestResourceKey, +) +from data_designer.engine.models.resources import ProviderModelKey +from data_designer.engine.observability import ( + RequestAdmissionEvent, + RequestAdmissionEventSink, + runtime_correlation_provider, +) + +logger = logging.getLogger(__name__) + +DEFAULT_MIN_LIMIT: int = 1 +RequestDenyReason = Literal[ + "no_capacity", + "cooldown", + "queue_timeout", + "queued_waiters_ahead", + "cancellation", + "shutdown", + "hard_policy_denial", +] +RELEASED_LEASE_HISTORY_LIMIT = 8192 +_TERMINAL_DENIAL_REASONS: frozenset[RequestDenyReason] = frozenset({"hard_policy_denial", "shutdown"}) + + +@dataclass(frozen=True) +class RequestAdmissionDenied: + item: RequestAdmissionItem + reason: RequestDenyReason + retry_after_seconds: float | None = None + available_after_monotonic: float | None = None + snapshot: object | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class RequestAdmissionLease: + lease_id: str + item: 
RequestAdmissionItem + acquired_at: float + current_adaptive_limit: int + effective_max: int + controller_generation: str + + +RequestAdmissionDecision = RequestAdmissionLease | RequestAdmissionDenied + + +class RequestAdmissionError(RuntimeError): + """Raised by blocking acquire paths when no request lease is acquired.""" + + def __init__(self, decision: RequestAdmissionDenied) -> None: + super().__init__(f"Request admission failed: {decision.reason}") + self.decision = decision + + +class RequestAdmissionController(Protocol): + def try_acquire(self, item: RequestAdmissionItem) -> RequestAdmissionDecision: ... + + def acquire_sync(self, item: RequestAdmissionItem) -> RequestAdmissionLease: ... + + async def acquire_async(self, item: RequestAdmissionItem) -> RequestAdmissionLease: ... + + def release(self, lease: RequestAdmissionLease, outcome: RequestReleaseOutcome) -> ReleaseResult: ... + + @property + def pressure(self) -> RequestPressureSnapshotProvider: ... + + +@dataclass +class _GlobalCapState: + limits_by_alias: dict[str, int] = field(default_factory=dict) + effective_max: int = 0 + + def register_alias(self, alias: str, max_parallel: int) -> None: + self.limits_by_alias[alias] = max(1, max_parallel) + self.effective_max = min(self.limits_by_alias.values()) + + +class AdaptiveRequestAdmissionController(RequestPressureSnapshotProvider): + """AIMD-backed request admission controller with exact request leases.""" + + def __init__( + self, + config: RequestAdmissionConfig | None = None, + *, + event_sink: RequestAdmissionEventSink | None = None, + ) -> None: + self._config = config or RequestAdmissionConfig() + self._lock = threading.Lock() + self._condition = threading.Condition(self._lock) + self._generation = uuid.uuid4().hex + self._global_caps: dict[ProviderModelKey, _GlobalCapState] = {} + self._domains: dict[RequestResourceKey, AdaptiveRequestLimitState] = {} + self._active_leases: dict[str, RequestAdmissionLease] = {} + self._released: set[str] = set() 
+ self._released_order: deque[str] = deque(maxlen=RELEASED_LEASE_HISTORY_LIMIT) + self._aggregate_in_flight: Counter[ProviderModelKey] = Counter() + self._aggregate_active_leases: Counter[ProviderModelKey] = Counter() + self._sequence = 0 + self._release_diagnostics: Counter[str] = Counter() + self._queue = RequestFairQueue() + self._event_sink = event_sink + + @property + def pressure(self) -> RequestPressureSnapshotProvider: + return self + + @property + def config(self) -> RequestAdmissionConfig: + return self._config + + def register( + self, + *, + provider_name: str, + model_id: str, + alias: str, + max_parallel_requests: int, + ) -> None: + events: list[RequestAdmissionEvent] = [] + with self._lock: + key = ProviderModelKey(provider_name, model_id) + cap = self._global_caps.setdefault(key, _GlobalCapState()) + before = cap.effective_max + cap.register_alias(alias, max_parallel_requests) + self._sequence += 1 + for resource, state in self._domains.items(): + if resource.provider_model_key == key: + effective_max = self._effective_max_for_resource(resource) + state.current_limit = min(state.current_limit, effective_max) + events.append( + self._request_event_locked( + "request_resource_registered", + request_resource_key=RequestResourceKey(provider_name, model_id, RequestDomain.CHAT), + diagnostics={"alias": alias, "provider_model": key, "max_parallel_requests": max_parallel_requests}, + ) + ) + if before != cap.effective_max: + events.append( + self._request_event_locked( + "request_effective_cap_changed", + request_resource_key=RequestResourceKey(provider_name, model_id, RequestDomain.CHAT), + diagnostics={"provider_model": key, "previous": before, "current": cap.effective_max}, + ) + ) + self._admit_waiters_locked(events) + self._condition.notify_all() + self._emit_events(events) + + def try_acquire(self, item: RequestAdmissionItem) -> RequestAdmissionDecision: + now = time.monotonic() + events: list[RequestAdmissionEvent] = [] + decision: 
RequestAdmissionDecision + with self._lock: + events.append(self._request_event_locked("request_wait_started", item=item)) + if self._queued_waiter_ahead_locked(item, now): + decision = RequestAdmissionDenied( + item=item, + reason="queued_waiters_ahead", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=decision)) + else: + denied = self._denial_for(item, now) + if denied is not None: + decision = denied + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=decision)) + else: + decision = self._acquire_locked(item, now) + events.append(self._request_event_locked("request_wait_completed", item=item, lease=decision)) + events.append(self._request_event_locked("request_lease_acquired", item=item, lease=decision)) + self._emit_events(events) + return decision + + def acquire_sync(self, item: RequestAdmissionItem) -> RequestAdmissionLease: + try: + asyncio.get_running_loop() + except RuntimeError: + pass + else: + raise RuntimeError("acquire_sync would block the running event loop; use acquire_async instead.") + + timeout = ( + item.queue_wait_timeout_seconds + if item.queue_wait_timeout_seconds is not None + else self._config.default_queue_wait_timeout_seconds + ) + now = time.monotonic() + deadline = now + timeout if timeout is not None else None + waiter = RequestWaiter(waiter_id=uuid.uuid4().hex, item=item, enqueued_at=now, deadline_monotonic=deadline) + events: list[RequestAdmissionEvent] = [] + try: + while True: + with self._lock: + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + 
raise RequestAdmissionError(denied) + if not self._queue.contains(waiter.waiter_id) and waiter.assigned_lease is None: + self._enqueue_waiter_locked(waiter, events) + self._admit_waiters_locked(events) + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if (denied := self._terminal_denial_for(item, now)) is not None: + self._remove_waiter_locked(waiter) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=denied)) + self._condition.notify_all() + raise RequestAdmissionError(denied) + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + wait = self._wait_seconds_locked(item, now, deadline) + self._condition.wait(timeout=wait) + finally: + self._emit_events(events) + + async def acquire_async(self, item: RequestAdmissionItem) -> RequestAdmissionLease: + loop = asyncio.get_running_loop() + wakeup = asyncio.Event() + timeout = ( + item.queue_wait_timeout_seconds + if item.queue_wait_timeout_seconds is not None + else self._config.default_queue_wait_timeout_seconds + ) + now = time.monotonic() + deadline = now + timeout if timeout is not None else None + waiter = RequestWaiter( + waiter_id=uuid.uuid4().hex, + item=item, + enqueued_at=now, + deadline_monotonic=deadline, + wakeup=lambda: loop.call_soon_threadsafe(wakeup.set), + ) + events: list[RequestAdmissionEvent] = [] + try: + while True: + with self._lock: + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + 
snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + if not self._queue.contains(waiter.waiter_id) and waiter.assigned_lease is None: + self._enqueue_waiter_locked(waiter, events) + self._admit_waiters_locked(events) + if waiter.assigned_lease is not None: + return waiter.assigned_lease + now = time.monotonic() + if (denied := self._terminal_denial_for(item, now)) is not None: + self._remove_waiter_locked(waiter) + events.append(self._request_event_locked("request_acquire_denied", item=item, decision=denied)) + self._condition.notify_all() + raise RequestAdmissionError(denied) + if deadline is not None and now >= deadline: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied( + item=item, + reason="queue_timeout", + snapshot=self._snapshot_locked(item.resource, now), + ) + events.append(self._request_event_locked("request_wait_timeout", item=item, decision=denied)) + raise RequestAdmissionError(denied) + wait = self._wait_seconds_locked(item, now, deadline) + try: + await asyncio.wait_for(wakeup.wait(), timeout=wait) + except asyncio.TimeoutError: + pass + wakeup.clear() + except asyncio.CancelledError: + lease_to_release: RequestAdmissionLease | None = None + with self._lock: + lease_to_release = waiter.assigned_lease + if lease_to_release is None: + self._remove_waiter_locked(waiter) + denied = RequestAdmissionDenied(item=item, reason="cancellation") + events.append( + self._request_event_locked( + "request_wait_cancelled", + item=item, + lease=lease_to_release, + decision=denied, + ) + ) + self._condition.notify_all() + if lease_to_release is not None: + self._emit_events(events) + events.clear() + self.release(lease_to_release, RequestReleaseOutcome(kind="local_cancelled")) + raise + finally: + self._emit_events(events) + + def release(self, lease: RequestAdmissionLease, outcome: RequestReleaseOutcome) -> 
ReleaseResult: + now = time.monotonic() + events: list[RequestAdmissionEvent] = [] + result: ReleaseResult + with self._lock: + if lease.controller_generation != self._generation: + self._release_diagnostics["wrong_controller_generation"] += 1 + result = ReleaseResult(released=False, reason="wrong_controller_generation") + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + elif (active := self._active_leases.pop(lease.lease_id, None)) is None: + reason = "duplicate" if lease.lease_id in self._released else "unknown_lease" + self._release_diagnostics[reason] += 1 + result = ReleaseResult(released=False, reason=reason) + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + elif active != lease: + self._active_leases[lease.lease_id] = active + self._release_diagnostics["stale_lease"] += 1 + result = ReleaseResult(released=False, reason="stale_lease") + events.append( + self._request_event_locked( + "request_release_diagnostic", item=lease.item, lease=lease, result=result + ) + ) + else: + self._remember_released_locked(lease.lease_id) + resource = active.item.resource + provider_model = resource.provider_model_key + state = self._get_or_create_state(resource) + state.in_flight = max(0, state.in_flight - 1) + state.active_lease_count = max(0, state.active_lease_count - 1) + state.last_outcome = outcome.kind + self._aggregate_in_flight[provider_model] = max(0, self._aggregate_in_flight[provider_model] - 1) + self._aggregate_active_leases[provider_model] = max( + 0, + self._aggregate_active_leases[provider_model] - 1, + ) + self._apply_outcome(state, resource, active.current_adaptive_limit, outcome, now, events) + self._sequence += 1 + result = ReleaseResult(released=True, reason="released") + if outcome.kind == "rate_limited": + events.append(self._request_event_locked("request_rate_limited", item=active.item, 
lease=active)) + events.append( + self._request_event_locked( + "request_lease_released", + item=active.item, + lease=active, + result=result, + outcome=outcome, + ) + ) + self._admit_waiters_locked(events) + self._condition.notify_all() + self._emit_events(events) + return result + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: + with self._lock: + if resource not in self._domains: + return None + return self._snapshot_locked(resource, time.monotonic()) + + def snapshots(self) -> Mapping[RequestResourceKey, RequestPressureSnapshot]: + with self._lock: + now = time.monotonic() + return {resource: self._snapshot_locked(resource, now) for resource in self._domains} + + def global_snapshot(self, provider: str, model: str) -> ProviderModelPressureSnapshot | None: + with self._lock: + key = ProviderModelKey(provider, model) + if key not in self._global_caps: + return None + return self._global_snapshot_locked(key, time.monotonic()) + + def global_snapshots(self) -> Mapping[ProviderModelKey, ProviderModelPressureSnapshot]: + with self._lock: + now = time.monotonic() + return {key: self._global_snapshot_locked(key, now) for key in self._global_caps} + + def _queued_waiter_ahead_locked(self, item: RequestAdmissionItem, now: float) -> bool: + if not self._queue.has_waiters: + return False + self._expire_waiters_locked(now) + selection = self._queue.select_next(lambda waiter, _view: self._denial_for(waiter.item, now) is None) + if selection is None: + return False + selected_key = selection.item.resource.provider_model_key + return selected_key == item.resource.provider_model_key or selection.item.resource == item.resource + + def _enqueue_waiter_locked(self, waiter: RequestWaiter, events: list[RequestAdmissionEvent]) -> None: + if self._queue.enqueue(waiter): + self._get_or_create_state(waiter.item.resource).waiters += 1 + self._sequence += 1 + if self._queue.view().queued_total == 1: + 
events.append(self._request_event_locked("request_queue_formed", item=waiter.item)) + events.append(self._request_event_locked("request_wait_started", item=waiter.item)) + + def _remove_waiter_locked(self, waiter: RequestWaiter) -> None: + removed = self._queue.remove(waiter.waiter_id) + if removed is None: + return + state = self._get_or_create_state(waiter.item.resource) + state.waiters = max(0, state.waiters - 1) + self._sequence += 1 + + def _expire_waiters_locked(self, now: float) -> None: + for waiter in self._queue.waiters(): + if waiter.deadline_monotonic is not None and now >= waiter.deadline_monotonic: + self._remove_waiter_locked(waiter) + self._wake_waiter_locked(waiter) + + def _admit_waiters_locked(self, events: list[RequestAdmissionEvent]) -> None: + while self._queue.has_waiters: + now = time.monotonic() + self._expire_waiters_locked(now) + if not self._queue.has_waiters: + return + selection = self._queue.select_next(lambda waiter, _view: self._denial_for(waiter.item, now) is None) + if selection is None: + return + waiter = self._queue.commit(selection) + if waiter is None: + return + state = self._get_or_create_state(waiter.item.resource) + state.waiters = max(0, state.waiters - 1) + lease = self._acquire_locked(waiter.item, now) + waiter.assigned_lease = lease + self._wake_waiter_locked(waiter) + events.append(self._request_event_locked("request_wait_completed", item=waiter.item, lease=lease)) + events.append(self._request_event_locked("request_lease_acquired", item=waiter.item, lease=lease)) + if not self._queue.has_waiters: + events.append(self._request_event_locked("request_queue_drained", item=waiter.item)) + + def _wake_waiter_locked(self, waiter: RequestWaiter) -> None: + if waiter.wakeup is None: + return + waiter.wakeup() + + def _wait_seconds_locked( + self, + item: RequestAdmissionItem, + now: float, + deadline: float | None, + ) -> float: + candidates = [0.05] + if deadline is not None: + candidates.append(max(0.0, deadline - now)) + 
state = self._domains.get(item.resource) + if state is not None and state.blocked_until > now: + candidates.append(max(0.0, state.blocked_until - now)) + return max(0.0, min(candidates)) + + def _denial_for(self, item: RequestAdmissionItem, now: float) -> RequestAdmissionDenied | None: + resource = item.resource + provider_model = resource.provider_model_key + if provider_model not in self._global_caps: + return RequestAdmissionDenied(item=item, reason="hard_policy_denial", diagnostics={"unregistered": True}) + state = self._get_or_create_state(resource) + self._apply_startup_ramp_locked(state, resource, now) + if now < state.blocked_until: + return RequestAdmissionDenied( + item=item, + reason="cooldown", + retry_after_seconds=state.blocked_until - now, + available_after_monotonic=state.blocked_until, + snapshot=self._snapshot_locked(resource, now), + ) + effective_max = self._effective_max_for_resource(resource) + aggregate_cap = self._global_caps[provider_model].effective_max + if state.in_flight >= min(state.current_limit, effective_max): + return RequestAdmissionDenied( + item=item, reason="no_capacity", snapshot=self._snapshot_locked(resource, now) + ) + if self._aggregate_in_flight[provider_model] >= aggregate_cap: + return RequestAdmissionDenied( + item=item, reason="no_capacity", snapshot=self._snapshot_locked(resource, now) + ) + return None + + def _terminal_denial_for(self, item: RequestAdmissionItem, now: float) -> RequestAdmissionDenied | None: + denied = self._denial_for(item, now) + if denied is None or denied.reason not in _TERMINAL_DENIAL_REASONS: + return None + return denied + + def _remember_released_locked(self, lease_id: str) -> None: + if lease_id in self._released: + return + maxlen = self._released_order.maxlen + if maxlen is not None and len(self._released_order) >= maxlen: + self._released.discard(self._released_order[0]) + self._released.add(lease_id) + self._released_order.append(lease_id) + + def _acquire_locked(self, item: 
RequestAdmissionItem, now: float) -> RequestAdmissionLease: + resource = item.resource + provider_model = resource.provider_model_key + state = self._get_or_create_state(resource) + state.in_flight += 1 + state.active_lease_count += 1 + self._aggregate_in_flight[provider_model] += 1 + self._aggregate_active_leases[provider_model] += 1 + lease = RequestAdmissionLease( + lease_id=uuid.uuid4().hex, + item=item, + acquired_at=now, + current_adaptive_limit=state.current_limit, + effective_max=self._effective_max_for_resource(resource), + controller_generation=self._generation, + ) + self._active_leases[lease.lease_id] = lease + self._sequence += 1 + return lease + + def _apply_outcome( + self, + state: AdaptiveRequestLimitState, + resource: RequestResourceKey, + admitted_adaptive_limit: int, + outcome: RequestReleaseOutcome, + now: float, + events: list[RequestAdmissionEvent], + ) -> None: + effective_max = self._effective_max_for_resource(resource) + if outcome.kind == "rate_limited": + state.startup_ramp_active = False + prev_limit = state.current_limit + should_decrease = admitted_adaptive_limit <= prev_limit + state.consecutive_rate_limits += 1 + cooldown = ( + outcome.retry_after_seconds + if outcome.retry_after_seconds is not None and outcome.retry_after_seconds > 0 + else self._config.cooldown_seconds + ) + state.blocked_until = now + cooldown + state.success_streak = 0 + if should_decrease: + state.current_limit = max( + 1, math.floor(state.current_limit * self._config.multiplicative_decrease_factor) + ) + if state.rate_limit_ceiling == 0: + state.rate_limit_ceiling = max(1, admitted_adaptive_limit) + if state.current_limit != prev_limit: + events.append( + self._request_event_locked( + "request_limit_decreased", + request_resource_key=resource, + diagnostics={"previous": prev_limit, "current": state.current_limit}, + ) + ) + return + if state.startup_ramp_active: + self._apply_startup_ramp_locked(state, resource, now) + if outcome.kind == "success": + 
state.success_streak = 0 + return + if outcome.kind == "success" and now >= state.blocked_until: + prev_limit = state.current_limit + state.consecutive_rate_limits = 0 + state.success_streak += 1 + if state.success_streak >= self._config.successes_until_increase: + state.current_limit = min(effective_max, state.current_limit + self._config.additive_increase_step) + state.success_streak = 0 + if state.current_limit != prev_limit: + events.append( + self._request_event_locked( + "request_limit_increased", + request_resource_key=resource, + diagnostics={"previous": prev_limit, "current": state.current_limit}, + ) + ) + if state.rate_limit_ceiling and state.current_limit > state.rate_limit_ceiling: + events.append( + self._request_event_locked( + "request_soft_ceiling_recovered", + request_resource_key=resource, + diagnostics={"rate_limit_ceiling": state.rate_limit_ceiling}, + ) + ) + if state.current_limit == effective_max and state.blocked_until <= now: + events.append( + self._request_event_locked("request_fully_recovered", request_resource_key=resource) + ) + return + if state.in_flight == 0 and outcome.kind not in {"local_cancelled", "local_timeout"}: + state.consecutive_rate_limits = 0 + + def _increment_waiter(self, item: RequestAdmissionItem) -> None: + with self._lock: + self._get_or_create_state(item.resource).waiters += 1 + self._sequence += 1 + + def _decrement_waiter(self, item: RequestAdmissionItem) -> None: + with self._lock: + state = self._get_or_create_state(item.resource) + state.waiters = max(0, state.waiters - 1) + self._sequence += 1 + + def _get_or_create_state(self, resource: RequestResourceKey) -> AdaptiveRequestLimitState: + state = self._domains.get(resource) + if state is None: + initial = self._initial_limit_for_resource(resource) + ramp_active = self._config.startup_ramp_seconds > 0.0 and initial > DEFAULT_MIN_LIMIT + state = AdaptiveRequestLimitState( + current_limit=DEFAULT_MIN_LIMIT if ramp_active else initial, + 
startup_ramp_started_at=time.monotonic(), + startup_ramp_active=ramp_active, + ) + self._domains[resource] = state + return state + + def _initial_limit_for_resource(self, resource: RequestResourceKey) -> int: + effective_max = self._effective_max_for_resource(resource) + initial = self._config.initial_limits.get(resource, effective_max) + return max(DEFAULT_MIN_LIMIT, min(initial, effective_max)) + + def _effective_max_for_resource(self, resource: RequestResourceKey) -> int: + provider_model_cap = self._global_caps.get(resource.provider_model_key) + static_cap = provider_model_cap.effective_max if provider_model_cap is not None else DEFAULT_MIN_LIMIT + clamp = self._config.max_limit_clamps.get(resource) + return max(DEFAULT_MIN_LIMIT, min(static_cap, clamp if clamp is not None else static_cap)) + + def _apply_startup_ramp_locked( + self, + state: AdaptiveRequestLimitState, + resource: RequestResourceKey, + now: float, + ) -> None: + if not state.startup_ramp_active: + return + target_limit = self._initial_limit_for_resource(resource) + if self._config.startup_ramp_seconds <= 0.0 or target_limit <= DEFAULT_MIN_LIMIT: + changed = state.current_limit != target_limit or state.startup_ramp_active + state.current_limit = min(state.current_limit, target_limit) + state.startup_ramp_active = False + if changed: + self._sequence += 1 + return + + elapsed = max(0.0, now - state.startup_ramp_started_at) + previous_limit = state.current_limit + if elapsed >= self._config.startup_ramp_seconds: + state.current_limit = target_limit + state.startup_ramp_active = False + else: + fraction = elapsed / self._config.startup_ramp_seconds + ramp_slots = math.floor((target_limit - DEFAULT_MIN_LIMIT) * fraction) + state.current_limit = min(target_limit, DEFAULT_MIN_LIMIT + ramp_slots) + if state.current_limit != previous_limit or not state.startup_ramp_active: + self._sequence += 1 + + def _snapshot_locked(self, resource: RequestResourceKey, now: float) -> RequestPressureSnapshot: + state 
= self._get_or_create_state(resource) + self._apply_startup_ramp_locked(state, resource, now) + blocked_until = state.blocked_until if state.blocked_until > now else None + return RequestPressureSnapshot( + captured_at=now, + sequence=self._sequence, + resource=resource, + effective_max=self._effective_max_for_resource(resource), + current_limit=state.current_limit, + in_flight_count=state.in_flight, + active_lease_count=state.active_lease_count, + waiters=state.waiters, + blocked_until_monotonic=blocked_until, + cooldown_remaining_seconds=max(0.0, state.blocked_until - now), + rate_limit_ceiling=state.rate_limit_ceiling, + consecutive_rate_limits=state.consecutive_rate_limits, + last_outcome=state.last_outcome, + leak_diagnostics=dict(self._release_diagnostics), + ) + + def _global_snapshot_locked(self, key: ProviderModelKey, now: float) -> ProviderModelPressureSnapshot: + cap = self._global_caps[key] + domains = { + resource.domain: state.current_limit + for resource, state in self._domains.items() + if resource.provider_model_key == key + } + return ProviderModelPressureSnapshot( + captured_at=now, + sequence=self._sequence, + provider_model=key, + static_cap=cap.effective_max, + aggregate_in_flight=self._aggregate_in_flight[key], + aggregate_active_lease_count=self._aggregate_active_leases[key], + aliases=tuple(sorted(cap.limits_by_alias)), + raw_caps=dict(cap.limits_by_alias), + domains=domains, + ) + + def _request_event_locked( + self, + event_kind: str, + *, + item: RequestAdmissionItem | None = None, + lease: RequestAdmissionLease | None = None, + decision: RequestAdmissionDenied | None = None, + result: ReleaseResult | None = None, + outcome: RequestReleaseOutcome | None = None, + request_resource_key: RequestResourceKey | None = None, + diagnostics: Mapping[str, object] | None = None, + ) -> RequestAdmissionEvent: + self._sequence += 1 + event_context = item.event_context if item is not None else None + resource = request_resource_key or (item.resource 
if item is not None else None) + group_key = item.group.key if item is not None else None + reason_or_outcome = None + if decision is not None: + reason_or_outcome = decision.reason + elif outcome is not None: + reason_or_outcome = outcome.kind + elif result is not None: + reason_or_outcome = result.reason + return RequestAdmissionEvent.capture( + event_kind, # type: ignore[arg-type] + sequence=self._sequence, + correlation=event_context.captured_correlation + if event_context is not None + else runtime_correlation_provider.current(), + request_attempt_id=event_context.request_attempt_id if event_context is not None else None, + request_lease_id=lease.lease_id if lease is not None else None, + request_resource_key=resource, + request_group_key=group_key, + reason_or_outcome=reason_or_outcome, + pressure_snapshot=self._snapshot_locked(resource, time.monotonic()) if resource is not None else None, + diagnostics=dict(diagnostics or {}), + ) + + def _emit_events(self, events: list[RequestAdmissionEvent]) -> None: + if self._event_sink is None: + return + for event in events: + try: + self._event_sink.emit_request_event(event) + except Exception: + logger.warning("Request admission event sink raised; dropping event.", exc_info=True) + continue diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py new file mode 100644 index 000000000..7df1bc278 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/limits.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass +class AdaptiveRequestLimitState: + current_limit: int + in_flight: int = 0 + blocked_until: float = 0.0 + success_streak: int = 0 + waiters: int = 0 + rate_limit_ceiling: int = 0 + consecutive_rate_limits: int = 0 + active_lease_count: int = 0 + last_outcome: str | None = None + startup_ramp_started_at: float = 0.0 + startup_ramp_active: bool = False diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py new file mode 100644 index 000000000..3399b07f4 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/outcomes.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Literal + + +@dataclass(frozen=True) +class RequestReleaseOutcome: + kind: Literal[ + "success", + "rate_limited", + "provider_failure", + "provider_timeout", + "local_cancelled", + "local_timeout", + "unexpected_exception", + ] + retry_after_seconds: float | None = None + provider_status: int | None = None + diagnostics: Mapping[str, object] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ReleaseResult: + released: bool + reason: Literal["released", "duplicate", "stale_lease", "wrong_controller_generation", "unknown_lease"] + diagnostics: Mapping[str, object] = field(default_factory=dict) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py new file mode 100644 index 000000000..a268f8898 
--- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/pressure.py @@ -0,0 +1,56 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Protocol + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + + +@dataclass(frozen=True) +class RequestPressureSnapshot: + captured_at: float + sequence: int + resource: RequestResourceKey + effective_max: int + current_limit: int + in_flight_count: int + active_lease_count: int + waiters: int + blocked_until_monotonic: float | None + cooldown_remaining_seconds: float + rate_limit_ceiling: int + consecutive_rate_limits: int + last_outcome: str | None + leak_diagnostics: Mapping[str, int] + + +@dataclass(frozen=True) +class ProviderModelPressureSnapshot: + captured_at: float + sequence: int + provider_model: ProviderModelKey + static_cap: int + aggregate_in_flight: int + aggregate_active_lease_count: int + aliases: tuple[str, ...] + raw_caps: Mapping[str, int | None] + domains: Mapping[RequestDomain, int] + + +class RequestPressureSnapshotProvider(Protocol): + @property + def config(self) -> RequestAdmissionConfig | None: ... + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: ... + + def snapshots(self) -> Mapping[RequestResourceKey, RequestPressureSnapshot]: ... + + def global_snapshot(self, provider: str, model: str) -> ProviderModelPressureSnapshot | None: ... + + def global_snapshots(self) -> Mapping[ProviderModelKey, ProviderModelPressureSnapshot]: ... 
diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py new file mode 100644 index 000000000..cdca7027b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/queue.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import heapq +from collections import Counter, deque +from collections.abc import Callable, Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from data_designer.engine.models.request_admission.resources import RequestAdmissionItem, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + +if TYPE_CHECKING: + from data_designer.engine.models.request_admission.controller import RequestAdmissionLease + + +@dataclass +class RequestWaiter: + waiter_id: str + item: RequestAdmissionItem + enqueued_at: float + deadline_monotonic: float | None = None + assigned_lease: RequestAdmissionLease | None = None + wakeup: Callable[[], None] | None = None + + +@dataclass(frozen=True) +class RequestQueueView: + queued_total: int + queued_by_group: Mapping[RequestResourceKey, int] + queued_demand_by_resource: Mapping[RequestResourceKey, int] + aggregate_provider_model_waiters: Mapping[ProviderModelKey, int] + + +@dataclass(frozen=True) +class RequestQueueSelection: + waiter: RequestWaiter + item: RequestAdmissionItem + waiter_id: str + queue_view: RequestQueueView + sequence_version: int + + +class RequestFairQueue: + """Weighted fair waiter queue used by request admission.""" + + def __init__(self) -> None: + self._queues: dict[RequestResourceKey, deque[RequestWaiter]] = {} + self._queued: dict[str, RequestWaiter] = {} + self._waiter_groups: dict[str, RequestResourceKey] = {} + 
self._group_finish: dict[RequestResourceKey, float] = {} + self._heap: list[tuple[float, int, RequestResourceKey]] = [] + self._active_heap_entries: dict[RequestResourceKey, tuple[float, int]] = {} + self._sequence = 0 + self._sequence_version = 0 + self._virtual_time = 0.0 + + @property + def has_waiters(self) -> bool: + return bool(self._queued) + + def contains(self, waiter_id: str) -> bool: + return waiter_id in self._queued + + def waiters(self) -> tuple[RequestWaiter, ...]: + return tuple(self._queued.values()) + + def enqueue(self, waiter: RequestWaiter) -> bool: + if waiter.waiter_id in self._queued: + return False + key = waiter.item.group.key + queue = self._queues.setdefault(key, deque()) + queue.append(waiter) + self._queued[waiter.waiter_id] = waiter + self._waiter_groups[waiter.waiter_id] = key + self._activate_group(key) + self._sequence_version += 1 + return True + + def remove(self, waiter_id: str) -> RequestWaiter | None: + waiter = self._queued.pop(waiter_id, None) + if waiter is None: + return None + self._waiter_groups.pop(waiter_id, None) + self._sequence_version += 1 + return waiter + + def select_next( + self, is_eligible: Callable[[RequestWaiter, RequestQueueView], bool] + ) -> RequestQueueSelection | None: + view = self.view() + heap_copy = list(self._heap) + heapq.heapify(heap_copy) + active_seen: set[RequestResourceKey] = set() + while heap_copy: + finish, sequence, key = heapq.heappop(heap_copy) + if key in active_seen: + continue + if self._active_heap_entries.get(key) != (finish, sequence): + continue + active_seen.add(key) + waiter = self._first_valid_waiter(key) + if waiter is None: + continue + if not is_eligible(waiter, view): + continue + return RequestQueueSelection( + waiter=waiter, + item=waiter.item, + waiter_id=waiter.waiter_id, + queue_view=view, + sequence_version=self._sequence_version, + ) + return None + + def commit(self, selection: RequestQueueSelection) -> RequestWaiter | None: + if selection.sequence_version != 
self._sequence_version: + return None + key = self._waiter_groups.get(selection.waiter_id) + if key is None or key != selection.item.group.key: + return None + queue = self._queues.get(key) + if queue is None: + return None + self._purge_queue_head(key) + if not queue or queue[0].waiter_id != selection.waiter_id: + return None + + waiter = queue.popleft() + self._queued.pop(waiter.waiter_id, None) + self._waiter_groups.pop(waiter.waiter_id, None) + self._active_heap_entries.pop(key, None) + weight = max(selection.item.group.weight, 1.0) + finish = self._group_finish.get(key, self._virtual_time) + self._virtual_time = max(self._virtual_time, finish) + self._group_finish[key] = self._virtual_time + (1.0 / weight) + self._sequence_version += 1 + self._purge_queue_head(key) + if queue: + self._activate_group(key) + return waiter + + def view(self) -> RequestQueueView: + queued_by_group: Counter[RequestResourceKey] = Counter() + demand_by_resource: Counter[RequestResourceKey] = Counter() + aggregate_waiters: Counter[ProviderModelKey] = Counter() + for waiter in self._queued.values(): + resource = waiter.item.resource + queued_by_group[waiter.item.group.key] += 1 + demand_by_resource[resource] += 1 + aggregate_waiters[resource.provider_model_key] += 1 + return RequestQueueView( + queued_total=len(self._queued), + queued_by_group=dict(queued_by_group), + queued_demand_by_resource=dict(demand_by_resource), + aggregate_provider_model_waiters=dict(aggregate_waiters), + ) + + def _activate_group(self, key: RequestResourceKey) -> None: + self._purge_queue_head(key) + queue = self._queues.get(key) + if not queue or key in self._active_heap_entries: + return + self._sequence += 1 + finish = self._group_finish.get(key, self._virtual_time) + heapq.heappush(self._heap, (finish, self._sequence, key)) + self._active_heap_entries[key] = (finish, self._sequence) + + def _first_valid_waiter(self, key: RequestResourceKey) -> RequestWaiter | None: + queue = self._queues.get(key) + if 
queue is None: + return None + for waiter in queue: + if waiter.waiter_id in self._queued and self._waiter_groups.get(waiter.waiter_id) == key: + return waiter + return None + + def _purge_queue_head(self, key: RequestResourceKey) -> None: + queue = self._queues.get(key) + if queue is None: + return + while queue: + waiter = queue[0] + if waiter.waiter_id in self._queued and self._waiter_groups.get(waiter.waiter_id) == key: + break + queue.popleft() diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py new file mode 100644 index 000000000..462e77427 --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resolver.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass + +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey + + +@dataclass(frozen=True) +class ResolvedRequestResource: + provider_model: ProviderModelKey + resource: RequestResourceKey + aliases: tuple[str, ...] 
= () + generation_kind: str | None = None + + +class RequestResourceResolver: + """Canonical provider/model/domain request-resource identity factory.""" + + def resolve( + self, + *, + provider_name: str, + model_id: str, + domain: RequestDomain, + model_alias: str | None = None, + provider_alias: str | None = None, + generation_kind: str | None = None, + ) -> ResolvedRequestResource: + resource = RequestResourceKey(provider_name=provider_name, model_id=model_id, domain=domain) + aliases = tuple(alias for alias in (provider_alias, model_alias) if alias) + return ResolvedRequestResource( + provider_model=resource.provider_model_key, + resource=resource, + aliases=aliases, + generation_kind=generation_kind, + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py new file mode 100644 index 000000000..b7b4bd2cd --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/request_admission/resources.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + +from data_designer.engine.models.resources import ProviderModelKey + + +class RequestDomain(str, Enum): + CHAT = "chat" + EMBEDDING = "embedding" + IMAGE = "image" + HEALTHCHECK = "healthcheck" + + +@dataclass(frozen=True, order=True) +class RequestResourceKey: + provider_name: str + model_id: str + domain: RequestDomain + + @property + def provider_model_key(self) -> ProviderModelKey: + return ProviderModelKey(self.provider_name, self.model_id) + + +@dataclass(frozen=True) +class RequestGroupSpec: + key: RequestResourceKey + weight: float = 1.0 + + +@dataclass(frozen=True) +class RequestEventContext: + captured_correlation: object | None = None + task_execution_id: str | None = None + request_attempt_id: str | None = None + + +@dataclass(frozen=True) +class RequestAdmissionItem: + resource: RequestResourceKey + group: RequestGroupSpec + queue_wait_timeout_seconds: float | None = None + event_context: RequestEventContext | None = None diff --git a/packages/data-designer-engine/src/data_designer/engine/models/resources.py b/packages/data-designer-engine/src/data_designer/engine/models/resources.py new file mode 100644 index 000000000..091e2936b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/models/resources.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass + + +@dataclass(frozen=True, order=True) +class ProviderModelKey: + provider_name: str + model_id: str + + +@dataclass +class ProviderModelStaticCap: + cap: int + aliases: tuple[str, ...] 
+ raw_caps: Mapping[str, int | None] + merge_rule: str = "min_same_endpoint" diff --git a/packages/data-designer-engine/src/data_designer/engine/observability.py b/packages/data-designer-engine/src/data_designer/engine/observability.py new file mode 100644 index 000000000..a7a28c41b --- /dev/null +++ b/packages/data-designer-engine/src/data_designer/engine/observability.py @@ -0,0 +1,244 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import contextvars +import math +import time +from collections.abc import Mapping +from dataclasses import dataclass, field, fields, is_dataclass +from enum import Enum +from typing import Literal, Protocol + + +@dataclass(frozen=True) +class RuntimeCorrelation: + run_id: str + row_group: int | None + task_column: str | None + task_type: str | None + scheduling_group_kind: str | None + scheduling_group_identity_hash: str | None + task_execution_id: str | None + + +JsonValue = str | int | float | bool | None | list["JsonValue"] | dict[str, "JsonValue"] + + +def _json_safe(value: object) -> JsonValue: + if value is None or isinstance(value, str | int | bool): + return value + if isinstance(value, float): + return value if math.isfinite(value) else str(value) + if isinstance(value, Enum): + return _json_safe(value.value) + if is_dataclass(value) and not isinstance(value, type): + return {field.name: _json_safe(getattr(value, field.name)) for field in fields(value)} + if isinstance(value, Mapping): + return {_json_safe_key(key): _json_safe(item) for key, item in value.items()} + if isinstance(value, list | tuple): + return [_json_safe(item) for item in value] + if isinstance(value, set | frozenset): + return [_json_safe(item) for item in sorted(value, key=repr)] + return str(value) + + +def _json_safe_key(value: object) -> str: + safe = _json_safe(value) + if isinstance(safe, str): + return safe + return 
str(safe) + + +def _json_safe_dict(value: Mapping[str, object] | None) -> dict[str, JsonValue]: + if value is None: + return {} + return {_json_safe_key(key): _json_safe(item) for key, item in value.items()} + + +class RuntimeCorrelationProvider: + """Context-variable backed runtime correlation provider.""" + + def __init__(self) -> None: + self._current: contextvars.ContextVar[RuntimeCorrelation | None] = contextvars.ContextVar( + "data_designer_runtime_correlation", + default=None, + ) + + def current(self) -> RuntimeCorrelation | None: + return self._current.get() + + def set(self, correlation: RuntimeCorrelation | None) -> contextvars.Token: + return self._current.set(correlation) + + def reset(self, token: contextvars.Token) -> None: + self._current.reset(token) + + +runtime_correlation_provider = RuntimeCorrelationProvider() + +SchedulerAdmissionEventKind = Literal[ + "scheduler_job_started", + "scheduler_job_completed", + "scheduler_health_snapshot", + "dependency_ready", + "ready_enqueued", + "row_group_admitted", + "row_group_admission_blocked", + "row_group_admission_target_changed", + "row_group_checkpointed", + "selected", + "queue_empty", + "admission_blocked", + "group_capped", + "request_pressure_advisory_skipped", + "task_lease_acquired", + "admission_denied", + "worker_spawned", + "worker_spawn_failed", + "stale_selection", + "retry_deferred", + "non_retryable_dropped", + "cancelled", + "salvage_redispatched", + "queue_drained", + "task_completed", + "task_lease_released", + "release_diagnostic", +] + +RequestAdmissionEventKind = Literal[ + "request_resource_registered", + "request_effective_cap_changed", + "request_queue_formed", + "request_wait_started", + "request_wait_completed", + "request_wait_timeout", + "request_wait_cancelled", + "request_acquire_denied", + "request_lease_acquired", + "model_request_started", + "model_request_completed", + "request_queue_drained", + "request_rate_limited", + "request_limit_decreased", + 
"request_limit_increased", + "request_soft_ceiling_recovered", + "request_fully_recovered", + "request_lease_released", + "request_release_diagnostic", +] + + +@dataclass(frozen=True) +class SchedulerAdmissionEvent: + event_kind: SchedulerAdmissionEventKind + captured_at_monotonic: float + sequence: int + captured_correlation: JsonValue = None + task_id: str | None = None + task_execution_id: str | None = None + task_lease_id: str | None = None + scheduler_resource_key: str | None = None + reason_or_result: str | None = None + snapshot: JsonValue = None + diagnostics: dict[str, JsonValue] = field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "captured_correlation", _json_safe(self.captured_correlation)) + object.__setattr__(self, "snapshot", _json_safe(self.snapshot)) + object.__setattr__(self, "diagnostics", _json_safe_dict(self.diagnostics)) + + @classmethod + def capture( + cls, + event_kind: SchedulerAdmissionEventKind, + *, + sequence: int, + correlation: RuntimeCorrelation | None = None, + **kwargs: object, + ) -> SchedulerAdmissionEvent: + return cls( + event_kind=event_kind, + captured_at_monotonic=time.monotonic(), + sequence=sequence, + captured_correlation=correlation, + **kwargs, + ) + + +@dataclass(frozen=True) +class RequestAdmissionEvent: + event_kind: RequestAdmissionEventKind + captured_at_monotonic: float + sequence: int + captured_correlation: JsonValue = None + request_attempt_id: str | None = None + request_lease_id: str | None = None + request_resource_key: JsonValue = None + request_group_key: JsonValue = None + reason_or_outcome: str | None = None + pressure_snapshot: JsonValue = None + diagnostics: dict[str, JsonValue] = field(default_factory=dict) + + def __post_init__(self) -> None: + object.__setattr__(self, "captured_correlation", _json_safe(self.captured_correlation)) + object.__setattr__(self, "request_resource_key", _json_safe(self.request_resource_key)) + object.__setattr__(self, 
"request_group_key", _json_safe(self.request_group_key)) + object.__setattr__(self, "pressure_snapshot", _json_safe(self.pressure_snapshot)) + object.__setattr__(self, "diagnostics", _json_safe_dict(self.diagnostics)) + + @classmethod + def capture( + cls, + event_kind: RequestAdmissionEventKind, + *, + sequence: int, + correlation: RuntimeCorrelation | None = None, + **kwargs: object, + ) -> RequestAdmissionEvent: + return cls( + event_kind=event_kind, + captured_at_monotonic=time.monotonic(), + sequence=sequence, + captured_correlation=correlation, + **kwargs, + ) + + +class SchedulerAdmissionEventSink(Protocol): + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: ... + + +class RequestAdmissionEventSink(Protocol): + def emit_request_event(self, event: RequestAdmissionEvent) -> None: ... + + +class InMemoryAdmissionEventSink: + """Small sink used by tests, diagnostics, and benchmark smoke runs.""" + + def __init__(self) -> None: + self.scheduler_events: list[SchedulerAdmissionEvent] = [] + self.request_events: list[RequestAdmissionEvent] = [] + + def emit_scheduler_event(self, event: SchedulerAdmissionEvent) -> None: + self.scheduler_events.append(event) + + def emit_request_event(self, event: RequestAdmissionEvent) -> None: + self.request_events.append(event) + + +@dataclass(frozen=True) +class CorrelatedRuntimeView: + scheduler_events: tuple[SchedulerAdmissionEvent, ...] + request_events: tuple[RequestAdmissionEvent, ...] 
+ + @property + def timeline(self) -> tuple[SchedulerAdmissionEvent | RequestAdmissionEvent, ...]: + return tuple( + sorted( + (*self.scheduler_events, *self.request_events), + key=lambda event: (event.captured_at_monotonic, event.sequence), + ) + ) diff --git a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py index bb012310b..98802726a 100644 --- a/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py +++ b/packages/data-designer-engine/src/data_designer/engine/resources/resource_provider.py @@ -28,7 +28,7 @@ from data_designer.engine.storage.artifact_storage import ArtifactStorage if TYPE_CHECKING: - from data_designer.engine.models.clients.throttle_manager import ThrottleManager + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController class ResourceType(StrEnum): @@ -95,7 +95,7 @@ def create_resource_provider( mcp_providers: list[MCPProviderT] | None = None, tool_configs: list[ToolConfig] | None = None, client_concurrency_mode: ClientConcurrencyMode | None = None, - throttle_manager: ThrottleManager | None = None, + request_admission: AdaptiveRequestAdmissionController | None = None, ) -> ResourceProvider: """Factory function for creating a ResourceProvider instance. @@ -116,7 +116,7 @@ def create_resource_provider( run_config: Optional runtime configuration. mcp_providers: Optional list of MCP provider configurations. tool_configs: Optional list of tool configurations. - throttle_manager: Optional shared throttle manager for model clients. + request_admission: Optional shared request-admission controller for model clients. Returns: A configured ResourceProvider instance. 
@@ -164,7 +164,7 @@ def create_resource_provider( mcp_registry=mcp_registry, client_concurrency_mode=client_concurrency_mode, run_config=effective_run_config, - throttle_manager=throttle_manager, + request_admission=request_admission, ), person_reader=person_reader, mcp_registry=mcp_registry, diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py similarity index 80% rename from packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py rename to packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py index 2ec7b4cd3..e647d4ac6 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_completion_tracker.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_completion.py @@ -14,9 +14,10 @@ SamplerColumnConfig, ) from data_designer.config.sampler_params import SamplerType -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker +from data_designer.engine.dataset_builders.scheduling.resources import stable_task_id +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef, Task from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import SliceRef, Task MODEL_ALIAS = "stub" @@ -189,7 +190,16 @@ def test_get_ready_tasks_seed_frontier(ready_ctx: ReadyTasksFixture) -> None: assert len(ready) == 1 assert ready[0].column == "topic" - assert ready[0].task_type == "batch" + assert ready[0].task_type == "from_scratch" + + +def test_mark_enqueued_uses_scheduler_stable_task_id(ready_ctx: ReadyTasksFixture) -> None: + ready_ctx.tracker.seed_frontier() + task = 
ready_ctx.tracker.ready_frontier()[0] + + ready_ctx.tracker.mark_enqueued({stable_task_id(task)}) + + assert ready_ctx.tracker.ready_frontier() == () def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> None: @@ -205,6 +215,53 @@ def test_get_ready_tasks_after_seed_complete(ready_ctx: ReadyTasksFixture) -> No assert delta.removed == () +def test_fan_out_cell_completion_readies_all_children_for_same_row() -> None: + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="heavy", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_a", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_b", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="child_c", prompt="{{ heavy }}", model_alias=MODEL_ALIAS), + ] + strategies = {config.name: GenerationStrategy.CELL_BY_CELL for config in configs[1:]} + strategies["topic"] = GenerationStrategy.FULL_COLUMN + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + tracker.mark_row_range_complete("topic", 0, 2) + + delta = tracker.mark_cell_complete("heavy", 0, 0) + + assert {task.column for task in delta.added} == {"child_a", "child_b", "child_c"} + assert {task.row_index for task in delta.added} == {0} + ready = tracker.get_ready_tasks(set()) + assert not any(task.column.startswith("child_") and task.row_index == 1 for task in ready) + + +def test_fan_in_cell_downstream_waits_for_all_same_row_upstreams() -> None: + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="up_a", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="up_b", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="up_c", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + 
LLMTextColumnConfig(name="judge", prompt="{{ up_a }} {{ up_b }} {{ up_c }}", model_alias=MODEL_ALIAS), + ] + strategies = {config.name: GenerationStrategy.CELL_BY_CELL for config in configs[1:]} + strategies["topic"] = GenerationStrategy.FULL_COLUMN + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + tracker.mark_row_range_complete("topic", 0, 2) + + first_delta = tracker.mark_cell_complete("up_a", 0, 0) + second_delta = tracker.mark_cell_complete("up_b", 0, 0) + final_delta = tracker.mark_cell_complete("up_c", 0, 0) + + assert not any(task.column == "judge" for task in first_delta.added) + assert not any(task.column == "judge" for task in second_delta.added) + assert final_delta.added == (Task(column="judge", row_group=0, row_index=0, task_type="cell"),) + ready = tracker.get_ready_tasks(set()) + assert not any(task.column == "judge" and task.row_index == 1 for task in ready) + + def test_get_ready_tasks_skips_dispatched(ready_ctx: ReadyTasksFixture) -> None: ready_ctx.tracker.mark_row_range_complete("topic", 0, 3) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py new file mode 100644 index 000000000..e2a9179f0 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from collections import Counter + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task + + +def _task(column: str, row_index: int) -> Task: + return Task(column=column, row_group=0, row_index=row_index, task_type="cell") + + +def _group(name: str, *, weight: float = 1.0, admitted_limit: int | None = None) -> TaskGroupSpec: + return TaskGroupSpec( + key=TaskGroupKey(kind="local", identity=(name,)), + weight=weight, + admitted_limit=admitted_limit, + ) + + +def _item(column: str, row_index: int, group: TaskGroupSpec | None = None) -> SchedulableTask: + task = _task(column, row_index) + group = group or _group(column) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + +def _select_and_commit(queue: FairTaskQueue) -> SchedulableTask | None: + selection = queue.select_next(lambda _item, _view: True) + if selection is None: + return None + return queue.commit(selection) + + +def test_fair_task_queue_equal_groups_round_robins() -> None: + queue = FairTaskQueue() + queue.enqueue( + [ + _item("a", 0), + _item("a", 1), + _item("b", 0), + _item("b", 1), + _item("c", 0), + _item("c", 1), + ] + ) + + selected = [_select_and_commit(queue) for _ in range(6)] + + assert [item.payload.column for item in selected if item is not None] == ["a", "b", "c", "a", "b", "c"] + + +def test_fair_task_queue_weighted_groups() -> None: + queue = FairTaskQueue() + queue.enqueue( + [_item("a", i, _group("a", weight=2)) for i in range(6)] + + [_item("b", i, _group("b", weight=1)) for i in range(6)] + ) + + selected = 
[_select_and_commit(queue) for _ in range(6)] + counts = Counter(item.payload.column for item in selected if item is not None) + + assert counts == {"a": 4, "b": 2} + + +def test_select_next_is_non_mutating_until_commit() -> None: + queue = FairTaskQueue() + first = _item("a", 0) + second = _item("b", 0) + queue.enqueue([first, second]) + + selection = queue.select_next(lambda _item, _view: True) + + assert selection is not None + assert queue.view().queued_total == 2 + committed = queue.commit(selection) + assert committed == first + assert queue.view().queued_total == 1 + + +def test_commit_rejects_stale_selection() -> None: + queue = FairTaskQueue() + first = _item("a", 0) + queue.enqueue([first]) + + selection = queue.select_next(lambda _item, _view: True) + assert selection is not None + queue.enqueue([_item("b", 0)]) + + assert queue.commit(selection) is None + assert queue.view().queued_total == 2 + + +def test_select_next_uses_scheduler_eligibility_callback() -> None: + queue = FairTaskQueue() + queue.enqueue([_item("a", 0), _item("b", 0)]) + + selection = queue.select_next(lambda item, _view: item.payload.column == "b") + + assert selection is not None + assert selection.item.payload.column == "b" + assert queue.commit(selection) == selection.item + + +def test_enqueue_is_idempotent_by_task_id() -> None: + queue = FairTaskQueue() + item = _item("a", 0) + + first = queue.enqueue([item]) + second = queue.enqueue([item]) + + assert first == (item.task_id,) + assert second == () + assert queue.view().queued_total == 1 + + +def test_discard_where_removes_matching_tasks() -> None: + queue = FairTaskQueue() + queue.enqueue([_item(column, i) for column in ["a", "b"] for i in range(2)]) + + queue.discard_where(lambda item: item.payload.column == "a") + selected = [_select_and_commit(queue) for _ in range(2)] + + assert [item.payload.column for item in selected if item is not None] == ["b", "b"] + assert _select_and_commit(queue) is None + + +def 
test_queue_view_exposes_group_and_resource_demand() -> None: + queue = FairTaskQueue() + group = _group("a") + task = _task("a", 0) + item = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1, "llm_wait": 1}), + ) + + queue.enqueue([item]) + view: QueueView = queue.view() + + assert view.queued_total == 1 + assert view.queued_by_group[group.key] == 1 + assert view.queued_resource_demand_by_group[group.key]["llm_wait"] == 1 + assert view.first_candidate_resources_by_group[group.key]["submission"] == 1 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py new file mode 100644 index 000000000..d6dfcdbab --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resolver.py @@ -0,0 +1,192 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Literal +from unittest.mock import MagicMock + +import pytest + +from data_designer.config.base import SingleColumnConfig +from data_designer.config.column_configs import ExpressionColumnConfig +from data_designer.config.models import GenerationType +from data_designer.config.scheduling import SchedulingMetadata, SchedulingMetadataError +from data_designer.engine.column_generators.generators.base import ColumnGeneratorWithModelRegistry +from data_designer.engine.dataset_builders.scheduling.resolver import TaskSchedulingResolver +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey + + +class _LocalGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.local() + + +class _ModelGenerator: + def __init__(self, metadata: SchedulingMetadata) -> None: + self._metadata = metadata + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return self._metadata + + +class _FallbackGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + raise SchedulingMetadataError( + code="partial", + message="using fallback", + fallback=SchedulingMetadata.local("fallback"), + diagnostics={"reason": "test"}, + ) + + +class _FatalGenerator: + def get_scheduling_metadata(self) -> SchedulingMetadata: + raise SchedulingMetadataError(code="fatal", message="fatal") + + +def _task(column: str = "answer") -> Task: + return Task(column=column, row_group=0, row_index=0, task_type="cell") + + +def test_task_scheduling_resolver_uses_local_default_metadata() -> None: + resolver = TaskSchedulingResolver({"answer": _LocalGenerator()}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.kind == "local" + assert 
schedulable.resource_request.amounts == {"submission": 1} + + +def test_task_scheduling_resolver_maps_model_metadata_to_model_resource() -> None: + metadata = SchedulingMetadata.model("nvidia", "nemotron", "chat", weight=3) + resolver = TaskSchedulingResolver({"answer": _ModelGenerator(metadata)}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.kind == "model" + assert schedulable.group.weight == 3.0 + assert schedulable.group.admitted_limit == 6 + assert schedulable.resource_request.amounts == {"submission": 1, "llm_wait": 1} + assert schedulable.request_resource_key == RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + + +def test_task_scheduling_resolver_records_safe_fallback_diagnostics() -> None: + resolver = TaskSchedulingResolver({"answer": _FallbackGenerator()}) # type: ignore[arg-type] + + schedulable = resolver.schedulable_task(_task(), ("answer",)) + + assert schedulable.group.key.identity[:2] == ("local", "fallback") + assert resolver.diagnostics[0]["code"] == "partial" + + +def test_task_scheduling_resolver_raises_fatal_metadata_error() -> None: + with pytest.raises(SchedulingMetadataError): + TaskSchedulingResolver({"answer": _FatalGenerator()}) # type: ignore[arg-type] + + +def test_model_registry_generator_metadata_deduplicates_same_endpoint_aliases() -> None: + class _RegistryGenerator(ColumnGeneratorWithModelRegistry[ExpressionColumnConfig]): + @staticmethod + def get_generation_strategy() -> object: + return object() + + def generate(self, data: object) -> object: + return data + + config = ExpressionColumnConfig(name="answer", expr="{{ x }}", dtype="str") + generator = _RegistryGenerator(config=config, resource_provider=MagicMock()) + generator._get_scheduling_model_aliases = lambda: ["primary", "secondary"] # type: ignore[method-assign] + configs = { + "primary": SimpleNamespace( + model="endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + 
inference_parameters=SimpleNamespace(max_parallel_requests=4), + ), + "secondary": SimpleNamespace( + model="endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + inference_parameters=SimpleNamespace(max_parallel_requests=2), + ), + } + providers = { + "primary": SimpleNamespace(name="nvidia"), + "secondary": SimpleNamespace(name="nvidia"), + } + generator.get_model_config = lambda model_alias: configs[model_alias] # type: ignore[method-assign] + generator.get_model_provider_name = lambda model_alias: providers[model_alias].name # type: ignore[method-assign] + + metadata = generator.get_scheduling_metadata() + + assert metadata.identity == ("model", "nvidia", "endpoint", "chat") + assert metadata.weight == 2 + assert metadata.diagnostics["merge_rule"] == "min_same_endpoint" + + resolver = TaskSchedulingResolver({"answer": generator}) # type: ignore[arg-type] + schedulable = resolver.schedulable_task(_task(), ("answer",)) + assert schedulable.request_resource_key == RequestResourceKey("nvidia", "endpoint", RequestDomain.CHAT) + + +def test_model_registry_generator_metadata_uses_custom_model_for_multi_endpoint_aliases() -> None: + class _PairwiseJudgeColumnConfig(SingleColumnConfig): + column_type: Literal["pairwise-judge-test"] = "pairwise-judge-test" + model_alias: str + judge_model_alias: str + + @property + def required_columns(self) -> list[str]: + return [] + + @property + def side_effect_columns(self) -> list[str]: + return [] + + def get_model_aliases(self) -> list[str]: + return [self.model_alias, self.judge_model_alias] + + class _RegistryGenerator(ColumnGeneratorWithModelRegistry[_PairwiseJudgeColumnConfig]): + @staticmethod + def get_generation_strategy() -> object: + return object() + + def generate(self, data: object) -> object: + return data + + config = _PairwiseJudgeColumnConfig(name="answer", model_alias="draft", judge_model_alias="judge") + assert config.get_model_aliases() == ["draft", "judge"] + generator = 
_RegistryGenerator(config=config, resource_provider=MagicMock()) + configs = { + "draft": SimpleNamespace( + model="draft-endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + inference_parameters=SimpleNamespace(max_parallel_requests=4), + ), + "judge": SimpleNamespace( + model="judge-endpoint", + generation_type=GenerationType.CHAT_COMPLETION, + inference_parameters=SimpleNamespace(max_parallel_requests=2), + ), + } + providers = { + "draft": SimpleNamespace(name="nvidia"), + "judge": SimpleNamespace(name="openai"), + } + generator.get_model_config = lambda model_alias: configs[model_alias] # type: ignore[method-assign] + generator.get_model_provider_name = lambda model_alias: providers[model_alias].name # type: ignore[method-assign] + + metadata = generator.get_scheduling_metadata() + + assert metadata.kind == "custom_model" + assert metadata.identity[1].endswith("._RegistryGenerator") + assert metadata.identity[2].startswith("alias-set-") + assert metadata.weight == 6 + assert metadata.diagnostics["aliases"] == ("draft", "judge") + assert metadata.diagnostics["fallback_reason"] == "multi_endpoint_alias_set" + assert metadata.diagnostics["raw_caps"] == (4, 2) + + resolver = TaskSchedulingResolver({"answer": generator}) # type: ignore[arg-type] + schedulable = resolver.schedulable_task(_task(), ("answer",)) + assert schedulable.group.key.kind == "custom_model" + assert schedulable.request_resource_key is None diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py new file mode 100644 index 000000000..935f2c074 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py @@ -0,0 +1,62 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest + +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task + + +def test_scheduler_resource_request_defaults_to_submission() -> None: + request = SchedulerResourceRequest() + + assert request.amounts == {"submission": 1} + + +def test_scheduler_resource_request_rejects_unknown_resource() -> None: + with pytest.raises(ValueError, match="Unknown scheduler resource key"): + SchedulerResourceRequest({"gpu": 1}) # type: ignore[arg-type] + + +def test_scheduler_resource_request_rejects_non_positive_amounts() -> None: + with pytest.raises(ValueError, match="must be a positive integer"): + SchedulerResourceRequest({"submission": 0}) + + +def test_stable_task_id_is_stable_for_task_identity() -> None: + task = Task(column="answer", row_group=3, row_index=8, task_type="cell") + + assert stable_task_id(task) == stable_task_id(task) + assert stable_task_id(task).startswith("task-") + + +def test_stable_task_id_distinguishes_task_identity_fields() -> None: + first = Task(column="answer", row_group=3, row_index=8, task_type="cell") + second = Task(column="answer", row_group=3, row_index=9, task_type="cell") + + assert stable_task_id(first) != stable_task_id(second) + + +def test_schedulable_task_binds_payload_group_and_resource_request() -> None: + task = Task(column="answer", row_group=0, row_index=1, task_type="cell") + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("nvidia", "nemotron")), admitted_limit=2) + request = SchedulerResourceRequest({"submission": 1, "llm_wait": 1}) + + item = SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=request, + ) + + assert item.payload == task + assert item.group == group + assert 
item.resource_request.amounts["llm_wait"] == 1 diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py new file mode 100644 index 000000000..fbb2fd469 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py @@ -0,0 +1,275 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import ( + RELEASED_TASK_LEASE_HISTORY_LIMIT, + TaskAdmissionConfig, + TaskAdmissionController, + TaskAdmissionDenied, + TaskAdmissionLease, +) +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import BoundedBorrowTaskAdmissionPolicyConfig + + +def _item( + column: str, + row: int = 0, + *, + group: TaskGroupSpec | None = None, + resources: dict[str, int] | None = None, +) -> SchedulableTask: + task = Task(column=column, row_group=0, row_index=row, task_type="cell") + group = group or TaskGroupSpec(TaskGroupKey(kind="local", identity=(column,))) + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest(resources or {"submission": 1}), + ) + + +def _queue_view(*items: SchedulableTask): + queue = FairTaskQueue() + queue.enqueue(items) + return queue.view() + + +def test_task_admission_acquires_and_releases_exact_lease() -> None: + controller = 
TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + item = _item("a") + + decision = controller.try_acquire(item, _queue_view(item)) + + assert isinstance(decision, TaskAdmissionLease) + assert controller.view().resources_available["submission"] == 0 + result = controller.release(decision) + assert result.released is True + assert controller.view().resources_available["submission"] == 1 + + +def test_task_admission_denies_when_resource_full() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + first = _item("a") + second = _item("b") + lease = controller.try_acquire(first, _queue_view(first, second)) + + assert isinstance(lease, TaskAdmissionLease) + decision = controller.try_acquire(second, _queue_view(second)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "no_capacity" + + +def test_task_admission_duplicate_release_does_not_increase_capacity() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + item = _item("a") + lease = controller.try_acquire(item, _queue_view(item)) + assert isinstance(lease, TaskAdmissionLease) + + first = controller.release(lease) + second = controller.release(lease) + + assert first.released is True + assert second.released is False + assert second.reason == "duplicate" + assert controller.view().resources_available["submission"] == 1 + + +def test_task_admission_released_history_is_bounded() -> None: + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=1)) + first_lease: TaskAdmissionLease | None = None + for index in range(RELEASED_TASK_LEASE_HISTORY_LIMIT + 5): + item = _item(f"task-{index}") + lease = controller.try_acquire(item, _queue_view(item)) + assert isinstance(lease, TaskAdmissionLease) + first_lease = first_lease or lease + controller.release(lease) + + assert len(controller._released) == RELEASED_TASK_LEASE_HISTORY_LIMIT + assert len(controller._released_order) == 
RELEASED_TASK_LEASE_HISTORY_LIMIT + assert controller._released_order.maxlen == RELEASED_TASK_LEASE_HISTORY_LIMIT + assert first_lease is not None + assert controller.release(first_lease).reason == "unknown_lease" + + +def test_task_admission_group_cap_yields_to_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=2)) + first = _item("a", 0, group=group) + second = _item("a", 1, group=group) + peer = _item("b") + lease = controller.try_acquire(first, _queue_view(first, second, peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, peer)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "group_cap" + + +def test_task_admission_group_cap_ignores_non_overlapping_typed_peer_resource() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig(submission_capacity=3, resource_limits={"llm_wait": 3, "local": 3}) + ) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + local_peer = _item("b", resources={"submission": 1, "local": 1}) + lease = controller.try_acquire(first, _queue_view(first, second, local_peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, local_peer)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_task_admission_group_cap_applies_to_overlapping_typed_peer_resource() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "peer")), admitted_limit=1) + controller = 
TaskAdmissionController(TaskAdmissionConfig(submission_capacity=3, resource_limits={"llm_wait": 3})) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + peer = _item("b", group=peer_group, resources={"submission": 1, "llm_wait": 1}) + lease = controller.try_acquire(first, _queue_view(first, second, peer)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, peer)) + + assert isinstance(decision, TaskAdmissionDenied) + assert decision.reason == "group_cap" + assert decision.diagnostics["pressure_resources"] == ("llm_wait",) + + +def test_task_admission_group_cap_ignores_peer_blocked_by_hard_resource_capacity() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "peer")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig(submission_capacity=4, resource_limits={"llm_wait": 3, "local": 1}) + ) + first = _item("a", 0, group=group, resources={"submission": 1, "llm_wait": 1}) + second = _item("a", 1, group=group, resources={"submission": 1, "llm_wait": 1}) + local_holder = _item("local-holder", resources={"submission": 1, "local": 1}) + blocked_peer = _item("b", group=peer_group, resources={"submission": 1, "llm_wait": 1, "local": 1}) + first_lease = controller.try_acquire(first, _queue_view(first, second, blocked_peer)) + local_lease = controller.try_acquire(local_holder, _queue_view(local_holder, blocked_peer)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(local_lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second, blocked_peer)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_explain_blocked_reports_group_cap_denials() -> None: + first_group = 
TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "first")), admitted_limit=1) + second_group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "second")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=4)) + first_active = _item("a", 0, group=first_group) + second_active = _item("b", 0, group=second_group) + first_queued = _item("a", 1, group=first_group) + second_queued = _item("b", 1, group=second_group) + first_lease = controller.try_acquire(first_active, _queue_view(first_active, second_active)) + second_lease = controller.try_acquire(second_active, _queue_view(second_active, first_queued)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(second_lease, TaskAdmissionLease) + queue = FairTaskQueue() + queue.enqueue((first_queued, second_queued)) + + assert queue.select_next(controller.is_eligible) is None + summary = controller.explain_blocked(queue.view()) + + assert summary.dominant_denial_reasons == {"group_cap": 2} + + +def test_task_admission_group_cap_does_not_block_solo_group() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController(TaskAdmissionConfig(submission_capacity=2)) + first = _item("a", 0, group=group) + second = _item("a", 1, group=group) + lease = controller.try_acquire(first, _queue_view(first, second)) + assert isinstance(lease, TaskAdmissionLease) + + decision = controller.try_acquire(second, _queue_view(second)) + + assert isinstance(decision, TaskAdmissionLease) + + +def test_bounded_borrow_limits_solo_group_borrow_debt() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + first = _item("a", 0, group=group) + 
second = _item("a", 1, group=group) + third = _item("a", 2, group=group) + first_lease = controller.try_acquire(first, _queue_view(first, second, third)) + assert isinstance(first_lease, TaskAdmissionLease) + borrowed = controller.try_acquire(second, _queue_view(second, third)) + assert isinstance(borrowed, TaskAdmissionLease) + + denied = controller.try_acquire(third, _queue_view(third)) + + assert isinstance(denied, TaskAdmissionDenied) + assert denied.reason == "borrow_debt" + assert controller.view().policy_debt_by_group_resource[(group.key, "submission")] == 1 + + +def test_bounded_borrow_debt_blocks_under_peer_pressure_and_releases() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + first = _item("a", 0, group=group) + borrowed_item = _item("a", 1, group=group) + blocked_item = _item("a", 2, group=group) + peer = _item("b") + first_lease = controller.try_acquire(first, _queue_view(first, borrowed_item)) + borrowed = controller.try_acquire(borrowed_item, _queue_view(borrowed_item)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(borrowed, TaskAdmissionLease) + + denied = controller.try_acquire(blocked_item, _queue_view(blocked_item, peer)) + + assert isinstance(denied, TaskAdmissionDenied) + assert denied.reason == "borrow_debt" + controller.release(borrowed) + assert (group.key, "submission") not in controller.view().policy_debt_by_group_resource + + +def test_bounded_borrow_release_repayment_is_group_level() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + controller = TaskAdmissionController( + TaskAdmissionConfig( + submission_capacity=3, + bounded_borrow=BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1), + ) + ) + 
first = _item("a", 0, group=group) + borrowed_item = _item("a", 1, group=group) + first_lease = controller.try_acquire(first, _queue_view(first, borrowed_item)) + borrowed = controller.try_acquire(borrowed_item, _queue_view(borrowed_item)) + assert isinstance(first_lease, TaskAdmissionLease) + assert isinstance(borrowed, TaskAdmissionLease) + assert controller.view().policy_debt_by_group_resource[(group.key, "submission")] == 1 + + controller.release(first_lease) + + assert (group.key, "submission") not in controller.view().policy_debt_by_group_resource + controller.release(borrowed) diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py similarity index 96% rename from packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py rename to packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py index 5d5716213..cdc5e6c6a 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_task_model.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_model.py @@ -5,7 +5,7 @@ import pytest -from data_designer.engine.dataset_builders.utils.task_model import Task, TaskResult, TaskTrace +from data_designer.engine.dataset_builders.scheduling.task_model import Task, TaskResult, TaskTrace def test_task_is_frozen() -> None: diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py new file mode 100644 index 000000000..286fdee96 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_policies.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.dataset_builders.scheduling.queue import FairTaskQueue, QueueView +from data_designer.engine.dataset_builders.scheduling.resources import ( + SchedulableTask, + SchedulerResourceRequest, + TaskGroupKey, + TaskGroupSpec, + stable_task_id, +) +from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionLease, TaskAdmissionView +from data_designer.engine.dataset_builders.scheduling.task_model import Task +from data_designer.engine.dataset_builders.scheduling.task_policies import ( + BoundedBorrowTaskAdmissionPolicy, + BoundedBorrowTaskAdmissionPolicyConfig, + StrictFairTaskAdmissionPolicy, +) + + +def _item(column: str, group: TaskGroupSpec) -> SchedulableTask: + task = Task(column=column, row_group=0, row_index=0, task_type="cell") + return SchedulableTask( + task_id=stable_task_id(task), + payload=task, + group=group, + resource_request=SchedulerResourceRequest({"submission": 1}), + ) + + +def _queue_view(*items: SchedulableTask) -> QueueView: + queue = FairTaskQueue() + queue.enqueue(items) + return queue.view() + + +def _admission_view( + *, + running_group: TaskGroupKey, + running_count: int = 1, + debt: int = 0, +) -> TaskAdmissionView: + return TaskAdmissionView( + resource_limits={"submission": 4}, + resources_available={"submission": 3}, + leased_resources={"submission": running_count}, + leased_resources_by_group={running_group: {"submission": running_count}}, + running_counts_by_group={running_group: running_count}, + policy_debt_by_group_resource={(running_group, "submission"): debt} if debt else {}, + ) + + +def _lease(item: SchedulableTask) -> TaskAdmissionLease: + return TaskAdmissionLease( + lease_id="lease", + item=item, + resources={"submission": 1}, + acquired_at=0.0, + controller_generation="generation", + ) + + +def test_strict_fair_policy_allows_group_without_peer_pressure() -> None: + group = 
TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = StrictFairTaskAdmissionPolicy() + + decision = policy.evaluate(item, _queue_view(item), _admission_view(running_group=group.key)) + + assert decision.allowed is True + + +def test_strict_fair_policy_denies_capped_group_with_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="local", identity=("peer",))) + item = _item("a", group) + peer = _item("b", peer_group) + policy = StrictFairTaskAdmissionPolicy() + + decision = policy.evaluate(item, _queue_view(item, peer), _admission_view(running_group=group.key)) + + assert decision.allowed is False + assert decision.reason == "group_cap" + + +def test_bounded_borrow_policy_records_borrow_without_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + decision = policy.evaluate(item, _queue_view(item), _admission_view(running_group=group.key)) + delta = policy.on_acquire(_lease(item), decision) + + assert decision.allowed is True + assert delta.debt_changes == {(group.key, "submission"): 1} + + +def test_bounded_borrow_policy_denies_existing_debt_under_peer_pressure() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + peer_group = TaskGroupSpec(TaskGroupKey(kind="local", identity=("peer",))) + item = _item("a", group) + peer = _item("b", peer_group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + decision = policy.evaluate(item, _queue_view(item, peer), _admission_view(running_group=group.key, debt=1)) + + assert decision.allowed is False + 
assert decision.reason == "borrow_debt" + + +def test_bounded_borrow_policy_releases_debt() -> None: + group = TaskGroupSpec(TaskGroupKey(kind="model", identity=("provider", "model")), admitted_limit=1) + item = _item("a", group) + policy = BoundedBorrowTaskAdmissionPolicy(BoundedBorrowTaskAdmissionPolicyConfig(default_borrow_ceiling=1)) + + delta = policy.on_release(_lease(item)) + + assert delta.debt_changes == {(group.key, "submission"): -1} diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py index 684c009ba..f01dc1d91 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_builder_integration.py @@ -5,10 +5,12 @@ import math import warnings +from types import SimpleNamespace from unittest.mock import MagicMock, Mock import pytest +import data_designer.engine.dataset_builders.dataset_builder as builder_mod import data_designer.lazy_heavy_imports as lazy from data_designer.config.column_configs import ( ExpressionColumnConfig, @@ -24,7 +26,7 @@ ) from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.dataset_builder import DatasetBuilder -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager from data_designer.engine.resources.resource_provider import ResourceProvider @@ -189,6 +191,39 @@ def finalize_row_group(rg_id: int) -> None: assert tracker.is_row_group_complete(1, 2, 
all_cols) +def test_prepare_async_run_enables_request_pressure_advisory(monkeypatch: pytest.MonkeyPatch) -> None: + captured_kwargs: dict[str, object] = {} + + class _SpyScheduler: + def __init__(self, **kwargs: object) -> None: + captured_kwargs.update(kwargs) + + monkeypatch.setattr(builder_mod, "AsyncTaskScheduler", _SpyScheduler) + request_admission = object() + model_registry = MagicMock() + model_registry.get_aggregate_max_parallel_requests.return_value = 2 + model_registry.request_admission = request_admission + provider = SimpleNamespace( + model_registry=model_registry, + run_config=SimpleNamespace(progress_interval=5.0, progress_bar=False), + ) + processor_runner = MagicMock() + processor_runner.has_processors_for.return_value = False + config = SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}) + builder = SimpleNamespace( + _column_configs=[config], + _processor_runner=processor_runner, + artifact_storage=MagicMock(), + _resource_provider=provider, + ) + generator = MockSeed(config=_expr_config("seed"), resource_provider=provider) + + DatasetBuilder._prepare_async_run(builder, [generator], num_records=1, buffer_size=1) + + assert captured_kwargs["request_pressure_provider"] is request_admission + assert captured_kwargs["request_pressure_advisory"] is True + + # -- Test that existing sync path is unaffected -------------------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py index 6097232ef..41191c609 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py +++ b/packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py @@ -4,6 +4,8 @@ from __future__ import annotations import asyncio +import logging +import time from collections.abc import Callable from types import SimpleNamespace from typing import 
Any @@ -23,24 +25,42 @@ from data_designer.config.custom_column import custom_column_generator from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig from data_designer.config.sampler_params import SamplerType +from data_designer.config.scheduling import SchedulingMetadata from data_designer.engine.column_generators.generators.base import ( ColumnGenerator, ColumnGeneratorFullColumn, + ColumnGeneratorWithModelRegistry, FromScratchColumnGenerator, ) from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator -from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler, build_llm_bound_lookup +from data_designer.engine.dataset_builders.async_scheduler import AsyncTaskScheduler from data_designer.engine.dataset_builders.errors import DatasetGenerationError -from data_designer.engine.dataset_builders.utils.completion_tracker import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.completion import CompletionTracker, FrontierDelta +from data_designer.engine.dataset_builders.scheduling.task_admission import TaskAdmissionConfig, TaskAdmissionLease +from data_designer.engine.dataset_builders.scheduling.task_model import Task from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph from data_designer.engine.dataset_builders.utils.row_group_buffer import RowGroupBufferManager -from data_designer.engine.dataset_builders.utils.task_model import Task from data_designer.engine.models.errors import ( RETRYABLE_MODEL_ERRORS, ModelInternalServerError, ModelRateLimitError, ModelTimeoutError, ) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.controller import ( + AdaptiveRequestAdmissionController, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from 
data_designer.engine.models.request_admission.pressure import RequestPressureSnapshot +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestGroupSpec, + RequestResourceKey, +) +from data_designer.engine.models.resources import ProviderModelKey +from data_designer.engine.observability import InMemoryAdmissionEventSink from data_designer.engine.resources.resource_provider import ResourceProvider MODEL_ALIAS = "stub" @@ -83,6 +103,25 @@ def generate(self, data: dict) -> dict: return data +class MockRootCellGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Root cell generator that records the shape it receives.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.call_types: list[str] = [] + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, data: dict) -> dict: + self.call_types.append(type(data).__name__) + if not isinstance(data, dict): + raise TypeError(f"expected dict, got {type(data).__name__}") + data[self.config.name] = f"root_{len(self.call_types)}" + return data + + class MockFullColumnGenerator(ColumnGeneratorFullColumn[ExpressionColumnConfig]): """Mock full-column generator.""" @@ -152,6 +191,59 @@ def generate(self, data: dict) -> dict: return data +class MockBuggyGenerator(ColumnGenerator[ExpressionColumnConfig]): + """Generator that raises a bare built-in exception from generator code.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.CELL_BY_CELL + + def generate(self, _data: dict) -> dict: + raise KeyError("missing internal key") + + +class MockBuggyFromScratchGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): + """From-scratch generator that raises a bare built-in exception from generator code.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return 
GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, _num_records: int) -> lazy.pd.DataFrame: + raise AssertionError("invalid seed source") + + +class MockMalformedFromScratchGenerator(FromScratchColumnGenerator[ExpressionColumnConfig]): + """From-scratch generator that returns a non-DataFrame object.""" + + @staticmethod + def get_generation_strategy() -> GenerationStrategy: + return GenerationStrategy.FULL_COLUMN + + def generate(self, data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + return data + + def generate_from_scratch(self, num_records: int) -> Any: + return [{"seed": index} for index in range(num_records)] + + +class MockBuggyFullColumnGenerator(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + """Full-column generator that raises a bare built-in exception from generator code.""" + + def generate(self, _data: lazy.pd.DataFrame) -> lazy.pd.DataFrame: + raise TypeError("bad batch shape") + + +class MockMalformedFullColumnGenerator(ColumnGeneratorFullColumn[ExpressionColumnConfig]): + """Full-column generator that returns a non-DataFrame object.""" + + def generate(self, data: lazy.pd.DataFrame) -> Any: + return [{"seed": value, self.config.name: "bad"} for value in data.get("seed", [])] + + class MockRateLimitGenerator(ColumnGenerator[ExpressionColumnConfig]): """Generator that fails with rate-limit errors before succeeding. @@ -228,8 +320,8 @@ def generate(self, data: dict) -> dict: class MockRetryableErrorGenerator(ColumnGenerator[ExpressionColumnConfig]): """Generator that raises a parametrizable retryable error then succeeds. - Declares ``is_llm_bound=True`` because it mimics model-call behavior; - the scheduler's degraded-provider WARN window only counts LLM-bound tasks. + Declares model scheduling metadata because it mimics model-call behavior; + the scheduler's degraded-provider WARN window counts model-stage tasks. 
""" def __init__( @@ -248,9 +340,8 @@ def __init__( def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") def generate(self, data: dict) -> dict: self._calls += 1 @@ -260,6 +351,11 @@ def generate(self, data: dict) -> dict: return data +class _BrokenSchedulerSink: + def emit_scheduler_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + # -- Helper to build graph + scheduler ---------------------------------------- @@ -270,6 +366,7 @@ def _build_simple_pipeline( generators: dict[str, ColumnGenerator] | None = None, configs: list[SamplerColumnConfig | LLMTextColumnConfig | ExpressionColumnConfig] | None = None, strategies: dict[str, GenerationStrategy] | None = None, + scheduler_event_sink: Any | None = None, ) -> tuple[AsyncTaskScheduler, CompletionTracker]: """Build a simple seed → cell pipeline for testing.""" if configs is None: @@ -308,6 +405,7 @@ def _build_simple_pipeline( tracker=tracker, row_groups=row_groups, trace=trace, + scheduler_event_sink=scheduler_event_sink, ) return scheduler, tracker @@ -377,6 +475,31 @@ async def test_scheduler_dispatches_seeds_first() -> None: assert seed_traces[0].dispatched_at < cell_traces[0].dispatched_at +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_dispatches_root_cell_by_cell_columns_per_row() -> None: + provider = _mock_provider() + generator = MockRootCellGenerator(config=_expr_config("root_cell"), resource_provider=provider) + configs = [SamplerColumnConfig(name="root_cell", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + strategies = {"root_cell": GenerationStrategy.CELL_BY_CELL} + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = 
AsyncTaskScheduler( + generators={"root_cell": generator}, + graph=graph, + tracker=tracker, + row_groups=row_groups, + trace=True, + ) + + await scheduler.run() + + assert generator.call_types == ["dict", "dict", "dict"] + assert [trace.task_type for trace in scheduler.traces] == ["cell", "cell", "cell"] + assert not any(tracker.is_dropped(0, row_index) for row_index in range(3)) + assert tracker.is_row_group_complete(0, 3, ["root_cell"]) + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_with_buffer_manager() -> None: """Scheduler writes results to buffer manager and checkpoints.""" @@ -475,6 +598,226 @@ async def test_scheduler_non_retryable_failure_drops_row() -> None: assert tracker.is_row_group_complete(0, 2, ["seed", "fail_col"]) +def test_scheduler_internal_bug_classifier_preserves_scheduler_builtin_failures() -> None: + scheduler, tracker = _build_simple_pipeline(num_records=1) + assert scheduler._is_internal_bug(KeyError("missing internal key")) + assert not scheduler._is_internal_bug(DatasetGenerationError("generator failure")) + assert not tracker.is_dropped(0, 0) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_generator_builtin_exception_drops_cell_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + scheduler, tracker = _build_simple_pipeline( + num_records=1, + configs=[ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="buggy_col", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ], + strategies={ + "seed": GenerationStrategy.FULL_COLUMN, + "buggy_col": GenerationStrategy.CELL_BY_CELL, + }, + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "buggy_col": MockBuggyGenerator(config=_expr_config("buggy_col"), resource_provider=provider), + }, + ) + + with caplog.at_level(logging.WARNING, 
logger="data_designer.engine.dataset_builders.async_scheduler"): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert isinstance(scheduler.first_non_retryable_error, DatasetGenerationError) + assert isinstance(scheduler.first_non_retryable_error.__cause__, KeyError) + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_generator_builtin_exception_drops_from_scratch_group_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockBuggyFromScratchGenerator(config=_expr_config("seed"), resource_provider=provider)} + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 2)]) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 2)], + ) + + with caplog.at_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler"): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert tracker.is_dropped(0, 1) + assert isinstance(scheduler.first_non_retryable_error, DatasetGenerationError) + assert isinstance(scheduler.first_non_retryable_error.__cause__, AssertionError) + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_generator_builtin_exception_drops_batch_group_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="buggy_batch", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": 
GenerationStrategy.FULL_COLUMN, + "buggy_batch": GenerationStrategy.FULL_COLUMN, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "buggy_batch": MockBuggyFullColumnGenerator( + config=_expr_config("buggy_batch"), + resource_provider=provider, + ), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=RowGroupBufferManager(_make_storage()), + ) + + with caplog.at_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler"): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert tracker.is_dropped(0, 1) + assert isinstance(scheduler.first_non_retryable_error, DatasetGenerationError) + assert isinstance(scheduler.first_non_retryable_error.__cause__, TypeError) + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_generator_malformed_from_scratch_return_drops_group_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + configs = [SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]})] + strategies = {"seed": GenerationStrategy.FULL_COLUMN} + generators = {"seed": MockMalformedFromScratchGenerator(config=_expr_config("seed"), resource_provider=provider)} + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=RowGroupBufferManager(_make_storage()), + ) + + with caplog.at_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler"): + await 
scheduler.run() + + assert tracker.is_dropped(0, 0) + assert tracker.is_dropped(0, 1) + assert isinstance(scheduler.first_non_retryable_error, DatasetGenerationError) + assert "must return a DataFrame, got list" in str(scheduler.first_non_retryable_error) + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_generator_malformed_batch_return_drops_group_without_fatal_abort( + caplog: pytest.LogCaptureFixture, +) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="malformed_batch", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "malformed_batch": GenerationStrategy.FULL_COLUMN, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "malformed_batch": MockMalformedFullColumnGenerator( + config=_expr_config("malformed_batch"), + resource_provider=provider, + ), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 2)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + buffer_manager=RowGroupBufferManager(_make_storage()), + ) + + with caplog.at_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler"): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert tracker.is_dropped(0, 1) + assert isinstance(scheduler.first_non_retryable_error, DatasetGenerationError) + assert "must return a DataFrame, got list" in str(scheduler.first_non_retryable_error) + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_custom_generator_key_error_drops_row_without_fatal_abort( + 
caplog: pytest.LogCaptureFixture, +) -> None: + @custom_column_generator() + def failing_custom(row: dict) -> dict: + raise KeyError("missing user field") + + provider = _mock_provider() + custom_config = CustomColumnConfig(name="custom_col", generator_function=failing_custom) + scheduler, tracker = _build_simple_pipeline( + num_records=1, + configs=[ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + custom_config, + ], + strategies={ + "seed": GenerationStrategy.FULL_COLUMN, + "custom_col": GenerationStrategy.CELL_BY_CELL, + }, + generators={ + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "custom_col": CustomColumnGenerator(config=custom_config, resource_provider=provider), + }, + ) + + with caplog.at_level(logging.WARNING): + await scheduler.run() + + assert tracker.is_dropped(0, 0) + assert "This record will be skipped" in caplog.text + assert "Unexpected fatal Non-retryable failure" not in caplog.text + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.dataset_builders.async_scheduler") + scheduler, tracker = _build_simple_pipeline(num_records=1, scheduler_event_sink=_BrokenSchedulerSink()) + + await scheduler.run() + + assert tracker.is_row_group_complete(0, 1, ["seed", "cell_out"]) + assert "Scheduler admission event sink raised; dropping event." in caplog.text + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_stateful_generator_serializes() -> None: """Stateful generators serialize across row groups.""" @@ -1084,10 +1427,10 @@ def _count_degraded_msgs(caplog: pytest.LogCaptureFixture) -> int: @pytest.mark.parametrize( "retryable_failures,num_records,window,interval_s,expected_count", [ - # Above-threshold + zero throttle: at least one WARN should fire. 
+ # Above-threshold + no log interval: at least one WARN should fire. pytest.param(6, 10, 8, 0.0, "at_least_one", id="fires_above_threshold"), - # Above-threshold + 1h throttle: only one WARN despite sustained degradation. - pytest.param(8, 12, 4, 3600.0, 1, id="throttled_to_one"), + # Above-threshold + 1h log interval: only one WARN despite sustained degradation. + pytest.param(8, 12, 4, 3600.0, 1, id="rate_limited_to_one"), ], ) @pytest.mark.asyncio(loop_scope="session") @@ -1397,15 +1740,14 @@ async def test_scheduler_out_of_order_row_group_completion() -> None: assert checkpoint_order.index(1) < checkpoint_order.index(0) -# -- Dual-semaphore / LLM-bound tests ----------------------------------------- +# -- Task-admission / model-stage tests --------------------------------------- class MockLLMBoundCellGenerator(ColumnGenerator[ExpressionColumnConfig]): - """Mock cell-by-cell generator that reports is_llm_bound=True.""" + """Mock cell-by-cell generator that reports model-stage scheduling metadata.""" - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1416,13 +1758,9 @@ def generate(self, data: dict) -> dict: return data -class MockConfiguredModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): +class MockConfiguredModelCellGenerator(ColumnGeneratorWithModelRegistry[LLMTextColumnConfig]): """Mock cell generator with model-registry helpers.""" - @property - def is_llm_bound(self) -> bool: - return True - @staticmethod def get_generation_strategy() -> GenerationStrategy: return GenerationStrategy.CELL_BY_CELL @@ -1447,9 +1785,8 @@ def __init__(self, *args: Any, rate_limit_failures: int = 0, **kwargs: Any) -> N self._rate_limit_failures = rate_limit_failures self._calls = 0 - @property - def is_llm_bound(self) -> bool: - return True + def 
get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1492,19 +1829,15 @@ async def test_scheduler_llm_bound_one_way_handoff() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, 3, ["seed", "llm_col"]) - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked after LLM handoff: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked after LLM handoff: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") @@ -1535,21 +1868,19 @@ async def test_scheduler_non_llm_holds_submission_slot() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=2, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) - _, llm_available = scheduler.get_semaphore_permits() - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore was consumed by non-LLM task: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_deadlock_regression() -> None: - """max_submitted_tasks=1, max_llm_wait_tasks=1, two ready LLM tasks completes without deadlock.""" + """max_submitted_tasks=1, 
max_model_task_admission=1, two ready LLM tasks completes without deadlock.""" provider = _mock_provider() configs = [ SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), @@ -1574,7 +1905,7 @@ async def test_scheduler_deadlock_regression() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=1, - max_llm_wait_tasks=1, + max_model_task_admission=1, ) await asyncio.wait_for(scheduler.run(), timeout=10.0) @@ -1613,15 +1944,88 @@ async def test_drain_frontier_raises_when_ready_but_no_capacity_or_inflight() -> graph=graph, tracker=tracker, row_groups=row_groups, - max_submitted_tasks=0, + task_admission_config=TaskAdmissionConfig(submission_capacity=1), ) scheduler._rg_states[0] = MagicMock(size=1) + blocker = scheduler._schedulable_task(Task(column="cell_out", row_group=0, row_index=99, task_type="cell")) + lease = scheduler._task_admission.try_acquire(blocker, scheduler._fair_queue.view()) + assert isinstance(lease, TaskAdmissionLease) scheduler._apply_frontier_delta(seed_delta) with pytest.raises(RuntimeError, match="Ready frontier is admission-blocked"): await scheduler._drain_frontier(("seed",), False) +def test_dispatch_selected_task_rolls_back_scheduler_state_when_worker_spawn_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _mock_provider() + config = ExpressionColumnConfig(name="cell_out", expr="'x'", dtype="str") + graph = ExecutionGraph.create([config], {"cell_out": GenerationStrategy.CELL_BY_CELL}) + scheduler = AsyncTaskScheduler( + generators={"cell_out": MockCellGenerator(config=config, resource_provider=provider)}, + graph=graph, + tracker=CompletionTracker.with_graph(graph, [(0, 1)]), + row_groups=[(0, 1)], + scheduler_event_sink=(sink := InMemoryAdmissionEventSink()), + ) + task = Task(column="cell_out", row_group=0, row_index=0, task_type="cell") + item = scheduler._schedulable_task(task) + lease = scheduler._task_admission.try_acquire(item, scheduler._fair_queue.view()) + 
assert isinstance(lease, TaskAdmissionLease) + scheduler._rg_states[0] = SimpleNamespace(size=1, in_flight_count=0) + + def fail_spawn(coro: Any) -> None: + coro.close() + raise RuntimeError("spawn failed") + + monkeypatch.setattr(scheduler, "_spawn_worker", fail_spawn) + + with pytest.raises(RuntimeError, match="spawn failed"): + scheduler._dispatch_selected_task(item, lease) + + assert task not in scheduler._dispatched + assert task not in scheduler._in_flight + assert scheduler._rg_states[0].in_flight_count == 0 + assert scheduler.task_admission_snapshot().leased_resources == {} + assert scheduler.task_admission_snapshot().running_counts_by_group == {} + assert any(event.event_kind == "worker_spawn_failed" for event in sink.scheduler_events) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_main_dispatch_loop_yields_when_pre_batch_is_pending( + monkeypatch: pytest.MonkeyPatch, +) -> None: + provider = _mock_provider() + seed_config = ExpressionColumnConfig(name="seed", expr="'seed'", dtype="str") + graph = ExecutionGraph.create([seed_config], {"seed": GenerationStrategy.FULL_COLUMN}) + scheduler = AsyncTaskScheduler( + generators={"seed": MockSeedGenerator(config=seed_config, resource_provider=provider)}, + graph=graph, + tracker=CompletionTracker.with_graph(graph, [(0, 1)]), + row_groups=[(0, 1)], + ) + scheduler._all_rgs_admitted = True + scheduler._rg_states[0] = SimpleNamespace(size=1, seeds_dispatched=True, pre_batch_done=False) + monkeypatch.setattr(scheduler, "_run_seeds_complete_check", lambda seed_cols: None) + monkeypatch.setattr( + scheduler, "_dispatch_queued_tasks", lambda: SimpleNamespace(dispatched=False, admission_blocked=False) + ) + monkeypatch.setattr(scheduler, "_checkpoint_completed_row_groups", lambda all_columns: None) + monkeypatch.setattr(scheduler, "_maybe_update_adaptive_row_group_target", lambda: None) + yielded_delays: list[float] = [] + + async def record_sleep(delay: float) -> None: + yielded_delays.append(delay) + 
scheduler._rg_states[0].pre_batch_done = True + + monkeypatch.setattr(asyncio, "sleep", record_sleep) + + await scheduler._main_dispatch_loop(("seed",), True, ["seed"]) + + assert yielded_delays == [0] + + @pytest.mark.asyncio(loop_scope="session") async def test_scheduler_dispatch_does_not_scan_ready_frontier(monkeypatch: pytest.MonkeyPatch) -> None: provider = _mock_provider() @@ -1695,22 +2099,99 @@ def drop_middle_row(row_group: int, row_group_size: int) -> FrontierDelta: assert tracker.is_row_group_complete(0, 3, ["seed", "cell_out"]) -@pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_is_llm_bound_property_drives_lookup() -> None: - """is_llm_bound property on generators drives the lookup, not isinstance.""" +def test_apply_frontier_delta_enqueues_ready_tasks_in_one_queue_operation(monkeypatch: pytest.MonkeyPatch) -> None: provider = _mock_provider() - llm_gen = MockLLMBoundCellGenerator(config=_expr_config("llm_col"), resource_provider=provider) - non_llm_gen = MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider) - - assert llm_gen.is_llm_bound is True - assert non_llm_gen.is_llm_bound is False + configs = [ + LLMTextColumnConfig(name="root", prompt="root", model_alias=MODEL_ALIAS), + ] + strategies = { + "root": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "root": MockCellGenerator(config=_expr_config("root"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 5)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + ) + scheduler._rg_states[0] = SimpleNamespace(size=5, pre_batch_done=True) + + enqueue_sizes: list[int] = [] + original_enqueue = scheduler._fair_queue.enqueue + + def spy_enqueue(items: Any) -> tuple[str, ...]: + materialized = tuple(items) + enqueue_sizes.append(len(materialized)) + return 
original_enqueue(materialized) + + monkeypatch.setattr(scheduler._fair_queue, "enqueue", spy_enqueue) - lookup = build_llm_bound_lookup({"llm_col": llm_gen, "cell_out": non_llm_gen}) - assert lookup == {"llm_col": True, "cell_out": False} + scheduler._apply_frontier_delta(tracker.add_root_tasks(0, 5)) + assert enqueue_sizes == [5] + assert tracker.ready_frontier() == () + assert scheduler._fair_queue.view().queued_total == 5 -def test_custom_generator_with_model_aliases_is_llm_bound() -> None: - """CustomColumnGenerator with model_aliases reports is_llm_bound=True.""" + +def test_pre_batch_flush_batches_pending_ready_and_skips_dropped_rows(monkeypatch: pytest.MonkeyPatch) -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="seed", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="cell_out", prompt="{{ seed }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "seed": GenerationStrategy.FULL_COLUMN, + "cell_out": GenerationStrategy.CELL_BY_CELL, + } + generators = { + "seed": MockSeedGenerator(config=_expr_config("seed"), resource_provider=provider), + "cell_out": MockCellGenerator(config=_expr_config("cell_out"), resource_provider=provider), + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + on_seeds_complete=lambda row_group, row_group_size: None, + scheduler_event_sink=sink, + ) + state = SimpleNamespace(size=3, pre_batch_done=False) + scheduler._rg_states[0] = state + + enqueue_sizes: list[int] = [] + original_enqueue = scheduler._fair_queue.enqueue + + def spy_enqueue(items: Any) -> tuple[str, ...]: + materialized = tuple(items) + enqueue_sizes.append(len(materialized)) + return original_enqueue(materialized) + + 
monkeypatch.setattr(scheduler._fair_queue, "enqueue", spy_enqueue) + + scheduler._apply_frontier_delta(tracker.mark_row_range_complete("seed", 0, 3)) + scheduler._apply_frontier_delta(tracker.drop_row(0, 1)) + state.pre_batch_done = True + scheduler._flush_pre_batch_ready(0) + + assert enqueue_sizes == [2] + assert scheduler._fair_queue.view().queued_total == 2 + assert {item.payload.row_index for item in scheduler._fair_queue._queued.values()} == {0, 2} + assert tracker.is_dropped(0, 1) + assert sum(event.event_kind == "ready_enqueued" for event in sink.scheduler_events) == 2 + assert sum(event.event_kind == "dependency_ready" for event in sink.scheduler_events) == 5 + + +def test_custom_generator_with_model_aliases_reports_custom_model_metadata() -> None: + """CustomColumnGenerator with model_aliases reports custom-model metadata.""" @custom_column_generator(model_aliases=["my_model"]) def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: @@ -1729,11 +2210,8 @@ def gen_no_models(row: dict) -> dict: llm_gen = CustomColumnGenerator(config=llm_config, resource_provider=provider) plain_gen = CustomColumnGenerator(config=plain_config, resource_provider=provider) - assert llm_gen.is_llm_bound is True - assert plain_gen.is_llm_bound is False - - lookup = build_llm_bound_lookup({"custom_llm": llm_gen, "custom_plain": plain_gen}) - assert lookup == {"custom_llm": True, "custom_plain": False} + assert llm_gen.get_scheduling_metadata().kind == "custom_model" + assert plain_gen.get_scheduling_metadata().kind == "local" def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: @@ -1762,13 +2240,13 @@ def test_scheduler_model_task_group_spec_uses_model_resource_and_flow() -> None: graph=graph, tracker=tracker, row_groups=[(0, 1)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) + spec = 
scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=0, task_type="cell")).group assert spec.key.kind == "model" - assert spec.key.identity[:2] == ("mock-provider", "model-text") + assert spec.key.identity[:3] == ("model", "mock-provider", "model-text") assert spec.key.identity[-1] == "answer" assert spec.weight == 3.0 assert spec.admitted_limit == 5 @@ -1792,21 +2270,19 @@ def test_scheduler_task_group_spec_is_cached_per_generator() -> None: graph=graph, tracker=tracker, row_groups=[(0, 2)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) - spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) + spec_a = scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=0, task_type="cell")).group + spec_b = scheduler._schedulable_task(Task(column="answer", row_group=0, row_index=1, task_type="cell")).group - assert spec_a is spec_b + assert spec_a == spec_b assert provider.model_registry.get_model_config.call_count == 1 assert provider.model_registry.get_model_provider.call_count == 1 -def test_scheduler_task_group_spec_logs_debug_on_model_resolution_fallback( - caplog: pytest.LogCaptureFixture, -) -> None: - """Direct spec resolution isolates fallback logging without timing-based scheduler traces.""" +def test_scheduler_task_group_spec_raises_on_model_resolution_failure() -> None: + """Model metadata resolution failures are fatal without an explicit fallback.""" provider = MagicMock() provider.model_registry = MagicMock() provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") @@ -1816,29 +2292,14 @@ def test_scheduler_task_group_spec_logs_debug_on_model_resolution_fallback( graph = ExecutionGraph.create([column_config], {"answer": GenerationStrategy.CELL_BY_CELL}) tracker = CompletionTracker.with_graph(graph, [(0, 2)]) - with 
caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): - scheduler = AsyncTaskScheduler( + with pytest.raises(Exception): + AsyncTaskScheduler( generators={"answer": generator}, graph=graph, tracker=tracker, row_groups=[(0, 2)], - max_llm_wait_tasks=5, + max_model_task_admission=5, ) - spec_a = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=0, task_type="cell")) - spec_b = scheduler._task_group_spec(Task(column="answer", row_group=0, row_index=1, task_type="cell")) - - assert spec_a is spec_b - assert spec_a.key.kind == "custom_model" - assert spec_a.key.identity == ("answer", MODEL_ALIAS) - assert spec_a.weight == 1.0 - assert provider.model_registry.get_model_config.call_count == 1 - fallback_records = [ - record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() - ] - assert len(fallback_records) == 1 - assert "answer" in fallback_records[0].getMessage() - assert MODEL_ALIAS in fallback_records[0].getMessage() - assert fallback_records[0].exc_info is not None def test_scheduler_custom_model_task_group_spec_uses_alias_set_weight() -> None: @@ -1874,15 +2335,15 @@ def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: graph=graph, tracker=tracker, row_groups=[(0, 1)], - max_llm_wait_tasks=10, + max_model_task_admission=10, ) - spec = scheduler._task_group_spec(Task(column="custom_llm", row_group=0, row_index=0, task_type="cell")) + spec = scheduler._schedulable_task(Task(column="custom_llm", row_group=0, row_index=0, task_type="cell")).group assert spec.key.kind == "custom_model" - assert spec.key.identity == ("custom_llm", "draft", "judge") - assert spec.weight == 5.0 - assert spec.admitted_limit == 10 + assert spec.key.identity[:3] == ("custom_model", "custom_column", "draft-judge") + assert spec.weight == 2.0 + assert spec.admitted_limit == 4 @pytest.mark.asyncio(loop_scope="session") @@ -1927,33 +2388,28 @@ async def 
test_scheduler_llm_bound_429_retried_in_salvage() -> None: row_groups=row_groups, buffer_manager=buffer_mgr, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + max_model_task_admission=max_llm_wait, ) await scheduler.run() assert tracker.is_row_group_complete(0, num_records, ["seed", "llm_col"]) - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked after salvage retry: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked after salvage retry: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait @pytest.mark.asyncio(loop_scope="session") -async def test_scheduler_cancellation_releases_semaphores() -> None: - """Cancelling the scheduler while LLM-bound tasks are in-flight releases all semaphore slots.""" +async def test_scheduler_cancellation_releases_task_admission_leases() -> None: + """Cancelling the scheduler while model-stage tasks are in-flight releases task leases.""" provider = _mock_provider() blocked = asyncio.Event() proceed = asyncio.Event() class BlockingLLMGenerator(ColumnGenerator[ExpressionColumnConfig]): - @property - def is_llm_bound(self) -> bool: - return True + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") @staticmethod def get_generation_strategy() -> GenerationStrategy: @@ -1987,13 +2443,15 @@ async def agenerate(self, data: dict) -> dict: max_submitted = 4 max_llm_wait = 2 + sink = InMemoryAdmissionEventSink() scheduler = AsyncTaskScheduler( generators=generators, graph=graph, tracker=tracker, row_groups=row_groups, max_submitted_tasks=max_submitted, - max_llm_wait_tasks=max_llm_wait, + 
max_model_task_admission=max_llm_wait, + scheduler_event_sink=sink, ) run_task = asyncio.create_task(scheduler.run()) @@ -2003,13 +2461,14 @@ async def agenerate(self, data: dict) -> dict: with pytest.raises(asyncio.CancelledError): await run_task - sub_available, llm_available = scheduler.get_semaphore_permits() - assert sub_available == max_submitted, ( - f"Submission semaphore leaked: available={sub_available}, expected={max_submitted}" - ) - assert llm_available == max_llm_wait, ( - f"LLM-wait semaphore leaked: available={llm_available}, expected={max_llm_wait}" - ) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == max_submitted + assert snapshot.resources_available["llm_wait"] == max_llm_wait + assert "cancelled" in [event.event_kind for event in sink.scheduler_events] + assert all(event.snapshot is not None for event in sink.scheduler_events) + task_events = [event for event in sink.scheduler_events if event.task_id is not None] + assert all("resource_request" in event.diagnostics for event in task_events) + assert any("llm_wait" in event.diagnostics["resource_request"] for event in task_events) @pytest.mark.asyncio(loop_scope="session") @@ -2017,7 +2476,7 @@ async def test_scheduler_rg_semaphore_deadlock_with_transient_failures() -> None """Row groups stalled by transient failures don't block admission of new row groups. Regression test: with max_concurrent_row_groups=1 and 2 row groups, if all - tasks in RG0 fail transiently, the semaphore must still be released so RG1 + tasks in RG0 fail transiently, row-group capacity must still be released so RG1 can be admitted. The scheduler salvages RG0 inline and continues. 
""" provider = _mock_provider() @@ -2098,31 +2557,6 @@ def test_side_effect_columns_separated_from_completion_tracking() -> None: assert "side_b" in write_cols -# -- TrackingSemaphore tests --------------------------------------------------- - - -def test_tracking_semaphore_try_acquire() -> None: - """try_acquire returns True when permits are available, False when exhausted.""" - from data_designer.engine.dataset_builders.async_scheduler import TrackingSemaphore - - sem = TrackingSemaphore(2) - assert sem.available_permits == 2 - - assert sem.try_acquire() is True - assert sem.available_permits == 1 - - assert sem.try_acquire() is True - assert sem.available_permits == 0 - - assert sem.try_acquire() is False - assert sem.available_permits == 0 - - sem.release() - assert sem.available_permits == 1 - assert sem.try_acquire() is True - assert sem.available_permits == 0 - - # -- Pipeline parallelism (stale dispatch fix, issue #504) --------------------- @@ -2147,11 +2581,80 @@ async def agenerate(self, data: dict) -> dict: class SlowLLMBoundCellGenerator(SlowCellGenerator): - """Slow cell generator that participates in LLM-wait scheduling.""" + """Slow cell generator that participates in model-stage scheduling.""" + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.custom_model("test", self.config.name, "v1") + + +class SlowModelBoundCellGenerator(SlowCellGenerator): + """Slow cell generator with concrete request-pressure identity.""" + + def __init__( + self, + *args: Any, + provider_name: str = "provider", + model_id: str = "model", + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._provider_name = provider_name + self._model_id = model_id + + def get_scheduling_metadata(self) -> SchedulingMetadata: + return SchedulingMetadata.model( + self._provider_name, + self._model_id, + "chat", + weight=1, + ) + + +class _StaticRequestPressureProvider: + def __init__(self, snapshots: dict[RequestResourceKey, 
RequestPressureSnapshot]) -> None: + self._snapshots = snapshots @property - def is_llm_bound(self) -> bool: - return True + def config(self) -> RequestAdmissionConfig | None: + return None + + def snapshot(self, resource: RequestResourceKey) -> RequestPressureSnapshot | None: + return self._snapshots.get(resource) + + def snapshots(self) -> dict[RequestResourceKey, RequestPressureSnapshot]: + return dict(self._snapshots) + + def global_snapshot(self, provider: str, model: str) -> None: + return None + + def global_snapshots(self) -> dict[ProviderModelKey, object]: + return {} + + +def _pressure_snapshot( + resource: RequestResourceKey, + *, + current_limit: int = 1, + in_flight: int = 0, + waiters: int = 0, + cooldown: float = 0.0, +) -> RequestPressureSnapshot: + return RequestPressureSnapshot( + captured_at=time.monotonic(), + sequence=1, + resource=resource, + effective_max=max(1, current_limit), + current_limit=current_limit, + in_flight_count=in_flight, + active_lease_count=in_flight, + waiters=waiters, + blocked_until_monotonic=time.monotonic() + cooldown if cooldown > 0.0 else None, + cooldown_remaining_seconds=cooldown, + rate_limit_ceiling=max(1, current_limit), + consecutive_rate_limits=0, + last_outcome=None, + leak_diagnostics={}, + ) @pytest.mark.asyncio(loop_scope="session") @@ -2304,7 +2807,7 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None: tracker=tracker, row_groups=row_groups, max_submitted_tasks=4, - max_llm_wait_tasks=4, + max_model_task_admission=4, trace=True, ) @@ -2318,7 +2821,9 @@ async def test_scheduler_fair_llm_group_cap_preserves_peer_admission() -> None: assert first_window.count("hot") == 2 assert first_window.count("peer") == 2 assert tracker.is_row_group_complete(0, 8, ["topic", *gen_names]) - assert scheduler.get_semaphore_permits() == (4, 4) + snapshot = scheduler.task_admission_snapshot() + assert snapshot.resources_available["submission"] == 4 + assert snapshot.resources_available["llm_wait"] 
== 4 @pytest.mark.asyncio(loop_scope="session") @@ -2332,9 +2837,9 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ├── gen_b (slow, 50ms) → judge_b (instant) └── gen_c (slow, 50ms) → judge_c (instant) - With a small semaphore (4) and 10 records, the 30 gen tasks (3 cols x 10 rows) - saturate the semaphore. The dispatch loop must re-query the frontier when the - semaphore is full so that judge tasks from completed gen rows are picked up + With small task admission capacity (4) and 10 records, the 30 gen tasks + saturate admission. The dispatch loop must re-query the frontier when capacity + is full so that judge tasks from completed gen rows are picked up before all gen tasks finish. """ provider = _mock_provider() @@ -2394,6 +2899,426 @@ async def test_scheduler_downstream_interleaves_with_upstream() -> None: ) +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_capacity_plan_observes_buffer_backpressure() -> None: + provider = _mock_provider() + gen_names = ["gen_a", "gen_b", "gen_c"] + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + *[LLMTextColumnConfig(name=g, prompt="{{ topic }}", model_alias=MODEL_ALIAS) for g in gen_names], + ] + strategies: dict[str, GenerationStrategy] = {"topic": GenerationStrategy.FULL_COLUMN} + strategies.update({column: GenerationStrategy.CELL_BY_CELL for column in gen_names}) + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + **{ + name: SlowCellGenerator(config=_expr_config(name), resource_provider=provider, delay=0.02) + for name in gen_names + }, + } + graph = ExecutionGraph.create(configs, strategies) + row_groups = [(0, 3), (1, 3), (2, 3), (3, 3)] + tracker = CompletionTracker.with_graph(graph, row_groups) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + 
max_concurrent_row_groups=2, + max_submitted_tasks=2, + trace=True, + num_records=12, + buffer_size=3, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + for row_group_index, row_count in row_groups: + assert tracker.is_row_group_complete(row_group_index, row_count, ["topic", *gen_names]) + assert plan.configured.row_group_admission.observed_in_flight == 0 + assert plan.observed_maxima.row_groups_in_flight == 2 + assert plan.observed_maxima.queued_tasks_by_group + assert max(plan.observed_maxima.task_leases_by_resource.values()) <= 2 + + +def test_scheduler_capacity_plan_reports_request_admission_state() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController( + RequestAdmissionConfig(initial_limits={resource: 2}, max_limit_clamps={resource: 3}) + ) + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + scheduler._record_observed_task_state() + plan = scheduler.capacity_plan() + + assert plan.configured.request_resources.value == (resource,) + assert plan.configured.request_domain_initial_limits.value[resource] == 2 + assert plan.configured.request_admission_config.value is not None + assert plan.configured.provider_model_static_caps.value[ProviderModelKey("provider", "model")].cap == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 3 + assert plan.runtime_snapshot.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + assert 
plan.observed_maxima.request_in_flight_by_resource[resource] == 1 + assert plan.observed_maxima.provider_model_aggregate_in_flight[ProviderModelKey("provider", "model")] == 1 + request_admission.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_scheduler_capacity_plan_reports_default_request_initial_limit_after_aimd_drop() -> None: + resource = RequestResourceKey("provider", "model", RequestDomain.CHAT) + request_admission = AdaptiveRequestAdmissionController() + request_admission.register( + provider_name="provider", + model_id="model", + alias="primary", + max_parallel_requests=4, + ) + lease = request_admission.try_acquire(RequestAdmissionItem(resource, RequestGroupSpec(resource))) + assert isinstance(lease, RequestAdmissionLease) + request_admission.release(lease, RequestReleaseOutcome(kind="rate_limited")) + + scheduler, _tracker = _build_simple_pipeline() + scheduler._request_pressure_provider = request_admission + plan = scheduler.capacity_plan() + + assert plan.configured.request_domain_initial_limits.value[resource] == 4 + assert plan.runtime_snapshot.request_domain_effective_max[resource] == 4 + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 3 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_emits_job_health_and_row_group_telemetry() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 2)] + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), 
+ "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=1, + max_submitted_tasks=2, + max_model_task_admission=1, + scheduler_event_sink=sink, + num_records=2, + buffer_size=2, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + kinds = [event.event_kind for event in sink.scheduler_events] + assert "scheduler_job_started" in kinds + assert "scheduler_health_snapshot" in kinds + assert "row_group_checkpointed" in kinds + assert "scheduler_job_completed" in kinds + + started = next(event for event in sink.scheduler_events if event.event_kind == "scheduler_job_started") + assert started.diagnostics["num_records"] == 2 + assert started.diagnostics["buffer_size"] == 2 + assert started.diagnostics["row_group_count"] == 1 + assert started.diagnostics["graph_depth"] == 2 + column_scheduling = started.diagnostics["column_scheduling"] + assert isinstance(column_scheduling, list) + model_column = next(item for item in column_scheduling if item["column"] == "model_col") + assert model_column["group_kind"] == "custom_model" + assert model_column["resource_request"] == {"submission": 1, "llm_wait": 1} + + health = next(event for event in sink.scheduler_events if event.event_kind == "scheduler_health_snapshot") + assert "queued_total" in health.diagnostics + assert "leased_resources" in health.diagnostics + assert "request_pressure" in health.diagnostics + + checkpointed = next(event for event in sink.scheduler_events if event.event_kind == "row_group_checkpointed") + assert checkpointed.diagnostics["row_group"] == 0 + assert checkpointed.diagnostics["row_group_size"] == 2 + assert checkpointed.diagnostics["surviving_rows"] == 2 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_adaptive_row_group_admission_expands_target_for_horizon_idle() -> None: + provider = _mock_provider() + 
configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1), (2, 1), (3, 1)] + generators: dict[str, ColumnGenerator] = { + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.04, + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, row_groups) + sink = InMemoryAdmissionEventSink() + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=row_groups, + max_concurrent_row_groups=4, + max_submitted_tasks=4, + max_model_task_admission=4, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=1, + scheduler_event_sink=sink, + trace=True, + num_records=4, + buffer_size=1, + ) + + await asyncio.wait_for(scheduler.run(), timeout=10.0) + + plan = scheduler.capacity_plan() + assert tracker.is_row_group_complete(0, 1, ["topic", "model_col"]) + assert plan.configured.row_group_admission.mode == "adaptive" + assert plan.configured.row_group_admission.observed_max_target is not None + assert plan.configured.row_group_admission.observed_max_target > 1 + assert plan.observed_maxima.row_groups_in_flight > 1 + assert any(event.event_kind == "row_group_admission_target_changed" for event in sink.scheduler_events) + + +def test_scheduler_adaptive_row_group_row_guard_blocks_extra_large_groups() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + 
strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 5_000), (1, 5_000)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=4, + adaptive_row_group_admission=True, + adaptive_row_group_initial_target=4, + num_records=10_000, + buffer_size=1, + ) + + scheduler._rg_states[0] = SimpleNamespace(size=5_000) + + assert scheduler._adaptive_max_admitted_rows == 8_192 + assert not scheduler._row_group_row_guard_allows(5_000) + assert scheduler._row_group_row_guard_allows(1_000) + scheduler._rg_states.clear() + assert scheduler._row_group_row_guard_allows(9_000) + + +def test_scheduler_adaptive_row_group_block_reason_prefers_llm_saturation() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1), (1, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + max_concurrent_row_groups=2, + adaptive_row_group_admission=True, + 
num_records=2, + buffer_size=1, + ) + scheduler._fair_queue = SimpleNamespace( + view=lambda: SimpleNamespace(queued_total=1, queued_peer_demand_by_resource={}) + ) + scheduler._task_admission = SimpleNamespace( + view=lambda: SimpleNamespace(resource_limits={"llm_wait": 1}, resources_available={"llm_wait": 0}) + ) + + assert scheduler._adaptive_row_group_block_reason() == "llm_wait_saturated" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_scheduler_raises_when_ready_frontier_blocked_without_in_flight() -> None: + provider = _mock_provider() + configs = [ + SamplerColumnConfig(name="topic", sampler_type=SamplerType.CATEGORY, params={"values": ["A"]}), + LLMTextColumnConfig(name="model_col", prompt="{{ topic }}", model_alias=MODEL_ALIAS), + ] + strategies = { + "topic": GenerationStrategy.FULL_COLUMN, + "model_col": GenerationStrategy.CELL_BY_CELL, + } + row_groups = [(0, 1)] + graph = ExecutionGraph.create(configs, strategies) + scheduler = AsyncTaskScheduler( + generators={ + "topic": MockSeedGenerator(config=_expr_config("topic"), resource_provider=provider), + "model_col": SlowLLMBoundCellGenerator( + config=_expr_config("model_col"), + resource_provider=provider, + delay=0.0, + ), + }, + graph=graph, + tracker=CompletionTracker.with_graph(graph, row_groups), + row_groups=row_groups, + task_admission_config=TaskAdmissionConfig( + submission_capacity=1, + resource_limits={"submission": 1, "local": 1}, + ), + ) + + with pytest.raises(RuntimeError, match="Ready frontier is admission-blocked"): + await asyncio.wait_for(scheduler.run(), timeout=2.0) + + +def test_scheduler_request_pressure_advisory_prefers_pressure_open_peer() -> None: + provider = _mock_provider() + configs = [ + LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS), + LLMTextColumnConfig(name="open", prompt="B", model_alias=MODEL_ALIAS), + ] + strategies = { + "pressured": GenerationStrategy.CELL_BY_CELL, + "open": GenerationStrategy.CELL_BY_CELL, + } + 
generators: dict[str, ColumnGenerator] = { + "pressured": SlowModelBoundCellGenerator( + config=_expr_config("pressured"), + resource_provider=provider, + provider_name="provider-a", + model_id="model-a", + ), + "open": SlowModelBoundCellGenerator( + config=_expr_config("open"), + resource_provider=provider, + provider_name="provider-b", + model_id="model-b", + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + pressured_key = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) + open_key = RequestResourceKey("provider-b", "model-b", RequestDomain.CHAT) + pressure = _StaticRequestPressureProvider( + { + pressured_key: _pressure_snapshot(pressured_key, current_limit=1, in_flight=1, waiters=1), + open_key: _pressure_snapshot(open_key, current_limit=1, in_flight=0, waiters=0), + } + ) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + request_pressure_provider=pressure, + request_pressure_advisory=True, + scheduler_event_sink=(sink := InMemoryAdmissionEventSink()), + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) + pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) + open_task = scheduler._schedulable_task(Task(column="open", row_group=0, row_index=0, task_type="cell")) + scheduler._fair_queue.enqueue((pressured, open_task)) + + selection = scheduler._fair_queue.select_next(scheduler._is_dispatch_eligible) + + assert selection is not None + assert selection.item.payload.column == "open" + skip = next(event for event in sink.scheduler_events if event.event_kind == "request_pressure_advisory_skipped") + assert skip.diagnostics["request_resource"] == "provider-a/model-a/chat" + assert skip.diagnostics["pressure_reason"] == "waiters" + assert skip.diagnostics["open_peer_column"] == "open" + assert 
skip.diagnostics["open_peer_request_resource"] == "provider-b/model-b/chat" + + +def test_scheduler_request_pressure_advisory_preserves_liveness_when_all_candidates_pressured() -> None: + provider = _mock_provider() + configs = [LLMTextColumnConfig(name="pressured", prompt="A", model_alias=MODEL_ALIAS)] + strategies = {"pressured": GenerationStrategy.CELL_BY_CELL} + generators: dict[str, ColumnGenerator] = { + "pressured": SlowModelBoundCellGenerator( + config=_expr_config("pressured"), + resource_provider=provider, + provider_name="provider-a", + model_id="model-a", + ), + } + graph = ExecutionGraph.create(configs, strategies) + tracker = CompletionTracker.with_graph(graph, [(0, 1)]) + pressured_key = RequestResourceKey("provider-a", "model-a", RequestDomain.CHAT) + pressure = _StaticRequestPressureProvider( + {pressured_key: _pressure_snapshot(pressured_key, current_limit=1, in_flight=1, waiters=1)} + ) + scheduler = AsyncTaskScheduler( + generators=generators, + graph=graph, + tracker=tracker, + row_groups=[(0, 1)], + request_pressure_provider=pressure, + request_pressure_advisory=True, + ) + scheduler._rg_states[0] = SimpleNamespace(size=1, pre_batch_done=True) + pressured = scheduler._schedulable_task(Task(column="pressured", row_group=0, row_index=0, task_type="cell")) + scheduler._fair_queue.enqueue((pressured,)) + + selection = scheduler._fair_queue.select_next(scheduler._is_dispatch_eligible) + + assert selection is not None + assert selection.item.payload.column == "pressured" + + # -- Skip / conditional generation tests (async engine) ----------------------- diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py index dfd219fd5..6a5b31a51 100644 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py +++ 
b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_execution_graph.py @@ -20,9 +20,9 @@ from data_designer.config.utils.code_lang import CodeLang from data_designer.config.validator_params import CodeValidatorParams from data_designer.engine.dataset_builders.multi_column_configs import SamplerMultiColumnConfig +from data_designer.engine.dataset_builders.scheduling.task_model import SliceRef from data_designer.engine.dataset_builders.utils.errors import ConfigCompilationError, DAGCircularDependencyError from data_designer.engine.dataset_builders.utils.execution_graph import ExecutionGraph -from data_designer.engine.dataset_builders.utils.task_model import SliceRef MODEL_ALIAS = "stub-model-alias" diff --git a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py deleted file mode 100644 index b929bce4f..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_fair_task_queue.py +++ /dev/null @@ -1,219 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from collections import Counter - -from data_designer.engine.dataset_builders.utils.fair_task_queue import ( - FairTaskQueue, - TaskGroupKey, - TaskGroupSpec, -) -from data_designer.engine.dataset_builders.utils.task_model import Task - - -def _task(column: str, row_index: int) -> Task: - return Task(column=column, row_group=0, row_index=row_index, task_type="cell") - - -def _group(name: str, *, weight: float = 1.0, admitted_limit: int | None = None) -> TaskGroupSpec: - return TaskGroupSpec( - key=TaskGroupKey(kind="local", identity=(name,)), - weight=weight, - admitted_limit=admitted_limit, - ) - - -def _enqueue(queue: FairTaskQueue, items: list[tuple[Task, TaskGroupSpec]]) -> None: - for task, group in items: - queue.enqueue(task, group) - - -def test_fair_task_queue_equal_groups_round_robins() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [ - (task, _group(task.column)) - for task in [ - _task("a", 0), - _task("a", 1), - _task("b", 0), - _task("b", 1), - _task("c", 0), - _task("c", 1), - ] - ], - ) - - selected = [queue.admit_next() for _ in range(6)] - - assert [selection.task.column for selection in selected if selection is not None] == ["a", "b", "c", "a", "b", "c"] - - -def test_fair_task_queue_weighted_groups() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [ - (task, _group(task.column, weight=2 if task.column == "a" else 1)) - for task in [_task("a", i) for i in range(6)] - ] - + [(_task("b", i), _group("b", weight=1)) for i in range(6)], - ) - - selected = [queue.admit_next() for _ in range(6)] - counts = Counter(selection.task.column for selection in selected if selection is not None) - - assert counts == {"a": 4, "b": 2} - - -def test_fair_task_queue_discards_queued_tasks() -> None: - queue = FairTaskQueue() - stale = _task("a", 0) - fresh = _task("a", 1) - - _enqueue(queue, [(stale, _group("a")), (fresh, _group("a"))]) - queue.discard(stale) - - 
selected = queue.admit_next() - - assert selected is not None - assert selected.task == fresh - assert queue.admit_next() is None - - -def test_fair_task_queue_admitted_cap_skips_saturated_group_with_waiting_peer() -> None: - queue = FairTaskQueue() - capped = _group("a", admitted_limit=1, weight=1_000) - peer = _group("b") - _enqueue( - queue, - [ - (_task("a", 0), capped), - (_task("a", 1), capped), - (_task("b", 0), peer), - (_task("b", 1), peer), - ], - ) - - first = queue.admit_next() - peer_first = queue.admit_next() - selected = queue.admit_next() - - assert first is not None - assert first.task.column == "a" - assert peer_first is not None - assert peer_first.task.column == "b" - assert selected is not None - assert selected.task.column == "b" - - -def test_fair_task_queue_solo_group_can_exceed_admitted_cap() -> None: - queue = FairTaskQueue() - group = _group("a", admitted_limit=1) - first_task = _task("a", 0) - second_task = _task("a", 1) - queue.enqueue(first_task, group) - queue.enqueue(second_task, group) - - first = queue.admit_next() - - assert first is not None - assert first.task == first_task - second = queue.admit_next() - assert second is not None - assert second.task == second_task - assert queue.has_queued_tasks is False - - -def test_fair_task_queue_over_cap_group_yields_to_queued_peer() -> None: - queue = FairTaskQueue() - capped = _group("a", admitted_limit=1) - peer = _group("b") - _enqueue(queue, [(_task("a", i), capped) for i in range(5)]) - - solo_selected = [queue.admit_next() for _ in range(3)] - _enqueue(queue, [(_task("b", i), peer) for i in range(2)]) - peer_selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in solo_selected if selection is not None] == ["a", "a", "a"] - assert [selection.task.column for selection in peer_selected if selection is not None] == ["b", "b"] - - -def test_fair_task_queue_returns_none_when_all_competing_groups_capped() -> None: - queue = FairTaskQueue() - 
group_a = _group("a", admitted_limit=1) - group_b = _group("b", admitted_limit=1) - _enqueue( - queue, - [ - (_task("a", 0), group_a), - (_task("a", 1), group_a), - (_task("b", 0), group_b), - (_task("b", 1), group_b), - ], - ) - - selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in selected if selection is not None] == ["a", "b"] - assert queue.admit_next() is None - assert queue.has_queued_tasks is True - - -def test_fair_task_queue_release_reopens_saturated_group() -> None: - queue = FairTaskQueue() - group_a = _group("a", admitted_limit=1) - group_b = _group("b", admitted_limit=1) - _enqueue( - queue, - [ - (_task("a", 0), group_a), - (_task("a", 1), group_a), - (_task("b", 0), group_b), - (_task("b", 1), group_b), - ], - ) - first = queue.admit_next() - second = queue.admit_next() - - assert first is not None - assert first.task.column == "a" - assert second is not None - assert second.task.column == "b" - assert queue.admit_next() is None - - queue.release(first.task) - reopened = queue.admit_next() - - assert reopened is not None - assert reopened.task == _task("a", 1) - - -def test_fair_task_queue_no_duplicate_on_repeated_enqueue() -> None: - queue = FairTaskQueue() - task = _task("a", 0) - - queue.enqueue(task, _group("a")) - queue.enqueue(task, _group("a")) - first = queue.admit_next() - - assert first is not None - assert first.task == task - assert queue.admit_next() is None - - -def test_fair_task_queue_discard_where_removes_matching_tasks() -> None: - queue = FairTaskQueue() - _enqueue( - queue, - [(_task(column, i), _group(column)) for column in ["a", "b"] for i in range(2)], - ) - - queue.discard_where(lambda task: task.column == "a") - selected = [queue.admit_next() for _ in range(2)] - - assert [selection.task.column for selection in selected if selection is not None] == ["b", "b"] - assert queue.admit_next() is None diff --git 
a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py b/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py deleted file mode 100644 index 4e46c07b0..000000000 --- a/packages/data-designer-engine/tests/engine/dataset_builders/utils/test_scheduling_hints.py +++ /dev/null @@ -1,150 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from data_designer.config.column_configs import ( - CustomColumnConfig, - ExpressionColumnConfig, - GenerationStrategy, - LLMTextColumnConfig, -) -from data_designer.config.custom_column import custom_column_generator -from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig -from data_designer.engine.column_generators.generators.base import ColumnGenerator -from data_designer.engine.column_generators.generators.custom import CustomColumnGenerator -from data_designer.engine.dataset_builders.utils.scheduling_hints import SchedulingHint, SchedulingHintResolver -from data_designer.engine.resources.resource_provider import ResourceProvider - -MODEL_ALIAS = "stub" - - -def _expr_config(name: str = "test") -> ExpressionColumnConfig: - return ExpressionColumnConfig(name=name, expr="{{ x }}", dtype="str") - - -def _provider_with_model_configs(configs: dict[str, ModelConfig]) -> MagicMock: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - provider.model_registry.get_model_config.side_effect = lambda model_alias: configs[model_alias] - provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") - return provider - - -class LocalCellGenerator(ColumnGenerator[ExpressionColumnConfig]): - @staticmethod - def get_generation_strategy() -> 
GenerationStrategy: - return GenerationStrategy.CELL_BY_CELL - - def generate(self, data: dict) -> dict: - data[self.config.name] = "local" - return data - - -class ModelCellGenerator(ColumnGenerator[LLMTextColumnConfig]): - @property - def is_llm_bound(self) -> bool: - return True - - @staticmethod - def get_generation_strategy() -> GenerationStrategy: - return GenerationStrategy.CELL_BY_CELL - - def generate(self, data: dict) -> dict: - data[self.config.name] = "model" - return data - - def get_model_config(self, model_alias: str) -> ModelConfig: - return self.resource_provider.model_registry.get_model_config(model_alias=model_alias) - - def get_model_provider_name(self, model_alias: str) -> str: - provider = self.resource_provider.model_registry.get_model_provider(model_alias=model_alias) - return str(provider.name) - - -def test_scheduling_hint_resolver_local_hint_does_not_touch_model_registry() -> None: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - generator = LocalCellGenerator(config=_expr_config("local_col"), resource_provider=provider) - - resolver = SchedulingHintResolver({"local_col": generator}) - - assert resolver.hint_for(generator) == SchedulingHint(group_kind="local") - provider.model_registry.get_model_config.assert_not_called() - provider.model_registry.get_model_provider.assert_not_called() - - -def test_scheduling_hint_resolver_resolves_primary_model_once_per_generator() -> None: - model_config = ModelConfig( - alias=MODEL_ALIAS, - model="model-text", - inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=3), - provider="mock-provider", - ) - provider = _provider_with_model_configs({MODEL_ALIAS: model_config}) - column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) - generator = ModelCellGenerator(config=column_config, resource_provider=provider) - - resolver = SchedulingHintResolver({"answer": generator, "answer_again": generator}) - hint = 
resolver.hint_for(generator) - - assert hint.group_kind == "model" - assert hint.identity_prefix[:2] == ("mock-provider", "model-text") - assert hint.weight == 3 - assert provider.model_registry.get_model_config.call_count == 1 - assert provider.model_registry.get_model_provider.call_count == 1 - - -def test_scheduling_hint_resolver_falls_back_to_custom_model_hint_with_debug( - caplog: pytest.LogCaptureFixture, -) -> None: - provider = MagicMock(spec=ResourceProvider) - provider.model_registry = MagicMock() - provider.model_registry.get_model_config.side_effect = RuntimeError("registry unavailable") - provider.model_registry.get_model_provider.return_value = SimpleNamespace(name="mock-provider") - column_config = LLMTextColumnConfig(name="answer", prompt="hello", model_alias=MODEL_ALIAS) - generator = ModelCellGenerator(config=column_config, resource_provider=provider) - - with caplog.at_level("DEBUG", logger="data_designer.engine.dataset_builders.utils.scheduling_hints"): - resolver = SchedulingHintResolver({"answer": generator}) - - hint = resolver.hint_for(generator) - - assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=(MODEL_ALIAS,), weight=1) - fallback_records = [ - record for record in caplog.records if "Falling back to custom-model scheduling group" in record.getMessage() - ] - assert len(fallback_records) == 1 - assert "answer" in fallback_records[0].getMessage() - assert MODEL_ALIAS in fallback_records[0].getMessage() - assert fallback_records[0].exc_info is not None - - -def test_scheduling_hint_resolver_partial_alias_fallback_preserves_resolved_weight() -> None: - @custom_column_generator(model_aliases=["resolved", "missing"]) - def gen_with_models(row: dict, generator_params: None, models: dict) -> dict: - row["custom_llm"] = "value" - return row - - provider = _provider_with_model_configs( - { - "resolved": ModelConfig( - alias="resolved", - model="model-resolved", - 
inference_parameters=ChatCompletionInferenceParams(max_parallel_requests=7), - provider="mock-provider", - ) - } - ) - config = CustomColumnConfig(name="custom_llm", generator_function=gen_with_models) - generator = CustomColumnGenerator(config=config, resource_provider=provider) - - resolver = SchedulingHintResolver({"custom_llm": generator}) - hint = resolver.hint_for(generator) - - assert hint == SchedulingHint(group_kind="custom_model", identity_suffix=("missing", "resolved"), weight=7) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_factory.py b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py index ffdad291f..f809db8be 100644 --- a/packages/data-designer-engine/tests/engine/models/clients/test_factory.py +++ b/packages/data-designer-engine/tests/engine/models/clients/test_factory.py @@ -18,9 +18,9 @@ from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient from data_designer.engine.models.clients.factory import create_model_client +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor from data_designer.engine.models.clients.retry import RetryConfig -from data_designer.engine.models.clients.throttle_manager import ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient +from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController from data_designer.engine.secret_resolver import SecretResolver @@ -178,40 +178,51 @@ def test_concurrency_mode_defaults_to_sync( assert client.concurrency_mode == ClientConcurrencyMode.SYNC -# --- Throttle manager wrapping --- +# --- Request admission wrapping --- -def test_throttle_manager_wraps_openai_client( +def test_request_admission_wraps_openai_client( openai_model_config: ModelConfig, secret_resolver: SecretResolver, 
openai_registry: ModelProviderRegistry, ) -> None: - tm = ThrottleManager() + controller = AdaptiveRequestAdmissionController() + retry_config = RetryConfig(max_retries=5) client = create_model_client( - openai_model_config, secret_resolver, openai_registry, retry_config=RetryConfig(), throttle_manager=tm + openai_model_config, + secret_resolver, + openai_registry, + retry_config=retry_config, + request_admission=controller, ) - assert isinstance(client, ThrottledModelClient) + assert isinstance(client, ModelRequestExecutor) assert isinstance(client._inner, OpenAICompatibleClient) + assert client._retry_config is retry_config + assert client._inner._retry_config.max_retries == 0 -def test_throttle_manager_wraps_anthropic_client( +def test_request_admission_wraps_anthropic_client( anthropic_model_config: ModelConfig, secret_resolver: SecretResolver, anthropic_registry: ModelProviderRegistry, ) -> None: - tm = ThrottleManager() + controller = AdaptiveRequestAdmissionController() client = create_model_client( - anthropic_model_config, secret_resolver, anthropic_registry, retry_config=RetryConfig(), throttle_manager=tm + anthropic_model_config, + secret_resolver, + anthropic_registry, + retry_config=RetryConfig(), + request_admission=controller, ) - assert isinstance(client, ThrottledModelClient) + assert isinstance(client, ModelRequestExecutor) assert isinstance(client._inner, AnthropicClient) -def test_no_throttle_manager_returns_inner_client_directly( +def test_no_request_admission_returns_inner_client_directly( openai_model_config: ModelConfig, secret_resolver: SecretResolver, openai_registry: ModelProviderRegistry, ) -> None: client = create_model_client(openai_model_config, secret_resolver, openai_registry) assert isinstance(client, OpenAICompatibleClient) - assert not isinstance(client, ThrottledModelClient) + assert not isinstance(client, ModelRequestExecutor) diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py 
b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py new file mode 100644 index 000000000..2806ae569 --- /dev/null +++ b/packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py @@ -0,0 +1,381 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging + +import pytest + +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind +from data_designer.engine.models.clients.model_request_executor import ModelRequestExecutor +from data_designer.engine.models.clients.retry import RetryConfig +from data_designer.engine.models.clients.types import ( + AssistantMessage, + ChatCompletionRequest, + ChatCompletionResponse, + EmbeddingRequest, + EmbeddingResponse, + ImageGenerationRequest, + ImageGenerationResponse, + ImagePayload, +) +from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController +from data_designer.engine.models.request_admission.resources import RequestDomain +from data_designer.engine.observability import InMemoryAdmissionEventSink + + +class _Client: + provider_name = "nvidia" + + def __init__(self) -> None: + self.error: Exception | None = None + + def supports_chat_completion(self) -> bool: + return True + + def supports_embeddings(self) -> bool: + return True + + def supports_image_generation(self) -> bool: + return True + + def close(self) -> None: + return None + + async def aclose(self) -> None: + return None + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + if self.error is not None: + raise self.error + return ChatCompletionResponse(AssistantMessage(content="ok")) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + if self.error is not None: + raise self.error + return 
ChatCompletionResponse(AssistantMessage(content="ok")) + + def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return EmbeddingResponse(vectors=[[1.0]]) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + return EmbeddingResponse(vectors=[[1.0]]) + + def generate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return ImageGenerationResponse(images=[ImagePayload("abc")]) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + return ImageGenerationResponse(images=[ImagePayload("abc")]) + + +class _BrokenSink: + def emit_request_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + +class _GatedAsyncClient(_Client): + def __init__(self) -> None: + super().__init__() + self.chat_started = asyncio.Event() + self.embedding_started = asyncio.Event() + self.image_started = asyncio.Event() + self.release_chat = asyncio.Event() + self.release_embedding = asyncio.Event() + self.release_image = asyncio.Event() + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self.chat_started.set() + await self.release_chat.wait() + return ChatCompletionResponse(AssistantMessage(content="chat")) + + async def aembeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + self.embedding_started.set() + await self.release_embedding.wait() + return EmbeddingResponse(vectors=[[1.0]]) + + async def agenerate_image(self, request: ImageGenerationRequest) -> ImageGenerationResponse: + self.image_started.set() + await self.release_image.wait() + return ImageGenerationResponse(images=[ImagePayload("image")]) + + +class _FlakyClient(_Client): + def __init__( + self, + *, + failures: int, + kind: ProviderErrorKind = ProviderErrorKind.INTERNAL_SERVER, + status_code: int | None = 503, + ) -> None: + super().__init__() + self.failures = failures + self.calls = 0 + self.kind = kind + self.status_code = status_code + + 
def _maybe_fail(self) -> None: + self.calls += 1 + if self.calls <= self.failures: + raise ProviderError( + kind=self.kind, + message="temporarily unavailable", + status_code=self.status_code, + provider_name="nvidia", + model_name="nemotron", + ) + + def completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self._maybe_fail() + return ChatCompletionResponse(AssistantMessage(content="ok")) + + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + self._maybe_fail() + return ChatCompletionResponse(AssistantMessage(content="ok")) + + +def _executor() -> tuple[ModelRequestExecutor, AdaptiveRequestAdmissionController, _Client]: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _Client() + return ModelRequestExecutor(client, controller, "nvidia", "nemotron"), controller, client + + +def test_model_request_executor_releases_successful_request() -> None: + executor, controller, _client = _executor() + + response = executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.last_outcome == "success" + + +def test_model_request_executor_classifies_rate_limit() -> None: + executor, controller, client = _executor() + client.error = ProviderError( + kind=ProviderErrorKind.RATE_LIMIT, + message="rate limited", + provider_name="nvidia", + model_name="nemotron", + retry_after=1.0, + ) + + with pytest.raises(ProviderError): + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.last_outcome == "rate_limited" + assert 
snapshot.cooldown_remaining_seconds > 0 + + +def test_model_request_executor_retries_provider_503_with_fresh_leases() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=1) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + event_sink=sink, + retry_config=RetryConfig(max_retries=1, backoff_factor=0.0), + ) + + response = executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + assert client.calls == 2 + acquired = [event for event in sink.request_events if event.event_kind == "request_lease_acquired"] + released = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert len(acquired) == 2 + assert len(released) == 2 + assert {event.request_lease_id for event in acquired} == {event.request_lease_id for event in released} + + +def test_model_request_executor_does_not_retry_provider_timeout_without_status() -> None: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=2, kind=ProviderErrorKind.TIMEOUT, status_code=None) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + retry_config=RetryConfig(max_retries=2, backoff_factor=0.0), + ) + + with pytest.raises(ProviderError) as exc_info: + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert exc_info.value.kind == ProviderErrorKind.TIMEOUT + assert client.calls == 1 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_retries_async_provider_503_with_fresh_leases() -> None: + sink = InMemoryAdmissionEventSink() + controller = 
AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _FlakyClient(failures=1) + executor = ModelRequestExecutor( + client, + controller, + "nvidia", + "nemotron", + event_sink=sink, + retry_config=RetryConfig(max_retries=1, backoff_factor=0.0), + ) + + response = await executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert response.message.content == "ok" + assert client.calls == 2 + acquired = [event for event in sink.request_events if event.event_kind == "request_lease_acquired"] + released = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert len(acquired) == 2 + assert len(released) == 2 + assert {event.request_lease_id for event in acquired} == {event.request_lease_id for event in released} + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_releases_async_cancellation() -> None: + class _SlowClient(_Client): + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + await asyncio.sleep(30) + return ChatCompletionResponse(AssistantMessage(content="late")) + + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_SlowClient(), controller, "nvidia", "nemotron") + + task = asyncio.create_task(executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[]))) + await asyncio.sleep(0) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.last_outcome == "local_cancelled" + + +@pytest.mark.asyncio(loop_scope="session") +async def 
test_model_request_executor_classifies_async_keyboard_interrupt_as_cancelled() -> None: + class _InterruptingClient(_Client): + async def acompletion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + raise KeyboardInterrupt + + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_InterruptingClient(), controller, "nvidia", "nemotron", event_sink=sink) + + with pytest.raises(KeyboardInterrupt): + await executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[])) + + snapshot = controller.pressure.snapshot(next(iter(controller.pressure.snapshots()))) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.last_outcome == "local_cancelled" + completed = [event for event in sink.request_events if event.event_kind == "model_request_completed"] + assert completed[-1].diagnostics["outcome"] == "local_cancelled" + + +def test_model_request_executor_maps_image_chat_domain() -> None: + executor, controller, _client = _executor() + + executor.generate_image(ImageGenerationRequest(model="nemotron", prompt="p", messages=[])) + + resources = controller.pressure.snapshots() + assert any(resource.domain == RequestDomain.CHAT for resource in resources) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_model_request_executor_shares_provider_model_cap_across_async_domains() -> None: + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + client = _GatedAsyncClient() + executor = ModelRequestExecutor(client, controller, "nvidia", "nemotron") + + chat_task = asyncio.create_task(executor.acompletion(ChatCompletionRequest(model="nemotron", messages=[]))) + await asyncio.wait_for(client.chat_started.wait(), 
timeout=1.0) + embedding_task = asyncio.create_task(executor.aembeddings(EmbeddingRequest(model="nemotron", inputs=["x"]))) + image_task = asyncio.create_task(executor.agenerate_image(ImageGenerationRequest(model="nemotron", prompt="image"))) + await _wait_for_request_waiters(controller, expected=2) + + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 1 + assert not client.embedding_started.is_set() + assert not client.image_started.is_set() + + client.release_chat.set() + await asyncio.wait_for(client.embedding_started.wait(), timeout=1.0) + assert not client.image_started.is_set() + assert (await chat_task).message.content == "chat" + + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 1 + client.release_embedding.set() + await asyncio.wait_for(client.image_started.wait(), timeout=1.0) + assert (await embedding_task).vectors == [[1.0]] + + client.release_image.set() + assert (await image_task).images[0].b64_data == "image" + global_snapshot = controller.pressure.global_snapshot("nvidia", "nemotron") + assert global_snapshot is not None + assert global_snapshot.aggregate_in_flight == 0 + + +async def _wait_for_request_waiters(controller: AdaptiveRequestAdmissionController, *, expected: int) -> None: + for _ in range(50): + waiters = sum(snapshot.waiters for snapshot in controller.pressure.snapshots().values()) + if waiters == expected: + return + await asyncio.sleep(0) + raise AssertionError(f"expected {expected} request waiters") + + +def test_model_request_executor_emits_attempt_events_with_correlation_fields() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = 
ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron", event_sink=sink) + + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + kinds = [event.event_kind for event in sink.request_events] + assert "request_wait_started" in kinds + assert "request_lease_acquired" in kinds + assert "model_request_started" in kinds + assert "model_request_completed" in kinds + assert "request_lease_released" in kinds + attempts = {event.request_attempt_id for event in sink.request_events if event.request_attempt_id is not None} + assert len(attempts) == 1 + assert all(event.request_resource_key is not None for event in sink.request_events) + assert all(event.pressure_snapshot is not None for event in sink.request_events) + attempt_events = [event for event in sink.request_events if event.request_attempt_id is not None] + assert attempt_events + assert all(event.request_group_key is not None for event in attempt_events) + for event in attempt_events: + assert isinstance(event.pressure_snapshot, dict) + assert event.pressure_snapshot["resource"] == event.request_resource_key + + +def test_model_request_executor_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.models.clients.model_request_executor") + controller = AdaptiveRequestAdmissionController() + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + executor = ModelRequestExecutor(_Client(), controller, "nvidia", "nemotron", event_sink=_BrokenSink()) + + executor.completion(ChatCompletionRequest(model="nemotron", messages=[])) + + assert "Model request event sink raised; dropping event." 
in caplog.text diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py deleted file mode 100644 index 44a69d549..000000000 --- a/packages/data-designer-engine/tests/engine/models/clients/test_throttle_manager.py +++ /dev/null @@ -1,720 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -import threading -import time - -import pytest - -from data_designer.config.run_config import ThrottleConfig -from data_designer.engine.models.clients.throttle_manager import ( - CAPACITY_POLL_INTERVAL, - ThrottleDomain, - ThrottleManager, -) - -PROVIDER = "test-provider" -MODEL = "gpt-test" -DOMAIN = ThrottleDomain.CHAT - - -@pytest.fixture -def manager() -> ThrottleManager: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - return tm - - -# --- try_acquire --- - - -def test_acquire_under_limit_returns_zero(manager: ThrottleManager) -> None: - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == 0.0 - - -def test_acquire_at_capacity_returns_short_poll_interval(manager: ThrottleManager) -> None: - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == pytest.approx(CAPACITY_POLL_INTERVAL) - - -def test_acquire_respects_blocked_until(manager: ThrottleManager) -> None: - """Rate-limit cooldown returns remaining block duration (not the short capacity poll).""" - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, 
retry_after=5.0, now=1.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=2.0) - assert wait == pytest.approx(4.0, abs=0.01) - - -def test_acquire_without_registration_raises() -> None: - tm = ThrottleManager() - with pytest.raises(RuntimeError, match="register"): - tm.try_acquire(provider_name="unknown", model_id="m", domain=DOMAIN, now=0.0) - - -# --- startup ramp --- - - -def test_startup_ramp_starts_at_one_and_reaches_effective_max_linearly() -> None: - tm = ThrottleManager(ThrottleConfig(rampup_seconds=10.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 1 - assert state.rampup_active is True - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == pytest.approx( - CAPACITY_POLL_INTERVAL - ) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.success_streak == 0 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=5.0) == 0.0 - assert state.current_limit == 3 - assert state.rampup_active is True - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=5.0) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) == 0.0 - assert state.current_limit == 5 - assert state.rampup_active is False - - -def test_rate_limit_aborts_startup_ramp_and_continues_with_aimd() -> None: - tm = ThrottleManager(ThrottleConfig(reduce_factor=0.5, success_window=1, rampup_seconds=100.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=9) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - tm.release_success(provider_name=PROVIDER, 
model_id=MODEL, domain=DOMAIN, now=0.0) - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=50.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 5 - assert state.rampup_active is True - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=50.0) - assert state.rampup_active is False - assert state.current_limit == 2 - assert state.rate_limit_ceiling == 5 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=60.0) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=60.0) - assert state.current_limit == 3 - - -def test_rate_limit_at_start_of_ramp_does_not_pin_recovery_to_minimum_ceiling() -> None: - tm = ThrottleManager( - ThrottleConfig(reduce_factor=0.5, success_window=1, ceiling_overshoot=0.0, rampup_seconds=100.0) - ) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 1 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.rampup_active is False - assert state.current_limit == 1 - assert state.rate_limit_ceiling == 0 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.current_limit == 2 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=11.0) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=11.0) - assert state.current_limit == 3 - - -def test_startup_ramp_skipped_when_effective_max_is_one() -> None: - tm = 
ThrottleManager(ThrottleConfig(rampup_seconds=10.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 1 - assert state.rampup_active is False - - -def test_startup_ramp_completes_on_first_call_after_elapsed_time() -> None: - tm = ThrottleManager(ThrottleConfig(rampup_seconds=10.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 1 - assert state.rampup_active is True - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=11.0) == 0.0 - assert state.current_limit == 5 - assert state.rampup_active is False - - -def test_release_failure_preserves_startup_ramp_and_progress() -> None: - tm = ThrottleManager(ThrottleConfig(rampup_seconds=10.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.rampup_active is True - tm.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.rampup_active is True - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=5.0) == 0.0 - assert state.current_limit == 3 - assert state.rampup_active is True - - -def test_non_ramp_rate_limit_at_minimum_does_not_pin_recovery_to_soft_ceiling() -> None: - tm = 
ThrottleManager(ThrottleConfig(reduce_factor=0.5, success_window=1, ceiling_overshoot=0.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) == 0.0 - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - state.current_limit = 1 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.current_limit == 1 - assert state.rate_limit_ceiling == 0 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.current_limit == 2 - - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=11.0) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=11.0) - assert state.current_limit == 3 - - -# --- release_success --- - - -def test_release_success_frees_slot(manager: ThrottleManager) -> None: - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - wait = manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert wait == 0.0 - - -def test_additive_increase_after_success_window() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=5)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_after_drop = state.current_limit - - for i in range(5): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, 
domain=DOMAIN, now=float(i)) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - - assert state.current_limit == limit_after_drop + 1 - - -def test_additive_increase_uses_configured_step() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=3)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=20) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_after_drop = state.current_limit - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - - assert state.current_limit == limit_after_drop + 3 - - -def test_current_limit_never_exceeds_effective_max() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) - for i in range(20): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=float(i)) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit <= 2 - - -def test_additive_increase_clamped_to_effective_max() -> None: - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=100)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=5) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - tm.release_success(provider_name=PROVIDER, 
model_id=MODEL, domain=DOMAIN, now=1.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 5 - - -# --- release_rate_limited --- - - -def test_rate_limited_reduces_current_limit(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 3 # floor(4 * 0.75) - - -def test_rate_limited_never_drops_below_one() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit >= 1 - - -def test_rate_limited_resets_success_streak(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.success_streak == 0 - - -def test_rate_limited_uses_retry_after_for_blocked_until(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=7.0, now=10.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is 
not None - assert state.blocked_until == pytest.approx(17.0, abs=0.01) - - -def test_rate_limited_uses_default_block_when_no_retry_after(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.blocked_until == pytest.approx(10.0 + ThrottleConfig.DEFAULT_COOLDOWN_SECONDS, abs=0.01) - - -# --- release_failure --- - - -def test_failure_releases_slot_without_limit_change(manager: ThrottleManager) -> None: - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - limit_before = state.current_limit - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.current_limit == limit_before - assert state.in_flight == 0 - - -def test_failure_does_not_reset_cascade_while_burst_in_flight(manager: ThrottleManager) -> None: - """Mixed-response burst (429 → 500 → 429 with multiple slots in-flight) must reduce only once. - - With a real burst of in-flight requests, an interleaved non-rate-limit - failure should NOT break the cascade - otherwise the next 429 from the - same wave would be treated as a new cascade and double-reduce the limit - even though the provider hasn't recovered between the two 429s. - """ - # Saturate to limit (4 concurrent slots). - for _ in range(4): - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 4 - limit_before = state.current_limit - - # First 429 from the burst: limit reduced once. 
- manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - limit_after_first_429 = state.current_limit - assert limit_after_first_429 < limit_before - assert state.consecutive_429s == 1 - assert state.in_flight == 3 - - # Second response from the same burst: 500. With the regression, this - # would reset the cascade to 0; with the fix, in_flight > 0 keeps it at 1. - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.consecutive_429s == 1, "cascade must not reset while the prior burst is still in-flight" - assert state.in_flight == 2 - - # Third response from the same burst: another 429. With the regression - # this would be treated as a new cascade and reduce the limit again. - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.current_limit == limit_after_first_429, "limit must not double-reduce within the same burst" - assert state.in_flight == 1 - - -def test_failure_resets_cascade_after_burst_drains(manager: ThrottleManager) -> None: - """Once the burst fully drains (in_flight == 0), the next non-RL failure breaks the cascade. - - This preserves the original PR intent for the sequential 429 → 500 → 429 - case: provider rate-limited, settled, then rate-limited again. - """ - # Saturate, then drain: one 429 then one 500 with no concurrency. - manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = manager.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.consecutive_429s == 1 - assert state.in_flight == 0 - - # New request after the burst drained. release_failure sees in_flight 1 → 0. 
- manager.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - manager.release_failure(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - assert state.consecutive_429s == 0 - assert state.in_flight == 0 - - -# --- Global cap --- - - -def test_two_aliases_effective_max_is_minimum() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) - assert tm.get_effective_max(PROVIDER, MODEL) == 3 - - -def test_domain_clamped_when_new_alias_lowers_cap() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=10) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 10 - - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a2", max_parallel_requests=3) - assert state.current_limit == 3 - - -# --- Domain isolation --- - - -def test_chat_and_embedding_throttle_independently() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=2) - - for _ in range(2): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) - wait_chat = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.CHAT, now=0.0) - assert wait_chat > 0.0 - - wait_emb = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=ThrottleDomain.EMBEDDING, now=0.0) - assert wait_emb == 0.0 - - -# --- 429 lifecycle scenario --- - - -def test_rate_limit_lifecycle_acquire_backoff_recover() -> None: - """End-to-end AIMD lifecycle: steady-state → 429 → backoff → cooldown → recovery. - - Uses the ``now`` parameter to simulate time without real sleeps. 
- Config: success_window=3, additive_increase=1, max_parallel=4, reduce_factor=0.75. - """ - tm = ThrottleManager(ThrottleConfig(success_window=3, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - t = 0.0 - - # Phase 1 — Steady state (t=0): all 4 slots acquired and released successfully. - for _ in range(4): - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - for _ in range(4): - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.current_limit == 4 - - # Phase 2 — 429 hits (t=10): reduce_factor=0.75 → floor(4*0.75)=3. - # Domain is blocked until t=10+5=15. - t = 10.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, retry_after=5.0, now=t) - assert state.current_limit == 3 - assert state.blocked_until == 15.0 - - # Phase 3 — During cooldown (t=12): acquire returns positive wait since 12 < 15. - wait = tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=12.0) - assert wait > 0.0 - - # Phase 4 — Cooldown expires (t=16): acquire succeeds, start accumulating successes. - # Need 3 successes (success_window=3) to bump limit 3 → 4. - t = 16.0 - for _ in range(3): - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - t += 1.0 - - assert state.current_limit == 4 - - -# --- Ceiling stabilization --- - - -def test_ceiling_stabilization_with_overshoot() -> None: - """After a 429, AIMD increase stops at ceiling + overshoot instead of effective_max. - - Config: effective_max=1000, success_window=1, ceiling_overshoot=0.10. 
- Scenario: 429 at limit 40 → floor(40*0.75)=30 → ceiling=40 → soft cap = 40 + 4 = 44. - Recovery should stop at 44, not climb to 1000. - """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1000) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - state.current_limit = 40 - - # 429 at limit 40 → floor(40*0.75)=30, ceiling recorded as 40. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.current_limit == 30 - assert state.rate_limit_ceiling == 40 - - # Pump success windows to climb back up. soft_cap = 40 + floor(40*0.1) = 44. - t = 20.0 - for _ in range(20): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 44, f"Expected stabilization at 44, got {state.current_limit}" - - # Further successes should not increase beyond the soft ceiling. - for _ in range(10): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 44, f"Limit crept past soft ceiling: {state.current_limit}" - - -def test_ceiling_lowers_on_repeated_429_after_recovery() -> None: - """A 429 after partial recovery lowers the ceiling, tightening the soft cap. - - Scenario: first 429 at 40 → floor(40*0.75)=30, ceiling=40. - Recovery: set limit to 30, one success bumps to 31 (success_window=1). - Second 429 at 31 → floor(31*0.75)=23, ceiling = min(40, 31) = 31. - Soft cap = 31 + max(1, floor(31*0.1)) = 31 + 3 = 34. 
- """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1000) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - state.current_limit = 40 - - # First 429 at 40 → floor(40*0.75)=30, ceiling=40. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - assert state.rate_limit_ceiling == 40 - assert state.current_limit == 30 - - # Recovery: one success bumps 30 → 31. - t = 20.0 - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - assert state.current_limit == 31 - - # Second 429 at 31 → floor(31*0.75)=23, ceiling = min(40, 31) = 31. - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t + 1) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t + 1) - assert state.rate_limit_ceiling == 31 - assert state.current_limit == 23 - - # Soft cap = 31 + max(1, floor(31*0.1)) = 34. Climb should stop there. 
- t = 40.0 - for _ in range(15): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 34, f"Expected soft cap at 34, got {state.current_limit}" - - -def test_cascade_only_first_429_reduces_limit() -> None: - """Only the first 429 in a cascade reduces the limit; subsequent ones just release permits.""" - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=100) - - for _ in range(4): - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 4 - - # First 429: limit 100 → 75, ceiling set to 100. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.rate_limit_ceiling == 100 - assert state.in_flight == 3 - - # Subsequent cascade 429s: limit stays at 75, only in_flight decrements. - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.rate_limit_ceiling == 100 - assert state.in_flight == 2 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.in_flight == 1 - - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=1.0) - assert state.current_limit == 75 - assert state.in_flight == 0 - - -def test_ceiling_does_not_restrict_when_at_effective_max() -> None: - """When effective_max is small (e.g. 4), the ceiling + overshoot should not - prevent recovery to effective_max. 
- """ - tm = ThrottleManager(ThrottleConfig(success_window=1, additive_increase=1)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=0.0) - tm.release_rate_limited(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=10.0) - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - # floor(4 * 0.75) = 3; ceiling=4, soft_cap = min(4 + max(1, floor(4*0.1)), 4) = 4 - assert state.current_limit == 3 - - t = 20.0 - for _ in range(5): - t += 1.0 - assert tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) == 0.0 - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, now=t) - - assert state.current_limit == 4, f"Should recover to effective_max=4, got {state.current_limit}" - - -# --- Acquire timeout --- - - -def test_acquire_sync_raises_timeout_when_at_capacity() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - # Saturate the single slot so try_acquire returns a positive wait. - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - with pytest.raises(TimeoutError, match="timed out"): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) - - -def test_acquire_sync_does_not_overshoot_timeout() -> None: - """When wait > remaining budget, raise immediately instead of sleeping the full wait.""" - tm = ThrottleManager(ThrottleConfig(cooldown_seconds=5.0)) - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - # Timeout of 0.5s is less than the 5s block wait — should raise fast, not sleep 5s. 
- start = time.monotonic() - with pytest.raises(TimeoutError, match="timed out"): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.5) - elapsed = time.monotonic() - start - assert elapsed < 2.0, f"acquire_sync overshot timeout: elapsed {elapsed:.1f}s (expected <2s)" - - -@pytest.mark.asyncio -async def test_acquire_async_raises_timeout_when_at_capacity() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - with pytest.raises(TimeoutError, match="timed out"): - await tm.acquire_async(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN, timeout=0.0) - - -@pytest.mark.asyncio -async def test_acquire_async_default_no_deadline_waits_for_release() -> None: - """``timeout=None`` (the default) waits for the permit instead of raising. - - Issue #551: the previous 300s default produced spurious ``ModelTimeoutError`` - cascades on slow endpoints with deep queues; now queue waits scale with - provider speed and only the HTTP timeout deadlines actual work. The - ``timeout=0.0`` case is covered by ``test_acquire_async_raises_timeout_when_at_capacity``. - """ - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - async def release_after(delay: float) -> None: - await asyncio.sleep(delay) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - # Hold a strong reference to the task so the loop's weak-ref bookkeeping - # can't GC it before the inner await observes the release. - release_task = asyncio.create_task(release_after(0.05)) - try: - # asyncio.wait_for caps the test runtime; the inner acquire_async passes None. 
- await asyncio.wait_for( - tm.acquire_async(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN), - timeout=2.0, - ) - finally: - await release_task - - -def test_acquire_sync_default_no_deadline_waits_for_release() -> None: - """Sync counterpart: ``timeout=None`` default blocks until release.""" - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=1) - tm.try_acquire(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - def release_after(delay: float) -> None: - time.sleep(delay) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - - threading.Thread(target=release_after, args=(0.05,), daemon=True).start() - start = time.monotonic() - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - elapsed = time.monotonic() - start - assert 0.04 < elapsed < 2.0, f"expected ~0.05s wait, got {elapsed:.3f}s" - - -# --- Thread safety --- - - -def test_concurrent_acquire_release_no_errors() -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL, alias="a1", max_parallel_requests=4) - errors: list[Exception] = [] - - def worker() -> None: - try: - for _ in range(50): - tm.acquire_sync(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - tm.release_success(provider_name=PROVIDER, model_id=MODEL, domain=DOMAIN) - except Exception as exc: - errors.append(exc) - - threads = [threading.Thread(target=worker) for _ in range(8)] - for t in threads: - t.start() - for t in threads: - t.join(timeout=10) - assert not errors, f"Thread errors: {errors}" - - state = tm.get_domain_state(PROVIDER, MODEL, DOMAIN) - assert state is not None - assert state.in_flight == 0 diff --git a/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py b/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py deleted file mode 100644 index bad6cb9b1..000000000 --- 
a/packages/data-designer-engine/tests/engine/models/clients/test_throttled_model_client.py +++ /dev/null @@ -1,673 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -import asyncio -import math -import time -from dataclasses import dataclass -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest - -from data_designer.config.run_config import ThrottleConfig -from data_designer.engine.models.clients.adapters.http_model_client import ClientConcurrencyMode -from data_designer.engine.models.clients.adapters.openai_compatible import OpenAICompatibleClient -from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind -from data_designer.engine.models.clients.throttle_manager import DomainThrottleState, ThrottleDomain, ThrottleManager -from data_designer.engine.models.clients.throttled import ThrottledModelClient -from data_designer.engine.models.clients.types import ( - AssistantMessage, - ChatCompletionRequest, - ChatCompletionResponse, - EmbeddingRequest, - EmbeddingResponse, - ImageGenerationRequest, - ImageGenerationResponse, - Usage, -) -from tests.engine.models.clients.conftest import mock_httpx_response - -PROVIDER = "test-provider" -MODEL_ID = "test-model" -ENDPOINT = "https://api.example.com/v1" - - -@dataclass(frozen=True) -class ColdServerSample: - elapsed: float - in_flight: int - allowed: int - status_code: int - - -class ColdServerAsyncHTTPClient: - def __init__( - self, - *, - max_parallel: int, - server_ramp_seconds: float, - service_seconds: float, - retry_after: float, - ) -> None: - self._max_parallel = max_parallel - self._server_ramp_seconds = server_ramp_seconds - self._service_seconds = service_seconds - self._retry_after = retry_after - self._lock = asyncio.Lock() - self._started_at: float | None = None - self._in_flight = 0 - self.samples: list[ColdServerSample] = [] - 
- @property - def rate_limits(self) -> int: - return sum(1 for sample in self.samples if sample.status_code == 429) - - @property - def peak_in_flight(self) -> int: - if not self.samples: - return 0 - return max(sample.in_flight for sample in self.samples) - - @property - def peak_allowed(self) -> int: - if not self.samples: - return 0 - return max(sample.allowed for sample in self.samples) - - async def post(self, *_args: object, **_kwargs: object) -> MagicMock: - async with self._lock: - self._in_flight += 1 - now = time.monotonic() - if self._started_at is None: - self._started_at = now - elapsed = now - self._started_at - allowed = self._allowed(elapsed) - status_code = 200 if self._in_flight <= allowed else 429 - self.samples.append( - ColdServerSample( - elapsed=elapsed, - in_flight=self._in_flight, - allowed=allowed, - status_code=status_code, - ) - ) - try: - if status_code == 429: - response = mock_httpx_response({"error": {"message": "synthetic rate limit"}}, status_code=429) - response.headers = {"Retry-After": str(self._retry_after)} - return response - await asyncio.sleep(self._service_seconds) - return mock_httpx_response( - { - "choices": [ - {"index": 0, "message": {"role": "assistant", "content": "ok"}, "finish_reason": "stop"} - ], - "usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, - } - ) - finally: - async with self._lock: - self._in_flight -= 1 - - async def aclose(self) -> None: - return None - - def _allowed(self, elapsed: float) -> int: - if self._server_ramp_seconds <= 0: - return self._max_parallel - fraction = min(1.0, max(0.0, elapsed / self._server_ramp_seconds)) - ramp_slots = math.floor((self._max_parallel - 1) * fraction) - return max(1, min(self._max_parallel, 1 + ramp_slots)) - - -@pytest.fixture -def throttle_manager() -> ThrottleManager: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=10) - return tm - - -@pytest.fixture -def 
inner_client() -> MagicMock: - client = MagicMock() - client.provider_name = PROVIDER - client.supports_chat_completion.return_value = True - client.supports_embeddings.return_value = True - client.supports_image_generation.return_value = True - client.completion.return_value = ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - client.acompletion = AsyncMock( - return_value=ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - ) - client.embeddings.return_value = EmbeddingResponse(vectors=[[0.1]], usage=Usage()) - client.aembeddings = AsyncMock(return_value=EmbeddingResponse(vectors=[[0.1]], usage=Usage())) - client.generate_image.return_value = ImageGenerationResponse(images=[]) - client.agenerate_image = AsyncMock(return_value=ImageGenerationResponse(images=[])) - client.close.return_value = None - client.aclose = AsyncMock() - return client - - -@pytest.fixture -def throttled_client(inner_client: MagicMock, throttle_manager: ThrottleManager) -> ThrottledModelClient: - return ThrottledModelClient( - inner=inner_client, - throttle_manager=throttle_manager, - provider_name=PROVIDER, - model_id=MODEL_ID, - ) - - -# --- Protocol delegation --- - - -def test_provider_name_delegates(throttled_client: ThrottledModelClient) -> None: - assert throttled_client.provider_name == PROVIDER - - -def test_supports_methods_delegate(throttled_client: ThrottledModelClient) -> None: - assert throttled_client.supports_chat_completion() is True - assert throttled_client.supports_embeddings() is True - assert throttled_client.supports_image_generation() is True - - -def test_close_delegates(throttled_client: ThrottledModelClient, inner_client: MagicMock) -> None: - throttled_client.close() - inner_client.close.assert_called_once() - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aclose_delegates(throttled_client: ThrottledModelClient, inner_client: MagicMock) -> None: - await throttled_client.aclose() - 
inner_client.aclose.assert_awaited_once() - - -# --- Sync: acquire/release on success --- - - -def test_completion_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - result = throttled_client.completion(request) - assert result.message.content == "ok" - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -def test_embeddings_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = EmbeddingRequest(model=MODEL_ID, inputs=["hello"]) - result = throttled_client.embeddings(request) - assert result.vectors == [[0.1]] - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.EMBEDDING) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -def test_generate_image_diffusion_uses_image_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=None) - throttled_client.generate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.IMAGE) - assert state is not None - assert state.success_streak == 1 - - -def test_generate_image_chat_backed_uses_chat_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=[{"role": "user", "content": "draw"}]) - throttled_client.generate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.success_streak == 1 - - -# --- Async: acquire/release on success --- - - -@pytest.mark.asyncio(loop_scope="session") -async 
def test_acompletion_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - result = await throttled_client.acompletion(request) - assert result.message.content == "ok" - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aembeddings_success_releases_success( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = EmbeddingRequest(model=MODEL_ID, inputs=["hello"]) - result = await throttled_client.aembeddings(request) - assert result.vectors == [[0.1]] - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.EMBEDDING) - assert state is not None - assert state.in_flight == 0 - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_agenerate_image_diffusion_uses_image_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=None) - await throttled_client.agenerate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.IMAGE) - assert state is not None - assert state.success_streak == 1 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_agenerate_image_chat_backed_uses_chat_domain( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - request = ImageGenerationRequest(model=MODEL_ID, prompt="a cat", messages=[{"role": "user", "content": "draw"}]) - await throttled_client.agenerate_image(request) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.success_streak == 1 - - -# --- 
Rate-limit error: release_rate_limited with retry_after --- - - -def test_completion_rate_limit_calls_release_rate_limited( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429", - status_code=429, - retry_after=5.0, - ) - with pytest.raises(ProviderError, match="429"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.blocked_until > 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_rate_limit_calls_release_rate_limited( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.acompletion = AsyncMock( - side_effect=ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429", - status_code=429, - retry_after=3.0, - ) - ) - with pytest.raises(ProviderError, match="429"): - await throttled_client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - assert state.blocked_until > 0 - - -# --- Non-rate-limit ProviderError: release_failure --- - - -def test_completion_non_rate_limit_error_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = ProviderError( - kind=ProviderErrorKind.INTERNAL_SERVER, - message="500", - status_code=500, - ) - with pytest.raises(ProviderError, match="500"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - 
assert state.in_flight == 0 - assert state.success_streak == 0 - - -# --- Non-ProviderError exception: release_failure --- - - -def test_completion_generic_exception_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.completion.side_effect = RuntimeError("boom") - with pytest.raises(RuntimeError, match="boom"): - throttled_client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_generic_exception_calls_release_failure( - throttled_client: ThrottledModelClient, throttle_manager: ThrottleManager -) -> None: - throttled_client._inner.acompletion = AsyncMock(side_effect=RuntimeError("boom")) - with pytest.raises(RuntimeError, match="boom"): - await throttled_client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 - - -# --- Acquire timeout: normalized to ProviderError(kind=TIMEOUT), no release --- - - -def test_sync_acquire_timeout_normalized_to_provider_error(inner_client: MagicMock) -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=1) - client = ThrottledModelClient(inner=inner_client, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - - with patch.object(tm, "acquire_sync", side_effect=TimeoutError("timed out")): - with pytest.raises(ProviderError) as exc_info: - client.completion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - assert exc_info.value.kind == ProviderErrorKind.TIMEOUT - - inner_client.completion.assert_not_called() - - -@pytest.mark.asyncio(loop_scope="session") -async def 
test_async_acquire_timeout_normalized_to_provider_error(inner_client: MagicMock) -> None: - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="alias", max_parallel_requests=1) - client = ThrottledModelClient(inner=inner_client, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - - with patch.object(tm, "acquire_async", side_effect=TimeoutError("timed out")): - with pytest.raises(ProviderError) as exc_info: - await client.acompletion(ChatCompletionRequest(model=MODEL_ID, messages=[])) - assert exc_info.value.kind == ProviderErrorKind.TIMEOUT - - inner_client.acompletion.assert_not_awaited() - - -# --- Cancellation: release_failure on CancelledError --- - - -@pytest.mark.asyncio(loop_scope="session") -async def test_acompletion_cancelled_releases_permit(throttle_manager: ThrottleManager) -> None: - """CancelledError during an in-flight async request releases the throttle permit.""" - blocked = asyncio.Event() - - async def slow_acompletion(_request: ChatCompletionRequest) -> ChatCompletionResponse: - blocked.set() - await asyncio.sleep(60) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = slow_acompletion - - client = ThrottledModelClient( - inner=inner, throttle_manager=throttle_manager, provider_name=PROVIDER, model_id=MODEL_ID - ) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - task = asyncio.create_task(client.acompletion(request)) - await blocked.wait() - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 1 - - task.cancel() - with pytest.raises(asyncio.CancelledError): - await task - - assert state.in_flight == 0 - - -# --- E2E: full AIMD feedback loop --- - - -async def _complete_with_rate_limit_retries( - client: ThrottledModelClient, - request: ChatCompletionRequest, - *, - max_attempts: 
int = 50, -) -> int: - for attempt in range(1, max_attempts + 1): - try: - await client.acompletion(request) - return attempt - except ProviderError as exc: - if exc.kind != ProviderErrorKind.RATE_LIMIT: - raise - await asyncio.sleep(exc.retry_after or 0.01) - raise AssertionError(f"request did not complete after {max_attempts} attempts") - - -async def _run_cold_server_scenario( - *, - throttle_ramp_seconds: float, - server_ramp_seconds: float, - max_parallel: int = 4, - tasks: int = 12, -) -> tuple[ColdServerAsyncHTTPClient, ThrottleManager, ThrottledModelClient]: - async_http_client = ColdServerAsyncHTTPClient( - max_parallel=max_parallel, - server_ramp_seconds=server_ramp_seconds, - service_seconds=0.04, - retry_after=0.02, - ) - inner = OpenAICompatibleClient( - provider_name=PROVIDER, - endpoint=ENDPOINT, - api_key="sk-test-key", - concurrency_mode=ClientConcurrencyMode.ASYNC, - async_client=async_http_client, - ) - throttle_manager = ThrottleManager( - ThrottleConfig( - reduce_factor=0.5, - additive_increase=1, - success_window=2, - cooldown_seconds=0.02, - rampup_seconds=throttle_ramp_seconds, - ) - ) - throttle_manager.register( - provider_name=PROVIDER, - model_id=MODEL_ID, - alias="cold-server", - max_parallel_requests=max_parallel, - ) - client = ThrottledModelClient( - inner=inner, - throttle_manager=throttle_manager, - provider_name=PROVIDER, - model_id=MODEL_ID, - ) - requests = [ - ChatCompletionRequest(model=MODEL_ID, messages=[{"role": "user", "content": f"request {i}"}]) - for i in range(tasks) - ] - await asyncio.gather(*(_complete_with_rate_limit_retries(client, request) for request in requests)) - return async_http_client, throttle_manager, client - - -@pytest.mark.asyncio(loop_scope="session") -async def test_startup_ramp_integration_eases_into_cold_server_without_429s() -> None: - throttle_ramp_seconds = 0.3 - no_ramp_client, _, _ = await _run_cold_server_scenario( - throttle_ramp_seconds=0.0, - server_ramp_seconds=0.3, - ) - ramped_client, 
throttle_manager, throttled_client = await _run_cold_server_scenario( - throttle_ramp_seconds=throttle_ramp_seconds, - server_ramp_seconds=0.1, - ) - - assert no_ramp_client.rate_limits > 0 - assert no_ramp_client.peak_in_flight > 1 - assert ramped_client.rate_limits == 0 - assert ramped_client.peak_allowed > 1 - assert ramped_client.peak_in_flight > 1 - - await asyncio.sleep(throttle_ramp_seconds + 0.2) - await _complete_with_rate_limit_retries( - throttled_client, - ChatCompletionRequest(model=MODEL_ID, messages=[{"role": "user", "content": "final ramp probe"}]), - ) - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.current_limit == 4 - assert state.rampup_active is False - - -@pytest.mark.asyncio(loop_scope="session") -async def test_startup_ramp_integration_overaggressive_ramp_aborts_to_aimd() -> None: - cold_client, throttle_manager, _ = await _run_cold_server_scenario( - throttle_ramp_seconds=0.05, - server_ramp_seconds=0.5, - ) - - assert cold_client.rate_limits > 0 - - state = throttle_manager.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.rampup_active is False - - -@pytest.mark.asyncio(loop_scope="session") -async def test_aimd_feedback_loop_rate_limit_reduces_then_successes_recover() -> None: - """Verify the full AIMD cycle: success -> rate-limit halves limit -> successes recover. - - Uses a real ThrottleManager with aggressive tuning (success_window=2, - additive_increase=1) so the test can drive a full decrease+increase cycle - with a small number of calls. - - Sequence: - 1. Register model with max_parallel_requests=4. - 2. Make 1 successful async completion -> limit stays 4, streak=1. - 3. Hit a 429 with retry_after=0.01s -> limit halves to 2, cooldown applied. - 4. Wait for cooldown to expire. - 5. Make 2 more successes -> streak reaches success_window=2, limit increases to 3. - 6. 
Make 2 more successes -> limit increases to 4 (full recovery). - """ - tm = ThrottleManager( - ThrottleConfig( - reduce_factor=0.5, - additive_increase=1, - success_window=2, - cooldown_seconds=0.01, - ) - ) - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="a", max_parallel_requests=4) - - call_count = 0 - rate_limit_on_call = 2 - - async def mock_acompletion(request: ChatCompletionRequest) -> ChatCompletionResponse: - nonlocal call_count - call_count += 1 - if call_count == rate_limit_on_call: - raise ProviderError( - kind=ProviderErrorKind.RATE_LIMIT, - message="429 Too Many Requests", - status_code=429, - retry_after=0.01, - ) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = mock_acompletion - - client = ThrottledModelClient(inner=inner, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - def get_state() -> DomainThrottleState: - s = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert s is not None - return s - - # Step 1: first success - await client.acompletion(request) - assert get_state().current_limit == 4 - assert get_state().success_streak == 1 - - # Step 2: 429 -> AIMD decrease - with pytest.raises(ProviderError): - await client.acompletion(request) - assert get_state().current_limit == 2 - assert get_state().success_streak == 0 - assert get_state().in_flight == 0 - - # Step 3: wait for cooldown - await asyncio.sleep(0.02) - - # Step 4: two successes -> additive increase (limit 2 -> 3) - await client.acompletion(request) - assert get_state().success_streak == 1 - await client.acompletion(request) - assert get_state().current_limit == 3 - assert get_state().success_streak == 0 - - # Step 5: two more successes -> additive increase (limit 3 -> 4, full recovery) - await client.acompletion(request) - await client.acompletion(request) - 
assert get_state().current_limit == 4 - assert get_state().success_streak == 0 - - -@pytest.mark.asyncio(loop_scope="session") -async def test_concurrent_requests_bounded_by_throttle_limit() -> None: - """Concurrent async requests are bounded by the throttle limit. - - Registers a model with max_parallel_requests=2, fires 5 concurrent - acompletion calls that each sleep briefly, and verifies that the - ThrottleManager never had more than 2 in-flight at once. - """ - tm = ThrottleManager() - tm.register(provider_name=PROVIDER, model_id=MODEL_ID, alias="a", max_parallel_requests=2) - - peak_in_flight = 0 - lock = asyncio.Lock() - - async def mock_acompletion(request: ChatCompletionRequest) -> ChatCompletionResponse: - nonlocal peak_in_flight - state = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - if state is not None: - async with lock: - peak_in_flight = max(peak_in_flight, state.in_flight) - await asyncio.sleep(0.02) - return ChatCompletionResponse(message=AssistantMessage(content="ok"), usage=Usage()) - - inner = MagicMock() - inner.provider_name = PROVIDER - inner.acompletion = mock_acompletion - - client = ThrottledModelClient(inner=inner, throttle_manager=tm, provider_name=PROVIDER, model_id=MODEL_ID) - request = ChatCompletionRequest(model=MODEL_ID, messages=[]) - - tasks = [asyncio.create_task(client.acompletion(request)) for _ in range(5)] - results = await asyncio.gather(*tasks, return_exceptions=True) - - successes = [r for r in results if not isinstance(r, Exception)] - assert len(successes) == 5 - assert peak_in_flight <= 2 - - state = tm.get_domain_state(PROVIDER, MODEL_ID, ThrottleDomain.CHAT) - assert state is not None - assert state.in_flight == 0 diff --git a/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py new file mode 100644 index 000000000..af77f8c40 --- /dev/null +++ 
b/packages/data-designer-engine/tests/engine/models/request_admission/test_controller.py @@ -0,0 +1,505 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import asyncio +import logging +import threading +import time + +import pytest + +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.controller import ( + RELEASED_LEASE_HISTORY_LIMIT, + AdaptiveRequestAdmissionController, + RequestAdmissionDenied, + RequestAdmissionError, + RequestAdmissionLease, +) +from data_designer.engine.models.request_admission.outcomes import RequestReleaseOutcome +from data_designer.engine.models.request_admission.resources import ( + RequestAdmissionItem, + RequestDomain, + RequestGroupSpec, + RequestResourceKey, +) +from data_designer.engine.observability import InMemoryAdmissionEventSink + + +def _item(domain: RequestDomain = RequestDomain.CHAT, timeout: float | None = None) -> RequestAdmissionItem: + resource = RequestResourceKey("nvidia", "nemotron", domain) + return RequestAdmissionItem( + resource=resource, + group=RequestGroupSpec(resource), + queue_wait_timeout_seconds=timeout, + ) + + +def _controller(cap: int = 2, config: RequestAdmissionConfig | None = None) -> AdaptiveRequestAdmissionController: + controller = AdaptiveRequestAdmissionController(config) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=cap) + return controller + + +class _BrokenRequestSink: + def emit_request_event(self, _event: object) -> None: + raise RuntimeError("sink boom") + + +def test_request_admission_acquires_and_releases_lease() -> None: + controller = _controller(cap=1) + item = _item() + + decision = controller.try_acquire(item) + + assert isinstance(decision, RequestAdmissionLease) + assert 
controller.pressure.snapshot(item.resource).in_flight_count == 1 # type: ignore[union-attr] + result = controller.release(decision, RequestReleaseOutcome(kind="success")) + assert result.released is True + assert controller.pressure.snapshot(item.resource).in_flight_count == 0 # type: ignore[union-attr] + + +def test_request_admission_enforces_provider_model_aggregate_cap() -> None: + controller = _controller(cap=1) + chat = _item(RequestDomain.CHAT) + embedding = _item(RequestDomain.EMBEDDING) + lease = controller.try_acquire(chat) + assert isinstance(lease, RequestAdmissionLease) + + denied = controller.try_acquire(embedding) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + + +def test_request_admission_duplicate_release_does_not_corrupt_counts() -> None: + controller = _controller(cap=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + first = controller.release(lease, RequestReleaseOutcome(kind="success")) + second = controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert first.released is True + assert second.released is False + assert second.reason == "duplicate" + assert controller.pressure.snapshot(item.resource).active_lease_count == 0 # type: ignore[union-attr] + + +def test_request_admission_stale_release_requires_exact_lease() -> None: + controller = _controller(cap=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + stale = RequestAdmissionLease( + lease_id=lease.lease_id, + item=lease.item, + acquired_at=lease.acquired_at, + current_adaptive_limit=lease.current_adaptive_limit + 1, + effective_max=lease.effective_max, + controller_generation=lease.controller_generation, + ) + + stale_result = controller.release(stale, RequestReleaseOutcome(kind="provider_failure")) + snapshot = controller.pressure.snapshot(item.resource) + + assert stale_result.released is False + 
assert stale_result.reason == "stale_lease" + assert snapshot is not None + assert snapshot.in_flight_count == 1 + assert snapshot.active_lease_count == 1 + + released = controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert released.released is True + assert controller.pressure.snapshot(item.resource).active_lease_count == 0 # type: ignore[union-attr] + + +def test_request_admission_rate_limit_decreases_and_sets_cooldown() -> None: + controller = _controller( + cap=4, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=10, + ), + ) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="rate_limited", retry_after_seconds=1.0)) + denied = controller.try_acquire(item) + snapshot = controller.pressure.snapshot(item.resource) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "cooldown" + assert snapshot is not None + assert snapshot.current_limit == 2 + assert snapshot.cooldown_remaining_seconds > 0 + + +def test_request_admission_rate_limit_burst_decreases_once_per_cascade() -> None: + controller = _controller( + cap=8, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + cooldown_seconds=10, + ), + ) + item = _item() + leases = [controller.try_acquire(item) for _ in range(8)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in leases) + + for lease in leases: + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + + assert snapshot is not None + assert snapshot.current_limit == 4 + assert snapshot.rate_limit_ceiling == 8 + assert snapshot.consecutive_rate_limits == 8 + + +def test_request_admission_fresh_rate_limit_after_burst_decreases_again() -> None: + controller = _controller( + cap=8, + config=RequestAdmissionConfig( + multiplicative_decrease_factor=0.5, + 
cooldown_seconds=0, + ), + ) + item = _item() + leases = [controller.try_acquire(item) for _ in range(8)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in leases) + + for lease in leases: + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + assert snapshot is not None + assert snapshot.current_limit == 4 + assert snapshot.rate_limit_ceiling == 8 + + fresh_lease = controller.try_acquire(item) + assert isinstance(fresh_lease, RequestAdmissionLease) + assert fresh_lease.current_adaptive_limit == 4 + + controller.release(fresh_lease, RequestReleaseOutcome(kind="rate_limited")) + snapshot = controller.pressure.snapshot(item.resource) + + assert snapshot is not None + assert snapshot.current_limit == 2 + assert snapshot.rate_limit_ceiling == 8 + assert snapshot.consecutive_rate_limits == 9 + + +def test_request_admission_additive_recovery_after_successes() -> None: + item = _item() + controller = _controller( + cap=3, + config=RequestAdmissionConfig( + initial_limits={item.resource: 1}, + successes_until_increase=1, + additive_increase_step=1, + ), + ) + + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert controller.pressure.snapshot(item.resource).current_limit == 2 # type: ignore[union-attr] + + +def test_request_admission_startup_ramp_starts_at_one_and_progresses_to_cap( + monkeypatch: pytest.MonkeyPatch, +) -> None: + now = 100.0 + monkeypatch.setattr("data_designer.engine.models.request_admission.controller.time.monotonic", lambda: now) + controller = _controller(cap=4, config=RequestAdmissionConfig(startup_ramp_seconds=10.0)) + item = _item() + + first = controller.try_acquire(item) + assert isinstance(first, RequestAdmissionLease) + second = controller.try_acquire(item) + assert isinstance(second, RequestAdmissionDenied) + assert second.reason == 
"no_capacity" + assert controller.pressure.snapshot(item.resource).current_limit == 1 # type: ignore[union-attr] + controller.release(first, RequestReleaseOutcome(kind="success")) + + now = 105.0 + halfway_leases = [controller.try_acquire(item) for _ in range(2)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in halfway_leases) + denied = controller.try_acquire(item) + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + assert controller.pressure.snapshot(item.resource).current_limit == 2 # type: ignore[union-attr] + for lease in halfway_leases: + assert isinstance(lease, RequestAdmissionLease) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + now = 110.0 + full_ramp_leases = [controller.try_acquire(item) for _ in range(4)] + assert all(isinstance(lease, RequestAdmissionLease) for lease in full_ramp_leases) + assert controller.pressure.snapshot(item.resource).current_limit == 4 # type: ignore[union-attr] + for lease in full_ramp_leases: + assert isinstance(lease, RequestAdmissionLease) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_request_admission_rate_limit_aborts_startup_ramp(monkeypatch: pytest.MonkeyPatch) -> None: + now = 100.0 + monkeypatch.setattr("data_designer.engine.models.request_admission.controller.time.monotonic", lambda: now) + controller = _controller( + cap=4, + config=RequestAdmissionConfig( + cooldown_seconds=0.0, + multiplicative_decrease_factor=0.5, + startup_ramp_seconds=10.0, + ), + ) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="rate_limited")) + now = 110.0 + + assert controller.pressure.snapshot(item.resource).current_limit == 1 # type: ignore[union-attr] + next_lease = controller.try_acquire(item) + assert isinstance(next_lease, RequestAdmissionLease) + denied = controller.try_acquire(item) + assert 
isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + controller.release(next_lease, RequestReleaseOutcome(kind="success")) + + +def test_request_admission_blocking_timeout_raises_typed_error() -> None: + controller = _controller(cap=1) + first = _item() + second = _item(timeout=0.01) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(second) + + assert exc_info.value.decision.reason == "queue_timeout" + + +def test_request_admission_zero_sync_timeout_is_immediate() -> None: + controller = _controller(cap=1) + lease = controller.try_acquire(_item()) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(_item(RequestDomain.EMBEDDING, timeout=0.0)) + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(RequestResourceKey("nvidia", "nemotron", RequestDomain.EMBEDDING)) + assert snapshot is not None + assert snapshot.waiters == 0 + controller.release(lease, RequestReleaseOutcome(kind="success")) + + +def test_request_admission_sync_unregistered_provider_raises_hard_denial() -> None: + controller = AdaptiveRequestAdmissionController() + + with pytest.raises(RequestAdmissionError) as exc_info: + controller.acquire_sync(_item()) + + assert exc_info.value.decision.reason == "hard_policy_denial" + snapshot = controller.pressure.snapshot(_item().resource) + assert snapshot is not None + assert snapshot.waiters == 0 + + +def test_request_admission_logs_sink_failures(caplog: pytest.LogCaptureFixture) -> None: + caplog.set_level(logging.WARNING, logger="data_designer.engine.models.request_admission.controller") + controller = AdaptiveRequestAdmissionController(event_sink=_BrokenRequestSink()) + + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", 
max_parallel_requests=1) + + assert "Request admission event sink raised; dropping event." in caplog.text + + +def test_request_lease_released_event_records_release_outcome() -> None: + sink = InMemoryAdmissionEventSink() + controller = AdaptiveRequestAdmissionController(event_sink=sink) + controller.register(provider_name="nvidia", model_id="nemotron", alias="default", max_parallel_requests=1) + item = _item() + lease = controller.try_acquire(item) + assert isinstance(lease, RequestAdmissionLease) + + controller.release(lease, RequestReleaseOutcome(kind="provider_failure")) + + release_events = [event for event in sink.request_events if event.event_kind == "request_lease_released"] + assert release_events + assert release_events[-1].reason_or_outcome == "provider_failure" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_sync_rejects_running_event_loop() -> None: + controller = _controller(cap=1) + + with pytest.raises(RuntimeError, match="would block the running event loop"): + controller.acquire_sync(_item()) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_try_acquire_does_not_bypass_queued_waiter_for_same_provider_model() -> None: + controller = _controller(cap=1) + first = _item(RequestDomain.CHAT) + queued = _item(RequestDomain.EMBEDDING, timeout=2) + incoming = _item(RequestDomain.IMAGE) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + await asyncio.sleep(0) + + denied = controller.try_acquire(incoming) + + assert isinstance(denied, RequestAdmissionDenied) + assert denied.reason == "no_capacity" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.waiters == 1 + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_lease = await queued_task + controller.release(queued_lease, RequestReleaseOutcome(kind="success")) + + 
+@pytest.mark.asyncio(loop_scope="session") +async def test_request_admission_zero_async_timeout_is_immediate() -> None: + controller = _controller(cap=1) + lease = controller.try_acquire(_item()) + assert isinstance(lease, RequestAdmissionLease) + + with pytest.raises(RequestAdmissionError) as exc_info: + await controller.acquire_async(_item(RequestDomain.EMBEDDING, timeout=0.0)) + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(RequestResourceKey("nvidia", "nemotron", RequestDomain.EMBEDDING)) + assert snapshot is not None + assert snapshot.waiters == 0 + controller.release(lease, RequestReleaseOutcome(kind="success")) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_does_not_assign_expired_waiter_after_release( + monkeypatch: pytest.MonkeyPatch, +) -> None: + controller = _controller(cap=1) + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + lease = controller.try_acquire(_item(RequestDomain.CHAT)) + assert isinstance(lease, RequestAdmissionLease) + queued = _item(RequestDomain.EMBEDDING, timeout=0.01) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + for _ in range(20): + snapshot = controller.pressure.snapshot(queued.resource) + if snapshot is not None and snapshot.waiters == 1: + break + await asyncio.sleep(0) + else: + raise AssertionError("async waiter did not enqueue") + + def release_after_deadline() -> None: + time.sleep(0.03) + controller.release(lease, RequestReleaseOutcome(kind="success")) + + release_thread = threading.Thread(target=release_after_deadline) + release_thread.start() + try: + time.sleep(0.06) + with pytest.raises(RequestAdmissionError) as exc_info: + await asyncio.wait_for(queued_task, timeout=0.5) + finally: + release_thread.join() + + assert exc_info.value.decision.reason == "queue_timeout" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert 
snapshot.waiters == 0 + assert snapshot.active_lease_count == 0 + assert snapshot.in_flight_count == 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_wakes_when_release_assigns_lease(monkeypatch: pytest.MonkeyPatch) -> None: + controller = _controller(cap=1) + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + lease = controller.try_acquire(_item(RequestDomain.CHAT)) + assert isinstance(lease, RequestAdmissionLease) + queued = _item(RequestDomain.EMBEDDING, timeout=30.0) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + for _ in range(20): + snapshot = controller.pressure.snapshot(queued.resource) + if snapshot is not None and snapshot.waiters == 1: + break + await asyncio.sleep(0) + else: + raise AssertionError("async waiter did not enqueue") + + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_lease = await asyncio.wait_for(queued_task, timeout=0.5) + + controller.release(queued_lease, RequestReleaseOutcome(kind="success")) + + +@pytest.mark.asyncio(loop_scope="session") +async def test_acquire_async_unregistered_provider_raises_hard_denial(monkeypatch: pytest.MonkeyPatch) -> None: + controller = AdaptiveRequestAdmissionController() + monkeypatch.setattr(controller, "_wait_seconds_locked", lambda _item, _now, _deadline: 10.0) + queued = _item(RequestDomain.CHAT, timeout=30.0) + + with pytest.raises(RequestAdmissionError) as exc_info: + await asyncio.wait_for(controller.acquire_async(queued), timeout=0.5) + + assert exc_info.value.decision.reason == "hard_policy_denial" + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.waiters == 0 + + +def test_request_admission_released_history_is_bounded() -> None: + controller = _controller(cap=1) + first_lease: RequestAdmissionLease | None = None + for _ in range(RELEASED_LEASE_HISTORY_LIMIT + 5): + lease = controller.try_acquire(_item()) + assert 
isinstance(lease, RequestAdmissionLease) + first_lease = first_lease or lease + controller.release(lease, RequestReleaseOutcome(kind="success")) + + assert len(controller._released) == RELEASED_LEASE_HISTORY_LIMIT + assert len(controller._released_order) == RELEASED_LEASE_HISTORY_LIMIT + assert controller._released_order.maxlen == RELEASED_LEASE_HISTORY_LIMIT + assert first_lease is not None + assert controller.release(first_lease, RequestReleaseOutcome(kind="success")).reason == "unknown_lease" + + +@pytest.mark.asyncio(loop_scope="session") +async def test_async_cancellation_after_waiter_assignment_releases_lease() -> None: + controller = _controller(cap=1) + first = _item(RequestDomain.CHAT) + queued = _item(RequestDomain.EMBEDDING, timeout=1.0) + lease = controller.try_acquire(first) + assert isinstance(lease, RequestAdmissionLease) + + queued_task = asyncio.create_task(controller.acquire_async(queued)) + await asyncio.sleep(0) + controller.release(lease, RequestReleaseOutcome(kind="success")) + queued_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await queued_task + + snapshot = controller.pressure.snapshot(queued.resource) + assert snapshot is not None + assert snapshot.active_lease_count == 0 + assert snapshot.in_flight_count == 0 + assert snapshot.waiters == 0 diff --git a/packages/data-designer-engine/tests/engine/models/test_facade.py b/packages/data-designer-engine/tests/engine/models/test_facade.py index 208010500..0be33bd02 100644 --- a/packages/data-designer-engine/tests/engine/models/test_facade.py +++ b/packages/data-designer-engine/tests/engine/models/test_facade.py @@ -9,6 +9,7 @@ import pytest from data_designer.engine.mcp.errors import MCPConfigurationError, MCPToolError +from data_designer.engine.models.clients.errors import ProviderError, ProviderErrorKind from data_designer.engine.models.clients.types import ( AssistantMessage, ChatCompletionRequest, @@ -19,7 +20,11 @@ ToolCall, Usage, ) -from 
data_designer.engine.models.errors import ImageGenerationError, ModelGenerationValidationFailureError +from data_designer.engine.models.errors import ( + ImageGenerationError, + ModelGenerationValidationFailureError, + ModelTimeoutError, +) from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.parsers.errors import ParserException from data_designer.engine.models.usage import TokenCountSource @@ -277,6 +282,35 @@ def test_generate_strips_response_content( assert result == expected +def test_generate_maps_statusless_provider_timeout_to_model_timeout(stub_model_facade: ModelFacade) -> None: + stub_model_facade._client.completion.side_effect = ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message="request timed out", + status_code=None, + provider_name="stub", + model_name=stub_model_facade.model_name, + ) + + with pytest.raises(ModelTimeoutError, match="timed out"): + stub_model_facade.generate(prompt="test", parser=lambda value: value) + + +@pytest.mark.asyncio +async def test_agenerate_maps_statusless_provider_timeout_to_model_timeout(stub_model_facade: ModelFacade) -> None: + stub_model_facade._client.acompletion = AsyncMock( + side_effect=ProviderError( + kind=ProviderErrorKind.TIMEOUT, + message="request timed out", + status_code=None, + provider_name="stub", + model_name=stub_model_facade.model_name, + ) + ) + + with pytest.raises(ModelTimeoutError, match="timed out"): + await stub_model_facade.agenerate(prompt="test", parser=lambda value: value) + + def test_model_alias_property(stub_model_facade: ModelFacade, stub_model_configs: list[Any]) -> None: assert stub_model_facade.model_alias == stub_model_configs[0].alias diff --git a/packages/data-designer-engine/tests/engine/models/test_model_registry.py b/packages/data-designer-engine/tests/engine/models/test_model_registry.py index 86b16864f..fc944d07d 100644 --- a/packages/data-designer-engine/tests/engine/models/test_model_registry.py +++ 
b/packages/data-designer-engine/tests/engine/models/test_model_registry.py @@ -6,6 +6,7 @@ import pytest from data_designer.config.models import ChatCompletionInferenceParams, ModelConfig +from data_designer.config.run_config import RequestAdmissionTuningConfig, RunConfig from data_designer.engine.models.errors import ModelAuthenticationError from data_designer.engine.models.facade import ModelFacade from data_designer.engine.models.factory import create_model_registry @@ -54,6 +55,35 @@ def test_create_model_registry( assert isinstance(model_registry, ModelRegistry) +def test_create_model_registry_maps_request_admission_tuning_config( + stub_model_configs: list[ModelConfig], + stub_secrets_resolver: object, + stub_model_provider_registry: object, +) -> None: + model_registry = create_model_registry( + model_configs=stub_model_configs, + secret_resolver=stub_secrets_resolver, + model_provider_registry=stub_model_provider_registry, + run_config=RunConfig( + request_admission=RequestAdmissionTuningConfig( + multiplicative_decrease_factor=0.5, + additive_increase_step=2, + successes_until_increase=7, + cooldown_seconds=1.5, + startup_ramp_seconds=30.0, + ) + ), + ) + + assert model_registry.request_admission is not None + request_config = model_registry.request_admission.config + assert request_config.multiplicative_decrease_factor == 0.5 + assert request_config.additive_increase_step == 2 + assert request_config.successes_until_increase == 7 + assert request_config.cooldown_seconds == 1.5 + assert request_config.startup_ramp_seconds == 30.0 + + def test_public_props(stub_model_configs, stub_model_registry): assert stub_model_registry.model_configs == { model_config.alias: model_config for model_config in stub_model_configs diff --git a/packages/data-designer-engine/tests/engine/test_capacity.py b/packages/data-designer-engine/tests/engine/test_capacity.py new file mode 100644 index 000000000..856aeba09 --- /dev/null +++ 
b/packages/data-designer-engine/tests/engine/test_capacity.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from data_designer.engine.capacity import ( + AsyncCapacityConfigured, + AsyncCapacityObservedMaxima, + AsyncCapacityPlan, + AsyncCapacityRuntimeSnapshot, + CapacityValue, + RequestAdmissionConfigSnapshot, + RowGroupAdmission, +) +from data_designer.engine.models.request_admission.config import RequestAdmissionConfig +from data_designer.engine.models.request_admission.resources import RequestDomain, RequestResourceKey +from data_designer.engine.models.resources import ProviderModelKey, ProviderModelStaticCap + + +def test_request_admission_config_snapshot_records_resources() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + config = RequestAdmissionConfig( + initial_limits={resource: 2}, + max_limit_clamps={resource: 4}, + startup_ramp_seconds=30.0, + ) + + snapshot = RequestAdmissionConfigSnapshot.from_config(config) + + assert snapshot.resources == (resource,) + assert snapshot.initial_limits[resource] == 2 + assert snapshot.max_limit_clamps[resource] == 4 + assert snapshot.startup_ramp_seconds == 30.0 + + +def test_async_capacity_plan_records_configured_runtime_and_maxima() -> None: + resource = RequestResourceKey("nvidia", "nemotron", RequestDomain.CHAT) + provider_model = ProviderModelKey("nvidia", "nemotron") + static_cap = ProviderModelStaticCap(cap=4, aliases=("default",), raw_caps={"default": 4}) + + plan = AsyncCapacityPlan( + configured=AsyncCapacityConfigured( + buffer_size=CapacityValue(value=16, source="run_config"), + row_group_admission=RowGroupAdmission( + row_group_concurrency=CapacityValue(value=2, source="dataset_builder"), + observed_in_flight=1, + ), + submission_capacity=CapacityValue(value=8, source="engine_internal_config"), + 
task_resource_limits=CapacityValue(value={"submission": 8, "llm_wait": 4}, source="engine_internal_config"), + request_resources=CapacityValue(value=(resource,), source="runtime_snapshot"), + provider_model_static_caps=CapacityValue(value={provider_model: static_cap}, source="model_metadata"), + request_domain_initial_limits=CapacityValue(value={resource: 2}, source="engine_internal_config"), + request_admission_config=CapacityValue( + value=RequestAdmissionConfigSnapshot.from_config(RequestAdmissionConfig(initial_limits={resource: 2})), + source="engine_internal_config", + ), + transport_pool_limits=CapacityValue(value={provider_model: 8}, source="adapter_config"), + ), + runtime_snapshot=AsyncCapacityRuntimeSnapshot( + request_domain_current_limits={resource: 2}, + request_domain_effective_max={resource: 4}, + request_domain_blocked_until={resource: None}, + provider_model_aggregate_in_flight={provider_model: 0}, + ), + observed_maxima=AsyncCapacityObservedMaxima( + row_groups_in_flight=1, + request_in_flight_by_resource={resource: 2}, + provider_model_aggregate_in_flight={provider_model: 2}, + ), + ) + + assert plan.configured.provider_model_static_caps.value[provider_model].merge_rule == "min_same_endpoint" + assert plan.runtime_snapshot.request_domain_current_limits[resource] == 2 + assert plan.observed_maxima.provider_model_aggregate_in_flight[provider_model] == 2 diff --git a/packages/data-designer-engine/tests/engine/test_observability.py b/packages/data-designer-engine/tests/engine/test_observability.py new file mode 100644 index 000000000..e7d9ce21b --- /dev/null +++ b/packages/data-designer-engine/tests/engine/test_observability.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass +from enum import Enum + +from data_designer.engine.observability import ( + CorrelatedRuntimeView, + InMemoryAdmissionEventSink, + RequestAdmissionEvent, + RuntimeCorrelation, + RuntimeCorrelationProvider, + SchedulerAdmissionEvent, +) + + +class _DiagnosticMode(Enum): + TEST = "test" + + +@dataclass(frozen=True) +class _DiagnosticPayload: + label: str + mode: _DiagnosticMode + + +def _correlation() -> RuntimeCorrelation: + return RuntimeCorrelation( + run_id="run", + row_group=1, + task_column="answer", + task_type="cell", + scheduling_group_kind="model", + scheduling_group_identity_hash="hash", + task_execution_id="task-exec", + ) + + +def test_runtime_correlation_provider_sets_and_resets_context() -> None: + provider = RuntimeCorrelationProvider() + correlation = _correlation() + + token = provider.set(correlation) + assert provider.current() == correlation + + provider.reset(token) + assert provider.current() is None + + +def test_admission_events_capture_correlation_and_diagnostics() -> None: + correlation = _correlation() + + scheduler_event = SchedulerAdmissionEvent.capture( + "task_lease_acquired", + sequence=1, + correlation=correlation, + task_id="task-1", + task_lease_id="lease-1", + diagnostics={"resource": "submission"}, + ) + request_event = RequestAdmissionEvent.capture( + "request_lease_acquired", + sequence=2, + correlation=correlation, + request_attempt_id="request-1", + request_lease_id="lease-2", + diagnostics={"resource": "chat"}, + ) + + assert scheduler_event.captured_correlation == asdict(correlation) + assert scheduler_event.task_id == "task-1" + assert scheduler_event.diagnostics == {"resource": "submission"} + assert request_event.captured_correlation == asdict(correlation) + assert request_event.request_attempt_id == "request-1" + assert request_event.diagnostics == {"resource": "chat"} + + +def 
test_admission_events_are_json_safe_at_construction() -> None: + correlation = _correlation() + payload = _DiagnosticPayload(label="payload", mode=_DiagnosticMode.TEST) + + scheduler_event = SchedulerAdmissionEvent.capture( + "admission_blocked", + sequence=1, + correlation=correlation, + snapshot=payload, + diagnostics={"payload": payload, "values": {"b", "a"}, "pair": ("x", _DiagnosticMode.TEST)}, + ) + request_event = RequestAdmissionEvent.capture( + "request_wait_started", + sequence=2, + correlation=correlation, + request_resource_key=payload, + request_group_key=("group", _DiagnosticMode.TEST), + pressure_snapshot={"payload": payload}, + diagnostics={"payload": payload}, + ) + + json.dumps(asdict(scheduler_event), sort_keys=True) + json.dumps(asdict(request_event), sort_keys=True) + assert scheduler_event.snapshot == {"label": "payload", "mode": "test"} + assert scheduler_event.diagnostics["values"] == ["a", "b"] + assert request_event.request_resource_key == {"label": "payload", "mode": "test"} + + +def test_in_memory_admission_event_sink_collects_scheduler_and_request_events() -> None: + sink = InMemoryAdmissionEventSink() + scheduler_event = SchedulerAdmissionEvent.capture("selected", sequence=1) + request_event = RequestAdmissionEvent.capture("request_wait_started", sequence=2) + + sink.emit_scheduler_event(scheduler_event) + sink.emit_request_event(request_event) + + assert sink.scheduler_events == [scheduler_event] + assert sink.request_events == [request_event] + + +def test_correlated_runtime_view_timeline_sorts_events() -> None: + scheduler_event = SchedulerAdmissionEvent(event_kind="selected", captured_at_monotonic=2.0, sequence=1) + first_request_event = RequestAdmissionEvent( + event_kind="request_wait_started", + captured_at_monotonic=1.0, + sequence=3, + ) + second_request_event = RequestAdmissionEvent( + event_kind="request_lease_acquired", + captured_at_monotonic=2.0, + sequence=0, + ) + view = CorrelatedRuntimeView( + 
scheduler_events=(scheduler_event,), + request_events=(first_request_event, second_request_event), + ) + + assert view.timeline == (first_request_event, second_request_event, scheduler_event) diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index ed42d6e0b..c1a4d0b9c 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -77,8 +77,8 @@ from data_designer.plugins.registry import PluginRegistry if TYPE_CHECKING: - from data_designer.engine.models.clients.throttle_manager import ThrottleManager from data_designer.engine.models.facade import ModelFacade + from data_designer.engine.models.request_admission.controller import AdaptiveRequestAdmissionController logger = logging.getLogger(__name__) @@ -156,7 +156,7 @@ def __init__( self._secret_resolver = secret_resolver or DEFAULT_SECRET_RESOLVER self._artifact_path = Path(artifact_path) if artifact_path is not None else Path.cwd() / "artifacts" self._run_config = RunConfig() - self._throttle_manager: ThrottleManager = self._create_throttle_manager() + self._request_admission: AdaptiveRequestAdmissionController = self._create_request_admission_controller() self._managed_assets_path = Path(managed_assets_path or MANAGED_ASSETS_PATH) self._person_reader = person_reader # Only consult the YAML's `default:` key when we are also falling back to @@ -613,7 +613,7 @@ def set_run_config(self, run_config: RunConfig) -> None: due to error-rate thresholds. Errors are still tracked for reporting. """ self._run_config = run_config - self._throttle_manager = self._create_throttle_manager() + self._request_admission = self._create_request_admission_controller() def get_models(self, model_aliases: list[str]) -> dict[str, ModelFacade]: """Get a dict of ModelFacade instances for custom column development. 
@@ -702,13 +702,13 @@ def _create_resource_provider( mcp_providers=self._mcp_providers, tool_configs=config_builder.tool_configs, client_concurrency_mode=self._resolve_client_concurrency_mode(config_builder), - throttle_manager=self._throttle_manager, + request_admission=self._request_admission, ) - def _create_throttle_manager(self) -> ThrottleManager: - from data_designer.engine.models.clients.throttle_manager import ThrottleManager + def _create_request_admission_controller(self) -> AdaptiveRequestAdmissionController: + from data_designer.engine.models.factory import create_request_admission_controller - return ThrottleManager(self._run_config.throttle) + return create_request_admission_controller(self._run_config) @staticmethod def _resolve_client_concurrency_mode(config_builder: DataDesignerConfigBuilder) -> ClientConcurrencyMode: diff --git a/packages/data-designer/src/data_designer/interface/results.py b/packages/data-designer/src/data_designer/interface/results.py index a7038866c..6c5b076a3 100644 --- a/packages/data-designer/src/data_designer/interface/results.py +++ b/packages/data-designer/src/data_designer/interface/results.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: import pandas as pd - from data_designer.engine.dataset_builders.utils.task_model import TaskTrace + from data_designer.engine.dataset_builders.scheduling.task_model import TaskTrace ExportFormat = Literal["jsonl", "csv", "parquet"] SUPPORTED_EXPORT_FORMATS: tuple[str, ...] 
= get_args(ExportFormat) diff --git a/packages/data-designer/tests/interface/test_acreate.py b/packages/data-designer/tests/interface/test_acreate.py index b87d8eeba..30c452fb2 100644 --- a/packages/data-designer/tests/interface/test_acreate.py +++ b/packages/data-designer/tests/interface/test_acreate.py @@ -119,7 +119,7 @@ def fake_create( assert data_designer.create.call_count == 2 -def test_data_designer_reuses_throttle_manager_across_create_calls( +def test_data_designer_reuses_request_admission_across_create_calls( tmp_path: Path, stub_model_providers: list[ModelProvider], stub_model_configs: list[ModelConfig], @@ -133,7 +133,8 @@ def test_data_designer_reuses_throttle_manager_across_create_calls( assert left_provider.model_registry is not None assert right_provider.model_registry is not None - assert left_provider.model_registry.throttle_manager is right_provider.model_registry.throttle_manager + assert left_provider.model_registry.request_admission is right_provider.model_registry.request_admission + assert left_provider.model_registry.request_admission is data_designer._request_admission @pytest.mark.asyncio diff --git a/packages/data-designer/tests/interface/test_data_designer.py b/packages/data-designer/tests/interface/test_data_designer.py index 0c5f6d1fe..fe88509d6 100644 --- a/packages/data-designer/tests/interface/test_data_designer.py +++ b/packages/data-designer/tests/interface/test_data_designer.py @@ -23,7 +23,7 @@ from data_designer.config.errors import InvalidConfigError from data_designer.config.models import ModelProvider from data_designer.config.processors import DropColumnsProcessorConfig -from data_designer.config.run_config import JinjaRenderingEngine, RunConfig, ThrottleConfig +from data_designer.config.run_config import JinjaRenderingEngine, RequestAdmissionTuningConfig, RunConfig from data_designer.config.sampler_params import CategorySamplerParams, DatetimeSamplerParams, SamplerType from data_designer.config.seed import IndexRange, 
PartitionBlock, SamplingStrategy from data_designer.config.seed_source import ( @@ -712,7 +712,7 @@ def test_init_no_user_providers_no_yaml_default_stays_quiet( def test_run_config_setting_persists(stub_artifact_path, stub_model_providers): """Test that run config setting persists across multiple calls.""" data_designer = DataDesigner(artifact_path=stub_artifact_path, model_providers=stub_model_providers) - original_throttle_manager = data_designer._throttle_manager + original_request_admission = data_designer._request_admission # Test default values assert data_designer.run_config.disable_early_shutdown is False @@ -731,7 +731,7 @@ def test_run_config_setting_persists(stub_artifact_path, stub_model_providers): buffer_size=500, max_conversation_restarts=7, max_conversation_correction_steps=2, - throttle=ThrottleConfig(success_window=7), + request_admission=RequestAdmissionTuningConfig(successes_until_increase=7), ) ) assert data_designer.run_config.disable_early_shutdown is True @@ -740,8 +740,8 @@ def test_run_config_setting_persists(stub_artifact_path, stub_model_providers): assert data_designer.run_config.buffer_size == 500 assert data_designer.run_config.max_conversation_restarts == 7 assert data_designer.run_config.max_conversation_correction_steps == 2 - assert data_designer._throttle_manager is not original_throttle_manager - assert data_designer._throttle_manager._success_window == 7 + assert data_designer._request_admission is not original_request_admission + assert data_designer._request_admission.config.successes_until_increase == 7 # Test updating values data_designer.set_run_config( diff --git a/plans/645/AsyncSchedulingEpicComponent.png b/plans/645/AsyncSchedulingEpicComponent.png new file mode 100644 index 000000000..4eab0b4e3 Binary files /dev/null and b/plans/645/AsyncSchedulingEpicComponent.png differ diff --git a/plans/645/AsyncSchedulingEpicIssueMap.png b/plans/645/AsyncSchedulingEpicIssueMap.png new file mode 100644 index 000000000..383107f72 
Binary files /dev/null and b/plans/645/AsyncSchedulingEpicIssueMap.png differ diff --git a/plans/645/AsyncSchedulingEpicRuntimeSequence.png b/plans/645/AsyncSchedulingEpicRuntimeSequence.png new file mode 100644 index 000000000..f5f24491e Binary files /dev/null and b/plans/645/AsyncSchedulingEpicRuntimeSequence.png differ diff --git a/plans/645/AsyncSchedulingRequestAdmissionClassModel.png b/plans/645/AsyncSchedulingRequestAdmissionClassModel.png new file mode 100644 index 000000000..34bdb2da2 Binary files /dev/null and b/plans/645/AsyncSchedulingRequestAdmissionClassModel.png differ diff --git a/plans/645/AsyncSchedulingSupportContractsClassModel.png b/plans/645/AsyncSchedulingSupportContractsClassModel.png new file mode 100644 index 000000000..1975d24f7 Binary files /dev/null and b/plans/645/AsyncSchedulingSupportContractsClassModel.png differ diff --git a/plans/645/AsyncSchedulingTaskAdmissionClassModel.png b/plans/645/AsyncSchedulingTaskAdmissionClassModel.png new file mode 100644 index 000000000..7fe1a5b06 Binary files /dev/null and b/plans/645/AsyncSchedulingTaskAdmissionClassModel.png differ diff --git a/plans/645/README.md b/plans/645/README.md new file mode 100644 index 000000000..0d95081fb --- /dev/null +++ b/plans/645/README.md @@ -0,0 +1,92 @@ +# Async Scheduling Architecture Plan + +Source-of-truth architecture plan for the async scheduling epic tracked by issue 645. The UML file is the visual index; the Markdown files in this directory are the durable spec. GitHub issues should point back here and focus on implementation sequencing, quality gates, tests, and evidence. + +If an issue body and this plan disagree, update this plan first, then adjust the issue to reference the corrected section. + +This directory is the maintainer source of truth while the epic is active. Issue 660 promotes the stabilized V1 content into current user/operator architecture docs and marks older pre-epic scheduling descriptions as historical or removes them. 
+ +## Spec + +- [Architecture](architecture.md): target system shape, ownership boundaries, invariants, and non-goals. +- [Contracts](contracts.md): durable DTO, protocol, event, and config names. +- [Module ownership](module-ownership.md): final repository/module homes, import rules, audience boundaries, tests, and benchmark ownership. +- [Capacity model](capacity-model.md): layered capacity vocabulary and ownership. +- [Task admission](task-admission.md): scheduler-owned ready selection, task leases, policy hooks, bounded borrowing, and resource-vector direction. +- [Request admission](request-admission.md): model-call admission, AIMD controller shape, dynamic request semantics, and replacement of pre-epic request-control names. +- [Observability](observability.md): scheduler events, request events, runtime correlation, snapshots, and cardinality rules. +- [Benchmark plan](benchmark-plan.md): scenarios, metrics, A/B baselines, and required artifacts. +- [Migration and cleanup](migration-and-cleanup.md): legacy-name removal, grep gates, and no-shim rules. +- [Issue map](issue-map.md): how the GitHub issues map to this source-of-truth plan. + +## Read This First + +Recommended reading paths: + +- Implementers: [Architecture](architecture.md), [Contracts](contracts.md), [Module ownership](module-ownership.md), then the topic file for the issue being implemented. +- Plugin documentation authors: [Contracts](contracts.md#metadata-contracts), [Architecture](architecture.md#audience-and-api-boundaries), and [Migration and cleanup](migration-and-cleanup.md#documentation-cleanup). +- Operators and performance reviewers: [Capacity model](capacity-model.md), [Observability](observability.md), and [Benchmark plan](benchmark-plan.md). +- Issue owners: [Issue map](issue-map.md), then the linked source sections for the issue. + +## Source + +- [async-scheduling-epic.puml](async-scheduling-epic.puml): PlantUML source for every diagram on this page. 
+ +The PNG files in this directory are generated review artifacts. The PlantUML file is authoritative for diagram source. Any PR that changes the UML should regenerate the PNGs and include them in the same diff, or explicitly state why rendering was unavailable. + +## Component View + +![Component view](AsyncSchedulingEpicComponent.png) + +## Task Admission Contracts + +![Task admission class model](AsyncSchedulingTaskAdmissionClassModel.png) + +## Request Admission Contracts + +![Request admission class model](AsyncSchedulingRequestAdmissionClassModel.png) + +## Capacity, Telemetry, and Evidence Contracts + +![Support contracts class model](AsyncSchedulingSupportContractsClassModel.png) + +## Runtime Sequence + +![Runtime sequence](AsyncSchedulingEpicRuntimeSequence.png) + +## Issue Dependency Map + +![Issue dependency map](AsyncSchedulingEpicIssueMap.png) + +## Render + +```bash +plantuml plans/645/async-scheduling-epic.puml +``` + +The expected runtime control owner is `AsyncTaskScheduler`: + +```text +ColumnGenerator.get_scheduling_metadata() + -> SchedulingMetadata + -> TaskSchedulingResolver + -> ResolvedTaskScheduling + -> SchedulableTask inputs + +AsyncTaskScheduler + -> CompletionTracker.ready_frontier() + -> FairTaskQueue.enqueue(...) + -> FairTaskQueue.select_next(scheduler-owned eligibility callback) + -> TaskAdmissionController.try_acquire(selection.item, selection.queue_view) + -> FairTaskQueue.commit(...) + -> execute admitted task/generator code + +Admitted task/generator code + -> model facade/provider boundary + -> ModelRequestExecutor.execute_attempt(...) per concrete request attempt + -> RequestAdmissionController.acquire_async(...) + -> provider/model endpoint + -> RequestAdmissionController.release(lease, outcome) +``` + +Task admission and request admission each have explicit controller, queue, policy, and lease/state boundaries where applicable. 
Telemetry observes scheduler admission and request admission separately, then issue 648 correlates the two timelines through the runtime correlation provider. diff --git a/plans/645/architecture.md b/plans/645/architecture.md new file mode 100644 index 000000000..58f756f07 --- /dev/null +++ b/plans/645/architecture.md @@ -0,0 +1,139 @@ +# Async Scheduling Architecture + +This plan moves Data Designer's async engine from implicit scheduling behavior to explicit, layered admission control. The target architecture separates static generator resource metadata, dependency readiness, ready-work ordering, scheduler-level task admission, concrete model-request admission, capacity diagnostics, and runtime observability. + +The guiding rule is: each layer owns one question and speaks through typed boundaries. + +## Source Of Truth + +The Markdown files in `plans/645` are the source of truth for this epic. The UML in [async-scheduling-epic.puml](async-scheduling-epic.puml) is the visual index and must be kept aligned with these files. GitHub issues should reference this plan and own implementation sequencing, validation commands, acceptance gates, and PR-level evidence. + +## Target Shape + +The durable data-preparation flow is: + +```text +ColumnGenerator / plugin + -> ColumnGenerator.get_scheduling_metadata() + -> SchedulingMetadata + -> TaskSchedulingResolver + -> ResolvedTaskScheduling + -> SchedulableTask inputs +``` + +The durable runtime control flow is: + +```text +AsyncTaskScheduler + -> CompletionTracker.ready_frontier() + -> FairTaskQueue.enqueue(...) + -> FairTaskQueue.select_next(scheduler-owned eligibility callback) + -> TaskAdmissionController.try_acquire(selection.item, selection.queue_view) + -> FairTaskQueue.commit(selection) + -> execute admitted task/generator code + +admitted task/generator code + -> model facade/provider boundary + -> ModelRequestExecutor.execute_attempt(...) 
+ -> RequestAdmissionController.acquire_async(RequestAdmissionItem) + -> provider/model endpoint + -> RequestAdmissionController.release(lease, outcome) +``` + +This is not a passive pipeline where `CompletionTracker`, `FairTaskQueue`, or `TaskAdmissionController` pushes work into the scheduler. `AsyncTaskScheduler` is the execution owner. It asks the readiness tracker for work, enqueues ready tasks, asks the queue to select a candidate through an admission eligibility callback, asks the task admission controller for a lease, commits the queue selection, executes the admitted task, and releases the lease. + +`ModelRequestExecutor` is not a scheduler task wrapper. It is reached only when admitted task/generator code makes a concrete model call through the model facade/provider boundary. A task may make zero, one, or many concrete calls; each call attempt receives request admission independently. + +## Layer Responsibilities + +`SchedulingMetadata` is a generator-facing static resource declaration. It describes the resource shape a generator expects, such as local work or model-backed work. It does not expose queue internals, admitted limits, request domains, AIMD state, or runtime pressure. + +`TaskSchedulingResolver` is the internal bridge from generator metadata to scheduler inputs. It produces `ResolvedTaskScheduling`, including `TaskGroupSpec` and `SchedulerResourceRequest`, and appends scheduler-owned flow identity such as output columns. It is the only scheduler grouping bridge once the legacy resolver is removed. + +`CompletionTracker` owns dependency readiness. It reports the ready frontier and completion state to `AsyncTaskScheduler`. It does not enqueue into the ready queue, order ready work, admit resources, or inspect provider/model pressure. + +`FairTaskQueue` owns ready-work membership and ordering. Its selection operation is non-mutating and takes an eligibility callback supplied by scheduler admission. 
It does not own dependency readiness, admitted counts, provider metadata, request admission, or policy state. + +`TaskAdmissionController` owns scheduler-level task leases and resource accounting. `TaskAdmissionPolicy` decides whether a queued task is eligible under the current queue and admission views. The controller consumes resolved scheduler inputs and its engine-internal `TaskAdmissionConfig`; it must not inspect generators, user config layout, model registries, or provider registries directly. + +`AsyncTaskScheduler` owns runtime control flow. It wires readiness, queue selection, task admission, worker spawn, task execution, salvage/retry behavior, shutdown, and lease release. + +`ModelRequestExecutor` is the durable model-call boundary. It maps each concrete provider/model/domain call attempt to a `RequestAdmissionItem`, acquires a request lease, calls the provider, records request timing, and releases that exact lease with a classified outcome on success, rate limit, failure, cancellation, timeout, or unexpected exception. + +`RequestAdmissionController` owns request-level provider/model/domain admission. `AdaptiveRequestAdmissionController` is the V1 AIMD-backed implementation. Internal `RequestFairQueue`, `RequestAdmissionPolicy`, and `AdaptiveRequestLimitState` are implementation components of this controller, not a second public layer. + +`SchedulerAdmissionEventSink` and `RequestAdmissionEventSink` observe their own layers separately. `RuntimeCorrelationProvider` supplies shared runtime context, and `CorrelatedRuntimeView` joins timelines without collapsing the two telemetry systems. + +## Audience And API Boundaries + +The plan uses several contract categories. Keeping them separate prevents internal scheduling mechanics from becoming accidental plugin API. + +Durable engine vocabulary is maintainer-facing unless this plan explicitly marks it plugin-facing or operator-facing. 
See [Module ownership](module-ownership.md) for final module homes, import rules, and test/benchmark ownership. + +| Audience | Durable surface | Must not expose | +| --- | --- | --- | +| Plugin authors | `ColumnGenerator.get_scheduling_metadata()` and `SchedulingMetadata` | queue state, task leases, request domains, AIMD state, runtime pressure | +| Users/operators | documented run config fields, `AsyncCapacityPlan`, benchmark and telemetry artifacts | internal queue/policy classes, per-lease mutation APIs | +| Engine implementers | scheduler/request admission protocols, DTOs, policies, snapshots, events | config-layer imports from engine runtime | +| Diagnostics and benchmarks | event DTOs, snapshots, correlation view, capacity plan | prompts, completions, row data, secrets, unbounded IDs as metric labels | + +Package ownership follows Data Designer's structural layering: + +| Package | Owns | +| --- | --- | +| `data-designer-config` | public configuration DTOs and generator-facing metadata, including `SchedulingMetadata`, metadata validation errors, and future stable config surfaces only after an issue explicitly promotes them to public API | +| `data-designer-engine` | scheduler runtime DTOs, task/request admission controllers, queues, policies, leases, snapshots, events, capacity plan construction, and benchmark harness internals | +| `data-designer` | public interface wiring, CLI/operator presentation, and integration docs; it may consume engine/config contracts but must not make engine internals plugin API | + +When a contract is shared across packages, the lower package owns the data definition and the higher package owns presentation or orchestration. Engine code may import config contracts; config code must not import engine runtime protocols. + +Target repository ownership is part of the architecture, not an implementation detail. 
The epic does not include compatibility aliases, shim modules, transitional reexports, or duplicate old/new paths for scheduler or request-admission names. + +## Module Ownership + +The final target module layout is defined in [Module ownership](module-ownership.md). In summary: + +- generator-facing metadata lives in `data_designer.config.scheduling` +- scheduler task models, readiness, queues, task admission, and task policies live under `data_designer.engine.dataset_builders.scheduling` +- `AsyncTaskScheduler` remains in `data_designer.engine.dataset_builders.async_scheduler` as the runtime coordinator only +- shared provider/model identity lives in `data_designer.engine.models.resources` +- concrete request admission lives under `data_designer.engine.models.request_admission` +- `ModelRequestExecutor` stays under `data_designer.engine.models.clients` because it wraps model clients at the acquire/call/release boundary +- capacity diagnostics live in `data_designer.engine.capacity` +- runtime admission events and correlation live in `data_designer.engine.observability` +- product/provider usage telemetry remains separate in `data_designer.engine.models.telemetry` + +Implementation PRs should use these final homes directly and must not leave production compatibility names, compatibility modules, old-path reexports, or durable tests for replaced names. + +## Two-Stage Admission + +Task admission controls when ready dataset work may become a running worker. Request admission controls concrete provider/model/domain calls at the moment they are made. + +The split is required because arbitrary custom Python can make zero, one, or many model calls dynamically. A task's metadata may help group and schedule the task, but it is not a promise of exact request count and must not reserve every future model call up front. + +Task admission may consume request pressure snapshots as read-only policy input. 
The current branch includes a narrow request-pressure advisory selection path that can prefer an eligible, unpressured peer over a request-pressured candidate. It must not pre-acquire request permits, emulate AIMD, mutate request admission, or wrap provider/model/domain request admission. Broader provider/resource-aware scheduling remains #651 scope. + +In V1, a task waiting inside request admission keeps its scheduler task lease until the task reaches a terminal outcome. This makes request wait visible without adding yield/reacquire complexity to the lease boundary. The cross-provider optimization target, where tasks blocked on one cooled-down provider do not occupy every scheduler slot while another provider has ready work, belongs to #651's provider/resource-aware task policy or an explicit later yield/reacquire design. + +## Core Invariants + +- Scheduler-level work is not spawned until `TaskAdmissionController` returns a `TaskAdmissionLease`. +- `FairTaskQueue.select_next(...)` does not remove work or mutate virtual-time state. `commit(selection)` is the only queue operation that removes the selected task. +- `select_next(...)`, `try_acquire(...)`, and `commit(selection)` are coordinated by `AsyncTaskScheduler` under a single dispatch critical section or an equivalent versioned-selection protocol. +- If `try_acquire(...)` succeeds but `commit(selection)` fails, the scheduler releases the task lease before retrying. +- Every task lease and request lease is released exactly once in all success, failure, retry, cancellation, shutdown, and salvage paths. +- Root/from-scratch work uses the same ready queue and task-admission path as downstream work. +- Request admission happens only at concrete model-call time through `ModelRequestExecutor`. +- Provider retries are visible to request admission: each outbound attempt either re-enters `ModelRequestExecutor` or is owned by a retry loop inside it that acquires/releases per attempt. 
+- Scheduler telemetry and request telemetry remain independently useful when the other subsystem is disabled. +- Capacity and benchmark artifacts must distinguish dependency readiness, ready ordering, scheduler admission wait, request admission wait, provider execution, cooldown/rate-limit behavior, and task completion. + +## Non-Goals + +- Do not collapse task admission and request admission into one subsystem. +- Do not expose scheduler internals as plugin API. +- Do not put provider retry, cooldown, or AIMD behavior into `AsyncTaskScheduler` or `TaskAdmissionController`. +- Do not put DAG readiness, row-group lifecycle, or task ordering into `RequestAdmissionController`. +- Do not configure OpenTelemetry SDKs or exporters in core runtime. +- Do not add public capacity knobs before benchmark evidence and docs justify them. +- Do not keep durable compatibility shims or aliases for replaced scheduler/request-admission names at epic completion. diff --git a/plans/645/async-scheduling-epic.puml b/plans/645/async-scheduling-epic.puml new file mode 100644 index 000000000..e581eab8f --- /dev/null +++ b/plans/645/async-scheduling-epic.puml @@ -0,0 +1,974 @@ +@startuml AsyncSchedulingEpicComponent +title Data Designer Async Scheduling Epic - Component View + +left to right direction +skinparam componentStyle rectangle +skinparam shadowing false +skinparam packageStyle rectangle +skinparam linetype ortho +skinparam nodesep 40 +skinparam ranksep 45 + +legend right + Epic issue 645 target shape. + Issue tags show primary ownership. + Solid arrows are runtime calls or data flow. + Dashed arrows are observability or policy inputs. 
+endlegend + +actor "Generator / Plugin Author" as Author +actor "Operator / Integrator" as Operator + +component "Plugin Metadata Surface\nColumnGenerator.get_scheduling_metadata\nSchedulingMetadata\nissues 641, 652" as MetadataStage +component "Scheduler Metadata Resolver\nTaskSchedulingResolver\nResolvedTaskScheduling\nissues 646, 653" as ResolverStage +component "Dependency Readiness\nCompletionTracker" as ReadinessStage +component "Ready Work Ordering\nFairTaskQueue\nQueueView\nQueueSelection" as TaskQueueStage +component "Task Admission\nTaskAdmissionController\nTaskAdmissionPolicy\nTaskAdmissionLease\nissue 644" as TaskAdmissionStage +component "AsyncTaskScheduler\nexecution owner" as SchedulerStage +component "Model Request Boundary\nModel facade/provider boundary\nModelRequestExecutor\nper concrete request attempt" as ModelBoundaryStage +component "Request Admission\nRequestAdmissionController interface\nAdaptiveRequestAdmissionController\nRequestFairQueue / Policy / LimitState\nissue 657" as RequestAdmissionStage +component "Provider / Model Endpoint" as ProviderStage + +component "Capacity Plan\nAsyncCapacityPlan\nTaskAdmissionConfig\nRequestAdmissionConfig\nissue 654" as CapacityStage +component "Telemetry and Correlation\nSchedulerAdmissionEventSink\nRequestAdmissionEventSink\nRuntimeCorrelationProvider\nCorrelated Runtime View\nissues 635, 647, 648" as TelemetryStage +component "Benchmark and Future Design\nbenchmark harness\nbounded-borrow policy\nresource-vector design\nissues 649, 650, 651" as EvidenceStage + +Author --> MetadataStage +MetadataStage -right-> ResolverStage : static metadata +ResolverStage -right-> ReadinessStage : SchedulableTask inputs +SchedulerStage -left-> ReadinessStage : ready_frontier() +SchedulerStage -left-> TaskQueueStage : enqueue / select_next / commit +SchedulerStage -down-> TaskAdmissionStage : eligibility / try_acquire +TaskAdmissionStage -up-> SchedulerStage : TaskAdmissionLease +SchedulerStage -right-> 
ModelBoundaryStage : admitted task makes model call +ModelBoundaryStage -right-> RequestAdmissionStage : model request item +RequestAdmissionStage -right-> ModelBoundaryStage : request lease +ModelBoundaryStage -right-> ProviderStage : admitted model call + +CapacityStage ..> TaskAdmissionStage +CapacityStage ..> RequestAdmissionStage +TaskAdmissionStage ..> TelemetryStage +RequestAdmissionStage ..> TelemetryStage +SchedulerStage ..> TelemetryStage +TelemetryStage --> Operator +EvidenceStage ..> CapacityStage +EvidenceStage ..> TaskAdmissionStage +EvidenceStage ..> RequestAdmissionStage + +note bottom of TaskAdmissionStage + Controller owns scheduler leases. + Policy decides admissibility. + FairTaskQueue owns only ready-work ordering. +end note + +note bottom of MetadataStage + SchedulingMetadata is static declaration only. + It does not encode queue depth, AIMD state, + admitted limits, RequestDomain, or runtime pressure. +end note + +@enduml + +@startuml AsyncSchedulingTaskAdmissionClassModel +title Data Designer Async Scheduling Epic - Task Admission Contracts + +left to right direction +skinparam classAttributeIconSize 0 +skinparam shadowing false +skinparam packageStyle rectangle +skinparam linetype ortho + +package "Metadata contracts" { + class ColumnGenerator { + +get_scheduling_metadata() + } + + class SchedulingMetadata { + +kind + +identity + +weight + } + + class SchedulingMetadataError { + +code + +message + +fallback + +diagnostics + } + + class TaskSchedulingResolver { + +resolve(generator) + } + + class ResolvedTaskScheduling { + +group + +resource_request + } +} + +package "Task scheduling contracts" { + class CompletionTracker { + +ready_frontier() + +mark_enqueued(task_ids) + +mark_complete(task) + } + + class TaskGroupSpec { + +key + +weight + } + + class SchedulerResourceRequest { + +amounts + } + + class SchedulerResourceKey { + +kind + +identity + } + + class SchedulableTask { + +task_id + +task + +group + +resource_request + } + + class 
FairTaskQueue { + +enqueue(items) -> Sequence[task_id] + +select_next(is_eligible) -> QueueSelection | None + +commit(selection) -> SchedulableTask | None + +view() -> QueueView + } + + class QueueSelection { + +item + +queue_view + +sequence_version + } + + class QueueView { + +queued_total + +queued_by_group + +queued_demand_by_group_resource + +first_candidate_resource_by_group + +queued_peer_demand_by_resource + } + + interface TaskAdmissionPolicy { + +evaluate(item, queue_view, admission_view) + +on_acquire(lease, decision) + +on_release(lease) + } + + class TaskAdmissionPolicyDecision { + +allowed + +reason + +available_after + +diagnostics + } + + class PolicyStateDelta { + +borrow_debt_delta + +diagnostics + } + + class StrictFairTaskAdmissionPolicy { + +enforce hard per-group limits + } + + class BoundedBorrowTaskAdmissionPolicy { + +borrow_debt(group, resource) + } + + class TaskAdmissionController { + +is_eligible(item, queue_view) -> bool + +try_acquire(item, queue_view) -> TaskAdmissionDecision + +release(lease) -> ReleaseResult + +view() -> TaskAdmissionView + +explain_blocked(queue_view) -> TaskAdmissionBlockSummary + } + + class TaskAdmissionView { + +limits_by_resource + +available_by_resource + +leased_by_resource + +leased_by_group_resource + +running_by_group_resource + +policy_debt_by_group_resource + } + + class TaskAdmissionBlockSummary { + +queued_count + +dominant_reasons + +available_after + +diagnostics + } + + class TaskAdmissionDecision <<union>> { + TaskAdmissionLease | TaskAdmissionDenied + } + + class TaskAdmissionDenied { + +item + +reason + +available_after + +snapshot + } + + class TaskAdmissionLease { + +lease_id + +item + +resources + +acquired_at + +controller_generation + } + + class ReleaseResult { + +released + +reason + +diagnostics + } + + class AsyncTaskScheduler <> { + +coordinate_ready_work() + +execute_admitted_task() + } +} + +package "Task config" { + class TaskAdmissionConfig { + +submission_capacity + +resource_limits + 
+policy_config + } +} + +ColumnGenerator --> SchedulingMetadata +SchedulingMetadataError --> SchedulingMetadata : optional fallback +TaskSchedulingResolver --> SchedulingMetadata +TaskSchedulingResolver --> ResolvedTaskScheduling +ResolvedTaskScheduling --> TaskGroupSpec +ResolvedTaskScheduling --> SchedulerResourceRequest +SchedulerResourceRequest --> SchedulerResourceKey +ResolvedTaskScheduling --> SchedulableTask +AsyncTaskScheduler --> CompletionTracker : ready_frontier / mark_enqueued +FairTaskQueue --> QueueSelection : returns +FairTaskQueue --> QueueView : view() +TaskAdmissionController --> TaskAdmissionPolicy +TaskAdmissionPolicy --> TaskAdmissionPolicyDecision +TaskAdmissionPolicy --> PolicyStateDelta +TaskAdmissionController --> TaskAdmissionBlockSummary +StrictFairTaskAdmissionPolicy ..|> TaskAdmissionPolicy +BoundedBorrowTaskAdmissionPolicy ..|> TaskAdmissionPolicy +TaskAdmissionController --> TaskAdmissionView +TaskAdmissionController --> TaskAdmissionDecision +TaskAdmissionController --> ReleaseResult +TaskAdmissionDecision --> TaskAdmissionLease +TaskAdmissionDecision --> TaskAdmissionDenied +TaskAdmissionLease --> SchedulableTask +AsyncTaskScheduler --> CompletionTracker +AsyncTaskScheduler --> FairTaskQueue : enqueue / select_next / commit +AsyncTaskScheduler --> TaskAdmissionController : eligibility callback / try_acquire +TaskAdmissionController ..> TaskAdmissionConfig : reads config + +note right of SchedulingMetadata + Static declaration only. + Must not encode queue depth, + AIMD state, admitted limits, + RequestDomain, or runtime pressure. +end note + +note bottom of SchedulerResourceRequest + Replaces legacy hidden-wait booleans + once the resource-vector design lands. +end note + +note bottom of FairTaskQueue + select_next receives an opaque scheduler-owned + eligibility predicate. The queue owns ordering only; + it does not depend on TaskAdmissionController. 
+end note + +@enduml + +@startuml AsyncSchedulingRequestAdmissionClassModel +title Data Designer Async Scheduling Epic - Request Admission Contracts + +left to right direction +skinparam classAttributeIconSize 0 +skinparam shadowing false +skinparam packageStyle rectangle +skinparam linetype ortho + +package "Model request boundary" { + class ModelRequestExecutor <> { + +execute_attempt(request) + } + + class ProviderModelEndpoint <> { + +request() + } +} + +package "Request admission contracts" { + class RequestResourceKey { + +provider_name + +model_id + +domain + } + + class RequestResourceResolver { + +resolve_provider_model(...) + +resolve_request_resource(...) + } + + class ProviderModelKey { + +provider_name + +model_id + } + + enum RequestDomain { + chat + embedding + image + healthcheck + } + + class RequestGroupSpec { + +key + +weight + } + + class RequestAdmissionItem { + +resource + +group + +timeout + +event_context + } + + class RequestWaiter { + +waiter_id + +item + +enqueued_at + +deadline + +completion_handle + } + + class RequestEventContext { + +captured_correlation + +task_execution_id + +request_attempt_id + } + + class RequestFairQueue <> { + +enqueue(waiter) + +select_next(is_eligible) -> RequestQueueSelection | None + +commit(selection) -> RequestWaiter | None + +remove(waiter_id) + +view() -> RequestQueueView + } + + class RequestQueueSelection { + +waiter + +item + +waiter_id + +queue_view + +sequence_version + } + + class RequestQueueView { + +queued_total + +queued_by_group + +queued_demand_by_resource + +aggregate_waiters + } + + interface RequestAdmissionPolicy <> { + +is_eligible(item, queue_view, limits) + +on_release(lease, outcome) + } + + class AdaptiveRequestLimitState { + +current_limit(domain_resource) + +effective_max(domain_resource) + +aggregate_in_flight(provider, model) + +record_acquire(resource) + +record_outcome(resource, outcome) + } + + interface RequestAdmissionController { + +try_acquire(item) -> 
RequestAdmissionDecision + +acquire_sync(item) -> RequestAdmissionLease + +acquire_async(item) -> RequestAdmissionLease + +release(lease, outcome) -> ReleaseResult + +pressure -> RequestPressureSnapshotProvider + } + + class AdaptiveRequestAdmissionController { + +AIMD reduce/increase behavior + } + + class RequestAdmissionDecision <<union>> { + RequestAdmissionLease | RequestAdmissionDenied + } + + class RequestAdmissionLease { + +lease_id + +item + +acquired_at + +current_limit + +effective_max + +controller_generation + } + + class RequestAdmissionDenied { + +item + +reason + +retry_after_seconds + +available_after_monotonic + +snapshot + } + + class RequestAdmissionError { + +denial + } + + class RequestReleaseOutcome { + +kind + +retry_after_seconds + +safe_status_metadata + } + + class ReleaseResult { + +released + +reason + +diagnostics + } + + class RequestPressureSnapshotProvider { + +snapshot(resource) + +snapshots() + +global_snapshot(provider, model) + +global_snapshots() + } +} + +package "Request config" { + class RequestAdmissionConfig { + +resources + +initial_limit + +max_limit_clamp + +cooldown_seconds + +multiplicative_decrease_factor + +additive_increase_step + +successes_until_increase + +startup_ramp_seconds + +default_queue_wait_timeout + } +} + +ModelRequestExecutor --> RequestAdmissionController +ModelRequestExecutor --> RequestResourceResolver +RequestResourceResolver --> ProviderModelKey +RequestResourceResolver --> RequestResourceKey +RequestAdmissionPolicy --> AdaptiveRequestLimitState +AdaptiveRequestAdmissionController ..|> RequestAdmissionController +AdaptiveRequestAdmissionController --> RequestFairQueue +AdaptiveRequestAdmissionController --> RequestAdmissionPolicy +AdaptiveRequestAdmissionController --> AdaptiveRequestLimitState +RequestAdmissionController --> RequestAdmissionDecision +RequestAdmissionController --> RequestReleaseOutcome +RequestAdmissionController --> ReleaseResult +RequestAdmissionDecision --> RequestAdmissionLease 
+RequestAdmissionDecision --> RequestAdmissionDenied +RequestAdmissionError --> RequestAdmissionDenied +RequestAdmissionController --> RequestPressureSnapshotProvider +AdaptiveRequestAdmissionController ..> RequestAdmissionConfig : reads config +RequestAdmissionItem --> RequestResourceKey +RequestAdmissionItem --> RequestEventContext +RequestWaiter --> RequestAdmissionItem +RequestResourceKey --> ProviderModelKey +RequestResourceKey --> RequestDomain +RequestAdmissionItem --> RequestGroupSpec +RequestFairQueue --> RequestQueueSelection +RequestFairQueue --> RequestQueueView +RequestQueueSelection --> RequestWaiter +RequestAdmissionLease --> RequestAdmissionItem +RequestAdmissionDenied --> RequestAdmissionItem +ModelRequestExecutor --> ProviderModelEndpoint + +note bottom of RequestFairQueue + Internal waiter ordering. + This is not a second public wrapper around + RequestAdmissionController. +end note + +note right of RequestAdmissionController + Owns concrete provider/model/domain + request admission at model-call time. + Does not own DAG readiness or task + scheduler admission. + No durable legacy request-control or hidden-wait + compatibility concepts. 
+end note + +@enduml + +@startuml AsyncSchedulingSupportContractsClassModel +title Data Designer Async Scheduling Epic - Capacity, Telemetry, and Evidence Contracts + +left to right direction +skinparam classAttributeIconSize 0 +skinparam shadowing false +skinparam packageStyle rectangle +skinparam linetype ortho + +package "Capacity planning (issue 654)" { + class RunConfigRuntimeArgs <> { + +env + +runtime_args + +run_config + } + + class AsyncCapacityPlan { + +configured + +runtime_snapshot + +observed_maxima + } + + class CapacityValue { + +value + +source + +fallback_from + +missing_reason + } + + class CapacityObservedMaxima { + +row_groups_in_flight + +queued_tasks_by_group + +task_leases_by_resource + +request_waiters_by_resource + +request_in_flight_by_resource + +provider_model_aggregate_in_flight + +request_domain_current_limits + +transport_pool_utilization + } + + class RowGroupAdmission { + +row_group_concurrency + +observed_in_flight + } + + class TransportPoolConfig { + +pool_limits + } + + class ProviderModelStaticCap { + +cap + +aliases + +raw_caps + +merge_rule + } + + class TaskAdmissionConfig { + +submission_capacity + +resource_limits + +policy_config + } + + class RequestAdmissionConfig { + +resources + +initial_limit + +max_limit_clamp + +cooldown_seconds + +multiplicative_decrease_factor + +additive_increase_step + +successes_until_increase + +startup_ramp_seconds + +default_queue_wait_timeout + } +} + +package "Telemetry and correlation (issues 635, 647, 648)" { + class RuntimeCorrelation { + +run_id + +row_group + +task_column + +task_type + +scheduling_group_kind + +scheduling_group_identity_hash + +task_execution_id + } + + class RuntimeCorrelationProvider { + +set(context) + +reset(token) + +current() + } + + class SchedulerAdmissionEventSink { + +emit(event) + } + + class RequestAdmissionEventSink { + +emit(event) + } + + class SchedulerAdmissionEvent { + +event_kind + +captured_at_monotonic + +sequence + +captured_correlation + 
+task_id + +task_execution_id + +task_lease_id + +scheduler_resource_key + +reason_or_result + +snapshot + +diagnostics + } + + class RequestAdmissionEvent { + +event_kind + +captured_at_monotonic + +sequence + +captured_correlation + +request_attempt_id + +request_lease_id + +request_resource + +request_group_key + +reason_or_outcome + +snapshot + +diagnostics + } + + class CorrelatedRuntimeView { + +join scheduler and request timelines + } +} + +package "Evidence and future design (issues 649, 650, 651)" { + class AsyncSchedulingBenchmarkHarness { + +run_ab(baseline_ref, candidate_ref) + +emit_json_csv_artifacts() + +verify_final_snapshots() + } + + class SchedulerResourceVectorDesign { + +static task resources + +read-only request pressure + } + + class BoundedBorrowTaskAdmissionPolicy { + +borrow_debt(group, resource) + } +} + +RunConfigRuntimeArgs --> AsyncCapacityPlan +AsyncCapacityPlan --> TaskAdmissionConfig +AsyncCapacityPlan --> RequestAdmissionConfig +AsyncCapacityPlan --> CapacityValue +AsyncCapacityPlan --> CapacityObservedMaxima +AsyncCapacityPlan --> RowGroupAdmission +AsyncCapacityPlan --> TransportPoolConfig +AsyncCapacityPlan --> ProviderModelStaticCap +SchedulerAdmissionEvent --> RuntimeCorrelation : captured at construction +RequestAdmissionEvent --> RuntimeCorrelation : captured at construction +SchedulerAdmissionEventSink --> SchedulerAdmissionEvent +RequestAdmissionEventSink --> RequestAdmissionEvent +RuntimeCorrelationProvider --> RuntimeCorrelation +CorrelatedRuntimeView --> SchedulerAdmissionEvent : consumes +CorrelatedRuntimeView --> RequestAdmissionEvent : consumes +AsyncSchedulingBenchmarkHarness ..> AsyncCapacityPlan +AsyncSchedulingBenchmarkHarness ..> SchedulerAdmissionEventSink +AsyncSchedulingBenchmarkHarness ..> RequestAdmissionEventSink +AsyncSchedulingBenchmarkHarness ..> BoundedBorrowTaskAdmissionPolicy +SchedulerResourceVectorDesign ..> BoundedBorrowTaskAdmissionPolicy + +@enduml + +@startuml AsyncSchedulingEpicRuntimeSequence 
+title Data Designer Async Scheduling Epic - Runtime Sequence + +skinparam shadowing false +skinparam sequenceMessageAlign center + +actor User +participant "ColumnGenerator" as Gen +participant "TaskSchedulingResolver\nissue 646" as Resolver +participant "CompletionTracker" as Tracker +participant "AsyncTaskScheduler" as Scheduler +participant "FairTaskQueue" as Queue +participant "TaskAdmissionController\nissue 644" as TaskAdmit +participant "TaskAdmissionPolicy" as TaskPolicy +participant "Admitted Task /\nGenerator Code" as TaskCode +participant "ModelRequestExecutor" as Executor +participant "RequestAdmissionController\nissue 657" as ReqAdmit +participant "RequestFairQueue" as ReqQueue +participant "RequestAdmissionPolicy" as ReqPolicy +participant "AdaptiveRequestLimitState" as ReqLimits +participant "Provider / Model" as Provider +participant "SchedulerAdmissionEventSink\nissue 647" as SchedEvents +participant "RequestAdmissionEventSink\nissue 635" as ReqEvents +participant "RuntimeCorrelationProvider\nissue 648" as Correlation + +User -> Gen : declare columns / plugins +Resolver -> Gen : get_scheduling_metadata() +alt valid metadata + Gen --> Resolver : SchedulingMetadata +else generator omitted override + Resolver -> Resolver : use documented default metadata +else recoverable metadata error + Gen --> Resolver : SchedulingMetadataError(fallback) + Resolver -> Resolver : emit metadata diagnostic +else fatal metadata error + Gen --> Resolver : SchedulingMetadataError(no fallback) + Resolver --> User : fail before scheduling +end +Resolver -> Resolver : resolve TaskGroupSpec\nand SchedulerResourceRequest +Resolver -> Tracker : provide SchedulableTask inputs + +Scheduler -> Tracker : ready_frontier() +Tracker --> Scheduler : un-enqueued ready SchedulableTasks +Scheduler -> SchedEvents : dependency_ready\n(correlation captured now) +Scheduler -> Queue : enqueue ready SchedulableTasks +Queue --> Scheduler : accepted task_ids +Scheduler -> Tracker : 
mark_enqueued(accepted task_ids) +Scheduler -> SchedEvents : ready_enqueued\n(correlation captured now) + +loop dispatch while capacity and ready work may exist + Scheduler -> Queue : select_next(scheduler eligibility predicate) + Queue -> Scheduler : is_eligible(candidate, queue_view) + Scheduler -> TaskAdmit : is_eligible(candidate, queue_view) + TaskAdmit -> TaskPolicy : evaluate candidate against queue/admission views + TaskPolicy --> TaskAdmit : eligible or denied + TaskAdmit --> Scheduler : eligibility result + + alt no eligible selection + Queue --> Scheduler : None + Scheduler -> TaskAdmit : explain_blocked(queue.view()) + TaskAdmit --> Scheduler : TaskAdmissionBlockSummary + Scheduler -> SchedEvents : queue_empty / admission_blocked\n(correlation captured now) + Scheduler -> Scheduler : wait_for_wake_or_deadline() + else selection returned + Queue --> Scheduler : QueueSelection(item, queue_view, sequence_version) + Scheduler -> SchedEvents : selected\n(correlation captured now) + Scheduler -> TaskAdmit : try_acquire(selection.item, selection.queue_view) + + alt admission denied + TaskAdmit --> Scheduler : TaskAdmissionDenied + Scheduler -> SchedEvents : admission_denied\n(correlation captured now) + Scheduler -> Scheduler : wake_dispatch_loop() + else lease acquired + TaskAdmit -> TaskPolicy : on_acquire(lease, decision) + TaskAdmit --> Scheduler : TaskAdmissionLease + Scheduler -> Queue : commit(selection) + + alt stale selection + Queue --> Scheduler : None + Scheduler -> TaskAdmit : release(lease) + TaskAdmit -> TaskPolicy : on_release(lease) + Scheduler -> SchedEvents : stale_selection / task_lease_released\n(correlation captured now) + Scheduler -> Scheduler : wake_dispatch_loop() + else committed + Queue --> Scheduler : SchedulableTask + Scheduler -> Correlation : set RuntimeCorrelation + Scheduler -> SchedEvents : task_lease_acquired\n(correlation captured now) + + alt worker spawn failed + Scheduler -> TaskAdmit : release(lease) + TaskAdmit -> 
TaskPolicy : on_release(lease) + Scheduler -> SchedEvents : worker_spawn_failed / task_lease_released\n(correlation captured now) + Scheduler -> Correlation : reset RuntimeCorrelation + else worker spawned + Scheduler -> SchedEvents : worker_spawned\n(correlation captured now) + Scheduler -> TaskCode : execute admitted task + + loop zero, one, or many concrete model calls + TaskCode -> Executor : model call attempt + Executor -> Correlation : current() + Executor -> ReqAdmit : acquire_async(RequestAdmissionItem\nwith RequestEventContext) + + ReqAdmit -> ReqEvents : request_wait_started\n(correlation captured now) + + alt immediate eligible and no queued waiter selected first + ReqAdmit -> ReqPolicy : evaluate resource/group against limits + ReqPolicy -> ReqLimits : read domain limit\nand aggregate in-flight + ReqLimits --> ReqPolicy : limit snapshot + ReqPolicy --> ReqAdmit : eligible + else queued waiter path + ReqAdmit -> ReqQueue : enqueue waiter + + loop until admitted, timeout, cancel, or shutdown + ReqAdmit -> ReqQueue : select_next(ReqPolicy.is_eligible) + ReqQueue --> ReqAdmit : RequestQueueSelection | None + ReqAdmit -> ReqPolicy : evaluate resource/group against limits + ReqPolicy -> ReqLimits : read domain limit\nand aggregate in-flight + ReqLimits --> ReqPolicy : limit snapshot + ReqPolicy --> ReqAdmit : eligible or denied + ReqAdmit -> ReqAdmit : timed wait to next\nblocked_until / timeout + end + end + + alt no lease: timeout, shutdown, or hard denial + ReqAdmit -> ReqQueue : remove waiter if queued + ReqAdmit -> ReqEvents : request_wait_timeout / request_acquire_denied\n(correlation captured now) + ReqAdmit --> Executor : raise RequestAdmissionError + else async cancellation before lease + ReqAdmit -> ReqQueue : remove waiter + ReqAdmit -> ReqEvents : request_wait_cancelled\n(correlation captured now) + ReqAdmit --> Executor : re-raise cancellation + else request lease acquired + ReqAdmit -> ReqQueue : commit(selection) if queued\nand fulfill 
selected waiter + ReqAdmit -> ReqLimits : record_acquire(resource) + note right of ReqAdmit + Once record_acquire succeeds, + cancellation either delivers the lease + for caller cleanup or internally releases + it as local_cancelled. + end note + ReqAdmit -> ReqEvents : request_wait_completed\n(correlation captured now) + ReqAdmit -> ReqEvents : request_lease_acquired\n(correlation captured now) + ReqAdmit --> Executor : RequestAdmissionLease + Executor -> ReqEvents : model_request_started\n(correlation captured now) + Executor -> Provider : concrete provider/model attempt + Provider --> Executor : response / rate limit / failure + Executor -> ReqEvents : model_request_completed\n(correlation captured now) + + alt success + Executor -> ReqAdmit : release(lease, success) + ReqAdmit -> ReqPolicy : on_release(lease, success) + ReqPolicy -> ReqLimits : record_outcome(resource, success) + else rate limit + Executor -> ReqAdmit : release(lease, rate_limited) + ReqAdmit -> ReqPolicy : on_release(lease, rate_limited) + ReqPolicy -> ReqLimits : record_outcome(resource, rate_limited) + ReqAdmit -> ReqEvents : request_rate_limited / request_limit_decreased\n(correlation captured now) + else provider failure or timeout + Executor -> ReqAdmit : release(lease, provider_failure/provider_timeout) + ReqAdmit -> ReqPolicy : on_release(lease, provider outcome) + ReqPolicy -> ReqLimits : record_outcome(resource, provider outcome) + else local_cancelled or local_timeout + Executor -> ReqAdmit : release(lease, local_cancelled/local_timeout) + ReqAdmit -> ReqPolicy : on_release(lease, local outcome) + ReqPolicy -> ReqLimits : record_outcome(resource, local outcome) + else unexpected exception + Executor -> ReqAdmit : release(lease, unexpected_exception) + ReqAdmit -> ReqPolicy : on_release(lease, unexpected_exception) + ReqPolicy -> ReqLimits : record_outcome(resource, unexpected_exception) + end + + ReqAdmit -> ReqQueue : wake/select next waiter + ReqAdmit -> ReqEvents : 
request_lease_released\n(+ request_limit_increased if changed) + end + end + + TaskCode --> Scheduler : generated value / terminal outcome + alt success + Scheduler -> Tracker : mark_complete(task) + Scheduler -> SchedEvents : task_completed\n(correlation captured now) + else retryable failure + Scheduler -> SchedEvents : retry_deferred\n(correlation captured now) + Scheduler -> Scheduler : record retry requested + else non-retryable failure or cancellation + Scheduler -> Tracker : mark_complete(task) + Scheduler -> SchedEvents : non_retryable_dropped / cancelled\n(correlation captured now) + else salvage redispatch + Scheduler -> SchedEvents : salvage_redispatched\n(correlation captured now) + Scheduler -> Scheduler : record salvage requested + end + + Scheduler -> TaskAdmit : release TaskAdmissionLease + TaskAdmit -> TaskPolicy : on_release(lease) + Scheduler -> SchedEvents : task_lease_released\n(correlation captured now) + Scheduler -> Correlation : reset RuntimeCorrelation + + alt retry or salvage replacement requested + Scheduler -> Tracker : record replacement work + Scheduler -> Scheduler : replacement re-enters ready_frontier path + else no replacement + Scheduler -> Scheduler : terminal accounting complete + end + end + end + end + end +end + +@enduml + +@startuml AsyncSchedulingEpicIssueMap +title Data Designer Async Scheduling Epic - Issue Dependency Map + +skinparam componentStyle rectangle +skinparam shadowing false +skinparam linetype ortho + +component "issue 641\nSchedulingMetadata" as I641 +component "issue 646\nTaskSchedulingResolver" as I646 +component "issue 653\nRemove legacy hints" as I653 +component "issue 652\nDocs" as I652 +component "issue 644\nTask admission" as I644 +component "issue 654\nAsyncCapacityPlan" as I654 +component "issue 657\nRequestAdmissionController" as I657 +component "issue 635\nRequest admission telemetry" as I635 +component "issue 647\nScheduler admission telemetry" as I647 +component "issue 648\nCorrelation" as I648 
+component "issue 649\nBenchmark harness" as I649 +component "issue 660\nArchitecture docs" as I660 +component "issue 650\nBounded borrow task policy" as I650 +component "issue 651\nResource-vector design" as I651 + +I641 --> I646 +I641 --> I653 +I641 --> I652 +I646 --> I653 +I646 --> I652 +I653 --> I652 +I652 --> I644 +I646 --> I644 +I641 --> I654 +I646 --> I654 +I644 --> I654 +I654 --> I657 +I644 --> I657 +I657 --> I635 +I654 ..> I635 : metric vocabulary +I644 --> I647 +I635 --> I647 +I657 --> I648 +I635 --> I648 +I647 --> I648 +I644 --> I649 +I654 --> I649 +I657 --> I649 +I635 --> I649 +I647 --> I649 +I648 --> I649 +I652 --> I660 +I654 --> I660 +I657 --> I660 +I648 --> I660 +I649 --> I660 +I660 --> I650 +I649 --> I650 +I644 --> I650 +I641 --> I651 +I646 --> I651 +I644 --> I651 +I649 --> I651 +I650 --> I651 +I660 --> I651 + +note bottom + Native GitHub subissue order: + issue 641 -> issue 646 -> issue 653 -> issue 652 -> issue 644 + -> issue 654 -> issue 657 -> issue 635 -> issue 647 + -> issue 648 -> issue 649 -> issue 660 -> issue 650 -> issue 651 + + Issues before 649 that need evidence emit provisional artifacts. + Issue 649 normalizes and reruns representative scenarios. +end note + +@enduml diff --git a/plans/645/benchmark-plan.md b/plans/645/benchmark-plan.md new file mode 100644 index 000000000..7048bc645 --- /dev/null +++ b/plans/645/benchmark-plan.md @@ -0,0 +1,247 @@ +# Benchmark Plan + +The benchmark harness turns architecture claims into reusable evidence. It prevents each implementation PR from inventing one-off scripts and makes fairness/throughput tradeoffs explicit. + +Until issue #649 closes, implementation PRs that need scheduling evidence must emit the provisional artifact schema in this file. 
A minimal deterministic smoke entrypoint and artifact writer should exist before the risky task/request admission implementation slices rely on it; issue #649 formalizes the reusable harness and reruns the provisional evidence against the accepted implementation chain before issue #645 closes. This prevents task/request admission PRs from landing without evidence while still allowing the harness to mature after capacity and telemetry contracts stabilize. + +## Harness Requirements + +Provide a repo-local benchmark entrypoint that can compare two refs or checkouts. + +Required inputs: + +- baseline ref +- candidate ref +- scenario +- record count +- buffer size +- row-group concurrency +- task admission capacity +- request latency knobs +- warmups +- measured iterations +- output directory +- seed +- scenario version +- harness version +- mock provider transcript or scripted provider behavior when live providers are not used +- monotonic clock/retry schedule when deterministic replay is claimed + +Required artifacts: + +- JSON and CSV outputs +- concise Markdown summary +- baseline and candidate commit SHAs +- command lines +- machine/runtime information +- environment knobs +- `AsyncCapacityPlan` +- per-layer observed maxima +- final task admission snapshot +- final request admission snapshot, or explicit `not_available_until_issue` marker before #657 lands +- completion timeline +- ready-idle/utilization timeline +- deterministic output hashes where applicable + +Final snapshots must prove zero active task leases, zero request leases, zero request waiters, and no resource-specific permit leaks after all terminal paths complete. Before #657 lands, request snapshot fields remain present but can carry `not_available_until_issue: 657` rather than fabricated zeros. + +The sync path can be used as a correctness/hash oracle, not as the timing baseline for async scheduling policy. 
+ +## Artifact Schema + +The provisional and final JSON artifacts use monotonic seconds for timeline fields and stable scenario ids for comparison: + +```text +scenario_id +artifact_schema_version +scenario_version +harness_version +baseline_sha +candidate_sha +inputs +provider_script +clock_script +capacity_plan +iterations[] + wall_time_seconds + timeline[] + event_kind + captured_at_monotonic + stream + sequence + captured_correlation + run_id + row_group + task_column + task_type + scheduling_group_kind + scheduling_group_identity_hash + task_execution_id + task_id + task_execution_id + task_lease_id + request_attempt_id + request_lease_id + scheduler_resource_key + request_resource_key + reason_or_outcome + final_task_snapshot + final_request_snapshot + output_hashes +derived_metrics +``` + +Derived metrics: + +- `ready_queue_wait = selected_at - ready_enqueued_at` +- `task_admission_wait = task_lease_acquired_at - selected_at` +- `ready_to_lease_gap = task_lease_acquired_at - ready_enqueued_at` +- `ready_idle_gap` is derived from intervals where dependency-ready work exists, scheduler task capacity is available, and no task lease is acquired. Per-task `selected_at -> task_lease_acquired_at` is task admission overhead, not the starvation metric. +- `active_capacity_integral = integral(active_leases / configured_capacity) over wall time` +- `root_over_admission_debt = admitted root work above strict fair share after first downstream-ready timestamp` +- `hidden_scheduler_resource_waiters` is the count of spawned workers waiting for scheduler-level resources that should have been acquired before spawn. After the task-admission lease boundary lands, the event stream should prove this is zero by showing no worker-spawned event before the corresponding task-lease-acquired event and no pre-epic scheduler-resource wait event for the task. 
+- deterministic hashes include generated output values and stable ordering metadata, not timing or event ids + +## Scenario Matrix + +### Queue And Admission Microbench + +Compare old `admit_next + release` behavior with new `select_next + try_acquire + commit + release` behavior. + +Matrix: + +- task counts: 1k, 10k, 100k +- group counts: 1, 8, 64, 256 +- resource mixes: local only, resource-bound, mixed local/resource-bound, stateful/exclusive if included + +Metrics: + +- p50/p95 admission cycle cost +- enqueue/select/acquire/commit/release breakdown +- total CPU time +- peak memory +- scaling by group count + +### Heavy-Root Downstream Benchmark + +Required shape: + +```text +true_from_scratch_root_slow -> downstream_fast +``` + +Optional secondary shape: + +```text +seed -> root_slow -> downstream_fast +``` + +Required metrics: + +- first downstream-ready time +- first downstream-dispatch time +- ready-but-not-running gap +- root over-admission debt after first downstream-ready timestamp +- time to first completed record +- time to 50 percent completed records +- p95 row completion time +- final wall time +- max ready-idle gap by group/resource +- active-capacity integral + +This scenario must exercise true root/from-scratch dispatch, not only downstream slow tasks. + +### Hidden-Waiter Proof + +After task admission lands: + +```text +max(hidden_scheduler_resource_waiters) == 0 +``` + +Required monotonic timeline fields: + +- dependency_ready_at +- ready_enqueued_at +- selected_at +- task_lease_acquired_at +- worker_spawned_at +- request_wait_started_at +- request_wait_completed_at +- request_lease_acquired_at +- model_request_started_at +- model_request_completed_at +- request_lease_released_at +- task_completed_at +- task_lease_released_at + +Scheduler events own selected/task-lease/spawn/task-completion/task-release. Request/model instrumentation owns request wait, request lease, model request start/complete, and request release. 
+ +Immediate request acquisition records `request_wait_started_at == request_wait_completed_at` so the timeline can distinguish a zero wait from missing instrumentation. + +### Idle And Utilization Proxy + +Use mock endpoint pools with request start/end events. + +Metrics: + +- active-capacity integral +- max ready-idle gap while work is available +- initial idle gap after first downstream-ready task + +### End-To-End A/B Timing + +Run paired A/B trials with warmup and at least five measured iterations for: + +- narrow serial workflow +- wide independent roots +- dual model generate-to-judge workflow +- heavy-root workflow +- dynamic request-count custom generator workflow +- cross-provider cooldown workflow where provider A is rate-limited or cooling down while provider B has ready independent work + +### Request Dynamic-Call Benchmark + +Use custom generators that make zero, one, and many model calls per task, including branch-dependent request counts. + +Metrics: + +- request admission acquire/release overhead +- queue wait +- event emission overhead +- emitted event count +- CPU time +- memory +- end-to-end timing + +## Baselines + +| Consumer | Baseline | Candidate | +| --- | --- | --- | +| Task admission | `origin/main` or the implementation PR merge-base before `TaskAdmissionController` | task-admission PR | +| Bounded borrow | accepted lease-only task-admission SHA | bounded-borrow PR | +| Resource vector | accepted bounded-borrow SHA or named policy baseline | resource-vector policy PR | + +## Evidence Thresholds + +All timing gates use paired same-machine runs with at least five measured iterations unless the scenario explicitly raises that count. Reports include mean, p50, p95, min, max, standard deviation, and a noise-floor note. If standard deviation is large enough to make a threshold ambiguous, the PR must either add iterations or treat the timing claim as inconclusive. 
+ +Neutral scenarios should regress mean wall time by no more than 5 percent versus the named baseline unless the PR explicitly justifies a fairness/utilization tradeoff. + +Heavy-root scenarios should show reduced downstream ready-to-dispatch lag versus the named baseline when the candidate claims to improve heavy-root behavior. + +Every run must show no permit leaks and deterministic output equality where applicable. + +Scenario-specific gates: + +- Queue/admission microbench: p95 admission cycle cost must not regress more than 10 percent unless the PR documents a fairness or correctness tradeoff. +- Heavy-root benchmark: p95 ready-to-dispatch gap for downstream work must improve versus the named baseline when the candidate claims heavy-root fairness; root over-admission debt must be bounded by the configured policy. +- Hidden-waiter proof: `max(hidden_scheduler_resource_waiters) == 0` across success, failure, cancellation, and salvage paths after task admission lands. +- Idle/utilization proxy: ready-idle gaps while eligible work and capacity are available must be zero except for documented event-loop scheduling granularity. +- Dynamic request benchmark: zero/one/many request tasks must produce matching output hashes, request lease counts must equal concrete outbound attempts, and request wait/execute/release timelines must be monotonic. +- Cross-provider cooldown benchmark: provider B ready work must continue to receive scheduler task leases while provider A is blocked by request cooldown once the provider-aware policy in #651 claims that optimization. +- Variance: measured iterations must report mean, p50, p95, min, max, and standard deviation. Any acceptance claim based on timing should remain directionally true after removing the fastest and slowest measured iteration. + +## CI Smoke + +The harness should have a small deterministic smoke mode using mock endpoints. It writes machine-readable artifacts and does not require live providers. 
diff --git a/plans/645/capacity-model.md b/plans/645/capacity-model.md new file mode 100644 index 000000000..811546c13 --- /dev/null +++ b/plans/645/capacity-model.md @@ -0,0 +1,119 @@ +# Capacity Model + +The async engine uses layered capacity. Each layer has a different owner and meaning. The epic goal is to make these layers visible, non-overlapping, and traceable in runtime artifacts. + +## Layer Vocabulary + +| Layer | Owner | Meaning | +| --- | --- | --- | +| Engine selection | Dataset builder / interface | Selects async or sync execution. Not a capacity control. | +| Record window | Dataset builder | Controls row grouping, checkpoint granularity, and memory shape. | +| Row-group admission | Async scheduler | Bounds row groups in flight. | +| Task-stage admission | `TaskAdmissionController` | Bounds scheduler-spawned work and scheduler-level resource pressure. | +| Request-stage admission | `RequestAdmissionController` | Bounds concrete provider/model/domain requests when they are made. | +| Static provider cap | Model config / metadata | User-declared provider/model upper bound and scheduling weight source. | +| Adaptive request-domain limit | `AdaptiveRequestAdmissionController` | Runtime AIMD limit for one provider/model/domain resource under the static provider/model cap. | +| Transport pool | HTTP/model client adapter | Socket/session pool sizing. Not scheduling or fairness policy. | + +## AsyncCapacityPlan + +`AsyncCapacityPlan` is the run-level explanation of capacity. 
It should record: + +- `buffer_size` +- row-group concurrency +- task admission capacity +- task resource limits +- request admission resources +- provider/model aggregate static caps +- provider/model aggregate in-flight maxima +- static provider/model caps used by the workflow +- adaptive request-admission config snapshot +- request-domain adaptive initial/current/effective limits when captured +- transport/session pool values if they remain distinct +- source of each value, such as default constant, model metadata, run config, request admission state, or environment selection + +The plan is emitted for diagnostics, traces, benchmarks, and operator documentation. It does not admit work by itself. + +`AsyncCapacityPlan` uses three sections: + +```text +configured: values computed before or at run start +runtime_snapshot: point-in-time controller snapshots, nullable until the owning issue lands +observed_maxima: maxima collected during execution or benchmark replay +``` + +Each configured value is a `CapacityValue` with `value`, `source`, `fallback_from`, and `missing_reason`. Fields that depend on request-admission runtime state may be present with `value = None` and `missing_reason` in #654 before #657 lands. + +`CapacityValue.source` uses the durable source vocabulary from [contracts.md](contracts.md#capacity-contracts), including `dataset_builder`, `engine_internal_config`, and `adapter_config` for values that do not come from public run config or model metadata. 
+ +Source precedence is per-field, not global: + +| Field | V1 precedence | +| --- | --- | +| `buffer_size` | explicit run config, then documented default | +| row-group concurrency | existing dataset-builder/runtime setting if present, then documented default | +| task admission limits | benchmark override for benchmark runs, then engine default | +| provider/model static cap | canonical model/provider metadata; request-admission config may lower but not raise it | +| request-domain initial/adaptive settings | public `RunConfig.request_admission` tuning where supported, benchmark override for non-public harness values, then engine default, all clamped under provider/model static cap | +| transport pool | adapter/client config, then documented default | + +If a value is missing, the capacity plan records the missing source and fallback used. If no safe fallback exists, construction fails with a typed configuration/metadata error before work is scheduled. + +## Ownership Rules + +Task admission capacity is scheduler-level capacity. It controls when a ready task can become a running worker. + +Request admission capacity is provider/model/domain request capacity. It controls when a concrete model call can execute. + +`max_parallel_requests` remains the user-facing static provider/model cap and scheduling metadata weight source. `AdaptiveRequestAdmissionController.current_limit` is the runtime adaptive request cap for a request domain. + +The provider/model static cap is an aggregate in-flight upper bound across all domains for that provider/model in V1. Domain adaptive limits operate under that aggregate cap. V1 intentionally does not define an aggregate cross-domain AIMD policy; adding one requires a later design that specifies fairness, telemetry, and benchmarks. + +HTTP transport pools may be larger than the static provider cap. They are transport sizing, not effective request concurrency. + +`DATA_DESIGNER_ASYNC_ENGINE` is an execution path selector. 
It is not a capacity knob. + +`RunConfig.buffer_size` shapes record windows and row groups. It is not a request-concurrency knob. + +## Row Groups And Record Windows + +`buffer_size` defines the record-window shape used by the dataset builder. Row groups are the concrete execution partitions produced from that windowing behavior. + +Row-group admission remains scheduler-owned and is separate from the V1 task-admission lease boundary. The normal dataset-builder wiring uses fixed row-group admission: `max_concurrent_row_groups` is the hard in-flight cap, and task admission leases then control ready task dispatch inside admitted row groups. + +The current branch also contains an internal adaptive row-group admission mode for direct scheduler use. That mode is additive-only: it starts from an initial target and can raise the soft in-flight row-group target up to the semaphore hard cap when no local scheduler-pressure reason blocks growth. It does not decrease the target, so docs and telemetry must not describe it as AIMD. It remains off by default unless a later issue explicitly promotes it to a durable scheduler policy. + +`AsyncCapacityPlan.configured.row_group_admission` records the mode, configured row-group concurrency, current/observed adaptive target when applicable, observed row groups in flight, optional max-admitted-row guardrail, and blocked reasons. Preview, resume, and checkpoint behavior use the existing dataset-builder partitioning rules. `AsyncCapacityPlan` reports the row-group values that the current engine used rather than redefining those rules. + +## Transitional Values + +Any hidden task-stage capacity concept left from the pre-epic design is transitional. At epic completion those names must be gone or represented by explicit scheduler-resource terminology in `TaskAdmissionConfig` and `AsyncCapacityPlan`. 
+ +If a distinct task-stage backpressure resource remains for model-producing work, it must be derived from actually used resolved `SchedulingMetadata`, not every registered model alias. It must be described as scheduler task-stage pressure, not provider request concurrency. + +## Alias And Provider Semantics + +Scheduling metadata may use model aliases to derive static resource identity and weight. Alias metadata should deduplicate aliases that resolve to the same provider/model/generation resource before deriving effective weight. The startup health-check hook `get_model_aliases()` remains separate from the scheduler metadata hook; a multi-endpoint alias set reported for health checks must not be forced into one provider/model resource. + +Request admission resources are provider/model/domain scoped. A provider/model may have a global effective static cap while each request domain has its own adaptive state. The capacity plan must make that distinction visible. + +V1 does not define a cross-domain aggregate AIMD provider cap beyond the documented provider/model effective static cap unless a later issue explicitly adds that policy. The request controller still enforces the static aggregate cap by checking provider/model aggregate in-flight counts before admitting a domain request. + +Alias-derived provider/model caps deduplicate aliases that resolve to the same concrete provider/model endpoint. If aliases for the same endpoint specify different `max_parallel_requests` values, V1 uses the minimum as the effective static cap and records every contributing alias and raw cap in `AsyncCapacityPlan`. This min-merge is not a metadata error. If a default registry-backed generator sees multiple aliases that resolve to different concrete endpoints, it should use deterministic `custom_model` task-stage metadata unless the plugin overrides `get_scheduling_metadata()` with a sharper declaration. 
If the provider treats generation type as a distinct endpoint, the canonical model id includes that distinction before cap merging. + +## Observability Requirements + +Operators should be able to answer: + +- Which capacity values were used for this run? +- Was progress limited by dependency readiness, queue ordering, task admission, request admission, provider cooldown, or provider execution? +- What static provider caps and adaptive request limits were active? +- Were transport pools distinct from request caps? + +Benchmarks and traces must include `AsyncCapacityPlan` plus per-layer observed maxima. + +Required per-layer maxima include row groups in flight, queued tasks by group/resource, task leases by resource, request waiters by resource, domain in-flight counts, provider/model aggregate in-flight counts, adaptive current limits, and transport pool utilization when available. + +## Public Knob Rule + +Do not add new public capacity knobs beyond the documented model `max_parallel_requests`, `buffer_size`, and advanced `RunConfig.request_admission` tuning fields until benchmark evidence shows a specific need and the docs explain the layer. Prefer clear defaults, internal configs, and diagnostics first. diff --git a/plans/645/contracts.md b/plans/645/contracts.md new file mode 100644 index 000000000..0c7a51a4f --- /dev/null +++ b/plans/645/contracts.md @@ -0,0 +1,527 @@ +# Contracts + +This file records the durable names and semantics used by the async scheduling architecture. Implementation details inside the owning target modules can evolve, but these names are the normative spec vocabulary for the epic. Topic files may explain behavior, but should not redefine fields or return shapes in ways that conflict with this file. + +Durable names in this file are not public API by default. Publicness and final module homes are defined in [Module ownership](module-ownership.md). 
+ +## Package Ownership + +| Contract family | Owning package | Notes | +| --- | --- | --- | +| Generator metadata and public config DTOs | `data-designer-config` | `SchedulingMetadata`, metadata validation errors, and exposed run-config fields live here when they are public/user-facing. | +| Scheduler/request runtime protocols | `data-designer-engine` | queues, controllers, policies, leases, runtime snapshots, event DTOs, capacity plan construction, and benchmark internals live here. | +| User interface and operator presentation | `data-designer` | consumes config and engine contracts for the public `DataDesigner` interface, CLI, and integrations. | + +Config-layer contracts must not import engine runtime protocols. Engine contracts may consume config-layer DTOs. + +The final repository layout is specified in [Module ownership](module-ownership.md). Runtime contracts must live in their owning target modules; do not preserve old engine module paths through aliases, shim files, or broad package reexports. Public config may keep explicitly deprecated compatibility DTOs when they translate into the new durable config surface and warn users. 
+ +## Config Surface Status + +| Contract | V1 status | Owner | +| --- | --- | --- | +| `SchedulingMetadata` | public plugin-facing DTO | `data-designer-config` | +| `RequestAdmissionTuningConfig` | public advanced `RunConfig.request_admission` DTO for supported AIMD tuning only | `data-designer-config` | +| `ThrottleConfig` | deprecated compatibility DTO translated into `RequestAdmissionTuningConfig` by `RunConfig.throttle` | `data-designer-config` | +| `TaskAdmissionConfig` | engine-internal config and benchmark injection surface; not a public `RunConfig` knob in V1 | `data-designer-engine` | +| `RequestAdmissionConfig` | engine-internal config and benchmark injection surface in V1 | `data-designer-engine` | +| `RunConfig.request_admission` | public advanced request-admission tuning surface backed by `RequestAdmissionTuningConfig`; does not expose internal controller APIs | `data-designer-config` | +| `RunConfig.throttle` | deprecated compatibility input translated into `RunConfig.request_admission` with a `DeprecationWarning` | `data-designer-config` | +| `AsyncCapacityPlan` | diagnostic/reporting DTO, emitted to explain a run | `data-designer-engine` | + +Public request-admission tuning is limited to the supported fields on `RequestAdmissionTuningConfig`: multiplicative decrease factor, additive increase step, successes before increase, fallback cooldown, and startup ramp seconds. Benchmarks may still inject lower-level capacity values through harness-only configuration without committing those values to public API. + +`ThrottleConfig` is retained only as a migration shim for existing configs that pass `RunConfig(throttle=...)`. The shim maps `reduce_factor`, `additive_increase`, `success_window`, and `cooldown_seconds` into the corresponding `RequestAdmissionTuningConfig` fields and emits a `DeprecationWarning`. `ceiling_overshoot` is accepted for DTO compatibility but is not forwarded because request admission does not expose an overshoot knob. 
+ +## Metadata Contracts + +`ColumnGenerator.get_scheduling_metadata()` returns generator-facing scheduling metadata. It is additive and non-abstract so existing generators keep working. + +`SchedulingMetadata` is a static declaration with: + +- `kind`: initial values are `local`, `model`, and `custom_model`. +- `identity`: deterministic tuple of broad-to-specific resource identity values. +- `weight`: positive static capacity hint. + +`SchedulingMetadataError` is the typed failure path for metadata resolution. It can carry fallback metadata when partial resolution is safe. The documented default metadata for generators that do not override `get_scheduling_metadata()` is a normal resolver path, not an error fallback. + +Rules: + +- Metadata identity is resource identity, not a queue key. +- Metadata cannot encode queue depth, admitted limits, runtime pressure, request domains, AIMD state, or provider cooldown. +- Alias-derived model metadata deduplicates aliases that resolve to the same provider/model/generation resource before deriving effective weight. +- Alias ordering is canonicalized so equivalent configs produce equivalent metadata. +- Generators that do not override `get_scheduling_metadata()` receive a documented default metadata value. The default must preserve current behavior and must not infer provider/model pressure dynamically. +- Invalid `kind`, non-deterministic `identity`, or non-positive `weight` raises `SchedulingMetadataError`. +- Differing `max_parallel_requests` values for aliases that resolve to the same concrete provider/model endpoint are not, by themselves, ambiguous. They merge through the static-cap min rule in the capacity model. +- A default implementation must not turn `get_model_aliases()` into a requirement that every health-check alias collapse to one scheduler endpoint. 
If a registry-backed generator reports multiple aliases that resolve to different concrete endpoints and does not override `get_scheduling_metadata()`, the default metadata should fall back to deterministic `custom_model` metadata with diagnostics. Plugins that need sharper endpoint-aware scheduler grouping should override `get_scheduling_metadata()` directly. +- Fallback metadata is safe only when it preserves current scheduling behavior and the resolver can explain the fallback in diagnostics. Invalid metadata shape or invalid weights are fatal. + +Normative V1 metadata shapes: + +| Kind | Identity tuple | Weight source | Default/fallback behavior | +| --- | --- | --- | --- | +| `local` | `("local", resource_name)` where `resource_name` defaults to `"default"` | positive integer, default `1` | the default for generators that do not override `get_scheduling_metadata()` is `SchedulingMetadata(kind="local", identity=("local", "default"), weight=1)` | +| `model` | `("model", provider_name, canonical_model_id, generation_kind)` after alias resolution | effective static provider/model capacity hint, normally derived from the model config's `max_parallel_requests` and clamped to at least `1` | safe fallback is allowed only when the resolver can identify the same canonical provider/model resource as the current implementation | +| `custom_model` | `("custom_model", plugin_namespace, resource_name, version)` with deterministic plugin-provided values | positive plugin-provided capacity hint, defaulting to `1` if omitted | used for explicit plugin metadata and for compatibility-preserving defaults when model aliases cannot be represented as one concrete provider/model/generation resource | + +`SchedulingMetadataError` contains: + +- `code` +- `message` +- optional `fallback: SchedulingMetadata` +- sanitized `diagnostics` + +If `fallback` is present, the resolver may continue and must emit diagnostics. 
If `fallback` is absent, metadata resolution is fatal before scheduler inputs are created. + +## Scheduler Input Contracts + +`TaskSchedulingResolver` consumes `SchedulingMetadata` and produces scheduler-internal inputs. It owns per-run metadata caching and scheduler flow-identity composition. + +`ResolvedTaskScheduling` contains: + +- `group: TaskGroupSpec` +- `resource_request: SchedulerResourceRequest` + +`TaskGroupSpec` contains a scheduler-internal task group key and static weight. + +`SchedulerResourceRequest` contains scheduler-level task-stage resources: + +```text +amounts: Mapping[SchedulerResourceKey, int] +``` + +`SchedulerResourceKey` identifies a scheduler-owned task-stage resource such as `submission`, `local`, or a future internal resource-vector key. It is not a provider request-domain key. + +The first implementation models scheduler task-stage pressure with explicit scheduler resources. Concrete provider/model/domain request pressure belongs to `RequestResourceKey` and request admission. Future resource-vector work may add local, GPU, or other scheduler resources, but those remain scheduler-internal unless a later design explicitly changes the public contract. + +`SchedulableTask` contains: + +- stable `task_id` +- task payload +- task group +- scheduler resource request + +`CompletionTracker` owns readiness state: + +```text +ready_frontier() -> Sequence[SchedulableTask] +mark_enqueued(task_ids) +mark_complete(task) +``` + +`ready_frontier()` returns dependency-ready tasks that have not yet been acknowledged as enqueued. After `FairTaskQueue.enqueue(...)` accepts a task, `AsyncTaskScheduler` calls `mark_enqueued(...)` with exactly the accepted task ids. `FairTaskQueue.enqueue(...)` is also idempotent by `task_id`, so duplicate frontier reads cannot create duplicate ready membership. If enqueue fails before acceptance, the task remains unacknowledged and appears in a later frontier read. 
+ +## Queue Contracts + +`FairTaskQueue` owns ready-task membership and ready ordering: + +```text +enqueue(items) -> Sequence[task_id] +select_next(is_eligible) -> QueueSelection | None +commit(selection) -> SchedulableTask | None +view() -> QueueView +``` + +`QueueSelection` returns from `FairTaskQueue` to `AsyncTaskScheduler`. It is not delivered to `TaskAdmissionController`. + +`QueueSelection` contains the selected item, the queue view used during selection, and an opaque `sequence_version` used by `commit(selection)` to detect stale selections. + +`QueueView` is read-only policy input. It exposes: + +- queued totals +- queued counts by group +- queued resource demand by group and `SchedulerResourceKey` +- first-candidate resource request by group where available +- queued peer demand by resource + +`QueueView` is produced by `FairTaskQueue`; policies must not traverse queue internals directly. It contains raw queued membership and demand facts only. `TaskAdmissionPolicy` computes eligibility and resource-aware peer pressure from `QueueView` plus `TaskAdmissionView`. + +`FairTaskQueue` must not invoke scheduler-supplied eligibility predicates while holding internal queue locks that can be needed by enqueue, commit, release wakeups, or diagnostics. The scheduler dispatch critical section owns the cross-component coordination; queue internals remain local to queue mutation. 
+ +## Task Admission Contracts + +`TaskAdmissionController` owns task-stage resource accounting and leases: + +```text +is_eligible(item, queue_view) -> bool +try_acquire(item, queue_view) -> TaskAdmissionDecision +release(lease) -> ReleaseResult +view() -> TaskAdmissionView +explain_blocked(queue_view) -> TaskAdmissionBlockSummary +``` + +`TaskAdmissionPolicy` owns the decision rule: + +```text +evaluate(item, queue_view, admission_view) -> TaskAdmissionPolicyDecision +on_acquire(lease, decision) -> PolicyStateDelta +on_release(lease) -> PolicyStateDelta +``` + +`evaluate(...)` is side-effect-free. It may be called while scanning queue candidates and must not mutate borrow debt, counters, timers, diagnostics, or resource ledgers. Only controller-mediated acquire/release paths apply `PolicyStateDelta` values. + +`TaskAdmissionPolicyDecision` contains: + +- `allowed` +- optional denial `reason`, such as no capacity, group cap, borrow debt, shutdown, or policy denial +- optional `available_after` +- sanitized diagnostic fields + +`PolicyStateDelta` contains policy-owned state changes such as borrow-debt increment, repayment, or diagnostic counters. The controller applies the delta in the same transaction as lease acquire/release and exposes the resulting policy counters through `TaskAdmissionView`. Bounded-borrow debt affects eligibility, but it is not part of the hard resource ledger and never changes resource availability counters directly. + +`AsyncTaskScheduler` supplies the boolean eligibility callback used by `FairTaskQueue`; that callback delegates to `TaskAdmissionController.is_eligible(...)`. The controller may call `TaskAdmissionPolicy.evaluate(...)` internally, but denial details are surfaced through `try_acquire(...)`, `explain_blocked(...)`, events, and tests rather than through the queue callback. 
+ +When `FairTaskQueue.select_next(...)` returns no selection while queued work exists, `AsyncTaskScheduler` calls `TaskAdmissionController.explain_blocked(queue_view)` before sleeping. `TaskAdmissionBlockSummary` contains queued count, dominant denial reasons, optional earliest `available_after`, and sanitized diagnostics. This is the source for `admission_blocked`, `group_capped`, and timed wakeups when no candidate can currently be admitted. + +`TaskAdmissionLease` contains: + +- `lease_id` +- `item` +- `resources` +- `acquired_at` +- controller identity or generation token sufficient to reject stale/wrong-controller releases + +`TaskAdmissionView` exposes a consistent read-only snapshot: + +- task resource limits by `SchedulerResourceKey` +- task resources available by `SchedulerResourceKey` +- leased resources by `SchedulerResourceKey` +- leased resources by group and `SchedulerResourceKey` +- running counts by group and resource where tracked +- policy-only debt by group/resource if the active policy uses bounded borrow + +`TaskAdmissionDecision` is a union of `TaskAdmissionLease` and `TaskAdmissionDenied`. + +`TaskAdmissionDenied` contains: + +- item +- reason, such as no capacity, group cap, borrow debt, shutdown, or policy denial +- optional available-after timing +- optional `TaskAdmissionView` snapshot + +Implementations may provide a local convenience helper that converts `TaskAdmissionDecision` to an optional lease, but telemetry, tests, and benchmark artifacts use the typed decision vocabulary. + +`TaskAdmissionConfig` is engine-internal in V1 and contains scheduler task-stage capacity values such as `submission_capacity`, resource limits, and optional policy-specific config. Bounded-borrow policy config, when enabled by #650, includes borrow ceiling by group/resource, strict-share rounding mode, and repayment behavior. The default V1 lease-boundary policy is behavior-preserving unless #650 explicitly enables bounded borrow. 
+ +`ReleaseResult` contains: + +- `released: bool` +- `reason`, such as released, duplicate, stale lease, wrong controller generation, or unknown lease +- sanitized diagnostics + +Terminal `finally` paths must not raise from release. Duplicate, stale, or wrong-controller releases return `ReleaseResult` and emit diagnostic events without increasing capacity. + +## Request Admission Contracts + +`ModelRequestExecutor` maps concrete model-call attempts into request-admission items and owns exact lease release around provider execution: + +```text +execute_attempt(request) -> provider response +``` + +`RequestResourceResolver` is the canonical request-resource identity factory. It maps provider alias, model alias, model id, generation kind, endpoint metadata, and `RequestDomain` into `ProviderModelKey` and `RequestResourceKey`. `TaskSchedulingResolver`, `ModelRequestExecutor`, `AsyncCapacityPlan`, and request admission all use the same provider/model canonicalization rules so alias merging, metadata weight, and request caps cannot drift. + +`RequestResourceKey` identifies a concrete provider/model/domain request resource: + +- `provider_name`, the canonical resolved provider name, not an alias +- `model_id`, the canonical resolved provider/model endpoint id, not a user alias +- `domain` + +Aliases are recorded in capacity plans and pressure snapshots for diagnostics, but request admission keys use canonical resolved provider/model identity so aliases cannot bypass aggregate caps. + +`ProviderModelKey` is the aggregate request-capacity key: + +- canonical provider name +- canonical model endpoint id + +`RequestResourceKey` is `ProviderModelKey + RequestDomain`. + +`RequestDomain` is the durable domain vocabulary for request admission. V1 includes `chat`, `embedding`, `image`, and `healthcheck`; adding new domains requires updating this plan and the request-admission docs. + +`RequestGroupSpec` contains the request fairness group key and static weight. 
In V1 the group key is the `RequestResourceKey`; a later design may split fairness group from resource key, but must specify the mapping before doing so. + +`RequestAdmissionItem` contains: + +- request resource +- request group +- optional queue-wait timeout +- optional `RequestEventContext` + +`RequestEventContext` is constructed by `ModelRequestExecutor` when it maps a model call attempt into a request item. It contains primitive, telemetry-only context: + +- captured `RuntimeCorrelation | None` +- `task_execution_id` +- `request_attempt_id` + +The request controller treats this as opaque event context. It does not import scheduler task types or mutate scheduler state. + +`RequestFairQueue` owns waiter ordering inside `AdaptiveRequestAdmissionController`: + +```text +enqueue(waiter) +select_next(is_eligible) -> RequestQueueSelection | None +commit(selection) -> RequestWaiter | None +remove(waiter_id) +view() -> RequestQueueView +``` + +`RequestWaiter` contains waiter id, item, enqueue timestamp, deadline/cancellation state, and the waiter completion handle used by the blocking acquire path. + +`RequestQueueSelection` contains the selected waiter, item, waiter id, queue view, and opaque `sequence_version` for stale-selection detection. + +`RequestQueueView` exposes queued totals, queued counts by request group, queued demand by request resource, and aggregate provider/model waiters. It does not inspect adaptive limit state. + +`try_acquire(...)` is non-blocking. It may immediately acquire only when the request is eligible and no queued eligible waiter for the same request resource or provider/model aggregate cap would be selected before the incoming item by `RequestFairQueue`'s weighted ordering. Otherwise it returns `RequestAdmissionDenied` with reason `queued_waiters_ahead` or another specific denial reason. 
+ +`RequestAdmissionController` owns request-level admission: + +```text +try_acquire(item) -> RequestAdmissionDecision +acquire_sync(item) -> RequestAdmissionLease +acquire_async(item) -> RequestAdmissionLease +release(lease, outcome) -> ReleaseResult +pressure -> RequestPressureSnapshotProvider +``` + +`acquire_sync(...)` and `acquire_async(...)` wait until a lease is available or a terminal no-lease condition occurs. Timeout, shutdown, or hard denial before a lease is acquired must remove the waiter and raise a typed project error that carries the corresponding `RequestAdmissionDenied` decision. They must not return `None`. + +`RequestAdmissionError` is the typed no-lease exception raised by blocking acquire paths. It wraps `RequestAdmissionDenied` and must not be raised after a lease has been returned; post-lease provider outcomes are represented by `RequestReleaseOutcome`. + +`acquire_async(...)` must preserve cooperative cancellation. If the awaiting task is cancelled before a lease is acquired, the controller removes the waiter, emits a cancellation/denial event, and re-raises the cancellation exception instead of converting it to `RequestAdmissionError`. + +Once a waiter is selected and in-flight counts are incremented, cancellation cannot orphan the lease. The controller either delivers the lease to that waiter's acquire call so caller cleanup can release it, or internally releases the admitted waiter as `local_cancelled` before completing cancellation. A caller's `acquire_async(item)` may only return the lease for its own waiter; if the controller admits another waiter while this caller is awake, it fulfills that other waiter's completion handle and this caller continues waiting. + +`RequestAdmissionDecision` is a union of `RequestAdmissionLease` and `RequestAdmissionDenied`. 
+ +`RequestAdmissionLease` contains: + +- `lease_id` +- item +- acquired timestamp +- current adaptive limit +- effective max +- controller identity or generation token sufficient to reject stale/wrong-controller releases + +`RequestAdmissionDenied` contains: + +- item +- reason, such as no capacity, cooldown, queue timeout, queued waiters ahead, cancellation, shutdown, or hard policy denial +- `retry_after_seconds` when supplied by the provider or policy +- `available_after_monotonic` when the controller can compute an unblock deadline +- optional snapshot + +`RequestReleaseOutcome` contains: + +- `kind`: one of `success`, `rate_limited`, `provider_failure`, `provider_timeout`, `local_cancelled`, `local_timeout`, or `unexpected_exception` +- `retry_after_seconds` when rate limited +- provider/status metadata safe for telemetry + +Only provider rate-limit outcomes drive multiplicative decrease/cooldown. Provider failures may affect diagnostic counters. Local cancellation and local timeout release capacity and wake waiters but must not be treated as provider pressure unless a later policy explicitly defines that behavior. + +`provider_timeout` is a timeout or timeout-shaped transport/provider failure after a lease has been acquired and an outbound provider attempt has started. `local_timeout` is a caller, queue-wait, or controller deadline that is not evidence of provider pressure. Cancellation after lease acquisition is classified as `local_cancelled`; release diagnostics must not mask the original cancellation and the cancellation is re-raised after accounting. + +`AdaptiveRequestAdmissionController` is the V1 concrete request controller. It owns AIMD behavior through internal `RequestFairQueue`, `RequestAdmissionPolicy`, and `AdaptiveRequestLimitState`. + +Request admission acquires under one controller lock/condition. An admitted lease increments domain in-flight counts and provider/model aggregate in-flight counts before the lease is returned. 
Release decrements those counts exactly once and wakes eligible waiters. + +Cross-domain arbitration under a provider/model aggregate cap uses `RequestFairQueue` ordering by `RequestGroupSpec` weight. V1 uses weighted fair ordering across request groups sharing the aggregate cap; if weights are equal, older waiters are selected first. + +V1 AIMD semantics: + +- `effective_max = min(provider_model_static_cap, request_config.max_limit_clamp_for_resource_if_present)` +- instantaneous aggregate availability is checked separately as `provider_model_aggregate_in_flight < provider_model_static_cap` +- `initial_limit` is clamped to `[1, effective_max]` +- `current_limit` starts at `initial_limit` +- on `rate_limited`, `current_limit = max(1, floor(current_limit * multiplicative_decrease_factor))`, `blocked_until_monotonic` is set from provider `retry_after_seconds` when supplied or the configured cooldown otherwise, and rate-limit counters increment +- on success outside cooldown, successful releases accumulate; after `successes_until_increase` successes, `current_limit = min(effective_max, current_limit + additive_increase_step)` +- `request_soft_ceiling_recovered` fires when `current_limit` rises above the last rate-limit ceiling +- `request_fully_recovered` fires when `current_limit == effective_max` and cooldown has cleared +- all timing uses a monotonic clock +- waiters use timed waits to the earliest relevant monotonic deadline: queue-wait timeout, cancellation, `available_after_monotonic`, or `blocked_until_monotonic`. Cooldown expiry must wake queued waiters even when no in-flight request releases. + +`RequestPressureSnapshotProvider` exposes read-only request pressure: + +```text +snapshot(resource) +snapshots() +global_snapshot(provider, model) +global_snapshots() +``` + +It has no mutation or admission methods. + +Snapshots are immutable and internally consistent for their capture point. 
Domain snapshots include `captured_at`, monotonic `sequence`, resource, effective max, current limit, in-flight count, active lease count, waiters, blocked-until timing, cooldown remaining, rate-limit ceiling, consecutive rate limits, last outcome summary, and leak diagnostic counters. Global provider/model snapshots include aggregate static cap, aggregate in-flight count across domains, aggregate active lease count, aliases contributing to the cap, and per-domain limit summaries. + +`RequestAdmissionConfig` is the durable engine-internal request-admission tuning/config vocabulary for V1. It includes request resources, per-resource `initial_limit`, optional `max_limit_clamp`, configured cooldown, `multiplicative_decrease_factor`, `additive_increase_step`, `successes_until_increase`, `startup_ramp_seconds`, and default queue-wait timeout. Legacy request-control config names are not durable names. `RunConfig.throttle` may translate a deprecated `ThrottleConfig` into `RequestAdmissionTuningConfig`, but the new request-admission DTOs store and document scheduler-era names. + +## Telemetry And Correlation Contracts + +`SchedulerAdmissionEventSink` emits scheduler admission events. + +`RequestAdmissionEventSink` emits request admission events. + +`RuntimeCorrelation` contains primitive runtime context: + +- run id +- row group +- task column +- task type +- scheduling group kind +- scheduling group identity hash +- task execution id + +`RuntimeCorrelationProvider` owns set/reset/current behavior, likely through context variables. It must not require request admission protocols to import scheduler types. + +Scheduler/request events capture correlation values when event DTOs are constructed and normalize correlation, keys, snapshots, and diagnostics to JSON-compatible values. Event sinks must not rely on reading mutable ambient context later, because deferred emission could attach the wrong task context. 
+ +Canonical scheduler `event_kind` values are snake_case and versioned as part of the benchmark artifact schema: + +```text +scheduler_job_started +scheduler_job_completed +scheduler_health_snapshot +dependency_ready +ready_enqueued +row_group_admitted +row_group_admission_blocked +row_group_admission_target_changed +row_group_checkpointed +selected +queue_empty +admission_blocked +group_capped +request_pressure_advisory_skipped +task_lease_acquired +admission_denied +worker_spawned +worker_spawn_failed +stale_selection +retry_deferred +non_retryable_dropped +cancelled +salvage_redispatched +queue_drained +task_completed +task_lease_released +release_diagnostic +``` + +`SchedulerAdmissionEvent` contains: + +- `event_kind` +- `captured_at_monotonic` +- monotonic `sequence` +- captured correlation as JSON-compatible values +- task id +- task execution id when a worker execution exists +- task lease id when available +- scheduler resource key when applicable +- decision reason or release result when applicable +- optional JSON-compatible scheduler snapshot +- JSON-compatible diagnostics + +Canonical request `event_kind` values are snake_case and versioned as part of the benchmark artifact schema: + +```text +request_resource_registered +request_effective_cap_changed +request_queue_formed +request_wait_started +request_wait_completed +request_wait_timeout +request_wait_cancelled +request_acquire_denied +request_lease_acquired +model_request_started +model_request_completed +request_queue_drained +request_rate_limited +request_limit_decreased +request_limit_increased +request_soft_ceiling_recovered +request_fully_recovered +request_lease_released +request_release_diagnostic +``` + +`RequestAdmissionEvent` contains: + +- `event_kind` +- `captured_at_monotonic` +- monotonic `sequence` +- captured correlation as JSON-compatible values +- request attempt id when the event belongs to one concrete model-call attempt +- request lease id when available +- canonical request 
resource as JSON-compatible values when the event is resource-specific +- request group key as JSON-compatible values when the event is queue/admission specific +- denial reason or release outcome when applicable +- optional JSON-compatible request pressure snapshot +- JSON-compatible diagnostics + +Lease ids, task ids, request attempt ids, and raw model ids are trace/artifact fields only; they are not metric labels. Metric exporters use bounded labels such as `metric_model_label`, model family, or allowlisted model label. The OTel bridge must reject raw model ids as metric labels. + +`CorrelatedRuntimeView` joins scheduler and request timelines for diagnostics, benchmarks, and future operator views. + +## Capacity Contracts + +`AsyncCapacityPlan` records computed per-run capacity values: + +```text +CapacityValue[T]: + value: T | None + source: default | run_config | dataset_builder | model_metadata | engine_internal_config | adapter_config | environment | runtime_snapshot | benchmark_override + fallback_from: str | None + missing_reason: str | None + +RowGroupAdmission: + row_group_concurrency: CapacityValue[int] + observed_in_flight: int | None + +ProviderModelStaticCap: + cap: int + aliases: Sequence[str] + raw_caps: Mapping[str, int | None] + merge_rule: min_same_endpoint + +RequestAdmissionConfigSnapshot: + resources: Sequence[RequestResourceKey] + initial_limits: Mapping[RequestResourceKey, int] + max_limit_clamps: Mapping[RequestResourceKey, int | None] + cooldown_seconds: float + multiplicative_decrease_factor: float + additive_increase_step: int + successes_until_increase: int + startup_ramp_seconds: float + default_queue_wait_timeout_seconds: float | None + +AsyncCapacityPlan: + configured: + buffer_size: CapacityValue[int] + row_group_admission: RowGroupAdmission + submission_capacity: CapacityValue[int] + task_resource_limits: CapacityValue[Mapping[SchedulerResourceKey, int]] + request_resources: CapacityValue[Sequence[RequestResourceKey]] + 
provider_model_static_caps: CapacityValue[Mapping[ProviderModelKey, ProviderModelStaticCap]] + request_domain_initial_limits: CapacityValue[Mapping[RequestResourceKey, int]] + request_admission_config: CapacityValue[RequestAdmissionConfigSnapshot] + transport_pool_limits: CapacityValue[Mapping[ProviderModelKey, int]] + runtime_snapshot: + request_domain_current_limits: Mapping[RequestResourceKey, int] | None + request_domain_effective_max: Mapping[RequestResourceKey, int] | None + request_domain_blocked_until: Mapping[RequestResourceKey, float | None] | None + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] | None + observed_maxima: + row_groups_in_flight: int + queued_tasks_by_group: Mapping[str, int] + task_leases_by_resource: Mapping[SchedulerResourceKey, int] + request_waiters_by_resource: Mapping[RequestResourceKey, int] + request_in_flight_by_resource: Mapping[RequestResourceKey, int] + provider_model_aggregate_in_flight: Mapping[ProviderModelKey, int] + request_domain_current_limits: Mapping[RequestResourceKey, int] + transport_pool_utilization: Mapping[ProviderModelKey, int] | None +``` + +Fields that depend on request-admission runtime state may be `None` in #654 before #657 lands, but the capacity plan and benchmark artifact must still include the field with `missing_reason` or an equivalent `not_available_until_issue` marker. + +The capacity plan explains observed runtime behavior. It is not itself a policy engine. diff --git a/plans/645/issue-map.md b/plans/645/issue-map.md new file mode 100644 index 000000000..14ee92989 --- /dev/null +++ b/plans/645/issue-map.md @@ -0,0 +1,75 @@ +# Issue Map + +This file maps GitHub issues to the source-of-truth plan sections. Issues should reference these files and own implementation/quality gates rather than restating the full architecture. + +## Source-Of-Truth Rule + +Use this directory for architecture, contracts, naming, invariants, and cross-cutting decisions. 
+ +Use GitHub issues for: + +- implementation slices +- dependencies and target branch +- acceptance criteria for that slice +- tests and validation commands +- benchmark evidence required by that slice +- PR-specific cleanup gates + +## Issues + +| Issue | Implementation Focus | Source Sections | +| --- | --- | --- | +| #641 | Add `SchedulingMetadata` and generator override contract | [architecture.md](architecture.md), [contracts.md](contracts.md), [module-ownership.md](module-ownership.md), [task-admission.md](task-admission.md) | +| #646 | Ingest metadata into scheduler grouping through `TaskSchedulingResolver` | [architecture.md](architecture.md), [contracts.md](contracts.md), [module-ownership.md](module-ownership.md), [task-admission.md](task-admission.md) | +| #653 | Remove legacy hint resolver path | [migration-and-cleanup.md](migration-and-cleanup.md), [module-ownership.md](module-ownership.md), [contracts.md](contracts.md) | +| #652 | Document plugin-facing metadata behavior | [architecture.md](architecture.md), [contracts.md](contracts.md), [migration-and-cleanup.md](migration-and-cleanup.md) | +| #644 | Implement task admission lease boundary | [task-admission.md](task-admission.md), [contracts.md](contracts.md), [module-ownership.md](module-ownership.md), [benchmark-plan.md](benchmark-plan.md) | +| #654 | Implement and document capacity vocabulary and snapshots | [capacity-model.md](capacity-model.md), [observability.md](observability.md), [benchmark-plan.md](benchmark-plan.md) | +| #657 | Refactor model-call request control into request admission | [request-admission.md](request-admission.md), [contracts.md](contracts.md), [module-ownership.md](module-ownership.md), [migration-and-cleanup.md](migration-and-cleanup.md), [benchmark-plan.md](benchmark-plan.md) | +| #635 | Instrument request admission state | [observability.md](observability.md), [request-admission.md](request-admission.md), [contracts.md](contracts.md), 
[benchmark-plan.md](benchmark-plan.md) | +| #647 | Instrument scheduler admission state | [observability.md](observability.md), [task-admission.md](task-admission.md), [contracts.md](contracts.md), [benchmark-plan.md](benchmark-plan.md) | +| #648 | Correlate scheduler and request observability | [observability.md](observability.md), [architecture.md](architecture.md), [benchmark-plan.md](benchmark-plan.md) | +| #649 | Build reusable benchmark harness and normalize provisional evidence | [benchmark-plan.md](benchmark-plan.md), [capacity-model.md](capacity-model.md), [observability.md](observability.md), [task-admission.md](task-admission.md), [request-admission.md](request-admission.md) | +| #660 | Produce final user/operator docs | [architecture.md](architecture.md), [contracts.md](contracts.md), [module-ownership.md](module-ownership.md), [request-admission.md](request-admission.md), [capacity-model.md](capacity-model.md), [observability.md](observability.md), [benchmark-plan.md](benchmark-plan.md), [migration-and-cleanup.md](migration-and-cleanup.md) | +| #650 | Implement bounded-borrow task policy | [task-admission.md](task-admission.md), [benchmark-plan.md](benchmark-plan.md), [capacity-model.md](capacity-model.md) | +| #651 | Design resource-vector/provider-aware policy | [task-admission.md](task-admission.md), [request-admission.md](request-admission.md), [capacity-model.md](capacity-model.md), [observability.md](observability.md), [benchmark-plan.md](benchmark-plan.md) | + +## Dependency Order + +The implementation order remains: + +```text +#641 -> #646 -> #653 -> #652 -> #644 -> #654 -> #657 +-> #635 -> #647 -> #648 -> #649 -> #660 -> #650 -> #651 +``` + +#644 cannot close until task admission uses the final scheduler module homes and the accepted metadata contract. 
The accepted end state is `SchedulingMetadata` feeding task admission through `TaskSchedulingResolver`; old resolver paths, compatibility adapters, and duplicate module homes are not part of the target architecture. + +#660 promotes the stabilized V1 admission/capacity/telemetry docs. #650 and #651 are follow-on policy/design issues; if they change behavior or public/operator guidance, they must update this source-of-truth plan and any promoted docs as part of their own acceptance gates. #651 is design-first unless its issue body explicitly promotes an implementation slice. The request-pressure advisory selection path currently in PR #661 is a narrow implementation slice ahead of the broader #651 design; it must remain read-only with respect to request admission until #651 defines a durable provider/resource-aware policy. + +## Evidence Phasing + +The native issue order keeps #649 after capacity, request admission, telemetry, and correlation because the reusable harness consumes those contracts. That does not waive evidence for earlier implementation PRs. + +Before #649 closes, issues #644, #654, #657, #635, #647, and #648 must produce provisional benchmark/evidence artifacts using the schema in [benchmark-plan.md](benchmark-plan.md). A minimal deterministic smoke writer should exist before those slices rely on one-off evidence. Issue #649 then converts those provisional artifacts into the reusable harness, reruns representative scenarios, and becomes the gate for #660, #650, and #651. 
+ +## Issue Body Cleanup Pattern + +When revising issue bodies, keep: + +- priority +- dependency metadata +- target branch +- short problem statement +- links to the relevant plan sections +- implementation checklist +- tests and validation commands +- evidence requirements +- acceptance criteria specific to the slice + +Remove or shorten: + +- duplicated contract definitions +- duplicated architecture diagrams +- broad cross-cutting non-goals already captured here +- stale naming decisions superseded by this plan diff --git a/plans/645/migration-and-cleanup.md b/plans/645/migration-and-cleanup.md new file mode 100644 index 000000000..2469a1593 --- /dev/null +++ b/plans/645/migration-and-cleanup.md @@ -0,0 +1,140 @@ +# Migration And Cleanup + +The epic is not complete until replaced names, old module paths, and compatibility paths are removed from production code, current docs, and this source-of-truth plan. The target architecture is defined by [Module ownership](module-ownership.md); implementation PRs should move directly to those final homes. + +## Scheduling Metadata Cleanup + +The durable scheduling metadata path is: + +```text +ColumnGenerator.get_scheduling_metadata() +-> SchedulingMetadata +-> TaskSchedulingResolver +``` + +Remove the legacy resolver types and any independent scheduler-side model/provider introspection path. All model/provider inference must live behind `ColumnGenerator.get_scheduling_metadata()`, `SchedulingMetadata`, and typed `SchedulingMetadataError` fallback behavior. 
+ +Unacceptable end states: + +- a parallel fallback that independently introspects generators, configs, model registries, aliases, or admitted policy data under the old resolver contract +- a compatibility adapter, alias, or reexport that preserves the old resolver vocabulary as a durable production path + +The final legacy-name search gate should have no production/current-doc matches for these historical strings: + +```text +SchedulingHintResolver +SchedulingHint +_model_aliases_for_generator +``` + +Independent scheduler-side model-bound fallback logic is also migration-only and should be folded behind metadata/resource requests by epic completion. + +## Request Admission Cleanup + +The durable request-admission names are: + +- `ModelRequestExecutor` +- `RequestAdmissionController` +- `AdaptiveRequestAdmissionController` +- `RequestAdmissionConfig` +- `RequestDomain` + +The final legacy-name search gate should have no production/current-doc matches for these historical strings, except the explicitly deprecated config compatibility shim for `ThrottleConfig` / `RunConfig.throttle`: + +```text +ThrottleManager +ThrottleDomain +ThrottleConfig +RunConfig.throttle +throttle_manager.py +ThrottledModelClient +throttled_model_client +``` + +Historical changelog text may remain only if it is clearly marked as historical and not presented as current API. + +## Task-Stage Wait Cleanup + +The durable architecture does not include: + +```text +needs_llm_wait +held_llm_wait +max_llm_wait_tasks +``` + +If a scheduler-level resource remains for LLM-bound work, it must be represented through `SchedulerResourceRequest`, `TaskAdmissionConfig`, and `AsyncCapacityPlan`, with names that describe scheduler task-stage pressure rather than request concurrency. 
+ +## Module Ownership Cleanup + +The target scheduler package is: + +```text +data_designer.engine.dataset_builders.scheduling +``` + +The durable architecture does not keep scheduler-owned task models, readiness tracking, queues, task admission, or task policies in `dataset_builders.utils`. + +The target request-admission package is: + +```text +data_designer.engine.models.request_admission +``` + +The durable architecture does not keep request-admission controllers, queues, waiters, AIMD state, pressure snapshots, or request leases in `models.clients.request_admission`. + +`ModelRequestExecutor` remains under `models.clients` because it wraps concrete model clients. Request admission itself must not be reexported from `models.clients.__init__`. + +## Compatibility Shim Rule + +Do not leave production compatibility aliases, subclasses, adapters, reexports, docs paths, or durable tests for replaced names at epic completion. + +Do not introduce shim modules or deprecation adapters under replaced names. Historical names may appear only in explicit cleanup/search-gate sections like this one or in clearly marked historical changelog/dev-note text. + +## Gate Semantics + +Before the migration issues land, stale-name matches can exist as current-state evidence. + +By #653 close, legacy scheduling-hint production paths are gone and tests have moved to metadata/resolver coverage. + +By #657 close, request-admission code has no production `Throttle*` aliases, exports, modules, or durable tests. + +By #645 close, production code lives in the target modules from [Module ownership](module-ownership.md), package `__init__.py` files do not reexport internal queues/controllers/leases, and public/current docs plus `plans/645` use only the durable architecture vocabulary except for this cleanup file's explicit legacy-name search lists. Historical changelog or dev-note text can remain only when explicitly marked historical. 
+ +## Documentation Cleanup + +Current maintainer architecture docs should use durable internal names when they discuss implementation internals: + +- `SchedulingMetadata` +- `TaskSchedulingResolver` +- `TaskAdmissionController` +- `TaskAdmissionPolicy` +- `TaskAdmissionLease` +- `ModelRequestExecutor` +- `RequestAdmissionController` +- `AdaptiveRequestAdmissionController` +- `RequestAdmissionConfig` +- `RuntimeCorrelationProvider` + +User/operator docs should expose public run config fields, including `RequestAdmissionTuningConfig`, `AsyncCapacityPlan`, benchmark artifacts, telemetry views, and high-level layer names. They must not present `TaskAdmissionConfig`, `RequestAdmissionConfig`, policies, leases, queues, pressure snapshots, or controller mutation APIs as public user knobs. Plugin-facing docs should describe metadata only, then link to architecture docs for maintainers/operators. + +Current architecture docs, diagrams, generated assets, and plan files must be checked as part of final cleanup. Existing historical dev notes may retain old names only when the text clearly says the name is historical and no longer current API. + +Current user/operator architecture docs must also remove or mark as historical capacity-control descriptions that imply the pre-epic architecture. This includes old model-client request-capacity names and scheduler-slot handoff explanations. 
+ +## Validation Commands + +Adjust paths as files move, but final PRs should include searches equivalent to: + +```bash +rg "SchedulingHintResolver|SchedulingHint|_model_aliases_for_generator|is_llm_bound" packages docs fern architecture plans/645 +rg "ThrottleManager|ThrottleDomain|ThrottleConfig|RunConfig\\.throttle|throttle_manager\\.py|ThrottledModelClient|throttled_model_client" packages docs fern architecture plans/645 +rg "_submission_semaphore|_llm_wait_semaphore|get_semaphore_permits|TrackingSemaphore" packages docs fern architecture plans/645 +rg "throttl(e|ed|ing)|semaphore" docs fern architecture plans/645 +rg "needs_llm_wait|held_llm_wait|max_llm_wait_tasks" packages docs fern architecture plans/645 +rg "dataset_builders\\.utils\\.(task_model|completion_tracker|task_scheduling|fair_task_queue|task_admission)" packages docs fern architecture plans/645 +rg "models\\.clients\\.request_admission|from data_designer\\.engine\\.models\\.clients import .*Request" packages docs fern architecture plans/645 +rg "SchedulingMetadata|TaskSchedulingResolver|FairTaskQueue|TaskAdmissionController|TaskAdmissionLease|ModelRequestExecutor|RequestAdmissionController|AdaptiveRequestAdmissionController|AsyncCapacityPlan|SchedulerResourceRequest|RequestResourceKey" docs fern architecture plans/645 +``` + +The last search is a positive check: it confirms that docs and plans use the durable names, so hits there are expected. For every other search, any remaining hit must be intentionally historical, not a current implementation or docs path. Allowed plan hits are limited to explicit cleanup/search-gate sections that name the legacy strings so reviewers know what to remove. The task-stage wait-specific search distinguishes obsolete scheduler-slot handoff primitives from unrelated internal synchronization primitives that may remain after review. 
diff --git a/plans/645/module-ownership.md b/plans/645/module-ownership.md new file mode 100644 index 000000000..8ed86e09c --- /dev/null +++ b/plans/645/module-ownership.md @@ -0,0 +1,217 @@ +# Module Ownership + +This page defines the target repository and module ownership for the async scheduling epic. It is an end-state design, not a migration plan. Implementation PRs should move directly toward these homes and must not introduce compatibility aliases, shim modules, transitional reexports, or duplicate old/new module paths. + +Durable engine names in this plan are maintainer contracts. They are not public import promises unless this page explicitly marks them plugin-facing or operator-facing. + +## Package Ownership + +| Package | Owns | Must not own | +| --- | --- | --- | +| `data-designer-config` | public configuration DTOs and generator-facing metadata | engine runtime protocols, queues, admission leases, request domains, AIMD state, runtime pressure | +| `data-designer-engine` | scheduler runtime, task admission, request admission, capacity diagnostics, runtime observability, benchmark internals | public interface orchestration, user-facing docs presentation | +| `data-designer` | public `DataDesigner` interface wiring, CLI presentation, integrations | scheduler internals, plugin-facing scheduling metadata definitions | + +Config code must not import engine runtime code. Engine code may import config DTOs. 
+ +## Target Module Layout + +```text +packages/data-designer-config/src/data_designer/config/ + scheduling.py + SchedulingMetadata + SchedulingMetadataError + +packages/data-designer-engine/src/data_designer/engine/ + dataset_builders/ + async_scheduler.py + AsyncTaskScheduler + + scheduling/ + task_model.py + Task + SliceRef + TaskTrace + + completion.py + CompletionTracker + FrontierDelta + + resources.py + TaskGroupKey + TaskGroupSpec + SchedulerResourceKey + SchedulerResourceRequest + SchedulableTask + stable_task_id + + resolver.py + TaskSchedulingResolver + ResolvedTaskScheduling + + queue.py + FairTaskQueue + QueueView + QueueSelection + + task_admission.py + TaskAdmissionController + TaskAdmissionConfig + TaskAdmissionLease + TaskAdmissionDenied + TaskAdmissionDecision + TaskAdmissionView + TaskAdmissionBlockSummary + ReleaseResult + + task_policies.py + TaskAdmissionPolicy + TaskAdmissionPolicyDecision + PolicyStateDelta + StrictFairTaskAdmissionPolicy + BoundedBorrowTaskAdmissionPolicy + BoundedBorrowTaskAdmissionPolicyConfig + + models/ + resources.py + ProviderModelKey + ProviderModelStaticCap + provider/model alias canonicalization helpers + + clients/ + model_request_executor.py + ModelRequestExecutor + + request_admission/ + resources.py + RequestDomain + RequestResourceKey + RequestGroupSpec + RequestEventContext + RequestAdmissionItem + + resolver.py + RequestResourceResolver + ResolvedRequestResource + + config.py + RequestAdmissionConfig + + queue.py + RequestFairQueue + RequestWaiter + RequestQueueView + RequestQueueSelection + + limits.py + AdaptiveRequestLimitState + provider/model aggregate limit state + + pressure.py + RequestPressureSnapshotProvider + RequestPressureSnapshot + ProviderModelPressureSnapshot + + outcomes.py + RequestReleaseOutcome + ReleaseResult + + controller.py + RequestAdmissionController + AdaptiveRequestAdmissionController + RequestAdmissionLease + RequestAdmissionDenied + RequestAdmissionDecision + 
RequestAdmissionError + + capacity.py + CapacityValue + AsyncCapacityPlan + AsyncCapacityConfigured + AsyncCapacityRuntimeSnapshot + AsyncCapacityObservedMaxima + RequestAdmissionConfigSnapshot + + observability.py + RuntimeCorrelation + RuntimeCorrelationProvider + runtime_correlation_provider + SchedulerAdmissionEvent + SchedulerAdmissionEventSink + RequestAdmissionEvent + RequestAdmissionEventSink + InMemoryAdmissionEventSink + CorrelatedRuntimeView + + models/telemetry.py + product/provider usage telemetry only +``` + +`AsyncTaskScheduler` is the runtime coordinator only. It owns ready-frontier polling, queue selection, task-lease acquire/release orchestration, worker lifecycle, salvage/retry coordination, shutdown, and row-group lifecycle integration. It does not own queue policy, task admission ledgers, request admission, provider cooldown, AIMD behavior, or model-client wrapping. + +`ModelRequestExecutor` remains under `models/clients` because it implements the model-client boundary and wraps concrete provider clients. Request admission itself lives under `models/request_admission` and must not import `ModelClient` or provider adapter classes. + +`models/resources.py` owns provider/model identity that is shared across metadata resolution, request admission, and capacity diagnostics. Request admission owns request-domain resources. Capacity consumes both as read-only diagnostic inputs; it does not own admission policy or controller state transitions. + +`observability.py` is the cross-layer runtime-observability home. It owns scheduler and request admission event DTOs, primitive runtime correlation, in-memory test/diagnostic sinks, and correlated runtime views. Product/provider usage telemetry remains separate in `models/telemetry.py`. 
+ +## Current Module Targets + +| Current or legacy module/concept | Target direction | +| --- | --- | +| `dataset_builders/async_scheduler.py` | keep as coordinator; remove durable queue, task-policy, and request-admission ownership | +| `dataset_builders/utils/task_model.py` | move scheduler task DTOs to `dataset_builders/scheduling/task_model.py` | +| `dataset_builders/utils/completion_tracker.py` | move readiness tracking to `dataset_builders/scheduling/completion.py` | +| `dataset_builders/utils/task_scheduling.py` | split scheduler resources into `scheduling/resources.py` and metadata resolution into `scheduling/resolver.py` | +| `dataset_builders/utils/fair_task_queue.py` | move to `dataset_builders/scheduling/queue.py`; keep ready ordering only | +| `dataset_builders/utils/task_admission.py` | split controller/lease DTOs into `scheduling/task_admission.py` and policies into `scheduling/task_policies.py` | +| `models/clients/model_request_executor.py` | keep as the concrete model-client acquire/call/release wrapper | +| `models/clients/request_admission.py` | split into the `models/request_admission/` package | +| `models/clients/__init__.py` request-admission reexports | remove; request-admission internals are imported from their owning modules only | +| `models/telemetry.py` | keep product/provider usage telemetry separate from admission event DTOs | +| `capacity.py` | keep as cross-cutting capacity diagnostic/reporting code that consumes read-only scheduler/request DTOs and snapshots | +| `SchedulingHintResolver`, `SchedulingHint`, and scheduler-side model-bound fallbacks | remove; `SchedulingMetadata` plus `TaskSchedulingResolver` are the only durable path | +| `ThrottleManager`, `ThrottleDomain`, `ThrottledModelClient`, and `throttled_model_client` | remove; request admission and `ModelRequestExecutor` are the only durable request-control path | +| `ThrottleConfig` and `RunConfig.throttle` | keep only as deprecated public config compatibility shims that 
translate to `RequestAdmissionTuningConfig` and emit `DeprecationWarning`; not durable engine architecture | + +## Import Rules + +- `data_designer.config.*` must not import `data_designer.engine.*` or `data_designer.interface.*`. +- Engine modules may import `SchedulingMetadata` and `SchedulingMetadataError` from config. +- `dataset_builders/scheduling/*` may import config scheduling metadata, dataset-builder task/readiness concepts, primitive runtime observability, and neutral provider/model identity helpers from `engine.models.resources`. +- `dataset_builders/scheduling/*` must not import model clients, request-admission controllers, request queues, AIMD state, provider adapters, or request leases. +- `models/request_admission/*` may import neutral provider/model identity helpers and primitive observability, but must not import dataset-builder scheduler types or `ModelClient`. +- `models/clients/model_request_executor.py` is the production bridge that imports both model-client types and request-admission protocols. It is the only model-client layer that acquires and releases request leases. +- `capacity.py` may import read-only resource DTOs, config snapshots, pressure snapshots, and event snapshots. It must not call controller mutation APIs or become a controller registry. +- `observability.py` must not import concrete controllers, queues, model clients, provider adapters, or dataset-builder schedulers. +- `data-designer` interface and CLI code may consume engine diagnostics for presentation, but must not reexport scheduler/request internals as plugin API. +- Package `__init__.py` files must not reexport internal queues, policies, leases, waiters, or controllers as broad public-looking APIs. 
+ +## Audience Boundaries + +| Audience | Exposed surface | Not exposed | +| --- | --- | --- | +| Plugin authors | `ColumnGenerator.get_scheduling_metadata()`, `SchedulingMetadata`, `SchedulingMetadataError` | queues, task groups, scheduler resources, task leases, request domains, pressure snapshots, AIMD state | +| Users/operators | documented public run/model config fields, `AsyncCapacityPlan`, benchmark artifacts, telemetry/event artifacts, correlated runtime views | controller mutation APIs, queues, policies, leases, waiters, internal config objects | +| Engine maintainers | scheduler/request admission modules, DTOs, protocols, policies, snapshots, events, capacity diagnostics | config-layer reverse imports, compatibility aliases, duplicate old/new module paths | +| Tests and benchmarks | local fakes, deterministic model clients, event sinks, benchmark override config | production `engine.testing` helpers, test-module imports from benchmark code, benchmark-module imports from unit tests | + +`TaskAdmissionConfig` and `RequestAdmissionConfig` are engine-internal in V1. They may appear inside capacity and benchmark artifacts as explanatory snapshots, but they are not public `RunConfig` knobs. Public request-admission tuning is exposed only through `RequestAdmissionTuningConfig` on `RunConfig.request_admission` and is translated into the engine-internal config at the engine boundary. 
+ +## Tests And Benchmarks + +Tests mirror target module ownership: + +| Area | Target test home | +| --- | --- | +| config metadata | `packages/data-designer-config/tests/config/test_scheduling.py` | +| scheduler task resources/resolver | `packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_resources.py` and `test_resolver.py` | +| fair task queue | `packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_queue.py` | +| task admission and policies | `packages/data-designer-engine/tests/engine/dataset_builders/scheduling/test_task_admission.py` and `test_task_policies.py` | +| scheduler integration | `packages/data-designer-engine/tests/engine/dataset_builders/test_async_scheduler.py` | +| request admission | `packages/data-designer-engine/tests/engine/models/request_admission/` | +| model request executor | `packages/data-designer-engine/tests/engine/models/clients/test_model_request_executor.py` | +| capacity diagnostics | `packages/data-designer-engine/tests/engine/test_capacity.py` | +| runtime observability | `packages/data-designer-engine/tests/engine/test_observability.py` | + +Test fakes live under tests near their consumers. Benchmark fakes and reusable scenarios live under `scripts/benchmarks/async_scheduling/`, with `scripts/benchmarks/benchmark_async_scheduling.py` as the runnable entrypoint. Production `data_designer.engine.testing` helpers are not part of the target architecture. diff --git a/plans/645/observability.md b/plans/645/observability.md new file mode 100644 index 000000000..2b09b5260 --- /dev/null +++ b/plans/645/observability.md @@ -0,0 +1,212 @@ +# Observability + +Observability must explain which layer is limiting progress without collapsing scheduler admission and request admission into one subsystem. + +## Separate Event Streams + +`SchedulerAdmissionEventSink` emits scheduler-owned task admission events. + +`RequestAdmissionEventSink` emits provider/model/domain request admission events. 
+ +Both sinks are generic-first interfaces; OpenTelemetry, structured logs, dashboards, benchmarks, and debug tools are adapters or consumers of them. + +Sink failures must never interrupt generation. Event data can be collected under locks, but event emission must happen after locks are released. + +Event DTOs capture primitive correlation fields at construction time. Sinks receive already-captured event data; they must not read ambient context later to discover which task/request an event belongs to. + +All event DTOs include `captured_at_monotonic` and a monotonic per-stream `sequence`. Scheduler events include task id, task execution id when a worker execution exists, task lease id when available, scheduler resource key when applicable, denial/release reason when applicable, optional snapshot, and diagnostics. Request events include request attempt id when they belong to one concrete attempt, request lease id when available, canonical request resource when resource-specific, request group key when queue/admission-specific, denial/release outcome when applicable, optional pressure snapshot, and diagnostics. Event construction normalizes correlation, keys, snapshots, and diagnostics to JSON-compatible values so structured sinks do not need to understand internal dataclass or enum types. + +## Scheduler Admission Events + +Scheduler events describe dependency-ready work moving through ready ordering, task admission, worker spawn, and task lease release.
+ +Canonical scheduler event kinds: + +- `scheduler_job_started` +- `scheduler_job_completed` +- `scheduler_health_snapshot` +- `dependency_ready` +- `ready_enqueued` +- `row_group_admitted` +- `row_group_admission_blocked` +- `row_group_admission_target_changed` +- `row_group_checkpointed` +- `selected` +- `queue_empty` +- `admission_blocked` +- `group_capped` +- `request_pressure_advisory_skipped` +- `task_lease_acquired` +- `admission_denied` +- `worker_spawned` +- `worker_spawn_failed` +- `stale_selection` +- `retry_deferred` +- `non_retryable_dropped` +- `cancelled` +- `salvage_redispatched` +- `queue_drained` +- `task_completed` +- `task_lease_released` +- `release_diagnostic` + +Scheduler snapshots include: + +- queued total +- queued by group +- queued demand by group/resource +- admitted/running by group +- resource limits by scheduler resource +- scheduler resources available by resource +- leased resources by group/resource +- active task lease count by resource +- release diagnostic counters +- bounded-borrow debt by group/resource when applicable + +Scheduler events must make the number of hidden scheduler-resource waiters derivable, and that number must be zero after the task-admission lease boundary lands. + +## Request Admission Events + +Request events describe provider/model/domain request admission and AIMD behavior.
+ +Canonical request event kinds: + +- `request_resource_registered` +- `request_effective_cap_changed` +- `request_queue_formed` +- `request_wait_started` +- `request_wait_completed` +- `request_wait_timeout` +- `request_wait_cancelled` +- `request_acquire_denied` +- `request_lease_acquired` +- `model_request_started` +- `model_request_completed` +- `request_queue_drained` +- `request_rate_limited` +- `request_limit_decreased` +- `request_limit_increased` +- `request_soft_ceiling_recovered` +- `request_fully_recovered` +- `request_lease_released` +- `request_release_diagnostic` + +Request snapshots include: + +- captured timestamp +- monotonic sequence +- request resource +- effective max +- current limit +- in-flight count +- active lease count +- waiters +- blocked-until timing +- cooldown remaining +- rate-limit ceiling +- consecutive rate limits +- last release outcome +- leak diagnostic counters + +Global provider/model snapshots capture effective static caps, aggregate in-flight counts across domains, aliases, and per-domain summaries. + +## Runtime Correlation + +`RuntimeCorrelationProvider` carries current task context while the scheduler executes a task. The likely implementation is a context variable with set/reset/current behavior. + +`RuntimeCorrelation` contains primitive values only: + +- run id +- row group +- task column +- task type +- scheduling group kind +- scheduling group identity hash +- task execution id + +`ModelRequestExecutor` reads the current correlation context when constructing `RequestEventContext` for each concrete request attempt. Scheduler event DTOs capture the scheduler's current task/run identity directly. `AdaptiveRequestAdmissionController` remains keyed by provider/model/domain resources and does not import scheduler task types; it may attach the opaque primitive request event context to request events. + +Correlation must propagate through child asyncio tasks created as part of admitted task execution. 
If execution crosses threads, callbacks, or background tasks that cannot preserve context variables, the caller must pass primitive `RuntimeCorrelation` explicitly or mark the event as intentionally uncorrelated. Late/background provider calls after the scheduler has reset task context are not considered part of the admitted task unless they carry explicit correlation. + +`CorrelatedRuntimeView` joins the timelines for diagnostics and benchmarks. + +## Joined Timeline + +The joined timeline should distinguish: + +```text +dependency readiness +ready enqueued +selected by fair queue +task lease acquired +worker spawned +request admission wait started +request admission wait completed +request lease acquired +model request started +model request completed +request lease released +task completed +task lease released +``` + +Runs should be diagnosable as limited by dependency readiness, ready-queue fairness, scheduler capacity, request-admission wait, provider cooldown/rate-limit behavior, transport/provider execution, or downstream completion. 
+ +Benchmark-required monotonic timeline fields are derived from these events: + +- `dependency_ready_at` +- `ready_enqueued_at` +- `selected_at` +- `task_lease_acquired_at` +- `worker_spawned_at` +- `request_wait_started_at` +- `request_wait_completed_at` +- `request_lease_acquired_at` +- `model_request_started_at` +- `model_request_completed_at` +- `request_lease_released_at` +- `task_completed_at` +- `task_lease_released_at` + +## Cardinality And Safety + +Metric-safe dimensions: + +- event kind +- scheduler resource kind +- request admission event kind +- provider name +- bounded model label, model family, or allowlisted model label +- metric model label +- request domain +- algorithm + +Trace-only or sampled fields: + +- run id +- row group +- task column +- task type +- scheduling group identity hash +- raw model id +- task id +- task execution id +- task lease id +- request attempt id +- request lease id +- queued maps by group + +Never emit: + +- prompts +- completions +- row values +- dataset records +- secrets +- raw provider response bodies +- raw exception payloads +- unbounded request IDs as metric labels + +## OpenTelemetry Rule + +Core runtime may provide an OTel bridge that depends on API-level primitives, but it must not configure OTel SDKs, exporters, or collectors. Applications embedding Data Designer own exporter configuration. diff --git a/plans/645/request-admission.md b/plans/645/request-admission.md new file mode 100644 index 000000000..b4204c3c3 --- /dev/null +++ b/plans/645/request-admission.md @@ -0,0 +1,196 @@ +# Request Admission + +Request admission controls concrete provider/model/domain calls at the moment they are made. It is separate from task admission because task-level scheduling cannot predict every model call inside arbitrary generator Python. 
+ +## Runtime Shape + +```text +ModelRequestExecutor + -> ModelRequestExecutor.execute_attempt(request) + -> RequestAdmissionController.acquire_async(RequestAdmissionItem) + -> RequestAdmissionLease + -> provider/model endpoint + -> RequestAdmissionController.release(lease, RequestReleaseOutcome) +``` + +`ModelRequestExecutor` is the durable model-call boundary. It maps each concrete call attempt to a request resource, acquires a lease, calls the provider, records timing, and releases the exact lease. + +The boundary is per outbound attempt. Provider retry behavior must either live inside `ModelRequestExecutor` and acquire/release a lease for each attempt, or call back through `ModelRequestExecutor` for each attempt. HTTP/provider-client retries that hide multiple outbound attempts under one request lease are not compatible with the target architecture because rate limits and provider timing would be invisible to request admission. + +After a lease is acquired, `ModelRequestExecutor` owns release in a non-cancellable cleanup path. Cancellation after lease acquisition is classified as `local_cancelled`, the exact lease is released before the cancellation is re-raised, and release diagnostics must not mask the original cancellation. + +## Dynamic Requests + +Custom generators may make zero, one, or many model requests depending on row data, branches, retries, validation failures, tool calls, or helper functions. `SchedulingMetadata` can describe static resource shape for task grouping, but it is not an exact request-count promise. + +Therefore: + +- task admission must not pre-acquire request permits +- request admission happens at concrete model-call time +- each acquired request lease is released exactly once +- each retry attempt is admitted and released independently +- request-level wait and provider execution timing remain visible separately + +## Durable Names + +The durable interface name is `RequestAdmissionController`. 
+ +The durable V1 implementation name is `AdaptiveRequestAdmissionController`. + +The durable model-call boundary name is `ModelRequestExecutor`. + +The durable internal config vocabulary is `RequestAdmissionConfig`. The public `RunConfig.request_admission` surface uses `RequestAdmissionTuningConfig`, a constrained advanced DTO for supported AIMD tuning only. Public config is translated into engine-internal `RequestAdmissionConfig` at the engine boundary; users do not receive controller, queue, lease, pressure snapshot, per-resource initial-limit, max-clamp, or queue-timeout mutation APIs. + +Do not keep production aliases, shims, subclasses, adapters, exports, docs paths, or durable tests for the replaced request-control vocabulary. [Migration and cleanup](migration-and-cleanup.md#request-admission-cleanup) lists the exact search terms. + +## Request Resource Model + +`RequestResourceKey` identifies canonical resolved request identity: + +- provider name, after alias resolution +- model id, after alias resolution +- request domain + +Aliases are diagnostic-only after request-key construction. They are recorded in `AsyncCapacityPlan` and snapshots, but request admission must not key aggregate caps by user alias. + +`RequestResourceResolver` is the single canonicalization contract for request admission. It resolves provider alias, model alias, model id, generation kind, endpoint metadata, and `RequestDomain` into `ProviderModelKey` and `RequestResourceKey`. Metadata resolution and capacity planning use the same provider/model canonicalization rules; generation kind is folded into the canonical model id only when the provider treats it as a distinct endpoint. + +`RequestDomain` V1 values are `chat`, `embedding`, `image`, and `healthcheck`. Additions require updating this plan. + +`RequestAdmissionItem` contains resource, group, optional queue-wait timeout, and `RequestEventContext`. `RequestGroupSpec` contains a fairness group key and weight. 
In V1 the fairness group key is the `RequestResourceKey`; a future policy may split resource identity from fairness identity only after updating this plan. + +`RequestEventContext` is created by `ModelRequestExecutor` from the current primitive runtime correlation plus a request-attempt id. It is telemetry context, not scheduler state. + +`RequestAdmissionDecision` is `RequestAdmissionLease | RequestAdmissionDenied`. + +`RequestAdmissionLease` records a unique lease id, item, acquired timestamp, current adaptive limit, effective max, and controller generation token. + +`RequestAdmissionDenied` records item, reason, retry timing, availability timing, and optional snapshot. + +`RequestAdmissionController.pressure` exposes the read-only `RequestPressureSnapshotProvider`. + +`acquire_sync(...)` and `acquire_async(...)` block until a lease is available or a terminal no-lease condition occurs. Queue-wait timeout, shutdown, or hard denial removes the waiter and raises `RequestAdmissionError`, a typed Data Designer error carrying `RequestAdmissionDenied`. These methods never return `None`. `try_acquire(...)` is the non-blocking path that returns the full decision union. + +`acquire_async(...)` preserves cooperative cancellation: if the awaiting task is cancelled before a lease is acquired, the controller removes the waiter, emits a cancellation event, and re-raises the cancellation exception. + +Once the controller selects a waiter and increments in-flight counts, cancellation cannot orphan the lease. The selected waiter's acquire call receives the lease for caller cleanup, or the controller internally releases it as `local_cancelled` before completing cancellation. A blocking acquire call may only return a lease for its own waiter; if a wakeup admits a different waiter, that other waiter is fulfilled and the current caller continues waiting. + +## Request Queue Semantics + +`AdaptiveRequestAdmissionController` owns an internal `RequestFairQueue`. 
The queue is protected by the controller lock/condition and exposes the same transaction shape as task admission: + +```text +enqueue(waiter) +select_next(is_eligible) -> RequestQueueSelection | None +commit(selection) -> RequestWaiter | None +remove(waiter_id) +view() -> RequestQueueView +``` + +`RequestWaiter` carries waiter id, item, enqueue timestamp, deadline/cancellation state, and the completion handle for the blocking acquire path. `RequestQueueSelection` carries waiter, item, waiter id, queue view, and a `sequence_version`. `commit(selection)` is the only operation that removes an admitted waiter. + +Wakeups occur when a request lease releases, cooldown expires, adaptive limit increases, shutdown/cancellation removes waiters, or provider/model aggregate capacity becomes available. Waiters use monotonic timed waits to the earliest queue timeout, `available_after_monotonic`, or `blocked_until_monotonic`; cooldown expiry cannot depend on a later provider release to wake the queue. + +Every concrete request attempt emits request-wait timeline events. Immediate acquisition emits `request_wait_started` and `request_wait_completed` as a zero-duration wait before `request_lease_acquired`; queued acquisition emits those events around actual queue wait. + +`try_acquire(...)` must not bypass queued work. It may return an immediate lease only when the item is eligible and no queued eligible waiter for the same request resource or provider/model aggregate cap would be selected first by `RequestFairQueue` weighted ordering. Otherwise it returns a typed denial, usually `queued_waiters_ahead`, `cooldown`, or `no_capacity`. + +## AdaptiveRequestAdmissionController + +`AdaptiveRequestAdmissionController` is the AIMD-backed request controller. 
It owns: + +- request fair queueing +- request admission policy +- adaptive request limit state +- provider/model/domain in-flight counts plus provider/model aggregate in-flight counts +- waiters +- cooldown state +- rate-limit cascades +- additive increase and multiplicative decrease +- request pressure snapshots + +Internal `RequestFairQueue`, `RequestAdmissionPolicy`, and `AdaptiveRequestLimitState` are part of the single canonical request-admission implementation. They are not a second public wrapper around request admission. + +An admitted request increments domain in-flight count and provider/model aggregate in-flight count before the lease is returned. Release decrements those counts exactly once before waking waiters. + +Weighted fairness applies across `RequestGroupSpec` groups that share a provider/model aggregate cap. Equal weights fall back to oldest waiter first. + +V1 AIMD contract: + +- all timing uses a monotonic clock +- `effective_max = min(provider_model_static_cap, request_config.max_limit_clamp_for_resource_if_present)` +- instantaneous aggregate availability is enforced separately by `provider_model_aggregate_in_flight < provider_model_static_cap` +- `initial_limit` is clamped to `[1, effective_max]` +- `current_limit` starts at `initial_limit`, unless `startup_ramp_seconds > 0`, in which case it starts at `1` and ramps linearly to `initial_limit` +- a provider rate-limit during startup ramp aborts the ramp and switches the resource to normal AIMD recovery +- provider rate limits apply multiplicative decrease and set `blocked_until_monotonic` +- success outside cooldown contributes to additive recovery +- `request_limit_increased`, `request_soft_ceiling_recovered`, and `request_fully_recovered` events are emitted from state transitions, not inferred later by sinks + +## Release Classification + +`ModelRequestExecutor` releases the exact acquired lease through the canonical release call: + +```text +release(lease, RequestReleaseOutcome) +``` + 
+Required outcome kinds: + +- `success` +- `rate_limited`, with `retry_after_seconds` when available +- `provider_failure` +- `provider_timeout` +- `local_cancelled` +- `local_timeout` +- `unexpected_exception` + +The release path is responsible for exactly-once accounting. Key-only release paths are not durable. + +Rate-limit outcomes drive AIMD decrease, cooldown, and waiter wake behavior. Provider failures may drive diagnostic counters but do not automatically imply provider pressure. Local cancellation and local timeout release capacity and wake waiters but must not be treated as rate limits or provider failures. + +`provider_timeout` means a timeout or timeout-shaped transport/provider error after a lease has been acquired and an outbound provider attempt has started. `local_timeout` means a caller, queue-wait, or controller deadline that is not evidence of provider pressure. + +Release returns `ReleaseResult` and must not raise from terminal cleanup paths. Duplicate release, stale release, or release against the wrong controller generation must return a diagnostic result and emit an error event without corrupting counters. + +## Request Pressure Snapshots + +`RequestPressureSnapshotProvider` exposes read-only state to diagnostics, benchmarks, telemetry, and future task policies. + +Domain snapshots include: + +- captured timestamp +- monotonic sequence/version +- request resource +- effective max +- current limit +- in-flight count +- active lease count +- waiters +- blocked-until timing +- cooldown remaining +- rate-limit ceiling +- consecutive rate limits +- last release outcome summary +- leak diagnostic counters + +Global snapshots include provider/model effective static caps, aggregate in-flight count across domains, aggregate active lease count, aliases contributing to the cap, and per-domain limit summaries. + +Task admission may read these snapshots as advisory input in later policy work. It must not mutate request state or emulate request admission. 
+ +## Static And Adaptive Cap Semantics + +`max_parallel_requests` remains the provider/model static cap when available. In V1, that cap is enforced as an aggregate upper bound across all request domains for the provider/model. Domain-specific adaptive limits decide how each domain is admitted beneath the aggregate cap; there is no cross-domain aggregate AIMD state beyond the static aggregate cap unless a later issue adds one. + +Effective admission for a request must satisfy both: + +- the provider/model aggregate static cap has available in-flight capacity +- the request domain's adaptive limit and cooldown state admits the item + +## Non-Goals + +- Do not make request admission aware of DAG dependencies. +- Do not make request admission own row-group lifecycle or ready-work ordering. +- Do not replace AIMD with token-bucket or leaky-bucket behavior in V1. +- Do not require static prediction of all model calls. +- Do not make task-level `TaskAdmissionController` responsible for provider retry, cooldown, or AIMD updates. diff --git a/plans/645/task-admission.md b/plans/645/task-admission.md new file mode 100644 index 000000000..1abd3e99b --- /dev/null +++ b/plans/645/task-admission.md @@ -0,0 +1,192 @@ +# Task Admission + +Task admission controls when dependency-ready dataset work may become a running worker. It is scheduler-level admission, not provider/model request admission. + +## Control Owner + +`AsyncTaskScheduler` is the control owner. 
Its dispatch loop follows this shape: + +```python +selection = queue.select_next(lambda item, view: admission.is_eligible(item, view)) +if selection is None: + block_summary = admission.explain_blocked(queue.view()) + emit_queue_empty_or_blocked(block_summary) + wait_for_wake_or_deadline(block_summary.available_after) + return + +decision = admission.try_acquire(selection.item, selection.queue_view) +if isinstance(decision, TaskAdmissionDenied): + emit_admission_denied(decision) + wake_dispatch_loop() + return +lease = decision + +committed = queue.commit(selection) +if committed is None: + admission.release(lease) + emit_stale_selection(selection, lease) + wake_dispatch_loop() + return + +try: + spawn_worker(committed, lease) +except Exception: + admission.release(lease) + emit_worker_spawn_failed(committed, lease) + raise +``` + +`FairTaskQueue` selects candidates. `TaskAdmissionController` leases scheduler resources. The scheduler coordinates both. + +V1 requires a scheduler dispatch mutex around `select_next -> try_acquire -> commit`. No concurrent dispatch iteration may acquire resources for the same selected task. `QueueSelection` still carries a queue version so `commit(selection)` can detect stale selections defensively. If `commit(selection)` fails because the selection is stale, the scheduler releases the exact task lease before retrying and emits a stale-selection event. + +Wakeups are required when ready work is enqueued, task admission capacity is released, policy state changes from denied to eligible, an `available_after` deadline expires, shutdown/cancellation is requested, or a stale selection is detected. The implementation must avoid lost wakeups: a sleeper cannot remain asleep while queued work is eligible and task capacity is available. 
+ +Lock ordering is part of the contract: the scheduler dispatch mutex coordinates the sequence, but `FairTaskQueue` must not hold queue-internal locks while invoking the scheduler eligibility predicate, and event sinks must not be called while queue or controller locks are held. + +## Queue Semantics + +`FairTaskQueue` owns ready-work ordering only. + +Rules: + +- `select_next(...)` is non-mutating. +- `select_next(...)` calls the eligibility callback with each candidate and the queue view. +- `QueueSelection` is returned to `AsyncTaskScheduler`. +- `QueueSelection` carries the queue view/version used to evaluate the candidate. +- `enqueue(...)` returns the accepted task ids; duplicate task ids are accepted idempotently and do not create duplicate queue entries. +- `commit(selection)` removes the selected task and advances queue state. +- The queue does not track admitted/running counts after this epic. +- The queue does not inspect model registries, provider pressure, or request-admission state. +- The queue may scan ready candidates to find the next eligible task, but eligibility is computed only through the scheduler-supplied predicate. + +`QueueView` must be strong enough for current strict fairness and future bounded borrow without policy traversal of queue internals. It includes queued counts by group, queued demand by group/resource, and first-candidate resources by group. It does not report admission-aware eligibility. `TaskAdmissionPolicy` computes whether a peer is eligible for a currently available resource by combining `QueueView` with `TaskAdmissionView`.
+ +## Admission Semantics + +`TaskAdmissionController` owns: + +- scheduler-resource availability +- task-stage leases +- admitted/running resource counts +- per-group accounting used by policy +- release on every worker terminal path +- rollback of acquired resources when the scheduler reports stale queue commit +- the authoritative hard resource ledger; policy debt is stored separately and affects eligibility without changing resource availability counters directly + +`TaskAdmissionPolicy` owns: + +- eligibility decisions +- acquisition/release policy callbacks +- strict fair admission +- bounded-borrow behavior +- future resource-vector policy decisions + +`TaskAdmissionPolicy.evaluate(...)` is a pure decision function. It can be called repeatedly while the queue scans candidates and must not mutate debt, counters, timers, or diagnostics. `on_acquire(...)` and `on_release(...)` return deterministic policy state deltas. They must not directly mutate the controller's authoritative lease/resource ledger. If a policy needs borrow debt or similar mutable state, the controller applies the state transition as part of the same acquire/release transaction and exposes the resulting policy state in `TaskAdmissionView`. + +Policy decisions are typed. A denied decision carries the reason used by scheduler telemetry and tests. Bounded-borrow policies return `PolicyStateDelta` values for borrow-debt increments and repayments; the controller applies those deltas atomically with the lease acquire/release path. + +`TaskAdmissionController` consumes `SchedulableTask`, `SchedulerResourceRequest`, `QueueView`, and `TaskAdmissionView`. It must not inspect `ColumnGenerator`, config layout, model registry, or provider registry directly. + +## V1 Lease Boundary + +The first task-admission implementation is lease-only and behavior-preserving. It centralizes resource ownership without changing fairness policy beyond what is required to eliminate hidden waiters and make root work visible. 
+ +V1 includes: + +- submission capacity for scheduler-spawned work +- explicit scheduler-resource leases for any task-stage backpressure that remains after request admission is separated +- current per-group admitted/running cap behavior +- typed `TaskAdmissionDecision` denial reasons for telemetry, tests, and benchmarks +- unique task lease identities so duplicate, stale, or wrong-controller releases are rejected or diagnosed + +V1 request waits remain inside admitted task execution and the task lease is retained until worker completion. That preserves the lease boundary and makes request waits visible, but it does not by itself solve cross-provider utilization when tasks for a cooled-down provider occupy all scheduler task slots. Issue #651 must address provider/resource-aware task admission or an explicit yield/reacquire design before the epic claims cross-provider scheduling optimization as complete. + +The current branch includes a narrow request-pressure advisory inside scheduler selection: when request-admission pressure is visible for one candidate and another eligible peer is not pressured, the scheduler may skip the pressured candidate for that selection pass. This consumes request-pressure snapshots as read-only input and does not mutate request-admission state or duplicate provider/model/domain AIMD. Treat broader provider/resource-aware scheduling as #651 scope. + +V1 excludes: + +- row-group admission +- concrete provider/model/domain request admission +- public runtime knobs +- distributed scheduling +- token budgets +- provider retry and AIMD behavior + +## Root And From-Scratch Work + +Root/from-scratch tasks must become `SchedulableTask`s and enter the same `FairTaskQueue` as downstream ready tasks. They must acquire scheduler-level leases through `TaskAdmissionController`. + +Initial root materialization is owned by `AsyncTaskScheduler`. 
`CompletionTracker.ready_frontier()` reports dependency-ready root tasks to the scheduler; the scheduler enqueues them into `FairTaskQueue` through the same path used for downstream work. `CompletionTracker` must not enqueue directly into `FairTaskQueue`. + +Readiness handoff is idempotent: + +- every `SchedulableTask` has a stable `task_id` +- `ready_frontier()` returns tasks that are ready and not yet acknowledged as enqueued +- `FairTaskQueue.enqueue(...)` is idempotent by `task_id` +- after enqueue succeeds, the scheduler calls `CompletionTracker.mark_enqueued(task_ids)` +- `CompletionTracker.mark_complete(task)` closes the task only after the scheduler records the terminal outcome + +No root dispatch path should bypass: + +- ready queue membership +- queue selection +- task admission +- lease release accounting +- scheduler admission telemetry + +This is required for heavy-root live-traffic evidence and later bounded-borrow policy. + +## Resource Handoff + +Resource-bound work must not become a spawned worker that waits for scheduler-level resources. The lease is acquired before spawn. + +Non-resource-bound work holds the relevant scheduler lease until worker completion. Resource-bound work holds the scheduler resource lease that represents the V1 task-stage resource request. Legacy hidden-wait booleans are not part of the target architecture. + +## Lease Lifecycle + +Every admitted task has one `TaskAdmissionLease` with a unique lease id. The scheduler releases that exact lease in a terminal `finally` path for success, retryable failure, non-retryable failure, cancellation, shutdown, salvage redispatch, and worker-spawn failure. 
+ +Release rules: + +- release returns `ReleaseResult` and must not raise from terminal `finally` paths +- duplicate release must not increment capacity +- releasing a stale lease or a lease from another controller generation returns a diagnostic release result and emits an error event +- stale queue commit releases the task lease before any worker is spawned +- salvage/retry may make replacement work visible only after the original lease terminal path is accounted for. Replacement work is recorded through `CompletionTracker` or an explicit retry tracker, then re-enters the normal `ready_frontier() -> enqueue -> mark_enqueued` handoff; it must not be inserted directly into `FairTaskQueue` while the original lease is active. +- task release wakes the dispatch loop if queued work may now be eligible + +## Bounded Borrow Policy + +`BoundedBorrowTaskAdmissionPolicy` is the first behavior-changing follow-up after the lease boundary. It limits how far one group may borrow ahead while no peer group is queued. + +Policy inputs: + +- `QueueView`: queued counts and queued resource demand. +- `TaskAdmissionView`: resource limits, availability, leased/running counts, and policy debt by group/resource. +- `TaskGroupSpec`: group key and weight. +- candidate `SchedulerResourceRequest`. +- engine-internal `BoundedBorrowTaskAdmissionPolicyConfig` when enabled by #650, including borrow ceiling by group/resource, strict-share rounding mode, and repayment behavior. + +Policy constraints: + +- Single-group workloads remain live. +- Borrow debt is measured in admitted scheduler-resource units above strict fair share for a group/resource. Strict share is computed from scheduler-known competing groups and their weights; #650 owns the exact rounding rule and benchmark evidence. +- A group may borrow beyond strict share only up to its configured ceiling while no eligible peer can use the resource. 
+- When peer queue pressure exists and a group has borrow debt, that group receives no further admissions for the borrowed resource while an eligible peer has queued work and the required resource is available. +- Debt repayment happens when peer pressure exists and the indebted group is withheld, or when policy-defined repayment work completes. Runtime debt is tracked by task group and scheduler resource; any completed lease in the same task group repays debt for the resources it releases. Repayment changes policy debt counters only, not hard resource availability. +- The policy must not traverse the DAG inside `FairTaskQueue`. +- No public knob is added until benchmark evidence supports it. + +## Resource-Vector Direction + +Future policy work may use `SchedulerResourceKey` and `SchedulerResourceRequest` for multi-resource admission. Candidate resources include submission capacity, local resources, GPU slots (if reliable metadata exists), and scheduler-owned task-stage resources derived from `SchedulingMetadata`. Provider/model/domain request resources remain owned by request admission. 
+ +Resource-vector policy must: + +- remain scheduler-internal unless a later design explicitly changes public metadata fields +- consume resolved metadata from `TaskSchedulingResolver` +- avoid duplicating provider/model/domain AIMD request admission +- use `RequestPressureSnapshotProvider` only as read-only pressure input +- preserve single-resource and single-group liveness +- produce benchmark evidence through the benchmark harness diff --git a/tests_e2e/tests/test_mcp_demo.py b/tests_e2e/tests/test_mcp_demo.py index 163e904cb..d7bca2e9e 100644 --- a/tests_e2e/tests/test_mcp_demo.py +++ b/tests_e2e/tests/test_mcp_demo.py @@ -101,25 +101,25 @@ def test_mcp_server_tool_usage_with_nvidia_text(tmp_path: Path) -> None: assert tool_call_messages tool_calls: list[dict[str, object]] = [] - tool_call_indices: dict[str, int] = {} + tool_call_positions: dict[str, tuple[int, int]] = {} for msg_index, msg in enumerate(trace): if not isinstance(msg, dict): continue if msg.get("role") != "assistant": continue - for tool_call in msg.get("tool_calls") or []: + for tool_call_index, tool_call in enumerate(msg.get("tool_calls") or []): if not isinstance(tool_call, dict): continue tool_calls.append(tool_call) function = tool_call.get("function") or {} if isinstance(function, dict): name = function.get("name") - if isinstance(name, str) and name not in tool_call_indices: - tool_call_indices[name] = msg_index + if isinstance(name, str) and name not in tool_call_positions: + tool_call_positions[name] = (msg_index, tool_call_index) - assert tool_call_indices.get("get_fact") is not None - assert tool_call_indices.get("add_numbers") is not None - assert tool_call_indices["get_fact"] < tool_call_indices["add_numbers"] + assert tool_call_positions.get("get_fact") is not None + assert tool_call_positions.get("add_numbers") is not None + assert tool_call_positions["get_fact"] < tool_call_positions["add_numbers"] def _tool_call_to_name_args(tool_call: dict[str, object]) -> tuple[str | None, 
dict[str, object]]: function = tool_call.get("function")