From 93da557ef147af1d16fa45f0b5b7b57664457338 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:09:19 -0700 Subject: [PATCH 1/2] ModelBuilder LoRA support through ModelPackage input --- .../src/sagemaker/serve/model_builder.py | 17 ++ .../unit/test_model_package_peft_detection.py | 222 ++++++++++++++++++ 2 files changed, 239 insertions(+) create mode 100644 sagemaker-serve/tests/unit/test_model_package_peft_detection.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder.py b/sagemaker-serve/src/sagemaker/serve/model_builder.py index 7c7af2defc..a48edea31a 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder.py @@ -2411,6 +2411,12 @@ def _build_single_modelbuilder( f"{self.model._latest_training_job.model_artifacts.s3_model_artifacts}" "/checkpoints/hf/" ) + elif isinstance(self.model, ModelPackage): + self._adapter_s3_uri = ( + model_package.inference_specification.containers[ + 0 + ].model_data_source.s3_data_source.s3_uri + ) else: # Non-LORA: Model points at training output self.s3_upload_path = model_package.inference_specification.containers[ @@ -2418,6 +2424,7 @@ def _build_single_modelbuilder( ].model_data_source.s3_data_source.s3_uri container_def = ContainerDefinition( image=self.image_uri, + environment=self.env_vars, model_data_source={ "s3_data_source": { "s3_uri": self.s3_upload_path.rstrip("/") + "/", @@ -4554,6 +4561,16 @@ def _fetch_peft(self) -> Optional[str]: training_job = self.model elif isinstance(self.model, ModelTrainer): training_job = self.model._latest_training_job + elif isinstance(self.model, ModelPackage): + try: + recipe_name = ( + self.model.inference_specification.containers[0].base_model.recipe_name + ) + if recipe_name and "lora" in recipe_name.lower(): + return "LORA" + except (AttributeError, IndexError): + pass + return None else: return None diff --git a/sagemaker-serve/tests/unit/test_model_package_peft_detection.py b/sagemaker-serve/tests/unit/test_model_package_peft_detection.py new file mode 100644 index 0000000000..3fc18d07b6 --- /dev/null +++ b/sagemaker-serve/tests/unit/test_model_package_peft_detection.py @@ -0,0 +1,222 @@ +"""Unit tests for ModelPackage LoRA detection in _fetch_peft() and related paths. + +Tests verify that: +1. _fetch_peft() returns "LORA" for ModelPackage with lora recipe name +2. _fetch_peft() returns None for ModelPackage with non-lora recipe name +3. _fetch_peft() returns None for ModelPackage with no recipe name +4. _adapter_s3_uri is correctly set from ModelPackage container S3 URI +5. env vars are applied in the non-LoRA ContainerDefinition path +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.resources import ModelPackage, TrainingJob + + +class TestModelPackagePeftDetection: + """Test _fetch_peft() behavior with ModelPackage input.""" + + def _create_model_package_mock(self, recipe_name=None): + """Helper to create a mock ModelPackage with a given recipe name.""" + mock_package = Mock(spec=ModelPackage) + mock_container = Mock() + mock_container.base_model = Mock() + mock_container.base_model.recipe_name = recipe_name + mock_package.inference_specification = Mock() + mock_package.inference_specification.containers = [mock_container] + return mock_package + + def test_fetch_peft_returns_lora_for_lora_recipe(self): + """_fetch_peft() returns 'LORA' when recipe name contains 'lora'.""" + mock_package = self._create_model_package_mock( + recipe_name="verl-grpo-rlvr-qwen-3-32b-lora" + ) + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() == "LORA" + + def test_fetch_peft_returns_lora_case_insensitive(self): + """_fetch_peft() matches 'lora' case-insensitively.""" + mock_package = self._create_model_package_mock( + recipe_name="some-model-LoRA-adapter" + ) + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() == "LORA" + + def test_fetch_peft_returns_none_for_fft_recipe(self): + """_fetch_peft() returns None when recipe name does not contain 'lora'.""" + mock_package = self._create_model_package_mock( + recipe_name="verl-grpo-rlvr-qwen-3-32b-fft" + ) + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() is None + + def test_fetch_peft_returns_none_for_no_recipe_name(self): + """_fetch_peft() returns None when recipe name is None.""" + mock_package = self._create_model_package_mock(recipe_name=None) + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() is None + + def test_fetch_peft_returns_none_when_base_model_missing(self): + """_fetch_peft() returns None when base_model attribute is missing.""" + mock_package = Mock(spec=ModelPackage) + mock_container = Mock() + mock_container.base_model = None + mock_package.inference_specification = Mock() + mock_package.inference_specification.containers = [mock_container] + + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() is None + + def test_fetch_peft_returns_none_when_containers_empty(self): + """_fetch_peft() returns None when containers list is empty.""" + mock_package = Mock(spec=ModelPackage) + mock_package.inference_specification = Mock() + mock_package.inference_specification.containers = [] + + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + assert builder._fetch_peft() is None + + +class TestModelPackageAdapterS3Uri: + """Test _adapter_s3_uri is correctly set from ModelPackage.""" + + @patch.object(ModelBuilder, "_fetch_model_package_arn") + @patch.object(ModelBuilder, "_fetch_model_package") + @patch.object(ModelBuilder, "_fetch_peft") + @patch.object(ModelBuilder, "_fetch_hub_document_for_custom_model") + @patch.object(ModelBuilder, "_fetch_and_cache_recipe_config") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) + @patch.object(ModelBuilder, "_is_model_customization") + @patch("sagemaker.core.resources.Model.create") + def test_adapter_s3_uri_set_from_model_package( + self, + mock_model_create, + mock_is_customization, + mock_is_nova_model, + mock_fetch_and_cache_recipe, + mock_fetch_hub, + mock_fetch_peft, + mock_fetch_package, + mock_fetch_package_arn, + ): + """_adapter_s3_uri is set from ModelPackage container S3 URI for LORA.""" + mock_is_customization.return_value = True + mock_fetch_peft.return_value = "LORA" + + expected_adapter_uri = "s3://bucket/adapter-weights/" + + mock_package = Mock(spec=ModelPackage) + mock_package.model_package_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:model-package/test-package" + ) + mock_container = Mock() + mock_container.base_model = Mock() + mock_container.base_model.recipe_name = "verl-grpo-rlvr-qwen-3-32b-lora" + mock_container.model_data_source = Mock() + mock_container.model_data_source.s3_data_source = Mock() + mock_container.model_data_source.s3_data_source.s3_uri = expected_adapter_uri + mock_package.inference_specification = Mock() + mock_package.inference_specification.containers = [mock_container] + mock_fetch_package.return_value = mock_package + mock_fetch_package_arn.return_value = mock_package.model_package_arn + + mock_fetch_hub.return_value = { + "HostingArtifactUri": "s3://jumpstart-bucket/base-model-artifacts/" + } + + mock_model_create.return_value = Mock() + + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + ) + builder.accept_eula = True + builder._build_single_modelbuilder() + + assert builder._adapter_s3_uri == expected_adapter_uri + + +class TestNonLoraEnvVars: + """Test env vars are applied in the non-LoRA ContainerDefinition path.""" + + @patch.object(ModelBuilder, "_fetch_model_package_arn") + @patch.object(ModelBuilder, "_fetch_model_package") + @patch.object(ModelBuilder, "_fetch_peft") + @patch.object(ModelBuilder, "_fetch_and_cache_recipe_config") + @patch.object(ModelBuilder, "_is_nova_model", return_value=False) + @patch.object(ModelBuilder, "_is_model_customization") + @patch("sagemaker.core.resources.Model.create") + def test_env_vars_passed_to_non_lora_container_def( + self, + mock_model_create, + mock_is_customization, + mock_is_nova_model, + mock_fetch_and_cache_recipe, + mock_fetch_peft, + mock_fetch_package, + mock_fetch_package_arn, + ): + """Non-LoRA ContainerDefinition includes environment vars.""" + mock_is_customization.return_value = True + mock_fetch_peft.return_value = None # Not LORA + + mock_package = Mock(spec=ModelPackage) + mock_package.model_package_arn = ( + "arn:aws:sagemaker:us-west-2:123456789012:model-package/test-package" + ) + mock_container = Mock() + mock_container.base_model = Mock() + mock_container.base_model.recipe_name = "verl-grpo-rlvr-qwen-3-32b-fft" + mock_container.model_data_source = Mock() + mock_container.model_data_source.s3_data_source = Mock() + mock_container.model_data_source.s3_data_source.s3_uri = "s3://bucket/model/" + mock_package.inference_specification = Mock() + mock_package.inference_specification.containers = [mock_container] + mock_fetch_package.return_value = mock_package + mock_fetch_package_arn.return_value = mock_package.model_package_arn + + mock_model_create.return_value = Mock() + + expected_env = {"SM_MODEL_ID": "test-model", "CUSTOM_VAR": "value"} + + builder = ModelBuilder( + model=mock_package, + role_arn="arn:aws:iam::123456789012:role/SageMakerRole", + instance_type="ml.g5.12xlarge", + env_vars=expected_env, + ) + builder._build_single_modelbuilder() + + # Verify Model.create was called and the container has environment set + assert mock_model_create.called + create_call = mock_model_create.call_args + containers = create_call[1].get("containers", []) + assert len(containers) == 1 + container_def = containers[0] + assert container_def.environment == expected_env From 5815eebc296e8c4cf62627efb969439c3292be15 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:37:05 -0700 Subject: [PATCH 2/2] Adding integration tests for model package lora detection --- .../test_model_package_lora_detection.py | 142 ++++++++++++++++++ 1 file changed, 142 insertions(+) create mode 100644 sagemaker-serve/tests/integ/test_model_package_lora_detection.py diff --git a/sagemaker-serve/tests/integ/test_model_package_lora_detection.py b/sagemaker-serve/tests/integ/test_model_package_lora_detection.py new file mode 100644 index 0000000000..559fa6f52c --- /dev/null +++ b/sagemaker-serve/tests/integ/test_model_package_lora_detection.py @@ -0,0 +1,142 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Integration tests for ModelPackage LoRA detection in ModelBuilder. + +Build-only tests that verify _fetch_peft() correctly detects LoRA from +ModelPackage recipe names and that the build path sets the right attributes. +No deployment or GPU instances required. +""" +from __future__ import absolute_import + +import os +import time +import random +import pytest + + +# LoRA model package (recipe name contains "lora") +LORA_MODEL_PACKAGE_ARN = ( + "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/1" +) + +# Non-LoRA model package (DPO recipe, no "lora" in name) +NON_LORA_MODEL_PACKAGE_ARN = ( + "arn:aws:sagemaker:us-west-2:729646638167:model-package/sdk-test-finetuned-models/264" +) + + +@pytest.fixture(scope="session", autouse=True) +def set_region(): + """Ensure us-west-2 region for all tests.""" + original = os.environ.get("SAGEMAKER_REGION") + os.environ["SAGEMAKER_REGION"] = "us-west-2" + yield + if original: + os.environ["SAGEMAKER_REGION"] = original + elif "SAGEMAKER_REGION" in os.environ: + del os.environ["SAGEMAKER_REGION"] + + +class TestModelPackageLoraDetection: + """Test _fetch_peft() LoRA detection from real ModelPackage resources.""" + + def test_fetch_peft_returns_lora_for_lora_model_package(self): + """_fetch_peft() returns 'LORA' for a ModelPackage with a LoRA recipe name.""" + from sagemaker.core.resources import ModelPackage + from sagemaker.serve import ModelBuilder + + model_package = ModelPackage.get(model_package_name=LORA_MODEL_PACKAGE_ARN) + model_builder = ModelBuilder(model=model_package) + + peft_type = model_builder._fetch_peft() + assert peft_type == "LORA", ( + f"Expected 'LORA' but got '{peft_type}' for {LORA_MODEL_PACKAGE_ARN}" + ) + + def test_fetch_peft_returns_none_for_non_lora_model_package(self): + """_fetch_peft() returns None for a ModelPackage without LoRA recipe name.""" + from sagemaker.core.resources import ModelPackage + from sagemaker.serve import ModelBuilder + + model_package = ModelPackage.get(model_package_name=NON_LORA_MODEL_PACKAGE_ARN) + model_builder = ModelBuilder(model=model_package) + + peft_type = model_builder._fetch_peft() + assert peft_type is None, ( + f"Expected None but got '{peft_type}' for {NON_LORA_MODEL_PACKAGE_ARN}" + ) + + +class TestModelPackageLoraBuild: + """Test that build() from a LoRA ModelPackage sets the right attributes.""" + + def test_build_lora_model_package_sets_adapter_s3_uri(self): + """build() from a LoRA ModelPackage sets _adapter_s3_uri.""" + from sagemaker.core.resources import ModelPackage + from sagemaker.serve import ModelBuilder + + model_package = ModelPackage.get(model_package_name=LORA_MODEL_PACKAGE_ARN) + model_builder = ModelBuilder(model=model_package) + model_builder.accept_eula = True + + model_name = f"integ-lora-test-{int(time.time())}-{random.randint(100, 10000)}" + model = model_builder.build(model_name=model_name) + + try: + # Verify model was created + assert model is not None + assert model.model_arn is not None + + # Verify LoRA-specific attributes + assert hasattr(model_builder, "_adapter_s3_uri"), ( + "_adapter_s3_uri should be set after build() for LoRA ModelPackage" + ) + assert model_builder._adapter_s3_uri is not None + assert model_builder._adapter_s3_uri.startswith("s3://"), ( + f"_adapter_s3_uri should be an S3 URI, got: {model_builder._adapter_s3_uri}" + ) + finally: + # Cleanup: delete the created model + try: + model.delete() + except Exception: + pass + + def test_build_non_lora_model_package_no_adapter_uri(self): + """build() from a non-LoRA ModelPackage does NOT set _adapter_s3_uri.""" + from sagemaker.core.resources import ModelPackage + from sagemaker.serve import ModelBuilder + + model_package = ModelPackage.get(model_package_name=NON_LORA_MODEL_PACKAGE_ARN) + model_builder = ModelBuilder(model=model_package) + model_builder.accept_eula = True + + model_name = f"integ-nonlora-test-{int(time.time())}-{random.randint(100, 10000)}" + model = model_builder.build(model_name=model_name) + + try: + # Verify model was created + assert model is not None + assert model.model_arn is not None + + # Verify no adapter URI is set (non-LoRA path) + adapter_uri = getattr(model_builder, "_adapter_s3_uri", None) + assert adapter_uri is None, ( + f"_adapter_s3_uri should not be set for non-LoRA ModelPackage, got: {adapter_uri}" + ) + finally: + # Cleanup: delete the created model + try: + model.delete() + except Exception: + pass