diff --git a/src/sagemaker/local/entities.py b/src/sagemaker/local/entities.py index 0cf6c6d55a..9e16d918f7 100644 --- a/src/sagemaker/local/entities.py +++ b/src/sagemaker/local/entities.py @@ -28,7 +28,7 @@ from sagemaker.local.image import _SageMakerContainer from sagemaker.local.utils import copy_directory_structure, move_to_destination, get_docker_host -from sagemaker.utils import DeferredError, get_config_value, format_tags +from sagemaker.utils import DeferredError, get_config_value, format_tags, validate_path_within_directory from sagemaker.local.exceptions import StepExecutionException logger = logging.getLogger(__name__) @@ -506,14 +506,20 @@ def _perform_batch_inference(self, input_data, output_data, **kwargs): working_dir = self._get_working_directory() dataset_dir = data_source.get_root_dir() + working_dir_real = os.path.realpath(working_dir) for fn in data_source.get_file_list(): relative_path = os.path.dirname(os.path.relpath(fn, dataset_dir)) filename = os.path.basename(fn) - copy_directory_structure(working_dir, relative_path) destination_path = os.path.join(working_dir, relative_path, filename + ".out") + validate_path_within_directory( + destination_path, working_dir, source_description=fn + ) + + copy_directory_structure(working_dir, relative_path) + with open(destination_path, "wb") as f: for item in batch_provider.pad(fn, max_payload): # call the container and add the result to inference. 
diff --git a/src/sagemaker/serve/model_format/mlflow/utils.py b/src/sagemaker/serve/model_format/mlflow/utils.py index 7ce3df6710..7e8132beda 100644 --- a/src/sagemaker/serve/model_format/mlflow/utils.py +++ b/src/sagemaker/serve/model_format/mlflow/utils.py @@ -21,6 +21,7 @@ import os from sagemaker import Session, image_uris +from sagemaker.utils import validate_path_within_directory from sagemaker.serve.utils.types import ModelServer from sagemaker.serve.detector.image_detector import _cast_to_compatible_version from sagemaker.serve.model_format.mlflow.constants import ( @@ -243,6 +244,7 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non s3 = session.boto_session.client("s3") os.makedirs(dst_path, exist_ok=True) + dst_path_real = os.path.realpath(dst_path) # Spot check: enforce ownership only when downloading from the session's default # bucket. Cross-account reads are left untouched. @@ -260,6 +262,10 @@ def _download_s3_artifacts(s3_path: str, dst_path: str, session: Session) -> Non rel_path = os.path.relpath(key, s3_key) local_file_path = os.path.join(dst_path, rel_path) + validate_path_within_directory( + local_file_path, dst_path, source_description=key + ) + if not key.endswith("/"): local_file_dir = os.path.dirname(local_file_path) os.makedirs(local_file_dir, exist_ok=True) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index ec17ad0d29..3b404fa65a 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -32,7 +32,7 @@ import botocore.config from botocore.exceptions import ClientError import six -from sagemaker.utils import instance_supports_kms, create_paginator_config +from sagemaker.utils import instance_supports_kms, create_paginator_config, validate_path_within_directory import sagemaker.logs from sagemaker import vpc_utils, s3_utils @@ -550,13 +550,20 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None): if expected_owner: download_extra_args["ExpectedBucketOwner"] 
def validate_path_within_directory(file_path, target_directory, source_description=""):
    """Validate that file_path resolves to a location within target_directory.

    Prevents path traversal attacks (CWE-22) by resolving both paths to their
    canonical forms (symlinks and ``..`` segments collapsed via
    ``os.path.realpath``) and checking containment on whole path components.

    Args:
        file_path (str): The file path to validate.
        target_directory (str): The directory that file_path must stay within.
        source_description (str): Optional description of the source (e.g. S3 key)
            included in the error message for debugging.

    Raises:
        ValueError: If file_path resolves to a location outside target_directory.
    """
    target_real = os.path.realpath(target_directory)
    file_real = os.path.realpath(file_path)

    try:
        # commonpath compares complete components, so "/data/out_evil" is never
        # mistaken for a child of "/data/out". Unlike a "target + os.sep" prefix
        # test it also works when the target is a filesystem root ("/" or "C:\\"),
        # whose canonical form already ends with a separator.
        common = os.path.commonpath([target_real, file_real])
    except ValueError:
        # Raised when the two paths cannot share an ancestor (e.g. different
        # drives on Windows); such a path is by definition outside the target.
        is_contained = False
    else:
        # normcase is a no-op on POSIX and folds case/separators on Windows.
        is_contained = os.path.normcase(common) == os.path.normcase(target_real)

    if not is_contained:
        source_info = f"'{source_description}' resolves to " if source_description else ""
        raise ValueError(
            f"Path traversal detected: {source_info}"
            f"'{file_real}' which is outside the target directory '{target_real}'"
        )
data_source.get_file_list.return_value = [traversal_file] + get_data_source_instance.return_value = data_source + + batch_strategy = Mock() + get_batch_strategy_instance.return_value = batch_strategy + + local_transform_job._get_working_directory = Mock(return_value=working_dir) + + input_data = { + "DataSource": {"S3DataSource": {"S3Uri": "s3://some_bucket/data"}}, + "ContentType": "text/csv", + } + output_data = {"S3OutputPath": "s3://bucket/output", "Accept": "text/csv"} + + with pytest.raises(ValueError, match="Path traversal detected"): + local_transform_job._perform_batch_inference( + input_data, output_data, BatchStrategy="MultiRecord", MaxPayloadInMB="6" + ) + + +@patch("sagemaker.local.entities.move_to_destination") +@patch("sagemaker.local.data.get_batch_strategy_instance") +@patch("sagemaker.local.data.get_data_source_instance") +@patch("sagemaker.local.entities.get_config_value") +def test_perform_batch_inference_safe_file_paths_are_allowed( + get_config_value, + get_data_source_instance, + get_batch_strategy_instance, + move_to_destination, + local_transform_job, +): + """Test that normal file paths within dataset_dir are allowed.""" + import tempfile + + with tempfile.TemporaryDirectory() as working_dir: + with tempfile.TemporaryDirectory() as dataset_dir: + test_file = os.path.join(dataset_dir, "input.csv") + with open(test_file, "w") as f: + f.write("test data") + + get_config_value.return_value = working_dir + + data_source = Mock() + data_source.get_root_dir.return_value = dataset_dir + data_source.get_file_list.return_value = [test_file] + get_data_source_instance.return_value = data_source + + batch_strategy = Mock() + batch_strategy.pad.return_value = [b"test data"] + get_batch_strategy_instance.return_value = batch_strategy + + local_transform_job._get_working_directory = Mock(return_value=working_dir) + local_transform_job.container = Mock() + + runtime_client = Mock() + response_object = Mock() + response_object.read.return_value = b"result" 
+ runtime_client.invoke_endpoint.return_value = {"Body": response_object} + local_transform_job.local_session.sagemaker_runtime_client = runtime_client + + input_data = { + "DataSource": {"S3DataSource": {"S3Uri": "s3://some_bucket/data"}}, + "ContentType": "text/csv", + } + output_data = {"S3OutputPath": "s3://bucket/output", "Accept": "text/csv"} + + local_transform_job._perform_batch_inference( + input_data, output_data, BatchStrategy="MultiRecord", MaxPayloadInMB="6" + ) + + expected_output = os.path.join(working_dir, "input.csv.out") + assert os.path.exists(expected_output) + + @patch("sagemaker.local.entities._SageMakerContainer", Mock()) @patch("sagemaker.local.entities.get_docker_host") @patch("sagemaker.local.entities._perform_request") diff --git a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py index be55ba71c3..43c21caaea 100644 --- a/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py +++ b/tests/unit/sagemaker/serve/model_format/mlflow/test_mlflow_utils.py @@ -230,6 +230,113 @@ def test_download_s3_artifacts_valid_s3_path(mock_os_makedirs, mock_session): ) +@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs") +def test_download_s3_artifacts_path_traversal_via_dotdot_in_key(mock_makedirs): + """Test that S3 keys with '..' 
traversal sequences are blocked.""" + import tempfile + + mock_session = MagicMock() + mock_s3_client = MagicMock() + mock_session.boto_session.client.return_value = mock_s3_client + + mock_paginator = MagicMock() + mock_s3_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "model/../../../../etc/passwd"}, + ] + } + ] + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="Path traversal detected"): + _download_s3_artifacts("s3://my-bucket/model", tmpdir, mock_session) + + mock_s3_client.download_file.assert_not_called() + + +@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs") +def test_download_s3_artifacts_path_traversal_overwrite_ssh_keys(mock_makedirs): + """Test the attack scenario targeting SSH keys.""" + import tempfile + + mock_session = MagicMock() + mock_s3_client = MagicMock() + mock_session.boto_session.client.return_value = mock_s3_client + + mock_paginator = MagicMock() + mock_s3_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "mlruns/exp1/model/../../../../Users/alice/.ssh/authorized_keys"}, + ] + } + ] + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="Path traversal detected"): + _download_s3_artifacts( + "s3://shared-bucket/mlruns/exp1/model", tmpdir, mock_session + ) + + mock_s3_client.download_file.assert_not_called() + + +@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs") +def test_download_s3_artifacts_safe_keys_are_allowed(mock_makedirs): + """Test that normal S3 keys within the target directory are allowed.""" + import tempfile + + mock_session = MagicMock() + mock_s3_client = MagicMock() + mock_session.boto_session.client.return_value = mock_s3_client + + mock_paginator = MagicMock() + mock_s3_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + 
"Contents": [ + {"Key": "model/MLmodel"}, + {"Key": "model/subdir/model.pkl"}, + ] + } + ] + + with tempfile.TemporaryDirectory() as tmpdir: + _download_s3_artifacts("s3://my-bucket/model", tmpdir, mock_session) + + assert mock_s3_client.download_file.call_count == 2 + + +@patch("sagemaker.serve.model_format.mlflow.utils.os.makedirs") +def test_download_s3_artifacts_folder_keys_are_skipped(mock_makedirs): + """Test that S3 folder objects (keys ending with /) are not downloaded.""" + import tempfile + + mock_session = MagicMock() + mock_s3_client = MagicMock() + mock_session.boto_session.client.return_value = mock_s3_client + + mock_paginator = MagicMock() + mock_s3_client.get_paginator.return_value = mock_paginator + mock_paginator.paginate.return_value = [ + { + "Contents": [ + {"Key": "model/subdir/"}, + {"Key": "model/subdir/file.txt"}, + ] + } + ] + + with tempfile.TemporaryDirectory() as tmpdir: + _download_s3_artifacts("s3://my-bucket/model", tmpdir, mock_session) + + # Only the file should be downloaded, not the folder + assert mock_s3_client.download_file.call_count == 1 + + @patch("sagemaker.image_uris.retrieve") @patch("sagemaker.serve.model_format.mlflow.utils._cast_to_compatible_version") @patch("sagemaker.serve.model_format.mlflow.utils._get_framework_version_from_requirements") diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 079bdea8eb..a4e2e8e6dc 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -7052,6 +7052,83 @@ def test_download_data_with_file_and_directory(makedirs, sagemaker_session): ) +@patch("os.makedirs") +def test_download_data_path_traversal_in_file_key(makedirs, sagemaker_session, tmp_path): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + {"Key": "data/../../../../etc/passwd", "Size": 100}, + ] + } + ) + + with pytest.raises(ValueError, match="Path traversal detected"): + sagemaker_session.download_data( + 
path=str(tmp_path), bucket="test-bucket", key_prefix="data/" + ) + + sagemaker_session.s3_client.download_file.assert_not_called() + + +@patch("os.makedirs") +def test_download_data_path_traversal_in_directory_key(makedirs, sagemaker_session, tmp_path): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + {"Key": "data/../../../etc/cron.d/", "Size": 0}, + ] + } + ) + + with pytest.raises(ValueError, match="Path traversal detected"): + sagemaker_session.download_data( + path=str(tmp_path), bucket="test-bucket", key_prefix="data/" + ) + + +@patch("os.makedirs") +def test_download_data_path_traversal_overwrite_aws_credentials( + makedirs, sagemaker_session, tmp_path +): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + {"Key": "data/../../../../Users/alice/.aws/credentials", "Size": 200}, + ] + } + ) + + with pytest.raises(ValueError, match="Path traversal detected"): + sagemaker_session.download_data( + path=str(tmp_path), bucket="shared-bucket", key_prefix="data/" + ) + + sagemaker_session.s3_client.download_file.assert_not_called() + + +@patch("os.makedirs") +def test_download_data_safe_keys_are_allowed(makedirs, sagemaker_session, tmp_path): + sagemaker_session.s3_client = Mock() + sagemaker_session.s3_client.list_objects_v2 = Mock( + return_value={ + "Contents": [ + {"Key": "data/train.csv", "Size": 100}, + {"Key": "data/models/model.pkl", "Size": 500}, + ] + } + ) + + result = sagemaker_session.download_data( + path=str(tmp_path), bucket="test-bucket", key_prefix="data/" + ) + + assert len(result) == 2 + assert sagemaker_session.s3_client.download_file.call_count == 2 + + def test_create_hub(sagemaker_session): sagemaker_session.create_hub( hub_name="mock-hub-name", diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a9eeb6ef50..cf726a1a7a 100644 --- a/tests/unit/test_utils.py +++ 
b/tests/unit/test_utils.py @@ -18,6 +18,7 @@ import logging import shutil import tarfile +import tempfile from datetime import datetime import os import re @@ -729,6 +730,92 @@ def test_download_folder_points_to_single_file(makedirs): obj_mock.reset_mock() +def test_download_files_under_prefix_path_traversal_via_dotdot_in_key(): + """Test that S3 keys with '..' traversal sequences are blocked.""" + from sagemaker.utils import _download_files_under_prefix + + mock_s3 = Mock() + mock_bucket = Mock() + mock_s3.Bucket.return_value = mock_bucket + + mock_obj_summary = Mock() + mock_obj_summary.key = "data/../../../../etc/passwd" + mock_obj_summary.bucket_name = "bucket" + mock_bucket.objects.filter.return_value = [mock_obj_summary] + + mock_obj = Mock() + mock_s3.Object.return_value = mock_obj + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="Path traversal detected"): + _download_files_under_prefix("bucket", "data/", tmpdir, mock_s3) + + mock_obj.download_file.assert_not_called() + + +def test_download_files_under_prefix_path_traversal_via_relative_escape(): + """Test that keys resolving outside target via relpath are blocked.""" + from sagemaker.utils import _download_files_under_prefix + + mock_s3 = Mock() + mock_bucket = Mock() + mock_s3.Bucket.return_value = mock_bucket + + mock_obj_summary = Mock() + mock_obj_summary.key = "prefix/../../../etc/cron.d/backdoor" + mock_obj_summary.bucket_name = "bucket" + mock_bucket.objects.filter.return_value = [mock_obj_summary] + + mock_obj = Mock() + mock_s3.Object.return_value = mock_obj + + with tempfile.TemporaryDirectory() as tmpdir: + with pytest.raises(ValueError, match="Path traversal detected"): + _download_files_under_prefix("bucket", "prefix/", tmpdir, mock_s3) + + mock_obj.download_file.assert_not_called() + + +def test_download_files_under_prefix_safe_keys_are_allowed(): + """Test that normal S3 keys within the target directory are allowed.""" + from sagemaker.utils import 
_download_files_under_prefix + + mock_s3 = Mock() + mock_bucket = Mock() + mock_s3.Bucket.return_value = mock_bucket + + mock_obj_summary = Mock() + mock_obj_summary.key = "data/subdir/file.txt" + mock_obj_summary.bucket_name = "bucket" + mock_bucket.objects.filter.return_value = [mock_obj_summary] + + mock_obj = Mock() + mock_s3.Object.return_value = mock_obj + + with tempfile.TemporaryDirectory() as tmpdir: + _download_files_under_prefix("bucket", "data/", tmpdir, mock_s3) + + mock_obj.download_file.assert_called_once() + + +def test_download_files_under_prefix_folder_objects_are_skipped(): + """Test that S3 folder objects (keys ending with /) are skipped.""" + from sagemaker.utils import _download_files_under_prefix + + mock_s3 = Mock() + mock_bucket = Mock() + mock_s3.Bucket.return_value = mock_bucket + + mock_obj_summary = Mock() + mock_obj_summary.key = "data/subdir/" + mock_bucket.objects.filter.return_value = [mock_obj_summary] + + with tempfile.TemporaryDirectory() as tmpdir: + _download_files_under_prefix("bucket", "data/", tmpdir, mock_s3) + + mock_s3.Object.assert_not_called() + + def test_download_file(): boto_mock = MagicMock(name="boto_session") boto_mock.client("sts").get_caller_identity.return_value = {"Account": "123"}