diff --git a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py index dfcfa10bc7..64c76f59cf 100644 --- a/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py +++ b/sagemaker-serve/src/sagemaker/serve/model_builder_servers.py @@ -36,9 +36,7 @@ _get_gpu_info_fallback, ) from sagemaker.serve.utils.hf_utils import _get_model_config_properties_from_hf -from sagemaker.serve.detector.image_detector import ( - _get_model_base, _detect_framework_and_version -) +from sagemaker.serve.detector.image_detector import _get_model_base, _detect_framework_and_version from sagemaker.serve.detector.pickler import save_pkl from sagemaker.serve.utils.types import ModelServer @@ -75,14 +73,14 @@ class _ModelBuilderServers(object): def _build_for_model_server(self) -> Model: """Build model using explicit model server configuration. - + This method routes to the appropriate model server builder based on the specified model_server parameter. It validates that the model server is supported and that required parameters are provided. - + Returns: Model: A deployable model object configured for the specified model server. - + Raises: ValueError: If the model server is not supported or required parameters are missing. """ @@ -118,18 +116,17 @@ def _build_for_model_server(self) -> Model: else: raise ValueError(f"Unsupported model server: {self.model_server}") - def _build_for_torchserve(self) -> Model: """Build model for TorchServe deployment. - + Configures the model for deployment using TorchServe model server. Handles both local model objects and HuggingFace model IDs. - + Returns: Model: Configured model ready for TorchServe deployment. """ self.secret_key = "" - + # Save inference spec if we have local artifacts self._save_model_inference_spec() @@ -137,11 +134,11 @@ def _build_for_torchserve(self) -> Model: # Configure HuggingFace model support if not self._is_jumpstart_model_id(): self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # HF models download directly from hub self.s3_upload_path = None @@ -176,19 +173,18 @@ def _build_for_torchserve(self) -> Model: return self._create_model() - def _build_for_tgi(self) -> Model: """Build model for Text Generation Inference (TGI) deployment. - + Configures the model for deployment using Hugging Face's TGI server, optimized for large language model inference with features like tensor parallelism and continuous batching. - + Returns: Model: Configured model ready for TGI deployment. 
""" self.secret_key = "" - + # Initialize TGI-specific configuration if self.model_server != ModelServer.TGI: messaging = ( @@ -200,7 +196,7 @@ def _build_for_tgi(self) -> Model: self.model_server = ModelServer.TGI self._validate_tgi_serving_sample_data() - + # Use notebook instance type if available nb_instance = _get_nb_instance() if nb_instance and not self._user_provided_instance_type: @@ -208,49 +204,51 @@ def _build_for_tgi(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TGI self.env_vars.setdefault("HF_MODEL_ID", self.model) - + self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Apply TGI-specific configurations default_tgi_configurations, _default_max_new_tokens = _get_default_tgi_configurations( self.model, self.hf_model_config, self.schema_builder ) # Filter out None values to avoid pydantic validation errors - filtered_tgi_config = {k: v for k, v in default_tgi_configurations.items() if v is not None} + filtered_tgi_config = { + k: v for k, v in default_tgi_configurations.items() if v is not None + } self.env_vars.update(filtered_tgi_config) - + # Configure schema builder for text generation if "parameters" not in self.schema_builder.sample_input: self.schema_builder.sample_input["parameters"] = {} - self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens - + self.schema_builder.sample_input["parameters"][ + "max_new_tokens" + ] = _default_max_new_tokens + # Set TGI sharding defaults self.env_vars.setdefault("SHARDED", "false") self.env_vars.setdefault("NUM_SHARD", "1") - + # Configure HuggingFace cache and authentication - tgi_env_vars = { - "HF_HOME": "/tmp", - "HUGGINGFACE_HUB_CACHE": "/tmp" - } - + tgi_env_vars = {"HF_HOME": "/tmp", "HUGGINGFACE_HUB_CACHE": "/tmp"} + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): tgi_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + self.env_vars.update(tgi_env_vars) - + # TGI downloads models directly from HuggingFace Hub self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -270,48 +268,51 @@ def _build_for_tgi(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." ) - + try: tot_gpus = _get_gpu_info(self.instance_type, self.sagemaker_session) except Exception: # pylint: disable=W0703 tot_gpus = _get_gpu_info_fallback(self.instance_type) - + default_num_shard = _get_default_tensor_parallel_degree(self.hf_model_config, tot_gpus) - self.env_vars.update({ - "NUM_SHARD": str(default_num_shard), - "SHARDED": "true" if default_num_shard > 1 else "false", - }) - + self.env_vars.update( + { + "NUM_SHARD": str(default_num_shard), + "SHARDED": "true" if default_num_shard > 1 else "false", + } + ) + model = self._create_model() - + if "HF_HUB_OFFLINE" in self.env_vars: self.env_vars.update({"HF_HUB_OFFLINE": "0"}) - + return model def _build_for_djl(self) -> Model: """Build model for DJL Serving deployment. - + Configures the model for deployment using Amazon's Deep Java Library (DJL) Serving, which provides high-performance inference with support for tensor parallelism and model partitioning. 
- + Returns: Model: Configured model ready for DJL Serving deployment. """ self.secret_key = "" self.model_server = ModelServer.DJL_SERVING - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) - + self._validate_djl_serving_sample_data() # Create DJL directory structure for local artifacts from sagemaker.serve.model_server.djl_serving.prepare import _create_dir_structure + _create_dir_structure(self.model_path) - + # Use notebook instance type if available nb_instance = _get_nb_instance() if nb_instance and not self._user_provided_instance_type: @@ -321,7 +322,7 @@ def _build_for_djl(self) -> Model: if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for DJL self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Get model configuration for DJL optimization self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") @@ -336,7 +337,9 @@ def _build_for_djl(self) -> Model: # Configure schema builder for text generation if "parameters" not in self.schema_builder.sample_input: self.schema_builder.sample_input["parameters"] = {} - self.schema_builder.sample_input["parameters"]["max_new_tokens"] = _default_max_new_tokens + self.schema_builder.sample_input["parameters"][ + "max_new_tokens" + ] = _default_max_new_tokens # Set DJL serving defaults (only if not already set by user) djl_env_vars = { @@ -360,7 +363,7 @@ def _build_for_djl(self) -> Model: self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -383,34 +386,31 @@ def _build_for_djl(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." ) - + try: tot_gpus = _get_gpu_info(self.instance_type, self.sagemaker_session) except Exception: # pylint: disable=W0703 tot_gpus = _get_gpu_info_fallback(self.instance_type) - + default_tensor_parallel_degree = _get_default_tensor_parallel_degree( self.hf_model_config, tot_gpus ) - self.env_vars.update({ - "TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree) - }) - + self.env_vars.update({"TENSOR_PARALLEL_DEGREE": str(default_tensor_parallel_degree)}) + model = self._create_model() - + if "HF_HUB_OFFLINE" in self.env_vars: self.env_vars.update({"TRANSFORMERS_OFFLINE": "0"}) - - return model + return model def _build_for_triton(self) -> Model: """Build model for NVIDIA Triton Inference Server deployment. - + Configures the model for deployment using NVIDIA Triton Inference Server, which provides high-performance inference with support for multiple frameworks and advanced features like dynamic batching. - + Returns: Model: Configured model ready for Triton deployment. 
""" @@ -420,7 +420,7 @@ def _build_for_triton(self) -> Model: if isinstance(self.model, str): self.framework = None self.framework_version = None - + # Configure HuggingFace model for Triton if not self._is_jumpstart_model_id(): # Get model metadata for task detection @@ -430,53 +430,53 @@ def _build_for_triton(self) -> Model: model_task = hf_model_md.get("pipeline_tag") if model_task: self.env_vars.update({"HF_TASK": model_task}) - + # Configure HuggingFace authentication self.env_vars.setdefault("HF_MODEL_ID", self.model) if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # Triton downloads models directly from HuggingFace Hub self.s3_upload_path = None elif self.model: - fw, self.framework_version = _detect_framework_and_version(str(_get_model_base(self.model))) + fw, self.framework_version = _detect_framework_and_version( + str(_get_model_base(self.model)) + ) self.framework = self._normalize_framework_to_enum(fw) - + # Auto-detect Triton image if not provided if not self.image_uri: self._auto_detect_image_for_triton() - + # Prepare Triton-specific artifacts self._save_inference_spec() self._prepare_for_triton() - + # Prepare deployment artifacts if self.mode in LOCAL_MODES: self._prepare_for_mode() else: # self._prepare_for_mode(should_upload_artifacts=True) self.s3_model_data_url, _ = self._prepare_for_mode(should_upload_artifacts=True) - - return self._create_model() - + return self._create_model() def _build_for_tensorflow_serving(self) -> Model: """Build model for TensorFlow Serving deployment. - + Configures the model for deployment using TensorFlow Serving, Google's high-performance serving system for TensorFlow models. - + Returns: Model: Configured model ready for TensorFlow Serving deployment. - + Raises: ValueError: If image_uri is not provided for TensorFlow Serving. """ self.secret_key = "" if not getattr(self, "_is_mlflow_model", False): raise ValueError("Tensorflow Serving is currently only supported for mlflow models.") - + # Save Schema Builder if not os.path.exists(self.model_path): os.makedirs(self.model_path) @@ -501,23 +501,22 @@ def _build_for_tensorflow_serving(self) -> Model: return self._create_model() - def _build_for_tei(self) -> Model: """Build model for Text Embeddings Inference (TEI) deployment. - + Configures the model for deployment using Hugging Face's TEI server, optimized for embedding model inference with features like pooling strategies and efficient batching. - + Returns: Model: Configured model ready for TEI deployment. 
""" self.secret_key = "" - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) - + if self.model_server != ModelServer.TEI: messaging = ( "HuggingFace Model ID support on model server: " @@ -534,33 +533,34 @@ def _build_for_tei(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if isinstance(self.model, str) and not self._is_jumpstart_model_id(): # Configure HuggingFace model for TEI self.env_vars.setdefault("HF_MODEL_ID", self.model) - + self.hf_model_config = _get_model_config_properties_from_hf( self.model, self.env_vars.get("HUGGING_FACE_HUB_TOKEN") ) - + # Configure TEI-specific environment variables tei_env_vars = { "HF_HOME": "/tmp", "HUGGINGFACE_HUB_CACHE": "/tmp", "MODEL_LOADING_TIMEOUT": "240", } - + if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): tei_env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + self.env_vars.update(tei_env_vars) - + # TEI downloads models directly from HuggingFace Hub self.s3_upload_path = None self._auto_detect_image_uri() - + if not self._optimizing: if self.mode in LOCAL_MODES: self._prepare_for_mode(should_upload_artifacts=True) @@ -579,27 +579,26 @@ def _build_for_tei(self) -> Model: raise ValueError( "Instance type must be provided when building for SageMaker Endpoint mode." ) - + model = self._create_model() - + if "HF_HUB_OFFLINE" in self.env_vars: self.env_vars.update({"HF_HUB_OFFLINE": "0"}) - - return model + return model def _build_for_smd(self) -> Model: """Build model for SageMaker Distribution (SMD) deployment. - + Configures the model for deployment using SageMaker's distribution container, which provides a comprehensive ML environment with pre-installed frameworks and optimizations. - + Returns: Model: Configured model ready for SMD deployment. """ self.secret_key = "" - + self._save_model_inference_spec() if self.mode != Mode.IN_PROCESS: @@ -625,16 +624,16 @@ def _build_for_smd(self) -> Model: def _build_for_transformers(self) -> Model: """Build model for HuggingFace Transformers deployment. - + Configures the model for deployment using Multi-Model Server (MMS) with HuggingFace Transformers backend for general-purpose model inference. - + Returns: Model: Configured model ready for Transformers deployment. 
""" self.secret_key = "" self.model_server = ModelServer.MMS - + # Set MODEL_LOADING_TIMEOUT from instance variable if self.model_data_download_timeout: self.env_vars.update({"MODEL_LOADING_TIMEOUT": str(self.model_data_download_timeout)}) @@ -676,6 +675,7 @@ def _build_for_transformers(self) -> Model: logger.debug(f"Using detected notebook instance type: {nb_instance}") from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if not isinstance(self.model, str) or not self._is_jumpstart_model_id(): @@ -700,13 +700,13 @@ def _build_for_transformers(self) -> Model: model_task = hf_model_md.get("pipeline_tag") if model_task: self.env_vars.update({"HF_TASK": model_task}) - + self.env_vars.setdefault("HF_MODEL_ID", self.model) - + # Add HuggingFace token if available if self.env_vars.get("HUGGING_FACE_HUB_TOKEN"): self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN") - + # Get HF config for string model IDs if hasattr(self.env_vars, "HF_API_TOKEN"): self.hf_model_config = _get_model_config_properties_from_hf( @@ -723,7 +723,6 @@ def _build_for_transformers(self) -> Model: else: self.s3_model_data_url, _ = self._prepare_for_mode() - # Clean up empty secret key if ( "SAGEMAKER_SERVE_SECRET_KEY" in self.env_vars @@ -739,14 +738,12 @@ def _build_for_transformers(self) -> Model: return self._create_model() - - def _build_for_djl_jumpstart(self, init_kwargs) -> Model: """Build DJL model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured DJL model for JumpStart deployment. """ @@ -754,14 +751,12 @@ def _build_for_djl_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.DJL_SERVING from sagemaker.serve.model_server.djl_serving.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: # Prepare DJL resources for local deployment - ( - self.js_model_config, - self.prepared_for_djl - ) = prepare_djl_js_resources( + (self.js_model_config, self.prepared_for_djl) = prepare_djl_js_resources( model_path=self.model_path, js_id=self.model, dependencies=self.dependencies, @@ -776,13 +771,12 @@ def _build_for_djl_jumpstart(self, init_kwargs) -> Model: self.prepared_for_djl = True return self._create_model() - def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: """Build TGI model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured TGI model for JumpStart deployment. """ @@ -790,6 +784,7 @@ def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.TGI from sagemaker.serve.model_server.tgi.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: @@ -809,16 +804,15 @@ def _build_for_tgi_jumpstart(self, init_kwargs) -> Model: if not hasattr(self, "s3_upload_path") or not self.s3_upload_path: self.s3_upload_path = init_kwargs.model_data - self.prepared_for_tgi = True + self.prepared_for_tgi = True return self._create_model() - def _build_for_mms_jumpstart(self, init_kwargs) -> Model: """Build MMS model using JumpStart artifacts. - + Args: init_kwargs: JumpStart initialization parameters. - + Returns: Model: Configured MMS model for JumpStart deployment. 
""" @@ -826,6 +820,7 @@ def _build_for_mms_jumpstart(self, init_kwargs) -> Model: self.model_server = ModelServer.MMS from sagemaker.serve.model_server.multi_model_server.prepare import _create_dir_structure + _create_dir_structure(self.model_path) if self.mode in LOCAL_MODES: @@ -848,49 +843,50 @@ def _build_for_mms_jumpstart(self, init_kwargs) -> Model: self.prepared_for_mms = True return self._create_model() - - def _build_for_jumpstart(self) -> Model: """Build model for JumpStart deployment. - + Configures the model for deployment using SageMaker JumpStart, which provides pre-trained models with optimized configurations and deployment settings. - + Returns: Model: Configured model ready for JumpStart deployment. - + Raises: ValueError: If the JumpStart image URI is not supported. """ from sagemaker.core.jumpstart.factory.utils import get_init_kwargs self.secret_key = "" - + # Get JumpStart model configuration init_kwargs = get_init_kwargs( model_id=self.model, model_version=self.model_version or "*", region=self.region, instance_type=self.instance_type, - tolerate_vulnerable_model=getattr(self, 'tolerate_vulnerable_model', None), - tolerate_deprecated_model=getattr(self, 'tolerate_deprecated_model', None), - config_name=getattr(self, 'config_name', None), + tolerate_vulnerable_model=getattr(self, "tolerate_vulnerable_model", None), + tolerate_deprecated_model=getattr(self, "tolerate_deprecated_model", None), + config_name=getattr(self, "config_name", None), ) - + # Configure image URI and environment variables self.image_uri = self.image_uri or init_kwargs.image_uri - - if hasattr(init_kwargs, 'env') and init_kwargs.env: + + if hasattr(init_kwargs, "env") and init_kwargs.env: self.env_vars.update(init_kwargs.env) - + # Handle model artifacts for fine-tuned models - if hasattr(init_kwargs, 'model_data') and init_kwargs.model_data: - if isinstance(init_kwargs.model_data, dict) and 'S3DataSource' in init_kwargs.model_data: - self.s3_model_data_url = init_kwargs.model_data['S3DataSource']['S3Uri'] + if hasattr(init_kwargs, "model_data") and init_kwargs.model_data: + if ( + isinstance(init_kwargs.model_data, dict) + and "S3DataSource" in init_kwargs.model_data + ): + self.s3_model_data_url = init_kwargs.model_data["S3DataSource"]["S3Uri"] elif isinstance(init_kwargs.model_data, str): self.s3_model_data_url = init_kwargs.model_data - + # Prepare resources based on mode and model server if self.mode == Mode.LOCAL_CONTAINER: # Route to appropriate model server and prepare resources @@ -907,7 +903,7 @@ def _build_for_jumpstart(self) -> Model: model_data=self.s3_model_data_url, ) return self._build_for_djl_jumpstart(init_kwargs) - + elif "tgi-inference" in self.image_uri: self.model_server = ModelServer.TGI if not hasattr(self, "prepared_for_tgi"): @@ -918,7 +914,7 @@ def _build_for_jumpstart(self) -> Model: model_data=self.s3_model_data_url, ) return self._build_for_tgi_jumpstart(init_kwargs) - + elif "huggingface-pytorch-inference" in self.image_uri: self.model_server = ModelServer.MMS if not hasattr(self, "prepared_for_mms"): @@ -930,27 +926,28 @@ def _build_for_jumpstart(self) -> Model: ) return self._build_for_mms_jumpstart(init_kwargs) else: - raise ValueError(f"Unsupported JumpStart image URI: {self.image_uri}") - + raise ValueError( + f"Local container mode is not yet supported for JumpStart image: {self.image_uri}. " + f"Use Mode.SAGEMAKER_ENDPOINT for deployment." 
+ ) + else: - # SAGEMAKER_ENDPOINT mode - prepare artifacts if needed + # SAGEMAKER_ENDPOINT mode — all JumpStart containers follow the same + # pattern: JumpStart provides the full image URI, env vars, and S3 + # model artifacts. No framework-specific prep is needed. if not self._optimizing: self._prepare_for_mode() - - # Route to appropriate model server based on image URI + if "djl-inference" in self.image_uri: self.model_server = ModelServer.DJL_SERVING - return self._build_for_djl_jumpstart(init_kwargs) elif "tgi-inference" in self.image_uri: self.model_server = ModelServer.TGI - return self._build_for_tgi_jumpstart(init_kwargs) elif "huggingface-pytorch-inference" in self.image_uri: self.model_server = ModelServer.MMS - return self._build_for_mms_jumpstart(init_kwargs) - else: - raise ValueError(f"Unsupported JumpStart image URI: {self.image_uri}") - + if not hasattr(self, "s3_upload_path") or not self.s3_upload_path: + self.s3_upload_path = init_kwargs.model_data + return self._create_model() def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Returns predictor depending on local mode or endpoint mode""" @@ -961,122 +958,120 @@ def _djl_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, # Handle local deployment modes if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(*args, **kwargs) - def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified TGI deploy wrapper - env vars already set during build.""" - + # Handle mode overrides for local deployment if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(*args, **kwargs) - def _tei_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified TEI deploy wrapper - env vars already set during build.""" - + # Handle local deployment modes if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) - - def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: """Simplified JumpStart deploy wrapper - resource prep 
already done during build.""" - + # Handle local deployment if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if hasattr(self, "instance_type"): kwargs.update({"instance_type": self.instance_type}) - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) - def _transformers_model_builder_deploy_wrapper(self, *args, **kwargs) -> Union[Endpoint, LocalEndpoint]: + def _transformers_model_builder_deploy_wrapper( + self, *args, **kwargs + ) -> Union[Endpoint, LocalEndpoint]: """Simplified Transformers deploy wrapper - env vars already set during build.""" - + # Handle local deployment modes if self.mode == Mode.LOCAL_CONTAINER: return self._deploy_local_endpoint(**kwargs) - + if self.mode == Mode.IN_PROCESS: return self._deploy_local_endpoint(**kwargs) - + # Remove mode/role from kwargs if present kwargs.pop("mode", None) if "role" in kwargs: self.role_arn = kwargs.pop("role") - + # Set default values if "endpoint_logging" not in kwargs: kwargs["endpoint_logging"] = True - + if "initial_instance_count" not in kwargs: kwargs["initial_instance_count"] = 1 - + # Deploy to SageMaker endpoint return self._deploy_core_endpoint(**kwargs) diff --git a/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py b/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py new file mode 100644 index 0000000000..861d9326d9 --- /dev/null +++ b/sagemaker-serve/tests/integ/test_jumpstart_vllm_integration.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import uuid +import pytest +import logging + +from sagemaker.serve.model_builder import ModelBuilder +from sagemaker.core.jumpstart.configs import JumpStartConfig +from sagemaker.train.configs import Compute + +logger = logging.getLogger(__name__) + +MODEL_ID = "huggingface-vlm-qwen3-5-4b" +INSTANCE_TYPE = "ml.g6.12xlarge" +MODEL_NAME_PREFIX = "js-vllm-test-model" + + +@pytest.mark.slow_test +def test_jumpstart_vllm_build(): + """Integration test for JumpStart model using vLLM container image. + + Validates that ModelBuilder correctly handles vLLM container images + which do not match the legacy djl-inference/tgi-inference/huggingface-pytorch-inference + patterns in the endpoint mode routing. 
+ """ + logger.info("Starting JumpStart vLLM build test...") + + compute = Compute(instance_type=INSTANCE_TYPE) + jumpstart_config = JumpStartConfig(model_id=MODEL_ID) + model_builder = ModelBuilder.from_jumpstart_config( + jumpstart_config=jumpstart_config, compute=compute + ) + unique_id = str(uuid.uuid4())[:8] + + core_model = model_builder.build(model_name=f"{MODEL_NAME_PREFIX}-{unique_id}") + logger.info(f"Model Successfully Created: {core_model.model_name}") + + try: + assert ( + "vllm" in model_builder.image_uri + ), f"Expected vLLM image URI, got: {model_builder.image_uri}" + assert model_builder.s3_upload_path is not None, "s3_upload_path should be set" + logger.info("JumpStart vLLM build test completed successfully") + finally: + core_model.delete() + logger.info("Model Successfully Deleted!") diff --git a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py index 1a6f3b2442..d905b7decf 100644 --- a/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py +++ b/sagemaker-serve/tests/unit/servers/test_model_builder_servers.py @@ -919,7 +919,7 @@ def test_build_unsupported_image_uri(self, mock_init): with self.assertRaises(ValueError) as ctx: self.builder._build_for_jumpstart() - self.assertIn("Unsupported", str(ctx.exception)) + self.assertIn("Local container mode is not yet supported", str(ctx.exception)) @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") @patch("sagemaker.serve.model_builder_servers.prepare_djl_js_resources") @@ -985,21 +985,21 @@ def test_build_passes_none_config_name_when_not_set( @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") @patch.object(MockModelBuilderServers, "_prepare_for_mode") - @patch.object(MockModelBuilderServers, "_build_for_djl_jumpstart") - def test_build_sagemaker_endpoint_djl(self, mock_djl_build, mock_prepare, mock_init): + @patch.object(MockModelBuilderServers, "_create_model") + def test_build_sagemaker_endpoint_djl(self, mock_create, mock_prepare, mock_init): """Test building DJL JumpStart for SAGEMAKER_ENDPOINT.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "djl-inference:0.21.0" mock_init_kwargs.env = {} mock_init_kwargs.model_data = "s3://bucket/model.tar.gz" mock_init.return_value = mock_init_kwargs - mock_djl_build.return_value = Mock() + mock_create.return_value = Mock() self.builder.mode = Mode.SAGEMAKER_ENDPOINT self.builder.image_uri = None result = self.builder._build_for_jumpstart() - mock_djl_build.assert_called_once() + mock_create.assert_called_once() class TestDeployWrappers(unittest.TestCase): diff --git a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py index 02b0962feb..3b00264ffa 100644 --- a/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py +++ b/sagemaker-serve/tests/unit/test_model_builder_servers_coverage.py @@ -18,7 +18,7 @@ mock_model_object, MOCK_ROLE_ARN, MOCK_IMAGE_URI, - MOCK_S3_URI + MOCK_S3_URI, ) @@ -35,181 +35,179 @@ def test_build_for_model_server_unsupported_raises(self): model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = "UNSUPPORTED_SERVER" - + with self.assertRaises(ValueError) as context: builder._build_for_model_server() - + self.assertIn("is not supported yet", str(context.exception)) def test_build_for_model_server_without_model_raises(self): """Test that 
missing model/inference_spec raises error.""" builder = ModelBuilder( - role_arn=MOCK_ROLE_ARN, - sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, image_uri=MOCK_IMAGE_URI ) builder.model = None builder.inference_spec = None builder.model_metadata = None builder.model_server = ModelServer.TORCHSERVE - + with self.assertRaises(ValueError) as context: builder._build_for_model_server() - + self.assertIn("Missing required parameter", str(context.exception)) - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_torchserve') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_torchserve") def test_build_for_model_server_routes_to_torchserve(self, mock_build): """Test routing to TorchServe builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TORCHSERVE - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_djl") def test_build_for_model_server_routes_to_djl(self, mock_build): """Test routing to DJL builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.DJL_SERVING - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tgi") def test_build_for_model_server_routes_to_tgi(self, mock_build): """Test routing to TGI builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TGI - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tei') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tei") def test_build_for_model_server_routes_to_tei(self, mock_build): """Test routing to TEI builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model="sentence-transformers/all-MiniLM-L6-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TEI - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_triton') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_triton") def test_build_for_model_server_routes_to_triton(self, mock_build): """Test routing to Triton builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, 
sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TRITON - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tensorflow_serving') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_tensorflow_serving") def test_build_for_model_server_routes_to_tensorflow(self, mock_build): """Test routing to TensorFlow Serving builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TENSORFLOW_SERVING - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_transformers') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_transformers") def test_build_for_model_server_routes_to_mms(self, mock_build): """Test routing to MMS/Transformers builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.MMS - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_smd') + @patch("sagemaker.serve.model_builder.ModelBuilder._build_for_smd") def test_build_for_model_server_routes_to_smd(self, mock_build): """Test routing to SMD builder.""" mock_model = Mock(spec=Model) mock_build.return_value = mock_model - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.SMD - + result = builder._build_for_model_server() - + self.assertEqual(result, mock_model) mock_build.assert_called_once() @@ -225,86 +223,93 @@ def setUp(self): def tearDown(self): """Clean up temp directory.""" import shutil + if os.path.exists(self.temp_dir): shutil.rmtree(self.temp_dir) - @patch('sagemaker.serve.model_builder.save_pkl') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder._detect_framework_and_version') - @patch('sagemaker.serve.model_builder._get_model_base') - def test_build_for_torchserve_with_model_object(self, mock_get_base, mock_detect_fw, mock_prepare, mock_create, mock_save_pkl): + @patch("sagemaker.serve.model_builder.save_pkl") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder._detect_framework_and_version") + @patch("sagemaker.serve.model_builder._get_model_base") + def test_build_for_torchserve_with_model_object( + self, mock_get_base, mock_detect_fw, mock_prepare, mock_create, mock_save_pkl + ): """Test TorchServe build with model object.""" mock_model = Mock(spec=Model) mock_create.return_value = mock_model mock_get_base.return_value = mock_model_object() mock_detect_fw.return_value = 
("pytorch", "1.13.0") - + builder = ModelBuilder( model=mock_model_object(), role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, mode=Mode.IN_PROCESS, # Use IN_PROCESS to skip prepare_for_torchserve - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TORCHSERVE - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) mock_save_pkl.assert_called() mock_create.assert_called_once() - @patch('sagemaker.serve.model_builder.save_pkl') - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - def test_build_for_torchserve_with_hf_model_id(self, mock_is_js, mock_prepare, mock_create, mock_save_pkl): + @patch("sagemaker.serve.model_builder.save_pkl") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_build_for_torchserve_with_hf_model_id( + self, mock_is_js, mock_prepare, mock_create, mock_save_pkl + ): """Test TorchServe build with HuggingFace model ID.""" mock_is_js.return_value = False mock_model = Mock(spec=Model) mock_create.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, - mode=Mode.IN_PROCESS # Use IN_PROCESS to skip prepare_for_torchserve + mode=Mode.IN_PROCESS, # Use IN_PROCESS to skip prepare_for_torchserve ) builder.model_server = ModelServer.TORCHSERVE builder.env_vars = {} builder.image_uri = MOCK_IMAGE_URI - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) self.assertEqual(builder.env_vars["HF_MODEL_ID"], "gpt2") self.assertIsNone(builder.s3_upload_path) - @patch('sagemaker.serve.model_builder.ModelBuilder._create_model') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - @patch('sagemaker.serve.model_builder.ModelBuilder._save_model_inference_spec') - @patch('sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id') - def test_build_for_torchserve_with_hf_token(self, mock_is_js, mock_save, mock_prepare, mock_create): + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + @patch("sagemaker.serve.model_builder.ModelBuilder._save_model_inference_spec") + @patch("sagemaker.serve.model_builder.ModelBuilder._is_jumpstart_model_id") + def test_build_for_torchserve_with_hf_token( + self, mock_is_js, mock_save, mock_prepare, mock_create + ): """Test TorchServe build with HuggingFace token.""" mock_is_js.return_value = False mock_model = Mock(spec=Model) mock_create.return_value = mock_model - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, model_path=self.temp_dir, - mode=Mode.IN_PROCESS + mode=Mode.IN_PROCESS, ) builder.model_server = ModelServer.TORCHSERVE builder.env_vars = {"HUGGING_FACE_HUB_TOKEN": "hf_token_123"} - + result = builder._build_for_torchserve() - + self.assertEqual(result, mock_model) self.assertEqual(builder.env_vars["HF_TOKEN"], "hf_token_123") @@ -316,46 +321,84 @@ def setUp(self): """Set up test fixtures.""" self.mock_session = mock_sagemaker_session() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') 
- @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") def test_build_for_jumpstart_unsupported_image_raises(self, mock_prepare, mock_get_kwargs): - """Test that unsupported JumpStart image raises error.""" + """Test that unsupported JumpStart image in LOCAL_CONTAINER mode raises error.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "unsupported-image:latest" mock_init_kwargs.env = {} mock_get_kwargs.return_value = mock_init_kwargs - + builder = ModelBuilder( model="some-model-id", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.LOCAL_CONTAINER, ) builder._optimizing = False - + with self.assertRaises(ValueError) as context: builder._build_for_jumpstart() - - self.assertIn("Unsupported JumpStart image URI", str(context.exception)) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, mock_get_kwargs): + self.assertIn("Local container mode is not yet supported", str(context.exception)) + self.assertIn("Mode.SAGEMAKER_ENDPOINT", str(context.exception)) + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_unknown_image_endpoint_uses_default( + self, mock_prepare, mock_create, mock_get_kwargs + ): + """Test that unknown image in endpoint mode passes through without error.""" + mock_init_kwargs = Mock() + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-east-1.amazonaws.com/" + "vllm:0.17-gpu-py312-cu129-ubuntu22.04-sagemaker-v1" + ) + mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/vllm/model.tar.gz" + mock_get_kwargs.return_value = mock_init_kwargs + + mock_model = Mock(spec=Model) + mock_create.return_value = mock_model + + builder = ModelBuilder( + model="huggingface-vlm-qwen3-5-4b", + role_arn=MOCK_ROLE_ARN, + sagemaker_session=self.mock_session, + mode=Mode.SAGEMAKER_ENDPOINT, + ) + builder._optimizing = False + + result = builder._build_for_jumpstart() + + self.assertEqual(result, mock_model) + self.assertIsNone(builder.model_server) + self.assertEqual(builder.s3_upload_path, "s3://jumpstart-cache/models/vllm/model.tar.gz") + mock_create.assert_called_once() + + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to DJL builder.""" mock_init_kwargs = Mock() - mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + ) mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/djl/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs mock_model = Mock(spec=Model) - mock_build_djl.return_value = 
mock_model + mock_create.return_value = mock_model builder = ModelBuilder( model="huggingface-llm-falcon-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False @@ -363,26 +406,31 @@ def test_build_for_jumpstart_routes_to_djl(self, mock_prepare, mock_build_djl, m self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.DJL_SERVING) - mock_build_djl.assert_called_once() + mock_create.assert_called_once() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_djl_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_djl, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_passes_config_name( + self, mock_prepare, mock_create, mock_get_kwargs + ): """Test that config_name is forwarded to get_init_kwargs.""" mock_init_kwargs = Mock() - mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + mock_init_kwargs.image_uri = ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/djl-inference:0.21.0-deepspeed0.8.0-cu117" + ) mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/djl/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs mock_model = Mock(spec=Model) - mock_build_djl.return_value = mock_model + mock_create.return_value = mock_model builder = ModelBuilder( model="meta-textgeneration-llama-3-3-70b-instruct", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False builder.config_name = "lmi-optimized" @@ -391,61 +439,66 @@ def test_build_for_jumpstart_passes_config_name(self, mock_prepare, mock_build_d mock_get_kwargs.assert_called_once() call_kwargs = mock_get_kwargs.call_args - self.assertEqual(call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), "lmi-optimized") + self.assertEqual( + call_kwargs.kwargs.get("config_name") or call_kwargs[1].get("config_name"), + "lmi-optimized", + ) - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_tgi_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_tgi(self, mock_prepare, mock_build_tgi, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_tgi(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to TGI builder.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04" mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/tgi/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs - + mock_model = Mock(spec=Model) - mock_build_tgi.return_value = mock_model - + 
mock_create.return_value = mock_model + builder = ModelBuilder( model="huggingface-llm-mistral-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False - + result = builder._build_for_jumpstart() - + self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.TGI) - mock_build_tgi.assert_called_once() + mock_create.assert_called_once() - @patch('sagemaker.core.jumpstart.factory.utils.get_init_kwargs') - @patch('sagemaker.serve.model_builder.ModelBuilder._build_for_mms_jumpstart') - @patch('sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode') - def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_build_mms, mock_get_kwargs): + @patch("sagemaker.core.jumpstart.factory.utils.get_init_kwargs") + @patch("sagemaker.serve.model_builder.ModelBuilder._create_model") + @patch("sagemaker.serve.model_builder.ModelBuilder._prepare_for_mode") + def test_build_for_jumpstart_routes_to_mms(self, mock_prepare, mock_create, mock_get_kwargs): """Test JumpStart routing to MMS builder.""" mock_init_kwargs = Mock() mock_init_kwargs.image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04" mock_init_kwargs.env = {} + mock_init_kwargs.model_data = "s3://jumpstart-cache/models/mms/model.tar.gz" mock_get_kwargs.return_value = mock_init_kwargs - + mock_model = Mock(spec=Model) - mock_build_mms.return_value = mock_model - + mock_create.return_value = mock_model + builder = ModelBuilder( model="pytorch-ic-mobilenet-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - mode=Mode.SAGEMAKER_ENDPOINT + mode=Mode.SAGEMAKER_ENDPOINT, ) builder._optimizing = False - + result = builder._build_for_jumpstart() - + self.assertEqual(result, mock_model) self.assertEqual(builder.model_server, ModelServer.MMS) - mock_build_mms.assert_called_once() + mock_create.assert_called_once() class TestDeployWrappers(unittest.TestCase): @@ -455,114 +508,104 @@ def setUp(self): """Set up test fixtures.""" self.mock_session = mock_sagemaker_session() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_djl_deploy_wrapper_sets_timeout(self, mock_deploy): """Test DJL deploy wrapper sets model data download timeout.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.DJL_SERVING builder.built_model = Mock(spec=Model) - - result = builder._djl_model_builder_deploy_wrapper( - model_data_download_timeout=1800 - ) - + + result = builder._djl_model_builder_deploy_wrapper(model_data_download_timeout=1800) + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() call_kwargs = mock_deploy.call_args[1] - self.assertEqual(call_kwargs['model_data_download_timeout'], 1800) + self.assertEqual(call_kwargs["model_data_download_timeout"], 1800) - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_tgi_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test TGI deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = 
ModelBuilder( model="gpt2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TGI builder.built_model = Mock(spec=Model) - - result = builder._tgi_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._tgi_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_tei_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test TEI deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="sentence-transformers/all-MiniLM-L6-v2", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.TEI builder.built_model = Mock(spec=Model) - - result = builder._tei_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._tei_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_js_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test JumpStart deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="huggingface-llm-falcon-7b", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.built_model = Mock(spec=Model) - - result = builder._js_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._js_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once() - @patch('sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint') + @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_core_endpoint") def test_transformers_deploy_wrapper_calls_core_deploy(self, mock_deploy): """Test Transformers deploy wrapper calls core deploy.""" mock_endpoint = Mock() mock_deploy.return_value = mock_endpoint - + builder = ModelBuilder( model="bert-base-uncased", role_arn=MOCK_ROLE_ARN, sagemaker_session=self.mock_session, - image_uri=MOCK_IMAGE_URI + image_uri=MOCK_IMAGE_URI, ) builder.model_server = ModelServer.MMS builder.built_model = Mock(spec=Model) - - result = builder._transformers_model_builder_deploy_wrapper( - endpoint_name="test-endpoint" - ) - + + result = builder._transformers_model_builder_deploy_wrapper(endpoint_name="test-endpoint") + self.assertEqual(result, mock_endpoint) mock_deploy.assert_called_once()
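
Reviewer note: below is a minimal usage sketch of the behavior this change introduces, adapted directly from the new integration test in this diff. In SAGEMAKER_ENDPOINT mode, _build_for_jumpstart no longer raises ValueError for JumpStart images that do not match the legacy djl-inference/tgi-inference/huggingface-pytorch-inference patterns (for example, vLLM containers); it falls through to _create_model with the image URI, env vars, and S3 artifacts that JumpStart metadata provides. The model ID, instance type, and model name below are the integration test's values, not general defaults.

# Sketch only: mirrors tests/integ/test_jumpstart_vllm_integration.py.
from sagemaker.serve.model_builder import ModelBuilder
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.train.configs import Compute

# JumpStart model served by a vLLM container; its image URI matches none of
# the legacy routing patterns in _build_for_jumpstart.
jumpstart_config = JumpStartConfig(model_id="huggingface-vlm-qwen3-5-4b")
compute = Compute(instance_type="ml.g6.12xlarge")

model_builder = ModelBuilder.from_jumpstart_config(
    jumpstart_config=jumpstart_config, compute=compute
)
core_model = model_builder.build(model_name="js-vllm-example")

# Before this change, endpoint-mode builds raised "Unsupported JumpStart
# image URI" here; now the image URI and S3 model artifacts pass through
# from JumpStart metadata unchanged.
assert "vllm" in model_builder.image_uri
assert model_builder.s3_upload_path is not None

core_model.delete()  # clean up the model resource, as the test does

Local container mode keeps its explicit routing and now fails fast with an actionable message directing users to Mode.SAGEMAKER_ENDPOINT, which is what the updated unit tests assert.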