Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,13 @@ def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int):
fast enough, or we're out of capacity, or the upstream svc is unhealthy)
"""
pass

@abstractmethod
def emit_cache_write_failure_metric(self):
"""
K8s cacher Redis-write failure metric. Emitted when the cacher's periodic
write loop fails to write endpoint info to Redis (e.g. bad auth, network
partition, expired credentials). An early-warning signal that the cache is
going stale before the Gateway starts reporting endpoint status as `unknown`.
"""
pass
5 changes: 4 additions & 1 deletion model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from model_engine_server.core.config import infer_registry_type, infra_config
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.db.base import get_session_async_null_pool
from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway
from model_engine_server.domain.repositories import DockerRepository
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
Expand Down Expand Up @@ -83,6 +84,7 @@ async def loop_iteration(
endpoint_record_repo: ModelEndpointRecordRepository,
image_cache_gateway: ImageCacheGateway,
docker_repository: DockerRepository,
monitoring_metrics_gateway: MonitoringMetricsGateway,
ttl_seconds: float,
):
image_cache_service = ImageCacheService(
Expand All @@ -91,7 +93,7 @@ async def loop_iteration(
docker_repository=docker_repository,
)
cache_write_service = ModelEndpointCacheWriteService(
cache_repo, k8s_resource_manager, image_cache_service
cache_repo, k8s_resource_manager, image_cache_service, monitoring_metrics_gateway
)
await cache_write_service.execute(ttl_seconds=ttl_seconds)

Expand Down Expand Up @@ -166,6 +168,7 @@ async def main(args: Any):
endpoint_record_repo,
image_cache_gateway,
docker_repo,
monitoring_metrics_gateway,
args.ttl_seconds,
)
loop_end = time.time()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,6 @@ def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int):
tags = self.tags
tags.extend([f"endpoint_name:{endpoint_name}", f"error_code:{error_code}"])
statsd.increment(f"{self.prefix}.upstream_sync_error", tags=tags)

def emit_cache_write_failure_metric(self):
statsd.increment("scale_launch.k8s_cache.redis_write_failure", tags=self.tags)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self):
self.token_count = 0
self.total_tokens_per_second = 0
self.sync_call_timeout = defaultdict(int)
self.cache_write_failure = 0

def reset(self):
self.attempted_build = 0
Expand All @@ -39,6 +40,7 @@ def reset(self):
self.token_count = 0
self.total_tokens_per_second = 0
self.sync_call_timeout = defaultdict(int)
self.cache_write_failure = 0

def emit_attempted_build_metric(self):
self.attempted_build += 1
Expand Down Expand Up @@ -79,3 +81,6 @@ def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMet

def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int):
self.sync_call_timeout[(endpoint_name, error_code)] += 1

def emit_cache_write_failure_metric(self):
self.cache_write_failure += 1
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, Tuple

from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.entities import ModelEndpointInfraState
from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway
from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import (
EndpointResourceGateway,
)
Expand All @@ -9,6 +11,8 @@
)
from model_engine_server.infra.services.image_cache_service import ImageCacheService

logger = make_logger(logger_name())


class ModelEndpointCacheWriteService:
"""
Expand All @@ -20,27 +24,38 @@ def __init__(
model_endpoint_cache_repository: ModelEndpointCacheRepository,
resource_gateway: EndpointResourceGateway,
image_cache_service: ImageCacheService,
monitoring_metrics_gateway: MonitoringMetricsGateway,
):
self.model_endpoint_cache_repository = model_endpoint_cache_repository
self.resource_gateway = resource_gateway
self.image_cache_service = image_cache_service
self.monitoring_metrics_gateway = monitoring_metrics_gateway

async def execute(self, ttl_seconds: float):
endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]] = (
await self.resource_gateway.get_all_resources()
)

for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items():
if is_key_an_endpoint_id:
await self.model_endpoint_cache_repository.write_endpoint_info(
endpoint_id=key, endpoint_info=state, ttl_seconds=ttl_seconds
)
else:
# TODO: Once we've backfilled all k8s resources to have an endpoint_id label, then
# we can get rid of this branch (also in the write_endpoint_info method, as well as
# simplifying the return type of get_all_resources() to not require the bool).
await self.model_endpoint_cache_repository.write_endpoint_info(
endpoint_id="", endpoint_info=state, ttl_seconds=ttl_seconds
)
try:
for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items():
if is_key_an_endpoint_id:
await self.model_endpoint_cache_repository.write_endpoint_info(
endpoint_id=key, endpoint_info=state, ttl_seconds=ttl_seconds
)
else:
# TODO: Once we've backfilled all k8s resources to have an endpoint_id label,
# then we can get rid of this branch (also in the write_endpoint_info method, as
# well as simplifying the return type of get_all_resources() to not require the
# bool).
await self.model_endpoint_cache_repository.write_endpoint_info(
endpoint_id="", endpoint_info=state, ttl_seconds=ttl_seconds
)
except Exception:
# A silent Redis-write failure here lets cache entries expire, after which the Gateway
# reports endpoint status as `unknown`. Surface the cause (log + metric) before
# re-raising so the failure is observable rather than deceptive.
logger.exception("Failed to write endpoint info to Redis cache")
self.monitoring_metrics_gateway.emit_cache_write_failure_metric()
raise

await self.image_cache_service.execute(endpoint_infra_states=endpoint_infra_states)
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
from model_engine_server.infra.services.model_endpoint_cache_service import (
ModelEndpointCacheWriteService,
)
from tests.unit.conftest import FakeModelEndpointCacheRepository


@pytest.mark.asyncio
async def test_model_endpoint_write_success(
fake_model_endpoint_cache_repository,
fake_resource_gateway,
fake_image_cache_service,
fake_monitoring_metrics_gateway,
model_endpoint_1,
model_endpoint_2,
):
Expand All @@ -18,7 +20,10 @@ async def test_model_endpoint_write_success(
)

cache_write_service = ModelEndpointCacheWriteService(
fake_model_endpoint_cache_repository, fake_resource_gateway, fake_image_cache_service
fake_model_endpoint_cache_repository,
fake_resource_gateway,
fake_image_cache_service,
fake_monitoring_metrics_gateway,
)
await cache_write_service.execute(42)
infra_state = await fake_model_endpoint_cache_repository.read_endpoint_info(
Expand Down Expand Up @@ -64,3 +69,37 @@ async def test_model_endpoint_write_success(
deployment_name=model_endpoint_1.infra_state.deployment_name,
)
assert infra_state is None

# Happy path emits no write-failure metric.
assert fake_monitoring_metrics_gateway.cache_write_failure == 0


class _RaisingCacheRepository(FakeModelEndpointCacheRepository):
"""Simulates Redis being unwritable (e.g. bad auth / network partition)."""

async def write_endpoint_info(self, endpoint_id, endpoint_info, ttl_seconds):
raise ConnectionError("Error connecting to Redis")


@pytest.mark.asyncio
async def test_model_endpoint_write_failure_emits_metric_and_reraises(
fake_resource_gateway,
fake_image_cache_service,
fake_monitoring_metrics_gateway,
model_endpoint_1,
):
fake_resource_gateway.add_resource(
endpoint_id=model_endpoint_1.record.id, infra_state=model_endpoint_1.infra_state
)

cache_write_service = ModelEndpointCacheWriteService(
_RaisingCacheRepository(),
fake_resource_gateway,
fake_image_cache_service,
fake_monitoring_metrics_gateway,
)

with pytest.raises(ConnectionError):
await cache_write_service.execute(42)

assert fake_monitoring_metrics_gateway.cache_write_failure == 1