Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 161 additions & 32 deletions sagemaker-serve/tests/integ/test_model_customization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,24 @@
"""Integration tests for ModelBuilder model customization deployment."""
from __future__ import absolute_import

import os
import json
import boto3
import time
import pytest
import random
import logging
from botocore.config import Config
from datetime import datetime, timezone, timedelta


logger = logging.getLogger(__name__)

from sagemaker.core.helper.session_helper import Session

# This test relies on resources in a specific region
AWS_REGION = "us-west-2"
os.environ.setdefault("AWS_DEFAULT_REGION", AWS_REGION)


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -135,6 +145,38 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION)
assert adapter_ic is not None

# Invoke verification
time.sleep(10) # brief buffer for IC readiness

invoke_ic_name = adapter_name if peft_type == "LORA" else f"{endpoint_name}-inference-component"

test_payload = {
"inputs": "What is machine learning?",
"parameters": {"max_new_tokens": 32},
}

invoke_response = endpoint.invoke(
body=json.dumps(test_payload),
content_type="application/json",
accept="application/json",
inference_component_name=invoke_ic_name,
)

response_body = json.loads(invoke_response.body.read())

# Validate response structure
assert response_body is not None, f"Empty response from invoke on {invoke_ic_name}"
if isinstance(response_body, list):
assert len(response_body) > 0
assert "generated_text" in response_body[0] or "generation" in response_body[0]
elif isinstance(response_body, dict):
assert (
"generated_text" in response_body
or "generation" in response_body
or "outputs" in response_body
)


def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session):
"""Test fetching endpoint names for base model."""
from sagemaker.core.resources import TrainingJob
Expand Down Expand Up @@ -300,9 +342,6 @@ def test_dpo_trainer_build(self, training_job_name, sagemaker_session):
- Improved test assertions to work with new object structures
"""

import json
import time
import pytest
from sagemaker.core.resources import TrainingJob, ModelPackage
from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder

Expand All @@ -316,7 +355,7 @@ def setup_config(self, training_job_name):
from sagemaker.core.helper.session_helper import get_execution_role
return {
"training_job_name": training_job_name,
"region": "us-west-2",
"region": AWS_REGION,
"bucket": "models-sdk-testing-pdx",
"role_arn": get_execution_role()
}
Expand All @@ -336,29 +375,48 @@ def s3_client(self, setup_config):

@pytest.fixture(scope="class")
def bedrock_client(self, setup_config):
    """Create Bedrock client. Eagerly cleans up test import jobs older than 24h.

    Only import jobs whose name starts with ``test-bedrock-`` are touched.
    Stale jobs that are still running are stopped; stale completed jobs have
    their imported model deleted. Cleanup is best-effort: every failure is
    logged as a warning and never prevents the client from being returned.

    Args:
        setup_config: Mapping that provides the ``"region"`` used to build
            the client.

    Returns:
        The boto3 ``bedrock`` client.
    """
    client = boto3.client('bedrock', region_name=setup_config["region"])

    try:
        cutoff = datetime.now(timezone.utc) - timedelta(hours=24)

        # The listing is paginated; follow nextToken so stale jobs beyond the
        # first page are cleaned up as well. All pages are collected before
        # any job is stopped/deleted so the listing is not mutated mid-walk.
        summaries = []
        list_kwargs = {}
        while True:
            page = client.list_model_import_jobs(**list_kwargs)
            summaries.extend(page.get('modelImportJobSummaries', []))
            next_token = page.get('nextToken')
            if not next_token:
                break
            list_kwargs = {'nextToken': next_token}

        for job in summaries:
            if not job['jobName'].startswith('test-bedrock-'):
                continue
            created = job.get('creationTime') or job.get('lastModifiedTime')
            # Skip jobs with no timestamp and jobs newer than the cutoff.
            if not created or created >= cutoff:
                continue
            try:
                status = job.get('status')
                if status in ('InProgress', 'Pending'):
                    client.stop_model_import_job(jobIdentifier=job['jobArn'])
                elif status == 'Completed' and job.get('importedModelArn'):
                    client.delete_imported_model(
                        modelIdentifier=job['importedModelArn']
                    )
            except Exception as e:
                # One job failing to clean up must not block the others.
                logger.warning(f"Eager cleanup failed for {job['jobName']}: {e}")
    except Exception as e:
        logger.warning(f"Failed to list import jobs for eager cleanup: {e}")

    return client

@pytest.fixture(scope="class")
def bedrock_runtime(self, setup_config):
    """Create Bedrock runtime client with an explicit retry configuration."""
    # Adding config based on: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html#handle-model-not-ready-exception
    retry_settings = {
        'total_max_attempts': 10,
        'mode': 'standard',
    }
    client_config = Config(retries=retry_settings)
    runtime_client = boto3.client(
        'bedrock-runtime',
        region_name=setup_config["region"],
        config=client_config,
    )
    return runtime_client

@pytest.fixture(scope="class")
def deployed_model_arn(self, training_job, bedrock_client, s3_client, setup_config):
"""Deploy model and return ARN."""
"""Deploy model and return ARN. Cleans up the imported model after tests."""
self._setup_model_files(training_job, s3_client, setup_config)

job_name = f"test-bedrock-{random.randint(1000, 9999)}-{int(time.time())}"
Expand All @@ -373,21 +431,37 @@ def deployed_model_arn(self, training_job, bedrock_client, s3_client, setup_conf

job_arn = deployment_result['jobArn']

# Wait for completion
while True:
# Wait for completion (max 1 hour wait)
max_wait = 60 * 60 # 60 minutes
start = time.time()
while time.time() - start < max_wait:
response = bedrock_client.get_model_import_job(jobIdentifier=job_arn)
status = response['status']
if status in ['Completed', 'Failed']:
break
time.sleep(30)
else:
pytest.fail(f"Model import job timed out after {max_wait}s")

model_arn = response['importedModelName']
return model_arn
if status == 'Failed':
pytest.fail(
f"Model import job failed: {response.get('failureMessage', 'unknown reason')}")

model_arn = response['importedModelArn']

yield model_arn

# Cleanup: delete the imported model
try:
logger.info(f"Cleaning up imported model: {model_arn}")
bedrock_client.delete_imported_model(modelIdentifier=model_arn)
logger.info(f"Successfully deleted imported model: {model_arn}")
except Exception as e:
logger.warning(f"Failed to delete imported model {model_arn}: {e}")

except Exception as e:
# If there's an issue with the new sagemaker-core integration, provide helpful error info
pytest.fail(
f"Deployment failed with error: {str(e)}.")
f"Bedrock deployment failed with error: {str(e)}.")

def _setup_model_files(self, training_job, s3_client, setup_config):
"""Setup required model files for Bedrock deployment."""
Expand Down Expand Up @@ -504,24 +578,79 @@ def test_bedrock_job_created(self, deployed_model_arn):
"""Test that Bedrock import job was created successfully."""
assert deployed_model_arn is not None

def test_zzz_cleanup_deployed_model(self, bedrock_client):
"""Cleanup deployed model and import jobs (runs last due to zzz prefix)."""
if hasattr(self, 'model_arn_for_cleanup'):
# Note: The test below is flaky; it can fail with a "model not ready" exception.
# The documentation recommends retrying: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html#handle-model-not-ready-exception.
# TODO: Fix using provisioned throughput or a better wait mechanism
@pytest.mark.slow
def test_bedrock_model_invoke(self, deployed_model_arn, bedrock_runtime):
    """Test invoking the imported Bedrock model to ensure it works end-to-end.

    Retries on failure since models can take several minutes
    to become ready after import. Only the invocation itself is retried;
    the response-structure assertions run once on the successful response so
    a malformed response is reported as an assertion failure, not as an
    invoke failure.
    """
    logger.warning(
        "This test is known to be flaky due to 'model not ready' exceptions from Bedrock. "
        "See: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html"
        "#handle-model-not-ready-exception"
    )
    max_retries = 2
    base_delay = 10

    # The request does not change between attempts, so build it once.
    request_body = json.dumps({
        "prompt": "What is the capital of France?",
        "max_gen_len": 100,
        "temperature": 0.7,
        "top_p": 0.9
    })

    result = None
    for attempt in range(max_retries):
        try:
            response = bedrock_runtime.invoke_model(
                modelId=deployed_model_arn,
                body=request_body
            )
            result = json.loads(response['body'].read().decode())
            break  # Success
        except Exception as e:
            if attempt < max_retries - 1:
                logger.info(
                    f"Invoke failed (attempt {attempt + 1}/{max_retries}): {e}. "
                    f"Retrying in {base_delay}s..."
                )
                time.sleep(base_delay)
            else:
                pytest.fail(
                    f"Invoke failed after {max_retries} attempts. "
                    f"Last error: {e}"
                )

    # Validate response structure
    assert "generation" in result, "Response missing 'generation' field"
    assert isinstance(result["generation"], str), "'generation' should be a string"
    assert len(result["generation"]) > 0, "'generation' should not be empty"


@pytest.fixture(scope="class", autouse=True)
def cleanup_import_jobs(self, bedrock_client):
    """Cleanup any leftover test import jobs after all tests in this class.

    Runs as class-scoped, autouse teardown: control is yielded to the tests
    first, then every import job whose name starts with ``test-bedrock-`` is
    either stopped (still running) or has its imported model deleted
    (completed). Cleanup is best-effort; failures are logged as warnings.
    """
    yield
    try:
        # The listing is paginated; follow nextToken so jobs beyond the first
        # page are cleaned up too. Collect all pages before acting on any job.
        summaries = []
        list_kwargs = {}
        while True:
            page = bedrock_client.list_model_import_jobs(**list_kwargs)
            summaries.extend(page.get('modelImportJobSummaries', []))
            next_token = page.get('nextToken')
            if not next_token:
                break
            list_kwargs = {'nextToken': next_token}

        for job in summaries:
            if not job['jobName'].startswith('test-bedrock-'):
                continue
            try:
                status = job.get('status')
                # Stop in-progress jobs
                if status in ('InProgress', 'Pending'):
                    bedrock_client.stop_model_import_job(jobIdentifier=job['jobArn'])
                # Delete completed imported models
                elif status == 'Completed' and job.get('importedModelArn'):
                    bedrock_client.delete_imported_model(
                        modelIdentifier=job['importedModelArn']
                    )
            except Exception as e:
                # One job failing to clean up must not block the others.
                logger.warning(f"Cleanup failed for job {job['jobName']}: {e}")
    except Exception as e:
        logger.warning(f"Failed to list/cleanup import jobs: {e}")


def test_model_customization_workflow(training_job_name):
Expand Down
Loading
Loading