Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 61 additions & 96 deletions sagemaker-serve/tests/integ/test_model_customization_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
"""Integration tests for ModelBuilder model customization deployment."""
from __future__ import absolute_import

import boto3
import pytest
import random

from sagemaker.core.helper.session_helper import Session

# This test relies on resources in a specific region
AWS_REGION = "us-west-2"


@pytest.fixture(scope="module")
def sagemaker_session():
    """Provide a module-scoped SageMaker ``Session`` pinned to ``AWS_REGION``.

    The underlying boto3 session is built with an explicit ``region_name``
    instead of whatever region boto3 would otherwise resolve by default.
    """
    regional_boto_session = boto3.Session(region_name=AWS_REGION)
    return Session(boto_session=regional_boto_session)


@pytest.fixture(scope="module")
def training_job_name():
Expand Down Expand Up @@ -45,52 +58,7 @@ def model_package_arn():
def endpoint_name():
"""Generate unique endpoint name."""
import time
return f"e2e-{int(time.time())}-{random.randint(100, 10000)}"


@pytest.fixture(scope="session", autouse=True)
def cleanup_e2e_endpoints():
    """Delete leftover ``e2e-`` endpoints before and after the test session.

    While the fixture is active, ``SAGEMAKER_REGION`` is set to ``us-west-2``.
    On teardown the variable is put back exactly as it was found: restored to
    its previous value (including an empty string) or removed if it was unset.

    Endpoint cleanup is best-effort: failures are logged as warnings and never
    fail the test run.
    """
    import logging
    import os

    logger = logging.getLogger(__name__)

    # This file's tests use us-west-2 resources. Set SAGEMAKER_REGION so the
    # SDK's SageMakerClient creates sessions in the correct region from the
    # start. Save/restore to avoid leaking into other test files.
    original_sm_region = os.environ.get("SAGEMAKER_REGION")
    os.environ["SAGEMAKER_REGION"] = "us-west-2"

    try:
        # Imported only after SAGEMAKER_REGION is set, as in the original
        # ordering, and inside the try so the variable is restored even if
        # the import itself fails.
        from sagemaker.core.resources import Endpoint

        def _delete_e2e_endpoints(phase):
            """Best-effort delete of every endpoint whose name starts with 'e2e-'."""
            try:
                for endpoint in Endpoint.get_all():
                    try:
                        if endpoint.endpoint_name.startswith("e2e-"):
                            endpoint.delete()
                    except Exception as exc:
                        # One bad endpoint must not stop the rest of the sweep.
                        logger.warning(
                            "e2e endpoint cleanup (%s): skipping endpoint: %s",
                            phase,
                            exc,
                        )
            except Exception as exc:
                logger.warning(
                    "e2e endpoint cleanup (%s): could not list endpoints: %s",
                    phase,
                    exc,
                )

        _delete_e2e_endpoints("setup")
        yield
        _delete_e2e_endpoints("teardown")
    finally:
        # Compare against None rather than truthiness so an originally empty
        # value is restored instead of being deleted.
        if original_sm_region is None:
            os.environ.pop("SAGEMAKER_REGION", None)
        else:
            os.environ["SAGEMAKER_REGION"] = original_sm_region
return f"xe2e-{int(time.time())}-{random.randint(100, 10000)}"


@pytest.fixture(scope="module")
Expand All @@ -102,7 +70,7 @@ def cleanup_endpoints():
for ep_name in endpoints_to_cleanup:
try:
from sagemaker.core.resources import Endpoint
endpoint = Endpoint.get(endpoint_name=ep_name)
endpoint = Endpoint.get(endpoint_name=ep_name, region=AWS_REGION)
endpoint.delete()
except Exception:
pass
Expand All @@ -111,24 +79,23 @@ def cleanup_endpoints():
class TestModelCustomizationFromTrainingJob:
"""Test model customization deployment from TrainingJob."""

def test_build_from_training_job(self, training_job_name):
def test_build_from_training_job(self, training_job_name, sagemaker_session):
"""Test building model from training job."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder
import time

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}", region=AWS_REGION)

assert model is not None
assert model.model_arn is not None
assert model_builder.image_uri is not None
assert model_builder.instance_type is not None

@pytest.mark.skip(reason="Skipped: parallel cleanup race condition under investigation")
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints, sagemaker_session):
"""Test deploying model from training job.

For LORA models, this verifies the two-step deployment:
Expand All @@ -138,10 +105,10 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
from sagemaker.serve import ModelBuilder
import time

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge", sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}", region=AWS_REGION)

peft_type = model_builder._fetch_peft()
adapter_name = f"{endpoint_name}-adapter"
Expand All @@ -160,52 +127,52 @@ def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanu
if peft_type == "LORA":
# Verify base IC was created
base_ic_name = f"{endpoint_name}-inference-component"
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
base_ic = InferenceComponent.get(inference_component_name=base_ic_name, region=AWS_REGION)
assert base_ic is not None
assert base_ic.inference_component_status == "InService"

# Verify adapter IC was created
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name, region=AWS_REGION)
assert adapter_ic is not None

def test_fetch_endpoint_names_for_base_model(self, training_job_name):
def test_fetch_endpoint_names_for_base_model(self, training_job_name, sagemaker_session):
"""Test fetching endpoint names for base model."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
endpoint_names = model_builder.fetch_endpoint_names_for_base_model()

assert isinstance(endpoint_names, set)


class TestModelCustomizationFromModelPackage:

def test_build_from_model_package(self, model_package_arn):
def test_build_from_model_package(self, model_package_arn, sagemaker_session):
"""Test building model from model package."""
from sagemaker.core.resources import ModelPackage
from sagemaker.serve import ModelBuilder

model_package = ModelPackage.get(model_package_name=model_package_arn)
model_builder = ModelBuilder(model=model_package)
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model = model_builder.build()
model = model_builder.build(region=AWS_REGION)

assert model is not None
assert model.model_arn is not None

def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints, sagemaker_session):
"""Test deploying model from model package."""
from sagemaker.core.resources import ModelPackage
from sagemaker.serve import ModelBuilder
import time

model_package = ModelPackage.get(model_package_name=model_package_arn)
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
endpoint_name = f"e2e-{int(time.time())}-{random.randint(100, 10000)}"
model_builder = ModelBuilder(model=model_package)
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model_builder.build()
model_builder.build(region=AWS_REGION)
endpoint = model_builder.deploy(endpoint_name=endpoint_name)

cleanup_endpoints.append(endpoint_name)
Expand All @@ -217,15 +184,15 @@ def test_deploy_from_model_package(self, model_package_arn, cleanup_endpoints):
class TestInstanceTypeAutoDetection:
"""Test automatic instance type detection."""

def test_instance_type_from_recipe(self, training_job_name):
def test_instance_type_from_recipe(self, training_job_name, sagemaker_session):
"""Test instance type auto-detection from recipe."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)
model_builder.accept_eula = True
model_builder.build()
model_builder.build(region=AWS_REGION)

assert model_builder.instance_type is not None
assert "ml." in model_builder.instance_type
Expand All @@ -234,33 +201,33 @@ def test_instance_type_from_recipe(self, training_job_name):
class TestModelCustomizationDetection:
"""Test model customization detection logic."""

def test_is_model_customization_training_job(self, training_job_name):
def test_is_model_customization_training_job(self, training_job_name, sagemaker_session):
"""Test detection from training job."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)

assert model_builder._is_model_customization() is True

def test_is_model_customization_model_package(self, model_package_arn):
def test_is_model_customization_model_package(self, model_package_arn, sagemaker_session):
"""Test detection from model package."""
from sagemaker.core.resources import ModelPackage
from sagemaker.serve import ModelBuilder

model_package = ModelPackage.get(model_package_name=model_package_arn)
model_builder = ModelBuilder(model=model_package)
model_package = ModelPackage.get(model_package_name=model_package_arn, region=AWS_REGION)
model_builder = ModelBuilder(model=model_package, sagemaker_session=sagemaker_session)

assert model_builder._is_model_customization() is True

def test_fetch_model_package_arn(self, training_job_name):
def test_fetch_model_package_arn(self, training_job_name, sagemaker_session):
"""Test fetching model package ARN."""
from sagemaker.core.resources import TrainingJob
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(training_job_name=training_job_name)
model_builder = ModelBuilder(model=training_job)
training_job = TrainingJob.get(training_job_name=training_job_name, region=AWS_REGION)
model_builder = ModelBuilder(model=training_job, sagemaker_session=sagemaker_session)

arn = model_builder._fetch_model_package_arn()

Expand All @@ -271,14 +238,14 @@ def test_fetch_model_package_arn(self, training_job_name):
class TestTrainerIntegration:
"""Test ModelBuilder integration with SFTTrainer and DPOTrainer."""

def test_sft_trainer_build(self, training_job_name):
def test_sft_trainer_build(self, training_job_name, sagemaker_session):
"""Test building model from SFTTrainer."""
from sagemaker.core.resources import TrainingJob
from sagemaker.train.sft_trainer import SFTTrainer
from sagemaker.serve import ModelBuilder

training_job = TrainingJob.get(
training_job_name=training_job_name
training_job_name=training_job_name, region=AWS_REGION
)

trainer = SFTTrainer(
Expand All @@ -289,21 +256,21 @@ def test_sft_trainer_build(self, training_job_name):
)
trainer._latest_training_job = training_job

model_builder = ModelBuilder(model=trainer)
model = model_builder.build()
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
model = model_builder.build(region=AWS_REGION)

assert model is not None
assert model.model_arn is not None

def test_dpo_trainer_build(self, training_job_name):
def test_dpo_trainer_build(self, training_job_name, sagemaker_session):
"""Test building model from DPOTrainer."""
from sagemaker.core.resources import TrainingJob
from sagemaker.train.dpo_trainer import DPOTrainer
from sagemaker.serve import ModelBuilder
from unittest.mock import patch

training_job = TrainingJob.get(
training_job_name=training_job_name
training_job_name=training_job_name, region=AWS_REGION
)

with patch('sagemaker.train.common_utils.finetune_utils._get_fine_tuning_options_and_model_arn',
Expand All @@ -316,8 +283,8 @@ def test_dpo_trainer_build(self, training_job_name):
)
trainer._latest_training_job = training_job

model_builder = ModelBuilder(model=trainer)
model = model_builder.build()
model_builder = ModelBuilder(model=trainer, sagemaker_session=sagemaker_session)
model = model_builder.build(region=AWS_REGION)

assert model is not None
assert model.model_arn is not None
Expand All @@ -335,8 +302,6 @@ def test_dpo_trainer_build(self, training_job_name):

import json
import time
import random
import boto3
import pytest
from sagemaker.core.resources import TrainingJob, ModelPackage
from sagemaker.serve.bedrock_model_builder import BedrockModelBuilder
Expand All @@ -361,6 +326,7 @@ def training_job(self, setup_config):
"""Get the training job."""
return TrainingJob.get(
training_job_name=setup_config["training_job_name"],
region=setup_config["region"],
)

@pytest.fixture(scope="class")
Expand Down Expand Up @@ -432,7 +398,7 @@ def _setup_model_files(self, training_job, s3_client, setup_config):
base_s3_path = training_job.model_artifacts.s3_model_artifacts
elif hasattr(training_job, 'output_model_package_arn'):
# If training job has model package ARN, get artifacts from model package
model_package = ModelPackage.get(training_job.output_model_package_arn)
model_package = ModelPackage.get(training_job.output_model_package_arn, region=AWS_REGION)
if hasattr(model_package,
'inference_specification') and model_package.inference_specification.containers:
container = model_package.inference_specification.containers[0]
Expand Down Expand Up @@ -561,8 +527,7 @@ def test_zzz_cleanup_deployed_model(self, bedrock_client):
def test_model_customization_workflow(training_job_name):
"""Standalone test function for pytest discovery.

Relies on SAGEMAKER_REGION being set by the cleanup_e2e_endpoints
session fixture (us-west-2).
Uses explicit region parameter for all SDK calls.
"""
config = {
"training_job_name": training_job_name,
Expand All @@ -572,7 +537,7 @@ def test_model_customization_workflow(training_job_name):

try:
s3_client = boto3.client('s3', region_name=config["region"])
training_job = TrainingJob.get(training_job_name=config["training_job_name"])
training_job = TrainingJob.get(training_job_name=config["training_job_name"], region=config["region"])

test_class = TestModelCustomizationDeployment()
test_class.test_training_job_exists(training_job)
Expand Down
Loading
Loading