diff --git a/sagemaker-core/src/sagemaker/core/common_utils.py b/sagemaker-core/src/sagemaker/core/common_utils.py index 8a8134f5ea..0c906e6480 100644 --- a/sagemaker-core/src/sagemaker/core/common_utils.py +++ b/sagemaker-core/src/sagemaker/core/common_utils.py @@ -819,6 +819,9 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} @@ -906,6 +909,9 @@ def aws_partition(region): Returns: str: partition corresponding to the region name passed in. Ex: "aws-cn" """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/helper/session_helper.py b/sagemaker-core/src/sagemaker/core/helper/session_helper.py index 4d33c9c064..c737df2dfc 100644 --- a/sagemaker-core/src/sagemaker/core/helper/session_helper.py +++ b/sagemaker-core/src/sagemaker/core/helper/session_helper.py @@ -2121,6 +2121,9 @@ def sts_regional_endpoint(region): Returns: str: AWS STS regional endpoint """ + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = botocore_resolver().construct_endpoint("sts", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py index c4c2f5a45e..4b3572dbad 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever.py @@ -5,6 +5,7 @@ from sagemaker.core.inference_config import ServerlessInferenceConfig from sagemaker.core.training_compiler.config import TrainingCompilerConfig from sagemaker.core.common_utils import _botocore_resolver +from sagemaker.core.region_validation import validate_region from sagemaker.core.workflow import is_pipeline_variable from sagemaker.core.image_retriever.image_retriever_utils import ( _config_for_framework_and_scope, @@ -161,6 +162,7 @@ def retrieve_hugging_face_uri( ) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -359,6 +361,7 @@ def retrieve_pytorch_uri( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -561,6 +564,7 @@ def retrieve( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -623,6 +627,7 @@ def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py index 6547ae0259..0ad3595924 100644 --- a/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py +++ b/sagemaker-core/src/sagemaker/core/image_retriever/image_retriever_utils.py @@ -483,6 +483,9 @@ def _retrieve_latest_pytorch_training_uri(region: str): version_config = config[image_scope]["versions"][latest_version] py_version = _validate_py_version_and_set_if_needed(None, version_config, None) + from sagemaker.core.region_validation import validate_region + + validate_region(region) endpoint_data = _botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/image_uris.py b/sagemaker-core/src/sagemaker/core/image_uris.py index 2f3ee0add5..4d4826b3dc 100644 --- a/sagemaker-core/src/sagemaker/core/image_uris.py +++ b/sagemaker-core/src/sagemaker/core/image_uris.py @@ -24,6 +24,7 @@ from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER from sagemaker.core.jumpstart.enums import JumpStartModelType from sagemaker.core.jumpstart.utils import is_jumpstart_model_input +from sagemaker.core.region_validation import validate_region from sagemaker.core.spark import defaults from sagemaker.core.jumpstart import artifacts from sagemaker.core.workflow import is_pipeline_variable @@ -213,6 +214,7 @@ def retrieve( py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework) version_config = version_config.get(py_version) or version_config registry = _registry_from_region(region, version_config["registries"]) + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} @@ -749,6 +751,7 @@ def get_base_python_image_uri(region, py_version="310") -> str: framework = "sagemaker-base-python" version = "1.0" + validate_region(region) endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region) if region == "il-central-1" and not endpoint_data: endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)} diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py index 9193be568d..2d293f3d06 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/detail_profiler_app.py @@ -79,6 +79,10 @@ def get_app_url(self, training_job_name: Optional[str] = None): Returns: str: An unsigned URL for DetailProfiler hosted on SageMaker. """ + from sagemaker.core.region_validation import validate_region + + validate_region(self.region) + if self._valid_domain_and_user: url = f"https://{self._domain_id}.studio.{self.region}.sagemaker.aws/profiler/default" if training_job_name is not None: diff --git a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py index cc082f6d6f..0a8d866e07 100644 --- a/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py +++ b/sagemaker-core/src/sagemaker/core/interactive_apps/tensorboard.py @@ -84,9 +84,13 @@ def get_app_url( Returns: str: A URL for TensorBoard hosted on SageMaker. """ + from sagemaker.core.region_validation import validate_region + if training_job_name is not None: self._validate_job_name(training_job_name) + validate_region(self.region) + if ( self._in_studio_env and self._validate_domain_id(self._domain_id) diff --git a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py index d46fa39df9..e2102be5d1 100644 --- a/sagemaker-core/src/sagemaker/core/jumpstart/utils.py +++ b/sagemaker-core/src/sagemaker/core/jumpstart/utils.py @@ -88,11 +88,14 @@ def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Sessi if sagemaker_session is None: sagemaker_session = Session() + from sagemaker.core.region_validation import validate_region + path_parts = document.HostingEulaUri.replace("s3://", "").split("/") bucket = path_parts[0] key = "/".join(path_parts[1:]) region = sagemaker_session.boto_region_name + validate_region(region) botocore_session = sagemaker_session.boto_session._session endpoint_resolver = botocore_session.get_component("endpoint_resolver") diff --git a/sagemaker-core/src/sagemaker/core/region_validation.py b/sagemaker-core/src/sagemaker/core/region_validation.py new file mode 100644 index 0000000000..76239eaf43 --- /dev/null +++ b/sagemaker-core/src/sagemaker/core/region_validation.py @@ -0,0 +1,90 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Region validation utilities to prevent SSRF via malicious region strings. + +This module provides validation for AWS region parameters before they are +interpolated into endpoint URLs. Without validation, a crafted region value +(e.g., ``x@attacker.com:443/#``) could redirect SDK API calls — including +SigV4-signed requests — to non-AWS hosts. + +See: CVE-2026-22611 (AWS SDK for .NET, same vulnerability class). +""" +from __future__ import absolute_import + +import re +from urllib.parse import urlparse + +# Regex for valid AWS region names (e.g., us-east-1, eu-west-2, cn-north-1, us-gov-west-1). +# Uses \A and \Z anchors to prevent newline injection bypass that $ allows. +_VALID_REGION_PATTERN = re.compile(r"\A[a-z]{2}(-[a-z]+)+-\d+\Z") + +# Trusted AWS domain suffixes for endpoint URL validation (defense-in-depth). +_AWS_DOMAINS = ( + ".amazonaws.com", + ".amazonaws.com.cn", + ".api.aws", + ".sagemaker.aws", +) + + +class InvalidRegionError(ValueError): + """Raised when an invalid AWS region string is provided. + + This prevents SSRF attacks where a crafted region value + (e.g., ``x@attacker.com:443/#``) could redirect SDK API calls + to non-AWS hosts. + """ + + +def validate_region(region: str) -> str: + """Validate that a region string is a well-formed AWS region name. + + Args: + region: The region string to validate. + + Returns: + The validated region string (unchanged). + + Raises: + InvalidRegionError: If the region does not match the expected pattern. + """ + if not isinstance(region, str) or not _VALID_REGION_PATTERN.match(region): + raise InvalidRegionError( + f"Invalid AWS region: {region!r}. " + "Region must match pattern like 'us-east-1', 'eu-west-2', 'cn-north-1'." + ) + return region + + +def validate_endpoint_url(url: str) -> str: + """Validate that a constructed endpoint URL resolves to an AWS host. + + This is a defense-in-depth check that catches URL manipulation even if + the region regex is somehow bypassed. + + Args: + url: The constructed endpoint URL. + + Returns: + The validated URL (unchanged). + + Raises: + InvalidRegionError: If the URL hostname does not end with a trusted AWS domain. + """ + parsed = urlparse(url) + hostname = parsed.hostname or "" + if not any(hostname.endswith(d) for d in _AWS_DOMAINS): + raise InvalidRegionError( + f"Constructed endpoint resolves to non-AWS host: {hostname!r}" + ) + return url diff --git a/sagemaker-core/src/sagemaker/core/spark/processing.py b/sagemaker-core/src/sagemaker/core/spark/processing.py index 82cdef954c..971a71f769 100644 --- a/sagemaker-core/src/sagemaker/core/spark/processing.py +++ b/sagemaker-core/src/sagemaker/core/spark/processing.py @@ -570,7 +570,10 @@ def _is_notebook_instance(self): def _get_notebook_instance_domain(self): """Get the instance's domain.""" + from sagemaker.core.region_validation import validate_region + region = self.sagemaker_session.boto_region_name + validate_region(region) with open("/opt/ml/metadata/resource-metadata.json") as file: data = json.load(file) notebook_name = data["ResourceName"] diff --git a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py index 738b47e309..8707aed22a 100644 --- a/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py +++ b/sagemaker-core/src/sagemaker/core/telemetry/telemetry_logging.py @@ -271,6 +271,9 @@ def _construct_url( ) -> str: """Construct the URL for the telemetry request""" + from sagemaker.core.region_validation import validate_region + + validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}" diff --git a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py index 0ce68b153a..9f2eee59be 100644 --- a/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py +++ b/sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py @@ -229,6 +229,9 @@ def _construct_url( ) -> str: """Placeholder docstring""" + from sagemaker.core.region_validation import validate_region + + validate_region(region) base_url = ( f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?" f"x-accountId={accountId}" diff --git a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py index fe837a91fc..d4d41fcef5 100644 --- a/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py +++ b/sagemaker-train/src/sagemaker/train/common_utils/metrics_visualizer.py @@ -15,18 +15,25 @@ def _is_in_studio() -> bool: def _get_studio_base_url(region: str) -> str: """Get Studio base URL, or empty string if domain not resolvable.""" + from sagemaker.core.region_validation import validate_region from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata domain_id = _read_domain_id_from_metadata() if not domain_id or not region: return "" + validate_region(region) return f"https://studio-{domain_id}.studio.{region}.sagemaker.aws" def _parse_job_arn(job_arn: str): """Parse a SageMaker job ARN into (region, resource) or None.""" import re + from sagemaker.core.region_validation import validate_region m = re.match(r'arn:aws(?:-[a-z]+)?:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn) - return (m.group(1), m.group(2)) if m else None + if not m: + return None + region = m.group(1) + validate_region(region) + return (region, m.group(2)) def get_console_job_url(job_arn: str) -> str: