From 7eb354676c6370ea6d70418ce4a58b885e8430df Mon Sep 17 00:00:00 2001 From: Lorenzo Norcini Date: Tue, 23 Jun 2026 01:16:44 +0000 Subject: [PATCH] [MLI-7219] feat(k8s-cache): log + metric on cacher Redis write failure Cacher write failures were observable nowhere: no log, no metric. Entries then expire and the Gateway reports endpoint status as `unknown`. Wrap the write loop in execute() to emit a logger.exception + a new scale_launch.k8s_cache.redis_write_failure counter, then re-raise. Co-Authored-By: Claude Opus 4.8 (1M context) --- .../gateways/monitoring_metrics_gateway.py | 10 +++++ .../entrypoints/k8s_cache.py | 5 ++- .../datadog_monitoring_metrics_gateway.py | 3 ++ .../fake_monitoring_metrics_gateway.py | 5 +++ .../services/model_endpoint_cache_service.py | 39 ++++++++++++------ .../test_model_endpoint_cache_service.py | 41 ++++++++++++++++++- 6 files changed, 89 insertions(+), 14 deletions(-) diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py index dcad95d5c..813a630c7 100644 --- a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -90,3 +90,13 @@ def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): fast enough, or we're out of capacity, or the upstream svc is unhealthy) """ pass + + @abstractmethod + def emit_cache_write_failure_metric(self): + """ + K8s cacher Redis-write failure metric. Emitted when the cacher's periodic + write loop fails to write endpoint info to Redis (e.g. bad auth, network + partition, expired credentials). An early-warning signal that the cache is + going stale before the Gateway starts reporting endpoint status as `unknown`. + """ + pass diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index a64d55e97..d19890b50 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -17,6 +17,7 @@ from model_engine_server.core.config import infer_registry_type, infra_config from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.db.base import get_session_async_null_pool +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.domain.repositories import DockerRepository from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, @@ -83,6 +84,7 @@ async def loop_iteration( endpoint_record_repo: ModelEndpointRecordRepository, image_cache_gateway: ImageCacheGateway, docker_repository: DockerRepository, + monitoring_metrics_gateway: MonitoringMetricsGateway, ttl_seconds: float, ): image_cache_service = ImageCacheService( @@ -91,7 +93,7 @@ async def loop_iteration( docker_repository=docker_repository, ) cache_write_service = ModelEndpointCacheWriteService( - cache_repo, k8s_resource_manager, image_cache_service + cache_repo, k8s_resource_manager, image_cache_service, monitoring_metrics_gateway ) await cache_write_service.execute(ttl_seconds=ttl_seconds) @@ -166,6 +168,7 @@ async def main(args: Any): endpoint_record_repo, image_cache_gateway, docker_repo, + monitoring_metrics_gateway, args.ttl_seconds, ) loop_end = time.time() diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py index 93a739709..433fc6140 100644 --- a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py @@ -90,3 +90,6 @@ def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): tags = self.tags tags.extend([f"endpoint_name:{endpoint_name}", f"error_code:{error_code}"]) statsd.increment(f"{self.prefix}.upstream_sync_error", tags=tags) + + def emit_cache_write_failure_metric(self): + statsd.increment("scale_launch.k8s_cache.redis_write_failure", tags=self.tags) diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py index 25bf45fa1..8ba2bcd57 100644 --- a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -23,6 +23,7 @@ def __init__(self): self.token_count = 0 self.total_tokens_per_second = 0 self.sync_call_timeout = defaultdict(int) + self.cache_write_failure = 0 def reset(self): self.attempted_build = 0 @@ -39,6 +40,7 @@ def reset(self): self.token_count = 0 self.total_tokens_per_second = 0 self.sync_call_timeout = defaultdict(int) + self.cache_write_failure = 0 def emit_attempted_build_metric(self): self.attempted_build += 1 @@ -79,3 +81,6 @@ def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMet def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): self.sync_call_timeout[(endpoint_name, error_code)] += 1 + + def emit_cache_write_failure_metric(self): + self.cache_write_failure += 1 diff --git a/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py index 7e193027d..1ed21664c 100644 --- a/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py +++ b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py @@ -1,6 +1,8 @@ from typing import Dict, Tuple +from model_engine_server.core.loggers import logger_name, make_logger from model_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) @@ -9,6 +11,8 @@ ) from model_engine_server.infra.services.image_cache_service import ImageCacheService +logger = make_logger(logger_name()) + class ModelEndpointCacheWriteService: """ @@ -20,27 +24,38 @@ def __init__( model_endpoint_cache_repository: ModelEndpointCacheRepository, resource_gateway: EndpointResourceGateway, image_cache_service: ImageCacheService, + monitoring_metrics_gateway: MonitoringMetricsGateway, ): self.model_endpoint_cache_repository = model_endpoint_cache_repository self.resource_gateway = resource_gateway self.image_cache_service = image_cache_service + self.monitoring_metrics_gateway = monitoring_metrics_gateway async def execute(self, ttl_seconds: float): endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]] = ( await self.resource_gateway.get_all_resources() ) - for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items(): - if is_key_an_endpoint_id: - await self.model_endpoint_cache_repository.write_endpoint_info( - endpoint_id=key, endpoint_info=state, ttl_seconds=ttl_seconds - ) - else: - # TODO: Once we've backfilled all k8s resources to have an endpoint_id label, then - # we can get rid of this branch (also in the write_endpoint_info method, as well as - # simplifying the return type of get_all_resources() to not require the bool). - await self.model_endpoint_cache_repository.write_endpoint_info( - endpoint_id="", endpoint_info=state, ttl_seconds=ttl_seconds - ) + try: + for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items(): + if is_key_an_endpoint_id: + await self.model_endpoint_cache_repository.write_endpoint_info( + endpoint_id=key, endpoint_info=state, ttl_seconds=ttl_seconds + ) + else: + # TODO: Once we've backfilled all k8s resources to have an endpoint_id label, + # then we can get rid of this branch (also in the write_endpoint_info method, as + # well as simplifying the return type of get_all_resources() to not require the + # bool). + await self.model_endpoint_cache_repository.write_endpoint_info( + endpoint_id="", endpoint_info=state, ttl_seconds=ttl_seconds + ) + except Exception: + # A silent Redis-write failure here lets cache entries expire, after which the Gateway + # reports endpoint status as `unknown`. Surface the cause (log + metric) before + # re-raising so the failure is observable rather than deceptive. + logger.exception("Failed to write endpoint info to Redis cache") + self.monitoring_metrics_gateway.emit_cache_write_failure_metric() + raise await self.image_cache_service.execute(endpoint_infra_states=endpoint_infra_states) diff --git a/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py index fc3661b24..e1342b54f 100644 --- a/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py @@ -2,6 +2,7 @@ from model_engine_server.infra.services.model_endpoint_cache_service import ( ModelEndpointCacheWriteService, ) +from tests.unit.conftest import FakeModelEndpointCacheRepository @pytest.mark.asyncio @@ -9,6 +10,7 @@ async def test_model_endpoint_write_success( fake_model_endpoint_cache_repository, fake_resource_gateway, fake_image_cache_service, + fake_monitoring_metrics_gateway, model_endpoint_1, model_endpoint_2, ): @@ -18,7 +20,10 @@ async def test_model_endpoint_write_success( ) cache_write_service = ModelEndpointCacheWriteService( - fake_model_endpoint_cache_repository, fake_resource_gateway, fake_image_cache_service + fake_model_endpoint_cache_repository, + fake_resource_gateway, + fake_image_cache_service, + fake_monitoring_metrics_gateway, ) await cache_write_service.execute(42) infra_state = await fake_model_endpoint_cache_repository.read_endpoint_info( @@ -64,3 +69,37 @@ async def test_model_endpoint_write_success( deployment_name=model_endpoint_1.infra_state.deployment_name, ) assert infra_state is None + + # Happy path emits no write-failure metric. + assert fake_monitoring_metrics_gateway.cache_write_failure == 0 + + +class _RaisingCacheRepository(FakeModelEndpointCacheRepository): + """Simulates Redis being unwritable (e.g. bad auth / network partition).""" + + async def write_endpoint_info(self, endpoint_id, endpoint_info, ttl_seconds): + raise ConnectionError("Error connecting to Redis") + + +@pytest.mark.asyncio +async def test_model_endpoint_write_failure_emits_metric_and_reraises( + fake_resource_gateway, + fake_image_cache_service, + fake_monitoring_metrics_gateway, + model_endpoint_1, +): + fake_resource_gateway.add_resource( + endpoint_id=model_endpoint_1.record.id, infra_state=model_endpoint_1.infra_state + ) + + cache_write_service = ModelEndpointCacheWriteService( + _RaisingCacheRepository(), + fake_resource_gateway, + fake_image_cache_service, + fake_monitoring_metrics_gateway, + ) + + with pytest.raises(ConnectionError): + await cache_write_service.execute(42) + + assert fake_monitoring_metrics_gateway.cache_write_failure == 1