Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions engine/src/agent_control_engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,33 @@

logger = logging.getLogger(__name__)


def _env_positive_int(*names: str, default: int) -> int:
    """Resolve a positive integer setting from a prioritized list of env vars.

    Variables are consulted in the order given. A variable that is unset or
    contains only whitespace is treated as "not configured" and the next name
    is tried. The first configured variable decides the outcome: it is either
    returned or rejected; later names are not consulted.

    Args:
        *names: Environment variable names, highest priority first.
        default: Value returned when none of the variables is configured.

    Returns:
        The parsed integer, or ``default`` when nothing is configured.

    Raises:
        RuntimeError: If the first configured variable is not an integer or is
            smaller than 1.
    """
    for env_name in names:
        raw = os.getenv(env_name)
        if raw is None or not raw.strip():
            # Unset and blank both mean "fall through to the next name".
            continue
        try:
            number = int(raw)
        except ValueError as exc:
            raise RuntimeError(f"{env_name}={raw!r} must be an integer.") from exc
        if number >= 1:
            return number
        raise RuntimeError(f"{env_name}={raw!r} must be greater than or equal to 1.")
    return default


# Default timeout for evaluator execution (seconds)
DEFAULT_EVALUATOR_TIMEOUT = float(os.environ.get("EVALUATOR_TIMEOUT_SECONDS", "30"))

# Max concurrent evaluations (limits task spawning overhead for large policies).
# Prefer the namespaced env var; MAX_CONCURRENT_EVALUATIONS is kept for compatibility.
# The value is resolved only through _env_positive_int: a blank variable falls back
# to the next name or the default, and an invalid one fails with a descriptive
# RuntimeError. A bare int(os.environ.get(...)) lookup must not run first, because
# it would raise a raw ValueError for those same inputs.
MAX_CONCURRENT_EVALUATIONS = _env_positive_int(
    "AGENT_CONTROL_MAX_CONCURRENT_EVALUATIONS",
    "MAX_CONCURRENT_EVALUATIONS",
    default=3,
)

SELECTED_DATA_PREVIEW_MAX_CHARS = int(
os.environ.get("AGENT_CONTROL_SELECTED_DATA_PREVIEW_MAX_CHARS", "500")
Expand Down
51 changes: 51 additions & 0 deletions engine/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,57 @@ async def test_timeout_does_not_affect_fast_evaluators(self):
class TestConcurrencyLimit:
"""Tests for semaphore-based concurrency limiting."""

def test_max_concurrency_env_prefers_agent_control_name(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """When both names are set, the namespaced variable wins over the legacy one."""
    import agent_control_engine.core as core_module

    namespaced = "AGENT_CONTROL_MAX_CONCURRENT_EVALUATIONS"
    legacy = "MAX_CONCURRENT_EVALUATIONS"
    monkeypatch.setenv(namespaced, "7")
    monkeypatch.setenv(legacy, "2")

    resolved = core_module._env_positive_int(namespaced, legacy, default=3)

    assert resolved == 7

def test_max_concurrency_env_reads_legacy_name(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """With the namespaced variable unset, the legacy short name is still read."""
    import agent_control_engine.core as core_module

    namespaced = "AGENT_CONTROL_MAX_CONCURRENT_EVALUATIONS"
    legacy = "MAX_CONCURRENT_EVALUATIONS"
    monkeypatch.delenv(namespaced, raising=False)
    monkeypatch.setenv(legacy, "5")

    resolved = core_module._env_positive_int(namespaced, legacy, default=3)

    assert resolved == 5

def test_max_concurrency_env_rejects_non_positive_values(
    self, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A configured cap below one is rejected with a RuntimeError."""
    import agent_control_engine.core as core_module

    monkeypatch.setenv("AGENT_CONTROL_MAX_CONCURRENT_EVALUATIONS", "0")
    env_names = (
        "AGENT_CONTROL_MAX_CONCURRENT_EVALUATIONS",
        "MAX_CONCURRENT_EVALUATIONS",
    )

    with pytest.raises(RuntimeError, match="greater than or equal to 1"):
        core_module._env_positive_int(*env_names, default=3)

@pytest.mark.asyncio
async def test_concurrency_limited_to_max(self, monkeypatch: pytest.MonkeyPatch):
"""Test that concurrent evaluations are limited by semaphore.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import ssl
import warnings
from asyncio import Lock
from base64 import urlsafe_b64encode
from hashlib import sha256
from hmac import new as hmac_new
Expand All @@ -27,6 +28,11 @@
# Connection defaults, used when the matching GALILEO_LUNA_* variable is unset or blank.
DEFAULT_KEEPALIVE_EXPIRY_SECS = 1.0
DEFAULT_MAX_CONNECTIONS = 100
DEFAULT_MAX_KEEPALIVE_CONNECTIONS = 20
DEFAULT_CLIENT_POOL_SIZE = 1
# Names of the environment variables that override the defaults above.
LUNA_KEEPALIVE_EXPIRY_ENV = "GALILEO_LUNA_KEEPALIVE_EXPIRY_SECONDS"
LUNA_MAX_CONNECTIONS_ENV = "GALILEO_LUNA_MAX_CONNECTIONS"
LUNA_MAX_KEEPALIVE_CONNECTIONS_ENV = "GALILEO_LUNA_MAX_KEEPALIVE_CONNECTIONS"
LUNA_CLIENT_POOL_SIZE_ENV = "GALILEO_LUNA_CLIENT_POOL_SIZE"
# Scorer invocation paths; presumably selected by auth mode (confirm in _endpoint_and_headers).
PUBLIC_SCORER_INVOKE_PATH = "/scorers/invoke"
INTERNAL_SCORER_INVOKE_PATH = "/internal/scorers/invoke"
# The two auth modes the client accepts.
AuthMode = Literal["public", "internal"]
Expand Down Expand Up @@ -78,6 +84,54 @@ def _env_auth_mode() -> AuthMode | None:
raise ValueError("GALILEO_LUNA_AUTH_MODE must be either 'public' or 'internal'.")


def _load_float_env(env_name: str, default: float) -> float:
    """Read a floating-point setting from the environment.

    Args:
        env_name: Name of the environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed value, or ``default`` when nothing is configured.

    Raises:
        ValueError: If the variable is set but does not hold a usable number.
            This includes ``nan``: ``float()`` accepts it, but it compares
            false against every bound, so a range check such as ``value < 0``
            would let it through unnoticed.
    """
    import math  # function-scope import keeps the module-level import block untouched

    raw = os.getenv(env_name)
    if raw is None or raw.strip() == "":
        return default
    try:
        parsed = float(raw)
    except ValueError as exc:
        raise ValueError(f"{env_name}={raw!r} is not a number.") from exc
    if math.isnan(parsed):
        raise ValueError(f"{env_name}={raw!r} is not a number.")
    return parsed


def _load_int_env(env_name: str, default: int) -> int:
    """Read an integer setting from the environment.

    Args:
        env_name: Name of the environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        The parsed integer, or ``default`` when nothing is configured.

    Raises:
        ValueError: If the variable is set to text that ``int()`` cannot parse.
    """
    configured = os.environ.get(env_name)
    if configured is not None and configured.strip() != "":
        try:
            return int(configured)
        except ValueError as exc:
            raise ValueError(f"{env_name}={configured!r} is not an integer.") from exc
    return default


def _validate_connection_config(
    *,
    keepalive_expiry_seconds: float,
    max_connections: int,
    max_keepalive_connections: int,
    client_pool_size: int,
) -> None:
    """Reject connection settings that are out of range or mutually inconsistent.

    Checks run in a fixed order and only the first violation is reported. Each
    message names the environment variable that carries the offending value.

    Args:
        keepalive_expiry_seconds: Must be greater than or equal to 0.
        max_connections: Must be greater than 0.
        max_keepalive_connections: Must be between 0 and ``max_connections``.
        client_pool_size: Must be greater than 0.

    Raises:
        ValueError: On the first setting that violates its constraint.
    """
    problem: str | None = None
    if keepalive_expiry_seconds < 0:
        problem = (
            f"{LUNA_KEEPALIVE_EXPIRY_ENV}={keepalive_expiry_seconds} "
            "must be greater than or equal to 0."
        )
    elif max_connections <= 0:
        problem = f"{LUNA_MAX_CONNECTIONS_ENV}={max_connections} must be greater than 0."
    elif max_keepalive_connections < 0:
        problem = (
            f"{LUNA_MAX_KEEPALIVE_CONNECTIONS_ENV}={max_keepalive_connections} "
            "must be greater than or equal to 0."
        )
    elif max_keepalive_connections > max_connections:
        # The idle pool cannot be larger than the total connection budget.
        problem = (
            f"{LUNA_MAX_KEEPALIVE_CONNECTIONS_ENV}={max_keepalive_connections} "
            f"must be less than or equal to {LUNA_MAX_CONNECTIONS_ENV}={max_connections}."
        )
    elif client_pool_size <= 0:
        problem = f"{LUNA_CLIENT_POOL_SIZE_ENV}={client_pool_size} must be greater than 0."

    if problem is not None:
        raise ValueError(problem)


def _as_float_or_none(value: JSONValue) -> float | None:
if isinstance(value, bool) or value is None:
return None
Expand Down Expand Up @@ -184,6 +238,10 @@ class GalileoLunaClient:
GALILEO_API_URL: Galileo API URL fallback.
GALILEO_LUNA_CA_FILE: CA bundle used to verify the scorer API endpoint, for
deployments whose API serves an internally-issued TLS certificate.
GALILEO_LUNA_KEEPALIVE_EXPIRY_SECONDS: HTTP pooled connection expiry.
GALILEO_LUNA_MAX_CONNECTIONS: Maximum outbound HTTP connections.
GALILEO_LUNA_MAX_KEEPALIVE_CONNECTIONS: Maximum idle pooled HTTP connections.
GALILEO_LUNA_CLIENT_POOL_SIZE: Number of outbound HTTP clients to rotate across.
GALILEO_CONSOLE_URL: Galileo Console URL (optional, defaults to production).
"""

Expand Down Expand Up @@ -235,7 +293,26 @@ def __init__(
self.api_base = self._resolve_api_base(api_url)
self.ca_file = (ca_file or os.getenv("GALILEO_LUNA_CA_FILE") or "").strip() or None
self._ssl_context = self._load_ssl_context(self.ca_file)
self.keepalive_expiry_seconds = _load_float_env(
LUNA_KEEPALIVE_EXPIRY_ENV, DEFAULT_KEEPALIVE_EXPIRY_SECS
)
self.max_connections = _load_int_env(LUNA_MAX_CONNECTIONS_ENV, DEFAULT_MAX_CONNECTIONS)
self.max_keepalive_connections = _load_int_env(
LUNA_MAX_KEEPALIVE_CONNECTIONS_ENV, DEFAULT_MAX_KEEPALIVE_CONNECTIONS
)
self.client_pool_size = _load_int_env(
LUNA_CLIENT_POOL_SIZE_ENV, DEFAULT_CLIENT_POOL_SIZE
)
_validate_connection_config(
keepalive_expiry_seconds=self.keepalive_expiry_seconds,
max_connections=self.max_connections,
max_keepalive_connections=self.max_keepalive_connections,
client_pool_size=self.client_pool_size,
)
self._client: httpx.AsyncClient | None = None
self._clients: list[httpx.AsyncClient] = []
self._next_client_index = 0
self._client_lock = Lock()
logger.info("[GalileoLunaClient] Auth mode selected: %s", self.auth_mode)

def _resolve_api_base(self, api_url: str | None) -> str:
Expand Down Expand Up @@ -316,26 +393,48 @@ def _derive_api_url(self, console_url: str) -> str:
parts._replace(netloc=parts.netloc.replace(host, new_host, 1))
)

def _create_client(self) -> httpx.AsyncClient:
    """Build a new async HTTP client from this instance's auth, TLS, and pool settings.

    The API key header is attached only in ``public`` auth mode and only when a
    key is present. Connection limits come from the instance attributes rather
    than the module defaults.
    """
    request_headers = {"Content-Type": "application/json"}
    if self.auth_mode == "public" and self.api_key is not None:
        request_headers["Galileo-API-Key"] = self.api_key

    # With no custom CA context loaded, pass True for verify.
    tls_verify: ssl.SSLContext | bool = True if self._ssl_context is None else self._ssl_context

    request_timeout = httpx.Timeout(DEFAULT_TIMEOUT_SECS)
    pool_limits = httpx.Limits(
        max_connections=self.max_connections,
        max_keepalive_connections=self.max_keepalive_connections,
        keepalive_expiry=self.keepalive_expiry_seconds,
    )
    return httpx.AsyncClient(
        headers=request_headers,
        timeout=request_timeout,
        limits=pool_limits,
        verify=tls_verify,
    )

def _select_pooled_client(self) -> httpx.AsyncClient:
    """Return the next pooled client in rotation and advance the cursor.

    Intended to be called while the client state lock is held, since it reads
    and rewrites ``_next_client_index``. The cursor is reduced modulo the
    current pool length, so an index left over from a larger pool still maps
    to a valid slot.
    """
    pool_length = len(self._clients)
    slot = self._next_client_index % pool_length
    self._next_client_index = (self._next_client_index + 1) % pool_length
    return self._clients[slot]

async def _get_client(self) -> httpx.AsyncClient:
    """Get or create the next HTTP client.

    All client state is read and updated while holding ``_client_lock``.

    * ``client_pool_size == 1``: a single client is reused for as long as it
      stays open; once closed it is replaced via ``_create_client``.
    * ``client_pool_size > 1``: the pool is topped up to the configured size
      and clients are handed out in rotation by ``_select_pooled_client``.

    Clients are always built by ``_create_client`` so the configured
    connection limits apply; nothing here falls back to the module-level
    ``DEFAULT_*`` limits.

    Returns:
        A client that was not closed at the time it was selected.
    """
    async with self._client_lock:
        # Forget clients closed since the last call so they are never handed out.
        self._clients = [client for client in self._clients if not client.is_closed]

        if self.client_pool_size == 1:
            if self._client is not None and not self._client.is_closed:
                return self._client
            # Reuse a surviving pooled client before paying for a new one.
            self._client = self._clients[0] if self._clients else self._create_client()
            self._clients = [self._client]
            return self._client

        # In pooled mode the single-client slot is unused.
        self._client = None
        while len(self._clients) < self.client_pool_size:
            self._clients.append(self._create_client())

        return self._select_pooled_client()

def _endpoint_and_headers(
self,
Expand Down Expand Up @@ -431,9 +530,23 @@ async def invoke(

async def close(self) -> None:
    """Close every HTTP client owned by this instance and reset client state.

    Client references are collected and cleared while holding
    ``_client_lock``. The clients are then closed after the lock is released,
    so the lock is not held across the ``aclose()`` awaits. A client referenced
    by both ``_client`` and ``_clients`` is closed once.
    """
    async with self._client_lock:
        clients: list[httpx.AsyncClient] = []
        seen_client_ids: set[int] = set()
        if self._client is not None:
            clients.append(self._client)
            seen_client_ids.add(id(self._client))
            self._client = None
        for client in self._clients:
            # Deduplicate by identity: in single-client mode the same object
            # is held in both _client and _clients.
            if id(client) not in seen_client_ids:
                clients.append(client)
                seen_client_ids.add(id(client))
        self._clients = []
        self._next_client_index = 0

    for client in clients:
        if not client.is_closed:
            await client.aclose()

async def __aenter__(self) -> GalileoLunaClient:
"""Async context manager entry."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from importlib.metadata import PackageNotFoundError, version
from typing import Any

import httpx
from agent_control_evaluators import Evaluator, EvaluatorMetadata, register_evaluator
from agent_control_models import EvaluatorResult, JSONValue

Expand All @@ -27,6 +28,7 @@ def _resolve_package_version() -> str:

_PACKAGE_VERSION = _resolve_package_version()
LUNA_AVAILABLE = True
# Maximum number of characters of an HTTP error response body kept in error metadata;
# longer bodies are cut to this length and flagged as truncated.
_HTTP_ERROR_BODY_LIMIT = 500


def _coerce_payload_text(value: Any) -> str | None:
Expand Down Expand Up @@ -74,6 +76,32 @@ def _confidence_from_score(score: JSONValue) -> float:
return 1.0


def _truncated_http_response_body(body: str) -> tuple[str, bool]:
    """Clip ``body`` to ``_HTTP_ERROR_BODY_LIMIT`` characters.

    Returns:
        A ``(text, was_truncated)`` pair. ``text`` is ``body`` unchanged when it
        fits within the limit, otherwise its leading ``_HTTP_ERROR_BODY_LIMIT``
        characters.
    """
    was_truncated = len(body) > _HTTP_ERROR_BODY_LIMIT
    clipped = body[:_HTTP_ERROR_BODY_LIMIT] if was_truncated else body
    return clipped, was_truncated


def _http_status_error_metadata(error: httpx.HTTPStatusError) -> dict[str, Any]:
    """Collect request and response details from an HTTP status error.

    The result holds the request method and URL path, the response status code
    and content type, and, when the response has a non-empty body, the
    (possibly truncated) body text with a truncation flag. Entries whose value
    is ``None`` are left out.

    Args:
        error: The status error carrying the request and response.

    Returns:
        A dict of ``http_*`` metadata entries.
    """
    request = error.request
    response = error.response
    details: dict[str, Any] = {
        "http_method": request.method,
        "http_endpoint_path": request.url.path,
        "http_status_code": response.status_code,
        "http_response_content_type": response.headers.get("content-type"),
    }

    body_text = response.text
    if body_text:
        clipped, was_truncated = _truncated_http_response_body(body_text)
        details["http_response_body"] = clipped
        details["http_response_body_truncated"] = was_truncated

    return {name: entry for name, entry in details.items() if entry is not None}


@register_evaluator
class LunaEvaluator(Evaluator[LunaEvaluatorConfig]):
"""Galileo Luna evaluator using the direct scorer invocation API."""
Expand Down Expand Up @@ -252,16 +280,20 @@ def _handle_error(
error: Exception,
) -> EvaluatorResult:
error_detail = str(error)
metadata: dict[str, Any] = {
"error_type": type(error).__name__,
"scorer_label": self.config.scorer_label,
"scorer_id": self.config.scorer_id,
"scorer_version_id": self.config.scorer_version_id,
}
if isinstance(error, httpx.HTTPStatusError):
metadata.update(_http_status_error_metadata(error))

return EvaluatorResult(
matched=False,
confidence=0.0,
message=f"Luna evaluation error: {error_detail}",
metadata={
"error_type": type(error).__name__,
"scorer_label": self.config.scorer_label,
"scorer_id": self.config.scorer_id,
"scorer_version_id": self.config.scorer_version_id,
},
metadata=metadata,
error=error_detail,
)

Expand Down
Loading
Loading