Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions sagemaker-core/src/sagemaker/core/common_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,9 @@ def sts_regional_endpoint(region):
Returns:
str: AWS STS regional endpoint
"""
from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -906,6 +909,9 @@ def aws_partition(region):
Returns:
str: partition corresponding to the region name passed in. Ex: "aws-cn"
"""
from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("sts", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/helper/session_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,9 @@ def sts_regional_endpoint(region):
Returns:
str: AWS STS regional endpoint
"""
from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = botocore_resolver().construct_endpoint("sts", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "sts.{}.amazonaws.com".format(region)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sagemaker.core.inference_config import ServerlessInferenceConfig
from sagemaker.core.training_compiler.config import TrainingCompilerConfig
from sagemaker.core.common_utils import _botocore_resolver
from sagemaker.core.region_validation import validate_region
from sagemaker.core.workflow import is_pipeline_variable
from sagemaker.core.image_retriever.image_retriever_utils import (
_config_for_framework_and_scope,
Expand Down Expand Up @@ -161,6 +162,7 @@ def retrieve_hugging_face_uri(
)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -359,6 +361,7 @@ def retrieve_pytorch_uri(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -561,6 +564,7 @@ def retrieve(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -623,6 +627,7 @@ def retrieve_base_python_image_uri(region: str, py_version: str = "310") -> str:

framework = "sagemaker-base-python"
version = "1.0"
validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,9 @@ def _retrieve_latest_pytorch_training_uri(region: str):
version_config = config[image_scope]["versions"][latest_version]
py_version = _validate_py_version_and_set_if_needed(None, version_config, None)

from sagemaker.core.region_validation import validate_region

validate_region(region)
endpoint_data = _botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from sagemaker.core.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION, JUMPSTART_LOGGER
from sagemaker.core.jumpstart.enums import JumpStartModelType
from sagemaker.core.jumpstart.utils import is_jumpstart_model_input
from sagemaker.core.region_validation import validate_region
from sagemaker.core.spark import defaults
from sagemaker.core.jumpstart import artifacts
from sagemaker.core.workflow import is_pipeline_variable
Expand Down Expand Up @@ -213,6 +214,7 @@ def retrieve(
py_version = _validate_py_version_and_set_if_needed(py_version, version_config, framework)
version_config = version_config.get(py_version) or version_config
registry = _registry_from_region(region, version_config["registries"])
validate_region(region)
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down Expand Up @@ -749,6 +751,7 @@ def get_base_python_image_uri(region, py_version="310") -> str:

framework = "sagemaker-base-python"
version = "1.0"
validate_region(region)
endpoint_data = utils._botocore_resolver().construct_endpoint("ecr", region)
if region == "il-central-1" and not endpoint_data:
endpoint_data = {"hostname": "ecr.{}.amazonaws.com".format(region)}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ def get_app_url(self, training_job_name: Optional[str] = None):
Returns:
str: An unsigned URL for DetailProfiler hosted on SageMaker.
"""
from sagemaker.core.region_validation import validate_region

validate_region(self.region)

if self._valid_domain_and_user:
url = f"https://{self._domain_id}.studio.{self.region}.sagemaker.aws/profiler/default"
if training_job_name is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,13 @@ def get_app_url(
Returns:
str: A URL for TensorBoard hosted on SageMaker.
"""
from sagemaker.core.region_validation import validate_region

if training_job_name is not None:
self._validate_job_name(training_job_name)

validate_region(self.region)

if (
self._in_studio_env
and self._validate_domain_id(self._domain_id)
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,14 @@ def get_eula_url(document: HubContentDocument, sagemaker_session: Optional[Sessi
if sagemaker_session is None:
sagemaker_session = Session()

from sagemaker.core.region_validation import validate_region

path_parts = document.HostingEulaUri.replace("s3://", "").split("/")

bucket = path_parts[0]
key = "/".join(path_parts[1:])
region = sagemaker_session.boto_region_name
validate_region(region)

botocore_session = sagemaker_session.boto_session._session
endpoint_resolver = botocore_session.get_component("endpoint_resolver")
Expand Down
90 changes: 90 additions & 0 deletions sagemaker-core/src/sagemaker/core/region_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Region validation utilities to prevent SSRF via malicious region strings.

This module provides validation for AWS region parameters before they are
interpolated into endpoint URLs. Without validation, a crafted region value
(e.g., ``x@attacker.com:443/#``) could redirect SDK API calls — including
SigV4-signed requests — to non-AWS hosts.

See: CVE-2026-22611 (AWS SDK for .NET, same vulnerability class).
"""
from __future__ import absolute_import

import re
from urllib.parse import urlparse

# Regex for well-formed AWS region names (e.g., us-east-1, eu-west-2, cn-north-1,
# us-gov-west-1, eusc-de-east-1).
#
# * \A and \Z anchor the whole string; unlike $, \Z does not tolerate a
#   trailing newline.
# * The leading token is 2-4 letters so that the European Sovereign Cloud
#   region (``eusc-de-east-1``) is accepted alongside two-letter prefixes.
# * [0-9] is used instead of \d: in a ``str`` pattern \d also matches
#   non-ASCII Unicode decimal digits, which never occur in a region name.
_VALID_REGION_PATTERN = re.compile(r"\A[a-z]{2,4}(-[a-z]+)+-[0-9]+\Z")

# Trusted AWS domain suffixes for endpoint URL validation (defense-in-depth).
# Each entry starts with a dot so that only true subdomains match.
_AWS_DOMAINS = (
    ".amazonaws.com",
    ".amazonaws.com.cn",
    ".api.aws",
    ".sagemaker.aws",
)


class InvalidRegionError(ValueError):
    """Signal that a supplied AWS region string failed validation.

    Rejecting such values guards against SSRF: a crafted region
    (e.g., ``x@attacker.com:443/#``) could otherwise steer SDK API calls
    toward hosts outside AWS.

    Subclasses ``ValueError`` so existing ``except ValueError`` handlers
    also catch it.
    """


def validate_region(region: str) -> str:
    """Check a region string against the AWS region-name pattern.

    Args:
        region: Candidate region name, e.g. ``"us-east-1"``.

    Returns:
        The same ``region`` value, untouched, when it is acceptable.

    Raises:
        InvalidRegionError: If ``region`` is not a ``str`` or does not match
            ``_VALID_REGION_PATTERN``.
    """
    # The isinstance test runs first so non-string input never reaches the regex.
    well_formed = isinstance(region, str) and _VALID_REGION_PATTERN.match(region) is not None
    if well_formed:
        return region

    detail = f"Invalid AWS region: {region!r}. "
    hint = "Region must match pattern like 'us-east-1', 'eu-west-2', 'cn-north-1'."
    raise InvalidRegionError(detail + hint)


def validate_endpoint_url(url: str) -> str:
    """Validate that a constructed endpoint URL points at an AWS host.

    This is a defense-in-depth check that catches URL manipulation even if
    the region regex is somehow bypassed.

    Args:
        url: The constructed endpoint URL.

    Returns:
        The validated URL (unchanged).

    Raises:
        InvalidRegionError: If the URL cannot be parsed, if its hostname
            contains characters outside ``[a-z0-9.-]``, or if the hostname
            does not end with one of the trusted ``_AWS_DOMAINS`` suffixes.
    """
    try:
        # ``hostname`` is already lower-cased by urllib and excludes any
        # userinfo ("user@") and port components.
        hostname = urlparse(url).hostname or ""
    except ValueError as err:
        # urlparse raises a bare ValueError for some malformed authorities
        # (e.g. unbalanced IPv6 brackets); surface it as the documented type.
        raise InvalidRegionError(
            f"Constructed endpoint is not a valid URL: {url!r}"
        ) from err

    # urlparse does not validate host characters. A host such as
    # "evil.com\.amazonaws.com" passes a pure suffix check, yet WHATWG-style
    # URL parsers treat the backslash as a path separator and would connect
    # to "evil.com". Restrict the host to the DNS alphabet before trusting
    # its suffix.
    has_dns_chars_only = re.fullmatch(r"[a-z0-9.-]+", hostname) is not None
    if not has_dns_chars_only or not hostname.endswith(tuple(_AWS_DOMAINS)):
        raise InvalidRegionError(
            f"Constructed endpoint resolves to non-AWS host: {hostname!r}"
        )
    return url
3 changes: 3 additions & 0 deletions sagemaker-core/src/sagemaker/core/spark/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,10 @@ def _is_notebook_instance(self):

def _get_notebook_instance_domain(self):
"""Get the instance's domain."""
from sagemaker.core.region_validation import validate_region

region = self.sagemaker_session.boto_region_name
validate_region(region)
with open("/opt/ml/metadata/resource-metadata.json") as file:
data = json.load(file)
notebook_name = data["ResourceName"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ def _construct_url(
) -> str:
"""Construct the URL for the telemetry request"""

from sagemaker.core.region_validation import validate_region

validate_region(region)
base_url = (
f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
f"x-accountId={accountId}"
Expand Down
3 changes: 3 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/utils/telemetry_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ def _construct_url(
) -> str:
"""Placeholder docstring"""

from sagemaker.core.region_validation import validate_region

validate_region(region)
base_url = (
f"https://sm-pysdk-t-{region}.s3.{region}.amazonaws.com/telemetry?"
f"x-accountId={accountId}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,25 @@ def _is_in_studio() -> bool:

def _get_studio_base_url(region: str) -> str:
    """Build the Studio base URL for ``region``.

    Returns an empty string when either the domain id read from metadata or
    ``region`` is falsy. Otherwise ``region`` is passed through
    ``validate_region`` before being interpolated into the hostname.
    """
    from sagemaker.core.region_validation import validate_region
    from sagemaker.train.common_utils.finetune_utils import _read_domain_id_from_metadata

    studio_domain = _read_domain_id_from_metadata()
    if studio_domain and region:
        # Validate only once we know a URL will actually be produced.
        validate_region(region)
        return f"https://studio-{studio_domain}.studio.{region}.sagemaker.aws"
    return ""


def _parse_job_arn(job_arn: str):
    """Parse a SageMaker job ARN into ``(region, resource)``.

    Args:
        job_arn: ARN of the form
            ``arn:<partition>:sagemaker:<region>:<account>:<resource>``.

    Returns:
        A ``(region, resource)`` tuple, or ``None`` when ``job_arn`` does not
        look like a SageMaker ARN.

    Raises:
        InvalidRegionError: If the ARN matches but its region component is
            rejected by ``validate_region``.
    """
    import re
    from sagemaker.core.region_validation import validate_region

    # ``(?:-[a-z]+)*`` (rather than ``?``) lets multi-hyphen partitions such
    # as ``aws-us-gov`` and ``aws-iso-b`` match, not only ``aws`` / ``aws-cn``.
    m = re.match(r'arn:aws(?:-[a-z]+)*:sagemaker:([a-z0-9-]+):\d+:(\S+)', job_arn)
    if not m:
        return None
    # The region is validated before it is returned to the caller; there must
    # be no earlier return on the matched path, or this check is skipped.
    region = m.group(1)
    validate_region(region)
    return (region, m.group(2))


def get_console_job_url(job_arn: str) -> str:
Expand Down
Loading