Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 121 additions & 16 deletions model-engine/tests/unit/api/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,41 @@
from model_engine_server.api.dependencies import _get_external_interfaces
from model_engine_server.common.config import HostedModelInferenceServiceConfig
from model_engine_server.infra.gateways import (
ABSFileStorageGateway,
ABSFilesystemGateway,
ABSLLMArtifactGateway,
ASBInferenceAutoscalingMetricsGateway,
GCSFileStorageGateway,
GCSFilesystemGateway,
GCSLLMArtifactGateway,
RedisInferenceAutoscalingMetricsGateway,
S3FilesystemGateway,
S3LLMArtifactGateway,
)
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.onprem_queue_endpoint_resource_delegate import (
OnPremQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import (
SQSQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway
from model_engine_server.infra.repositories import (
ABSFileLLMFineTuneEventsRepository,
ABSFileLLMFineTuneRepository,
ACRDockerRepository,
ECRDockerRepository,
GARDockerRepository,
GCSFileLLMFineTuneEventsRepository,
GCSFileLLMFineTuneRepository,
OnPremDockerRepository,
S3FileLLMFineTuneEventsRepository,
S3FileLLMFineTuneRepository,
)


Expand Down Expand Up @@ -86,40 +111,120 @@ def test_default_task_queue_selection_when_celery_broker_type_redis_disabled():
assert external_interfaces.inference_task_queue_gateway == sqs_gateway


def test_gcp_provider_selects_gcp_implementations():
"""Test that cloud_provider='gcp' wires the correct GCP implementations."""
# Expected concrete backend class per cloud_provider. docker_repository is keyed on
# registry_type (not cloud_provider): aws/azure/gcp drive it via a realistic prefix, onprem
# via docker_registry_type since no prefix infers "onprem".
_PROVIDER_CASES = [
pytest.param(
"aws",
"000000000000.dkr.ecr.us-east-1.amazonaws.com/my-repo",
None,
{
"queue_delegate": SQSQueueEndpointResourceDelegate,
"filesystem_gateway": S3FilesystemGateway,
"llm_artifact_gateway": S3LLMArtifactGateway,
"file_storage_gateway": S3FileStorageGateway,
"docker_repository": ECRDockerRepository,
"llm_fine_tune_repository": S3FileLLMFineTuneRepository,
"llm_fine_tune_events_repository": S3FileLLMFineTuneEventsRepository,
"inference_autoscaling_metrics_gateway": RedisInferenceAutoscalingMetricsGateway,
},
id="aws",
),
pytest.param(
"azure",
"myregistry.azurecr.io/my-repo",
None,
{
"queue_delegate": ASBQueueEndpointResourceDelegate,
"filesystem_gateway": ABSFilesystemGateway,
"llm_artifact_gateway": ABSLLMArtifactGateway,
"file_storage_gateway": ABSFileStorageGateway,
"docker_repository": ACRDockerRepository,
"llm_fine_tune_repository": ABSFileLLMFineTuneRepository,
"llm_fine_tune_events_repository": ABSFileLLMFineTuneEventsRepository,
"inference_autoscaling_metrics_gateway": ASBInferenceAutoscalingMetricsGateway,
},
id="azure",
),
pytest.param(
"gcp",
"us-docker.pkg.dev/my-project/my-repo",
None,
{
"queue_delegate": GcpPubSubQueueEndpointResourceDelegate,
"filesystem_gateway": GCSFilesystemGateway,
"llm_artifact_gateway": GCSLLMArtifactGateway,
"file_storage_gateway": GCSFileStorageGateway,
"docker_repository": GARDockerRepository,
"llm_fine_tune_repository": GCSFileLLMFineTuneRepository,
"llm_fine_tune_events_repository": GCSFileLLMFineTuneEventsRepository,
"inference_autoscaling_metrics_gateway": RedisInferenceAutoscalingMetricsGateway,
},
id="gcp",
),
pytest.param(
"onprem",
"registry.internal/my-repo",
"onprem",
{
"queue_delegate": OnPremQueueEndpointResourceDelegate,
"filesystem_gateway": S3FilesystemGateway,
"llm_artifact_gateway": S3LLMArtifactGateway,
"file_storage_gateway": S3FileStorageGateway,
"docker_repository": OnPremDockerRepository,
"llm_fine_tune_repository": S3FileLLMFineTuneRepository,
"llm_fine_tune_events_repository": S3FileLLMFineTuneEventsRepository,
"inference_autoscaling_metrics_gateway": RedisInferenceAutoscalingMetricsGateway,
},
id="onprem",
),
]


@pytest.mark.parametrize(
"cloud_provider, docker_repo_prefix, docker_registry_type, expected",
_PROVIDER_CASES,
)
def test_cloud_provider_selects_expected_backends(
cloud_provider, docker_repo_prefix, docker_registry_type, expected
):
"""Pin the concrete backend class selected for each cloud_provider."""
with (
patch("model_engine_server.api.dependencies.infra_config") as mock_config,
patch("model_engine_server.api.dependencies.CIRCLECI", False),
patch("model_engine_server.api.dependencies.CeleryTaskQueueGateway"),
patch("model_engine_server.api.dependencies.get_tracing_gateway"),
patch("model_engine_server.api.dependencies.aioredis"),
patch("model_engine_server.api.dependencies.get_or_create_aioredis_pool"),
patch("model_engine_server.api.dependencies.ASBInferenceAutoscalingMetricsGateway"),
patch("model_engine_server.api.dependencies.get_monitoring_metrics_gateway"),
):
mock_config_instance = MagicMock()
mock_config_instance.cloud_provider = "gcp"
mock_config_instance.cloud_provider = cloud_provider
mock_config_instance.docker_repo_prefix = docker_repo_prefix
mock_config_instance.docker_registry_type = docker_registry_type
mock_config_instance.celery_broker_type_redis = None
mock_config_instance.docker_repo_prefix = "us-docker.pkg.dev/my-project/my-repo"
mock_config_instance.docker_registry_type = None
mock_config_instance.gcp_project_id = "test-project"
mock_config.return_value = mock_config_instance

mock_session = MagicMock()
external_interfaces = _get_external_interfaces(read_only=False, session=mock_session)
ei = _get_external_interfaces(read_only=False, session=MagicMock())

assert isinstance(external_interfaces.filesystem_gateway, GCSFilesystemGateway)
assert isinstance(external_interfaces.llm_artifact_gateway, GCSLLMArtifactGateway)
assert isinstance(external_interfaces.file_storage_gateway, GCSFileStorageGateway)
assert isinstance(external_interfaces.docker_repository, GARDockerRepository)
assert isinstance(ei.resource_gateway.queue_delegate, expected["queue_delegate"])
assert isinstance(ei.filesystem_gateway, expected["filesystem_gateway"])
assert isinstance(ei.llm_artifact_gateway, expected["llm_artifact_gateway"])
assert isinstance(ei.file_storage_gateway, expected["file_storage_gateway"])
assert isinstance(ei.docker_repository, expected["docker_repository"])
assert isinstance(
ei.llm_fine_tuning_service.llm_fine_tune_repository,
expected["llm_fine_tune_repository"],
)
assert isinstance(
external_interfaces.llm_fine_tune_events_repository,
GCSFileLLMFineTuneEventsRepository,
ei.llm_fine_tune_events_repository,
expected["llm_fine_tune_events_repository"],
)
assert isinstance(
external_interfaces.resource_gateway.queue_delegate,
GcpPubSubQueueEndpointResourceDelegate,
ei.resource_gateway.inference_autoscaling_metrics_gateway,
expected["inference_autoscaling_metrics_gateway"],
)


Expand Down