From f48a573f8760129cb91fffe58263568b17e14a35 Mon Sep 17 00:00:00 2001 From: Liam Neal Reilly Date: Fri, 1 May 2026 14:28:47 +0000 Subject: [PATCH] fix: handle unrecognized JumpStart container images in ModelBuilder Collapse endpoint mode routing into a single code path since all JumpStart builders (DJL, TGI, MMS) perform identical operations for endpoint deployment. Unknown container images (e.g. vLLM) now pass through instead of raising ValueError. For local container mode, unrecognized images raise ValueError with a message directing users to use SAGEMAKER_ENDPOINT mode. --- .../sagemaker/serve/model_builder_servers.py | 349 ++++++++-------- .../integ/test_jumpstart_vllm_integration.py | 58 +++ .../servers/test_model_builder_servers.py | 10 +- .../test_model_builder_servers_coverage.py | 371 ++++++++++-------- 4 files changed, 442 insertions(+), 346 deletions(-) create mode 100644 sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index dfcfa10bc7..64c76f59cf 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -36,9 +36,7 @@ _get_gpu_info_fallback, ) from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf -from sagemaker.serve.detector.image_detector import ( - _get_model_base, _detect_framework_and_version -) +from sagemaker.serve.detector.image_detector import _get_model_base, _detect_framework_and_version from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.utils.types import ModelServer @@ -75,14 +73,14 @@ class _ModelBuilderServers(object): def _build_for_model_server(self) -> Model: """Build model using explicit model server configuration. - + This method routes to the appropriate model server builder based on the specified model_server parameter. It validates that the model server is supported and that required parameters are provided. - + Returns: Model: A deployable model object configured for the specified model server. - + Raises: ValueError: If the model server is not supported or required parameters are missing. """ @@ -118,18 +116,17 @@ def _build_for_model_server(self) -> Model: else: raise ValueError(f"Unsupported model server: {self.model_server}") - def _build_for_torchserve(self) -> Model: """Build model for TorchServe deployment. - + Configures the model for deployment using TorchServe model server. Handles both local model objects and HuggingFace model IDs. - + Returns: Model: Configured model ready for TorchServe deployment. """ self.secret_key = "" - + # Save inference spec if we have local artifacts self._save_model_inference_spec() @@ -137,11 +134,11 @@ def _build_for_torchserve(self) -> Model: # Configure HuggingFace model support if not self._is_jumpstart_model_id(): self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # HF models download directly from hub self.s3_upload_path = None @@ -176,19 +173,18 @@ def _build_for_torchserve(self) -> Model: return self._create_model() - def _build_for_tgi(self) -> Model: """Build model for Text Generation Inference (TGI) deployment. 
- + Configures the model for deployment using Hugging Face's TGI server, optimized for large language model inference with features like tensor parallelism and continuous batching. - + Returns: Model: Configured model ready for TGI deployment. """ self.secret_key = "" - + # Initialize TGI-specific configuration if self.model_server != ModelServer.TGI: messaging = ( @@ -200,7 +196,7 @@ def _build_for_tgi(self) -> Model: self.model_server = ModelServer.TGI self._validate_tgi_serving_sample_data() - + # Use notebook instance type if available nb_instance = _get_nb_instance() if nb_instance and not self._user_provided_instance_type: @@ -208,49 +204,51 @@ def _build_for_tgi(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TGI self.env_vars.setdefault("HF_MODEL_ID", self.model) - + self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Apply TGI-specific configurations default_tgi_configurations, _default_max_new_tokens = _get_default_tgi_configurations( self.model, self.hf_model_config, self.schema_builder ) # Filter out None values to avoid pydantic validation errors - filtered_tgi_config = {k: v for k, v in default_tgi_configurations.items() if v is not None} + filtered_tgi_config = { + k: v for k, v in default_tgi_configurations.items() if v is not None + } self.env_vars.update(filtered_tgi_config) - + # Configure schema builder for text generation if "parameters" not in self.schema_builder.sample_input: self.schema_builder.sample_input["parameters"] = {} - self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens - + self.schema_builder.sample_input["parameters"][ + "max_new_tokens" + ] = _default_max_new_tokens + # Set TGI sharding defaults self.env_vars.setdefault("SHARDED", "false") self.env_vars.setdefault("NUM_SHARD", "1") - + # Configure HuggingFace cache and authentication - tgi_env_vars = { - "HF_HOME": "/tmp", - "HUGGINGFACE_HUB_CACHE": "/tmp" - } - + tgi_env_vars = {"HF_HOME": "/tmp", "HUGGINGFACE_HUB_CACHE": "/tmp"} + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): tgi_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + self.env_vars.update(tgi_env_vars) - + # TGI downloads models directly from HuggingFace Hub self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -270,48 +268,51 @@ def _build_for_tgi(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." 
) - + try: tot_gpus = _get_gpu_info(self.instance_type, self.sagemaker_session) except Exception: # pylint: disable=W0703 tot_gpus = _get_gpu_info_fallback(self.instance_type) - + default_num_shard = _get_default_tensor_parallel_degree(self.hf_model_config, tot_gpus) - self.env_vars.update({ - "NUM_SHARD": str(default_num_shard), - "SHARDED": "true" if default_num_shard > 1 else "false", - }) - + self.env_vars.update( + { + "NUM_SHARD": str(default_num_shard), + "SHARDED": "true" if default_num_shard > 1 else "false", + } + ) + model = self._create_model() - + if "HF_HUB_OFFLINE" in self.env_vars: self.env_vars.update({"HF_HUB_OFFLINE": "0"}) - + return model def _build_for_djl(self) -> Model: """Build model for DJL Serving deployment. - + Configures the model for deployment using Amazon's Deep Java Library (DJL) Serving, which provides high-performance inference with support for tensor parallelism and model partitioning. - + Returns: Model: Configured model ready for DJL Serving deployment. """ self.secret_key = "" self.model_server = ModelServer.DJL_SERVING - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) - + self._validate_djl_serving_sample_data() # Create DJL directory structure for local artifacts from sagemaker.serve.model_server.djl_serving.prepare import _create_dir_structure + _create_dir_structure(self.model_path) - + # Use notebook instance type if available nb_instance = _get_nb_instance() if nb_instance and not self._user_provided_instance_type: @@ -321,7 +322,7 @@ def _build_for_djl(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for DJL self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -336,7 +337,9 @@ def _build_for_djl(self) -> Model: # Configure schema builder for text generation if "parameters" not in self.schema_builder.sample_input: self.schema_builder.sample_input["parameters"] = {} - self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens + self.schema_builder.sample_input["parameters"][ + "max_new_tokens" + ] = _default_max_new_tokens # Set DJL serving defaults (only if not already set by user) djl_env_vars = { @@ -360,7 +363,7 @@ def _build_for_djl(self) -> Model: self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -383,34 +386,31 @@ def _build_for_djl(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." 
             )
-            
+
             try:
                 tot_gpus = _get_gpu_info(self.instance_type, self.sagemaker_session)
             except Exception:  # pylint: disable=W0703
                 tot_gpus = _get_gpu_info_fallback(self.instance_type)
-            
+
             default_tensor_parallel_degree = _get_default_tensor_parallel_degree(
                 self.hf_model_config, tot_gpus
             )
-            self.env_vars.update({
-                "TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)
-            })
-        
+            self.env_vars.update({"TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)})
+
         model = self._create_model()
-        
+
         if "HF_HUB_OFFLINE" in self.env_vars:
             self.env_vars.update({"TRANSFORMERS_OFFLINE": "0"})
-        
-        return model
 
+        return model
 
     def _build_for_triton(self) -> Model:
         """Build model for NVIDIA Triton Inference Server deployment.
-        
+
         Configures the model for deployment using NVIDIA Triton Inference
         Server, which provides high-performance inference with support for
         multiple frameworks and advanced features like dynamic batching.
-        
+
         Returns:
             Model: Configured model ready for Triton deployment.
         """
@@ -420,7 +420,7 @@ def _build_for_triton(self) -> Model:
         if isinstance(self.model, str):
             self.framework = None
             self.framework_version = None
-            
+
             # Configure HuggingFace model for Triton
             if not self._is_jumpstart_model_id():
                 # Get model metadata for task detection
@@ -430,53 +430,53 @@ def _build_for_triton(self) -> Model:
                 model_task = hf_model_md.get("pipeline_tag")
                 if model_task:
                     self.env_vars.update({"HF_TASK": model_task})
-                
+
                 # Configure HuggingFace authentication
                 self.env_vars.setdefault("HF_MODEL_ID", self.model)
                 if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
                     self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
-                
+
                 # Triton downloads models directly from HuggingFace Hub
                 self.s3_upload_path = None
         elif self.model:
-            fw, self.framework_version = _detect_framework_and_version(str(_get_model_base(self.model)))
+            fw, self.framework_version = _detect_framework_and_version(
+                str(_get_model_base(self.model))
+            )
             self.framework = self._normalize_framework_to_enum(fw)
-        
+
         # Auto-detect Triton image if not provided
         if not self.image_uri:
             self._auto_detect_image_for_triton()
-        
+
         # Prepare Triton-specific artifacts
         self._save_inference_spec()
         self._prepare_for_triton()
-        
+
         # Prepare deployment artifacts
         if self.mode in LOCAL_MODES:
             self._prepare_for_mode()
         else:
             # self._prepare_for_mode(should_upload_artifacts=True)
             self.s3_model_data_url, _ = self._prepare_for_mode(should_upload_artifacts=True)
-        
-        return self._create_model()
-        
+        return self._create_model()
 
     def _build_for_tensorflow_serving(self) -> Model:
         """Build model for TensorFlow Serving deployment.
-        
+
         Configures the model for deployment using TensorFlow Serving,
         Google's high-performance serving system for TensorFlow models.
-        
+
         Returns:
             Model: Configured model ready for TensorFlow Serving deployment.
-        
+
         Raises:
             ValueError: If the model is not an MLflow model.
         """
         self.secret_key = ""
         if not getattr(self, "_is_mlflow_model", False):
             raise ValueError("Tensorflow Serving is currently only supported for mlflow models.")
-        
+
         # Save Schema Builder
         if not os.path.exists(self.model_path):
             os.makedirs(self.model_path)
@@ -501,23 +501,22 @@ def _build_for_tensorflow_serving(self) -> Model:
 
         return self._create_model()
 
-
     def _build_for_tei(self) -> Model:
         """Build model for Text Embeddings Inference (TEI) deployment.
-        
+
         Configures the model for deployment using Hugging Face's TEI server,
         optimized for embedding model inference with features like
         pooling strategies and efficient batching.
- + Returns: Model: Configured model ready for TEI deployment. """ self.secret_key = "" - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) - + if self.model_server != ModelServer.TEI: messaging = ( "HuggingFace Model ID support on model server: " @@ -534,33 +533,34 @@ def _build_for_tei(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TEI self.env_vars.setdefault("HF_MODEL_ID", self.model) - + self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Configure TEI-specific environment variables tei_env_vars = { "HF_HOME": "/tmp", "HUGGINGFACE_HUB_CACHE": "/tmp", "MODEL_LOADING_TIMEOUT": "240", } - + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): tei_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + self.env_vars.update(tei_env_vars) - + # TEI downloads models directly from HuggingFace Hub self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -579,27 +579,26 @@ def _build_for_tei(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." ) - + model = self._create_model() - + if "HF_HUB_OFFLINE" in self.env_vars: self.env_vars.update({"HF_HUB_OFFLINE": "0"}) - - return model + return model def _build_for_smd(self) -> Model: """Build model for SageMaker Distribution (SMD) deployment. - + Configures the model for deployment using SageMaker's distribution container, which provides a comprehensive ML environment with pre-installed frameworks and optimizations. - + Returns: Model: Configured model ready for SMD deployment. """ self.secret_key = "" - + self._save_model_inference_spec() if self.mode != Mode.IN_PROCESS: @@ -625,16 +624,16 @@ def _build_for_smd(self) -> Model: def _build_for_transformers(self) -> Model: """Build model for HuggingFace Transformers deployment. - + Configures the model for deployment using Multi-Model Server (MMS) with HuggingFace Transformers backend for general-purpose model inference. - + Returns: Model: Configured model ready for Transformers deployment. 
""" self.secret_key = "" self.model_server = ModelServer.MMS - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) @@ -676,6 +675,7 @@ def _build_for_transformers(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if not isinstance(self.model, str) or not self._is_jumpstart_model_id(): @@ -700,13 +700,13 @@ def _build_for_transformers(self) -> Model: model_task = hf_model_md.get("pipeline_tag") if model_task: self.env_vars.update({"HF_TASK": model_task}) - + self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( @@ -723,7 +723,6 @@ def _build_for_transformers(self) -> Model: else: self.s3_model_data_url, _ = self._prepare_for_mode() - # Clean up empty secret key if ( "SAGEMAKER_SERVE_SECRET_KEY" in self.env_vars @@ -739,14 +738,12 @@ def _build_for_transformers(self) -> Model: return self._create_model() - - def _build_for_djl_jumpstart(self, init_kwargs) -> Model: """Build DJL model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured DJL model for JumpStart deployment. """ @@ -754,14 +751,12 @@ def _build_for_djl_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.DJL_SERVING from sagemaker.serve.model_server.djl_serving.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: # Prepare DJL resources for local deployment - ( - self.js_model_config, - self.prepared_for_djl - ) = prepare_djl_js_resources( + (self.js_model_config, self.prepared_for_djl) = prepare_djl_js_resources( model_path=self.model_path, js_id=self.model, dependencies=self.dependencies, @@ -776,13 +771,12 @@ def _build_for_djl_jumpstart(self, init_kwargs) -> Model: self.prepared_for_djl = True return self._create_model() - def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: """Build TGI model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured TGI model for JumpStart deployment. """ @@ -790,6 +784,7 @@ def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.TGI from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: @@ -809,16 +804,15 @@ def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: if not hasattr(self, "s3_upload_path") or not self.s3_upload_path: self.s3_upload_path = init_kwargs.model_data - self.prepared_for_tgi = True + self.prepared_for_tgi = True return self._create_model() - def _build_for_mms_jumpstart(self, init_kwargs) -> Model: """Build MMS model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured MMS model for JumpStart deployment. 
""" @@ -826,6 +820,7 @@ def _build_for_mms_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.MMS from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: @@ -848,49 +843,50 @@ def _build_for_mms_jumpstart(self, init_kwargs) -> Model: self.prepared_for_mms = True return self._create_model() - - def _build_for_jumpstart(self) -> Model: """Build model for JumpStart deployment. - + Configures the model for deployment using SageMaker JumpStart, which provides pre-trained models with optimized configurations and deployment settings. - + Returns: Model: Configured model ready for JumpStart deployment. - + Raises: ValueError: If the JumpStart image URI is not supported. """ from sagemaker.core.jumpstart.factory.utils import get_init_kwargs self.secret_key = "" - + # Get JumpStart model configuration init_kwargs = get_init_kwargs( model_id=self.model, model_version=self.model_version or "*", region=self.region, instance_type=self.instance_type, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None), - config_name=getattr(self, 'config_name', None), + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), + config_name=getattr(self, "config_name", None), ) - + # Configure image URI and environment variables self.image_uri = self.image_uri or init_kwargs.image_uri - - if hasattr(init_kwargs, 'env') and init_kwargs.env: + + if hasattr(init_kwargs, "env") and init_kwargs.env: self.env_vars.update(init_kwargs.env) - + # Handle model artifacts for fine-tuned models - if hasattr(init_kwargs, 'model_data') and init_kwargs.model_data: - if isinstance(init_kwargs.model_data, dict) and 'S3DataSource' in init_kwargs.model_data: - self.s3_model_data_url = init_kwargs.model_data['S3DataSource']['S3Uri'] + if hasattr(init_kwargs, "model_data") and init_kwargs.model_data: + if ( + isinstance(init_kwargs.model_data, dict) + and "S3DataSource" in init_kwargs.model_data + ): + self.s3_model_data_url = init_kwargs.model_data["S3DataSource"]["S3Uri"] elif isinstance(init_kwargs.model_data, str): self.s3_model_data_url = init_kwargs.model_data - + # Prepare resources based on mode and model server if self.mode == Mode.LOCAL_CONTAINER: # Route to appropriate model server and prepare resources @@ -907,7 +903,7 @@ def _build_for_jumpstart(self) -> Model: model_data=self.s3_model_data_url, ) return self._build_for_djl_jumpstart(init_kwargs) - + elif "tgi-inference" in self.image_uri: self.model_server = ModelServer.TGI if not hasattr(self, "prepared_for_tgi"): @@ -918,7 +914,7 @@ def _build_for_jumpstart(self) -> Model: model_data=self.s3_model_data_url, ) return self._build_for_tgi_jumpstart(init_kwargs) - + elif "huggingface-pytorch-inference" in self.image_uri: self.model_server = ModelServer.MMS if not hasattr(self, "prepared_for_mms"): @@ -930,27 +926,28 @@ def _build_for_jumpstart(self) -> Model: ) return self._build_for_mms_jumpstart(init_kwargs) else: - raise ValueError(f"Unsupported JumpStart image URI: {self.image_uri}") - + raise ValueError( + f"Local container mode is not yet supported for JumpStart image: {self.image_uri}. " + f"Use Mode.SAGEMAKER_ENDPOINT for deployment." 
+ ) + else: - # SAGEMAKER_ENDPOINT mode - prepare artifacts if needed + # SAGEMAKER_ENDPOINT mode — all JumpStart containers follow the same + # pattern: JumpStart provides the full image URI, env vars, and S3 + # model artifacts. No framework-specific prep is needed. if not self._optimizing: self._prepare_for_mode() - - # Route to appropriate model server based on image URI + if "djl-inference" in self.image_uri: self.model_server = ModelServer.DJL_SERVING - return self._build_for_djl_jumpstart(init_kwargs) elif "tgi-inference" in self.image_uri: self.model_server = ModelServer.TGI - return self._build_for_tgi_jumpstart(init_kwargs) elif "huggingface-pytorch-inference" in self.image_uri: self.model_server = ModelServer.MMS - return self._build_for_mms_jumpstart(init_kwargs) - else: - raise ValueError(f"Unsupported JumpStart image URI: {self.image_uri}") - + if not hasattr(self, "s3_upload_path") or not self.s3_upload_path: + self.s3_upload_path = init_kwargs.model_data + return self._create_model() def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Returns predictor depending on local mode or endpoint mode""" @@ -961,122 +958,120 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, # Handle local deployment modes if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(*args, **kwargs) - def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified TGI deploy wrapper - env vars already set during build.""" - + # Handle mode overrides for local deployment if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(*args, **kwargs) - def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified TEI deploy wrapper - env vars already set during build.""" - + # Handle local deployment modes if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) - - def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified JumpStart deploy wrapper - resource prep 
already done during build.""" - + # Handle local deployment if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if hasattr(self, "instance_type"): kwargs.update({"instance_type": self.instance_type}) - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) - def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: + def _transformers_model_builder_deploy_wrapper( + self, *args, **kwargs + ) -> Union[Endpoint, LocalEndpoint]: """Simplified Transformers deploy wrapper - env vars already set during build.""" - + # Handle local deployment modes if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) diff --git a/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py b/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py new file mode 100644 index 0000000000..861d9326d9 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import uuid +import pytest +import logging + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.train.configs import Compute + +logger = logging.getLogger(__name__) + +MODEL_ID = "huggingface-vlm-qwen3-5-4b" +INSTANCE_TYPE = "ml.g6.12xlarge" +MODEL_NAME_PREFIX = "js-vllm-test-model" + + +@pytest.mark.slow_test +def test_jumpstart_vllm_build(): + """Integration test for JumpStart model using vLLM container image. + + Validates that ModelBuilder correctly handles vLLM container images + which do not match the legacy djl-inference/tgi-inference/huggingface-pytorch-inference + patterns in the endpoint mode routing. 
+ """ + logger.info("Starting JumpStart vLLM build test...") + + compute = Compute(instance_type=INSTANCE_TYPE) + jumpstart_config = JumpStartConfig(model_id=MODEL_ID) + model_builder = ModelBuilder.from_jumpstart_config( + jumpstart_config=jumpstart_config, compute=compute + ) + unique_id = str(uuid.uuid4())[:8] + + core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}") + logger.info(f"Model Successfully Created: {core_model.model_name}") + + try: + assert ( + "vllm" in model_builder.image_uri + ), f"Expected vLLM image URI, got: {model_builder.image_uri}" + assert model_builder.s3_upload_path is not None, "s3_upload_path should be set" + logger.info("JumpStart vLLM build test completed successfully") + finally: + core_model.delete() + logger.info("Model Successfully Deleted!") diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index 1a6f3b2442..d905b7decf 100644 --- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -919,7 +919,7 @@ def test_build_unsupported_image_uri(self, mock_init): with self.assertRaises(ValueError) as ctx: self.builder._build_for_jumpstart() - self.assertIn("Unsupported", str(ctx.exception)) + self.assertIn("Local container mode is not yet supported", str(ctx.exception)) @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") @@ -985,21 +985,21 @@ def test_build_passes_none_config_name_when_not_set( @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") @patch.object(MockModelBuilderServers, "_prepare_for_mode") - @patch.object(MockModelBuilderServers, "_build_for_djl_jumpstart") - def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_init): + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_djl(self, mock_create, mock_prepare, mock_init): """Test building DJL JumpStart for SAGEMAKER_ENDPOINT.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" mock_init_kwargs.env = {} mock_init_kwargs.model_data = "s3://bucket/model.tar.gz" mock_init.return_value = mock_init_kwargs - mock_djl_build.return_value = Mock() + mock_create.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.image_uri = None result = self.builder._build_for_jumpstart() - mock_djl_build.assert_called_once() + mock_create.assert_called_once() class TestDeployWrappers(unittest.TestCase): diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py index 02b0962feb..3b00264ffa 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py @@ -18,7 +18,7 @@ mock_model_object, MOCK_ROLE_ARN, MOCK_IMAGE_URI, - MOCK_S3_URI + MOCK_S3_URI, ) @@ -35,181 +35,179 @@ def test_build_for_model_server_unsupported_raises(self): model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = "UNSUPPORTED_SERVER" - + with self.assertRaises(ValueError) as context: builder._build_for_model_server() - + self.assertIn("is not supported yet", str(context.exception)) def test_build_for_model_server_without_model_raises(self): """Test that 
missing model/inference_spec raises error.""" builder = ModelBuilder( - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, image_uri=MOCK_IMAGE_URI ) builder.model = None builder.inference_spec = None builder.model_metadata = None builder.model_server = ModelServer.TORCHSERVE - + with self.assertRaises(ValueError) as context: builder._build_for_model_server() - + self.assertIn("Missing required parameter", str(context.exception)) - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_torchserve') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_torchserve") def test_build_for_model_server_routes_to_torchserve(self, mock_build): """Test routing to TorchServe builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TORCHSERVE - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_djl") def test_build_for_model_server_routes_to_djl(self, mock_build): """Test routing to DJL builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.DJL_SERVING - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tgi") def test_build_for_model_server_routes_to_tgi(self, mock_build): """Test routing to TGI builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TGI - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tei') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tei") def test_build_for_model_server_routes_to_tei(self, mock_build): """Test routing to TEI builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model="sentence-transformers/all-MiniLM-L6-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TEI - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_triton') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_triton") def test_build_for_model_server_routes_to_triton(self, mock_build): """Test routing to Triton builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, 
sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TRITON - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tensorflow_serving') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tensorflow_serving") def test_build_for_model_server_routes_to_tensorflow(self, mock_build): """Test routing to TensorFlow Serving builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TENSORFLOW_SERVING - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_transformers') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_transformers") def test_build_for_model_server_routes_to_mms(self, mock_build): """Test routing to MMS/Transformers builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.MMS - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_smd') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_smd") def test_build_for_model_server_routes_to_smd(self, mock_build): """Test routing to SMD builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.SMD - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() @@ -225,86 +223,93 @@ def setUp(self): def tearDown(self): """Clean up temp directory.""" import shutil + if os.path.exists(self.temp_dir): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_builder.save_pkl') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder._detect_framework_and_version') - @patch('sagemaker.serve.model_builder._get_model_base') - def test_build_for_torchserve_with_model_object(self, mock_get_base, mock_detect_fw, mock_prepare, mock_create, mock_save_pkl): + @patch("sagemaker.serve.model_builder.save_pkl") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.model_builder._get_model_base") + def test_build_for_torchserve_with_model_object( + self, mock_get_base, mock_detect_fw, mock_prepare, mock_create, mock_save_pkl + ): """Test TorchServe build with model object.""" mock_model = Mock(spec=Model) mock_create.return_value = mock_model mock_get_base.return_value = mock_model_object() mock_detect_fw.return_value = 
("pytorch", "1.13.0") - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, mode=Mode.IN_PROCESS, # Use IN_PROCESS to skip prepare_for_torchserve - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TORCHSERVE - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) mock_save_pkl.assert_called() mock_create.assert_called_once() - @patch('sagemaker.serve.model_builder.save_pkl') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - def test_build_for_torchserve_with_hf_model_id(self, mock_is_js, mock_prepare, mock_create, mock_save_pkl): + @patch("sagemaker.serve.model_builder.save_pkl") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_build_for_torchserve_with_hf_model_id( + self, mock_is_js, mock_prepare, mock_create, mock_save_pkl + ): """Test TorchServe build with HuggingFace model ID.""" mock_is_js.return_value = False mock_model = Mock(spec=Model) mock_create.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, - mode=Mode.IN_PROCESS # Use IN_PROCESS to skip prepare_for_torchserve + mode=Mode.IN_PROCESS, # Use IN_PROCESS to skip prepare_for_torchserve ) builder.model_server = ModelServer.TORCHSERVE builder.env_vars = {} builder.image_uri = MOCK_IMAGE_URI - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) self.assertEqual(builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertIsNone(builder.s3_upload_path) - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._save_model_inference_spec') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - def test_build_for_torchserve_with_hf_token(self, mock_is_js, mock_save, mock_prepare, mock_create): + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder.ModelBuilder._save_model_inference_spec") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_build_for_torchserve_with_hf_token( + self, mock_is_js, mock_save, mock_prepare, mock_create + ): """Test TorchServe build with HuggingFace token.""" mock_is_js.return_value = False mock_model = Mock(spec=Model) mock_create.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, - mode=Mode.IN_PROCESS + mode=Mode.IN_PROCESS, ) builder.model_server = ModelServer.TORCHSERVE builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "hf_token_123"} - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) self.assertEqual(builder.env_vars["HF_TOKEN"], "hf_token_123") @@ -316,46 +321,84 @@ def setUp(self): """Set up test fixtures.""" self.mock_session = mock_sagemaker_session() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') 
- @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") def test_build_for_jumpstart_unsupported_image_raises(self, mock_prepare, mock_get_kwargs): - """Test that unsupported JumpStart image raises error.""" + """Test that unsupported JumpStart image in LOCAL_CONTAINER mode raises error.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "unsupported-image:latest" mock_init_kwargs.env = {} mock_get_kwargs.return_value = mock_init_kwargs - + builder = ModelBuilder( model="some-model-id", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.LOCAL_CONTAINER, ) builder._optimizing = False - + with self.assertRaises(ValueError) as context: builder._build_for_jumpstart() - - self.assertIn("Unsupported JumpStart image URI", str(context.exception)) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, mock_get_kwargs): + self.assertIn("Local container mode is not yet supported", str(context.exception)) + self.assertIn("Mode.SAGEMAKER_ENDPOINT", str(context.exception)) + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_unknown_image_endpoint_uses_default( + self, mock_prepare, mock_create, mock_get_kwargs + ): + """Test that unknown image in endpoint mode passes through without error.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "vllm:0.17-gpu-py312-cu129-ubuntu22.04-sagemaker-v1" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/vllm/model.tar.gz" + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock(spec=Model) + mock_create.return_value = mock_model + + builder = ModelBuilder( + model="huggingface-vlm-qwen3-5-4b", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + + result = builder._build_for_jumpstart() + + self.assertEqual(result, mock_model) + self.assertIsNone(builder.model_server) + self.assertEqual(builder.s3_upload_path, "s3://jumpstart-cache/models/vllm/model.tar.gz") + mock_create.assert_called_once() + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to DJL builder.""" mock_init_kwargs = Mock() - mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + ) mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/djl/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs mock_model = Mock(spec=Model) - mock_build_djl.return_value = 
mock_model + mock_create.return_value = mock_model builder = ModelBuilder( model="huggingface-llm-falcon-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False @@ -363,26 +406,31 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_build_djl.assert_called_once() + mock_create.assert_called_once() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_djl, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_passes_config_name( + self, mock_prepare, mock_create, mock_get_kwargs + ): """Test that config_name is forwarded to get_init_kwargs.""" mock_init_kwargs = Mock() - mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + ) mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/djl/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs mock_model = Mock(spec=Model) - mock_build_djl.return_value = mock_model + mock_create.return_value = mock_model builder = ModelBuilder( model="meta-textgeneration-llama-3-3-70b-instruct", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False builder.config_name = "lmi-optimized" @@ -391,61 +439,66 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d mock_get_kwargs.assert_called_once() call_kwargs = mock_get_kwargs.call_args - self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized") + self.assertEqual( + call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), + "lmi-optimized", + ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_tgi(self, mock_prepare, mock_build_tgi, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_tgi(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to TGI builder.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04" mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/tgi/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs - + mock_model = Mock(spec=Model) - mock_build_tgi.return_value = mock_model - + 
mock_create.return_value = mock_model + builder = ModelBuilder( model="huggingface-llm-mistral-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False - + result = builder._build_for_jumpstart() - + self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.TGI) - mock_build_tgi.assert_called_once() + mock_create.assert_called_once() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_mms_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_build_mms, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to MMS builder.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/mms/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs - + mock_model = Mock(spec=Model) - mock_build_mms.return_value = mock_model - + mock_create.return_value = mock_model + builder = ModelBuilder( model="pytorch-ic-mobilenet-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False - + result = builder._build_for_jumpstart() - + self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.MMS) - mock_build_mms.assert_called_once() + mock_create.assert_called_once() class TestDeployWrappers(unittest.TestCase): @@ -455,114 +508,104 @@ def setUp(self): """Set up test fixtures.""" self.mock_session = mock_sagemaker_session() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_djl_deploy_wrapper_sets_timeout(self, mock_deploy): """Test DJL deploy wrapper sets model data download timeout.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.DJL_SERVING builder.built_model = Mock(spec=Model) - - result = builder._djl_model_builder_deploy_wrapper( - model_data_download_timeout=1800 - ) - + + result = builder._djl_model_builder_deploy_wrapper(model_data_download_timeout=1800) + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() call_kwargs = mock_deploy.call_args[1] - self.assertEqual(call_kwargs['model_data_download_timeout'], 1800) + self.assertEqual(call_kwargs["model_data_download_timeout"], 1800) - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_tgi_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test TGI deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = 
ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TGI builder.built_model = Mock(spec=Model) - - result = builder._tgi_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._tgi_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_tei_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test TEI deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="sentence-transformers/all-MiniLM-L6-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TEI builder.built_model = Mock(spec=Model) - - result = builder._tei_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._tei_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_js_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test JumpStart deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="huggingface-llm-falcon-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.built_model = Mock(spec=Model) - - result = builder._js_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._js_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_transformers_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test Transformers deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="bert-base-uncased", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.MMS builder.built_model = Mock(spec=Model) - - result = builder._transformers_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._transformers_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once()