diff --git a/packages/sdk/server-ai/src/ldai/client.py b/packages/sdk/server-ai/src/ldai/client.py index 9f8616f4..623ff825 100644 --- a/packages/sdk/server-ai/src/ldai/client.py +++ b/packages/sdk/server-ai/src/ldai/client.py @@ -832,10 +832,12 @@ def __evaluate( if 'model' in variation and isinstance(variation['model'], dict): parameters = variation['model'].get('parameters', None) custom = variation['model'].get('custom', None) + region = variation['model'].get('region', None) model = ModelConfig( name=variation['model']['name'], parameters=parameters, - custom=custom + custom=custom, + region=region, ) variation_key = variation.get('_ldMeta', {}).get('variationKey', '') diff --git a/packages/sdk/server-ai/src/ldai/models.py b/packages/sdk/server-ai/src/ldai/models.py index 8e89008c..38e4fdc9 100644 --- a/packages/sdk/server-ai/src/ldai/models.py +++ b/packages/sdk/server-ai/src/ldai/models.py @@ -52,15 +52,23 @@ class ModelConfig: Configuration related to the model. """ - def __init__(self, name: str, parameters: Optional[Dict[str, Any]] = None, custom: Optional[Dict[str, Any]] = None): + def __init__( + self, + name: str, + parameters: Optional[Dict[str, Any]] = None, + custom: Optional[Dict[str, Any]] = None, + region: Optional[str] = None, + ): """ :param name: The name of the model. :param parameters: Additional model-specific parameters. :param custom: Additional customer provided data. + :param region: The region the model is deployed in. """ self._name = name self._parameters = parameters self._custom = custom + self._region = region @property def name(self) -> str: @@ -93,6 +101,13 @@ def get_custom(self, key: str) -> Any: return self._custom.get(key) + @property + def region(self) -> Optional[str]: + """ + The region the model is deployed in. + """ + return self._region + def to_dict(self) -> dict: """ Render the given model config as a dictionary object. @@ -101,6 +116,7 @@ def to_dict(self) -> dict: 'name': self._name, 'parameters': self._parameters, 'custom': self._custom, + 'region': self._region, } diff --git a/packages/sdk/server-ai/tests/test_model_config.py b/packages/sdk/server-ai/tests/test_model_config.py index 838b39d9..6d0f0147 100644 --- a/packages/sdk/server-ai/tests/test_model_config.py +++ b/packages/sdk/server-ai/tests/test_model_config.py @@ -31,6 +31,23 @@ def td() -> TestData: .variation_for_all(0) ) + td.update( + td.flag('model-config-with-region') + .variations( + { + 'model': { + 'name': 'anthropic.claude-opus-4-7', + 'parameters': {}, + 'region': 'us', + }, + 'provider': {'name': 'Bedrock'}, + 'messages': [{'role': 'system', 'content': 'Hello!'}], + '_ldMeta': {'enabled': True, 'variationKey': 'us-variation', 'version': 1}, + }, + ) + .variation_for_all(0) + ) + td.update( td.flag('multiple-messages') .variations( @@ -482,6 +499,36 @@ def test_create_tracker_preserves_config_metadata(): assert 'runId' in track_data +def test_model_config_region(): + model = ModelConfig('fakeModel', region='us') + assert model.region == 'us' + + +def test_model_config_region_defaults_to_none(): + model = ModelConfig('fakeModel') + assert model.region is None + + +def test_model_config_region_from_flag(ldai_client: LDAIClient): + context = Context.create('user-key') + default = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[]) + + config = ldai_client.completion_config('model-config-with-region', context, default) + + assert config.model is not None + assert config.model.region == 'us' + + +def test_model_config_no_region_is_none(ldai_client: LDAIClient): + context = Context.create('user-key') + default = AICompletionConfigDefault(enabled=True, model=ModelConfig('fake-model'), messages=[]) + + config = ldai_client.completion_config('model-config', context, default) + + assert config.model is not None + assert config.model.region is None + + def test_create_tracker_each_call_has_different_run_id(): from unittest.mock import Mock