Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/sagemaker/local/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host
from sagemaker.utils import DeferredError, get_config_value, format_tags
from sagemaker.utils import DeferredError, get_config_value, format_tags, validate_path_within_directory
from sagemaker.local.exceptions import StepExecutionException

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -506,14 +506,20 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs):

working_dir = self._get_working_directory()
dataset_dir = data_source.get_root_dir()
working_dir_real = os.path.realpath(working_dir)

for fn in data_source.get_file_list():

relative_path = os.path.dirname(os.path.relpath(fn, dataset_dir))
filename = os.path.basename(fn)
copy_directory_structure(working_dir, relative_path)
destination_path = os.path.join(working_dir, relative_path, filename + ".out")

validate_path_within_directory(
destination_path, working_dir, source_description=fn
)

copy_directory_structure(working_dir, relative_path)

with open(destination_path, "wb") as f:
for item in batch_provider.pad(fn, max_payload):
# call the container and add the result to inference.
Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/serve/model_format/mlflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os

from sagemaker import Session, image_uris
from sagemaker.utils import validate_path_within_directory
from sagemaker.serve.utils.types import ModelServer
from sagemaker.serve.detector.image_detector import _cast_to_compatible_version
from sagemaker.serve.model_format.mlflow.constants import (
Expand Down Expand Up @@ -243,6 +244,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non
s3 = session.boto_session.client("s3")

os.makedirs(dst_path, exist_ok=True)
dst_path_real = os.path.realpath(dst_path)

# Spot check: enforce ownership only when downloading from the session's default
# bucket. Cross-account reads are left untouched.
Expand All @@ -260,6 +262,10 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non
rel_path = os.path.relpath(key, s3_key)
local_file_path = os.path.join(dst_path, rel_path)

validate_path_within_directory(
local_file_path, dst_path, source_description=key
)

if not key.endswith("/"):
local_file_dir = os.path.dirname(local_file_path)
os.makedirs(local_file_dir, exist_ok=True)
Expand Down
9 changes: 8 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import botocore.config
from botocore.exceptions import ClientError
import six
from sagemaker.utils import instance_supports_kms, create_paginator_config
from sagemaker.utils import instance_supports_kms, create_paginator_config, validate_path_within_directory

import sagemaker.logs
from sagemaker import vpc_utils, s3_utils
Expand Down Expand Up @@ -550,13 +550,20 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
if expected_owner:
download_extra_args["ExpectedBucketOwner"] = expected_owner
downloaded_paths = []
path_real = os.path.realpath(path)
for dir_path in directories:
validate_path_within_directory(dir_path, path)
os.makedirs(os.path.dirname(dir_path), exist_ok=True)
for key in keys:
tail_s3_uri_path = os.path.basename(key)
if not os.path.splitext(key_prefix)[1]:
tail_s3_uri_path = os.path.relpath(key, key_prefix)
destination_path = os.path.join(path, tail_s3_uri_path)

validate_path_within_directory(
destination_path, path, source_description=key
)

if not os.path.exists(os.path.dirname(destination_path)):
os.makedirs(os.path.dirname(destination_path), exist_ok=True)
s3.download_file(
Expand Down
28 changes: 28 additions & 0 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=Non
extra_args (dict): Optional extra arguments passed to each download_file call.
Used to carry ExpectedBucketOwner when the bucket is the session's default.
"""
target_real = os.path.realpath(target)
bucket = s3.Bucket(bucket_name)
for obj_sum in bucket.objects.filter(Prefix=prefix):
# if obj_sum is a folder object skip it.
Expand All @@ -465,6 +466,8 @@ def _download_files_under_prefix(bucket_name, prefix, target, s3, extra_args=Non
s3_relative_path = obj_sum.key[len(prefix) :].lstrip("/")
file_path = os.path.join(target, s3_relative_path)

validate_path_within_directory(file_path, target, source_description=obj_sum.key)

try:
os.makedirs(os.path.dirname(file_path))
except OSError as exc:
Expand Down Expand Up @@ -1673,6 +1676,31 @@ def _get_resolved_path(path):
return normpath(realpath(abspath(path)))


def validate_path_within_directory(file_path, target_directory, source_description=""):
    """Ensure that ``file_path`` cannot escape ``target_directory``.

    Both paths are canonicalized with ``os.path.realpath`` (resolving
    symlinks and ``..`` components) before the containment check, which
    guards against path traversal attacks (CWE-22).

    Args:
        file_path (str): Candidate path to check.
        target_directory (str): Directory the candidate must remain inside.
        source_description (str): Optional origin of the path (e.g. an S3 key)
            that is echoed in the exception text to aid debugging.

    Raises:
        ValueError: If the canonical form of ``file_path`` lies outside the
            canonical form of ``target_directory``.
    """
    resolved_target = os.path.realpath(target_directory)
    resolved_file = os.path.realpath(file_path)
    # Appending os.sep prevents sibling-prefix false positives
    # (e.g. "/base-extra" must not pass for target "/base").
    contained = resolved_file == resolved_target or resolved_file.startswith(
        resolved_target + os.sep
    )
    if contained:
        return
    source_info = f"'{source_description}' resolves to " if source_description else ""
    raise ValueError(
        f"Path traversal detected: {source_info}"
        f"'{resolved_file}' which is outside the target directory '{resolved_target}'"
    )


def _is_bad_path(path, base):
"""Checks if the joined path (base directory + file path) is rooted under the base directory

Expand Down
93 changes: 93 additions & 0 deletions tests/unit/sagemaker/local/test_local_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,99 @@ def test_local_transform_job_perform_batch_inference(
assert "file2.out" in output_files


@patch("sagemaker.local.data.get_batch_strategy_instance")
@patch("sagemaker.local.data.get_data_source_instance")
@patch("sagemaker.local.entities.get_config_value")
def test_perform_batch_inference_path_traversal_in_file_list(
    get_config_value,
    get_data_source_instance,
    get_batch_strategy_instance,
    local_transform_job,
):
    """A dataset file that escapes working_dir must raise before any output is written."""
    import tempfile

    with tempfile.TemporaryDirectory() as working_dir, tempfile.TemporaryDirectory() as dataset_dir:
        get_config_value.return_value = working_dir

        # Fake data source whose file list points outside the dataset root.
        escaping_path = os.path.join(dataset_dir, "..", "..", "..", "etc", "passwd")
        mock_source = Mock()
        mock_source.get_root_dir.return_value = dataset_dir
        mock_source.get_file_list.return_value = [escaping_path]
        get_data_source_instance.return_value = mock_source

        get_batch_strategy_instance.return_value = Mock()

        local_transform_job._get_working_directory = Mock(return_value=working_dir)

        input_data = {
            "DataSource": {"S3DataSource": {"S3Uri": "s3://some_bucket/data"}},
            "ContentType": "text/csv",
        }
        output_data = {"S3OutputPath": "s3://bucket/output", "Accept": "text/csv"}

        with pytest.raises(ValueError, match="Path traversal detected"):
            local_transform_job._perform_batch_inference(
                input_data, output_data, BatchStrategy="MultiRecord", MaxPayloadInMB="6"
            )


@patch("sagemaker.local.entities.move_to_destination")
@patch("sagemaker.local.data.get_batch_strategy_instance")
@patch("sagemaker.local.data.get_data_source_instance")
@patch("sagemaker.local.entities.get_config_value")
def test_perform_batch_inference_safe_file_paths_are_allowed(
    get_config_value,
    get_data_source_instance,
    get_batch_strategy_instance,
    move_to_destination,
    local_transform_job,
):
    """A legitimate file under dataset_dir is transformed and written into working_dir."""
    import tempfile

    with tempfile.TemporaryDirectory() as working_dir, tempfile.TemporaryDirectory() as dataset_dir:
        input_file = os.path.join(dataset_dir, "input.csv")
        with open(input_file, "w") as fh:
            fh.write("test data")

        get_config_value.return_value = working_dir

        mock_source = Mock()
        mock_source.get_root_dir.return_value = dataset_dir
        mock_source.get_file_list.return_value = [input_file]
        get_data_source_instance.return_value = mock_source

        mock_strategy = Mock()
        mock_strategy.pad.return_value = [b"test data"]
        get_batch_strategy_instance.return_value = mock_strategy

        local_transform_job._get_working_directory = Mock(return_value=working_dir)
        local_transform_job.container = Mock()

        # Stub the runtime client so every endpoint invocation yields a fixed payload.
        runtime = Mock()
        body = Mock()
        body.read.return_value = b"result"
        runtime.invoke_endpoint.return_value = {"Body": body}
        local_transform_job.local_session.sagemaker_runtime_client = runtime

        input_data = {
            "DataSource": {"S3DataSource": {"S3Uri": "s3://some_bucket/data"}},
            "ContentType": "text/csv",
        }
        output_data = {"S3OutputPath": "s3://bucket/output", "Accept": "text/csv"}

        local_transform_job._perform_batch_inference(
            input_data, output_data, BatchStrategy="MultiRecord", MaxPayloadInMB="6"
        )

        assert os.path.exists(os.path.join(working_dir, "input.csv.out"))


@patch("sagemaker.local.entities._SageMakerContainer", Mock())
@patch("sagemaker.local.entities.get_docker_host")
@patch("sagemaker.local.entities._perform_request")
Expand Down
107 changes: 107 additions & 0 deletions tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,113 @@ def test_download_s3_artifacts_valid_s3_path(mock_os_makedirs, mock_session):
)


@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs")
def test_download_s3_artifacts_path_traversal_via_dotdot_in_key(mock_makedirs):
    """An S3 key containing '..' segments must abort the download with ValueError."""
    import tempfile

    session = MagicMock()
    s3_client = MagicMock()
    session.boto_session.client.return_value = s3_client

    # Paginator returns a single page whose only key tries to climb out of dst_path.
    paginator = MagicMock()
    s3_client.get_paginator.return_value = paginator
    paginator.paginate.return_value = [
        {"Contents": [{"Key": "model/../../../../etc/passwd"}]}
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(ValueError, match="Path traversal detected"):
            _download_s3_artifacts("s3://my-bucket/model", tmpdir, session)

    s3_client.download_file.assert_not_called()


@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs")
def test_download_s3_artifacts_path_traversal_overwrite_ssh_keys(mock_makedirs):
    """A crafted key aimed at a user's SSH authorized_keys file must be rejected."""
    import tempfile

    session = MagicMock()
    s3_client = MagicMock()
    session.boto_session.client.return_value = s3_client

    # Single page with one malicious key that resolves under ~/.ssh.
    paginator = MagicMock()
    s3_client.get_paginator.return_value = paginator
    paginator.paginate.return_value = [
        {
            "Contents": [
                {"Key": "mlruns/exp1/model/../../../../Users/alice/.ssh/authorized_keys"},
            ]
        }
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        with pytest.raises(ValueError, match="Path traversal detected"):
            _download_s3_artifacts(
                "s3://shared-bucket/mlruns/exp1/model", tmpdir, session
            )

    s3_client.download_file.assert_not_called()


@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs")
def test_download_s3_artifacts_safe_keys_are_allowed(mock_makedirs):
    """Ordinary keys under the prefix download without raising."""
    import tempfile

    session = MagicMock()
    s3_client = MagicMock()
    session.boto_session.client.return_value = s3_client

    paginator = MagicMock()
    s3_client.get_paginator.return_value = paginator
    paginator.paginate.return_value = [
        {
            "Contents": [
                {"Key": "model/MLmodel"},
                {"Key": "model/subdir/model.pkl"},
            ]
        }
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        _download_s3_artifacts("s3://my-bucket/model", tmpdir, session)

    # Both well-formed keys should have been fetched.
    assert s3_client.download_file.call_count == 2


@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs")
def test_download_s3_artifacts_folder_keys_are_skipped(mock_makedirs):
    """Keys ending in '/' denote S3 folder objects and must not be downloaded."""
    import tempfile

    session = MagicMock()
    s3_client = MagicMock()
    session.boto_session.client.return_value = s3_client

    # One folder marker plus one real file under it.
    paginator = MagicMock()
    s3_client.get_paginator.return_value = paginator
    paginator.paginate.return_value = [
        {
            "Contents": [
                {"Key": "model/subdir/"},
                {"Key": "model/subdir/file.txt"},
            ]
        }
    ]

    with tempfile.TemporaryDirectory() as tmpdir:
        _download_s3_artifacts("s3://my-bucket/model", tmpdir, session)

    # Only the file should be downloaded, not the folder
    assert s3_client.download_file.call_count == 1


@patch("sagemaker.image_uris.retrieve")
@patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version")
@patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements")
Expand Down
Loading
Loading