17 commits
ac5f7b7  feat(feature-processor): Add use_lake_formation_credentials parameter (BassemHalim, Apr 28, 2026)
18cd319  feat(feature-processor): Add signing key for stored function integrity (BassemHalim, Apr 28, 2026)
1aa4b3d  feat(feature-processor): Add Spark image resolver for dynamic image URI (BassemHalim, Apr 28, 2026)
8796003  feat(feature-processor): Use image resolver and pass signing key in s… (BassemHalim, Apr 28, 2026)
8d8847f  fix(feature-processor): Dynamic Hadoop version and always-on Feature … (BassemHalim, Apr 28, 2026)
544fd94  feat(remote-function): Support Python 3.12 and auto-detect Spark version (BassemHalim, Apr 30, 2026)
932c997  Add Integ test (WIP) (BassemHalim, Apr 30, 2026)
6d7e7bf  Merge branch 'master' into feat/feature-store-fp-lf (BassemHalim, Apr 30, 2026)
5eca24b  fix(sagemaker-core): Update Spark image error message to include Pyth… (BassemHalim, May 1, 2026)
6a9e051  chore(sagemaker-mlops): Add pyspark 3.5.1 to test and feature-process… (BassemHalim, May 1, 2026)
eb1eaf8  test(sagemaker-mlops): Skip Spark integ tests on unsupported Python v… (BassemHalim, May 1, 2026)
2df34bb  feat(sagemaker-mlops): Auto-install feature-store-pyspark in to_pipeline (BassemHalim, May 1, 2026)
6f08b06  feat(sagemaker-mlops): Add Python 3.12 Spark support and auto-install… (BassemHalim, May 1, 2026)
50c87f0  feat(feature-processor): Auto-install feature-store-pyspark for Spark… (BassemHalim, May 1, 2026)
08a884f  test(feature-processor): Update test_to_pipeline to match injected pr… (BassemHalim, May 1, 2026)
f8479c5  Merge branch 'master' into feat/feature-store-fp-lf (BassemHalim, May 1, 2026)
609c0e6  fix (BassemHalim, May 2, 2026)
32 changes: 29 additions & 3 deletions sagemaker-core/src/sagemaker/core/remote_function/job.py
@@ -670,6 +670,24 @@ def __init__(
sagemaker_session=self.sagemaker_session,
)

# When using Spark, ensure sagemaker-feature-store-pyspark is installed
# and its version-matched JAR is on Spark's classpath before spark-submit
if spark_config:
install_cmd = (
"pip install --root-user-action=ignore"
" 'sagemaker-feature-store-pyspark>=2,<3'"
)
copy_jar_cmd = (
"python3 -c \"import feature_store_pyspark, shutil; "
"[shutil.copy(j, '/usr/lib/spark/jars/') "
"for j in feature_store_pyspark.classpath_jars()]\""
)
if self.pre_execution_commands is None:
self.pre_execution_commands = [install_cmd, copy_jar_cmd]
elif install_cmd not in self.pre_execution_commands:
self.pre_execution_commands.append(install_cmd)
self.pre_execution_commands.append(copy_jar_cmd)

self.pre_execution_script = resolve_value_from_config(
direct_input=pre_execution_script,
config_path=REMOTE_FUNCTION_PRE_EXECUTION_SCRIPT,
@@ -822,15 +840,23 @@ def _get_default_spark_image(session):

py_version = str(sys.version_info[0]) + str(sys.version_info[1])

if py_version not in ["39"]:
if py_version not in ["39", "312"]:
raise ValueError(
"The SageMaker Spark image for remote job only supports Python version 3.9. "
"The SageMaker Spark image for remote job only supports Python versions 3.9 and 3.12."
)

# Detect Spark version from installed pyspark, fall back to default
spark_version = DEFAULT_SPARK_VERSION
try:
import pyspark
spark_version = ".".join(pyspark.__version__.split(".")[:2])
except ImportError:
pass

image_uri = image_uris.retrieve(
framework=SPARK_NAME,
region=region,
version=DEFAULT_SPARK_VERSION,
version=spark_version,
instance_type=None,
py_version=f"py{py_version}",
container_version=DEFAULT_SPARK_CONTAINER_VERSION,
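For reviewers, a minimal standalone sketch of the two behaviors added in job.py above: detecting the Spark major.minor version from the installed pyspark package (falling back to the module default when pyspark is absent), and the pre-execution commands injected ahead of spark-submit when a spark_config is supplied. The DEFAULT_SPARK_VERSION value below is a stand-in, not the constant's actual value.

import sys

DEFAULT_SPARK_VERSION = "3.3"  # stand-in; the real constant is defined in job.py

def detect_spark_version() -> str:
    """Mirror of the fallback logic added to _get_default_spark_image."""
    try:
        import pyspark
        return ".".join(pyspark.__version__.split(".")[:2])  # "3.5.1" -> "3.5"
    except ImportError:
        return DEFAULT_SPARK_VERSION

# Commands injected before spark-submit when spark_config is set:
pre_execution_commands = [
    "pip install --root-user-action=ignore 'sagemaker-feature-store-pyspark>=2,<3'",
    # Copy the connector JARs onto Spark's system classpath so the JVM can load them.
    "python3 -c \"import feature_store_pyspark, shutil; "
    "[shutil.copy(j, '/usr/lib/spark/jars/') for j in feature_store_pyspark.classpath_jars()]\"",
]

print(detect_spark_version(), f"py{sys.version_info[0]}{sys.version_info[1]}")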
2 changes: 1 addition & 1 deletion sagemaker-core/tests/unit/remote_function/test_job.py
@@ -133,7 +133,7 @@ def test_get_default_spark_image_unsupported_python_raises_error(self, mock_sess
with patch.object(sys, "version_info", (3, 8, 0)):
with pytest.raises(
ValueError,
match="SageMaker Spark image for remote job only supports Python version 3.9",
match="SageMaker Spark image for remote job only supports Python versions 3.9 and 3.12",
):
_JobSettings._get_default_spark_image(mock_session)

Changes to a second unit test file (path not shown)
@@ -131,7 +131,7 @@ def test_get_default_spark_image_unsupported_python(self, mock_session):
with patch.object(sys, "version_info", (3, 8, 0)):
with pytest.raises(
ValueError,
match="SageMaker Spark image for remote job only supports Python version 3.9",
match="SageMaker Spark image for remote job only supports Python versions 3.9 and 3.12",
):
_JobSettings._get_default_spark_image(mock_session)

8 changes: 4 additions & 4 deletions sagemaker-mlops/pyproject.toml
@@ -35,8 +35,8 @@ dependencies = [

[project.optional-dependencies]
feature-processor = [
"pyspark==3.3.2",
"sagemaker-feature-store-pyspark-3.3",
"pyspark==3.5.1",
"sagemaker-feature-store-pyspark==2.0.0",
"setuptools<82",
]

@@ -45,8 +45,8 @@ test = [
"pytest-cov",
"mock",
"setuptools<82",
"pyspark==3.3.2",
"sagemaker-feature-store-pyspark-3.3",
"pyspark==3.5.1",
"sagemaker-feature-store-pyspark==2.0.0",
"pandas<3.0",
"numpy<3.0",
]
Changes to another file (path not shown): feature processor ConfigUploader
@@ -15,6 +15,8 @@
from typing import Callable, Dict, Optional, Tuple, List, Union

import attr
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives import serialization as crypto_serialization

from sagemaker.core.helper.session_helper import Session
from sagemaker.mlops.feature_store.feature_processor._constants import (
@@ -52,13 +54,13 @@ class ConfigUploader:

def prepare_step_input_channel_for_spark_mode(
self, func: Callable, s3_base_uri: str, sagemaker_session: Session
) -> Tuple[List[Channel], Dict]:
) -> Tuple[List[Channel], Dict, str]:
"""Prepares input channels for SageMaker Pipeline Step.

Returns:
Tuple of (List[Channel], spark_dependency_paths dict)
Tuple of (List[Channel], spark_dependency_paths dict, public_key_pem str)
"""
self._prepare_and_upload_callable(func, s3_base_uri, sagemaker_session)
public_key_pem = self._prepare_and_upload_callable(func, s3_base_uri, sagemaker_session)
bootstrap_scripts_s3uri = self._prepare_and_upload_runtime_scripts(
self.remote_decorator_config.spark_config,
s3_base_uri,
@@ -139,18 +141,33 @@ def prepare_step_input_channel_for_spark_mode(
SPARK_JAR_FILES_PATH: submit_jars_s3_paths,
SPARK_PY_FILES_PATH: submit_py_files_s3_paths,
SPARK_FILES_PATH: submit_files_s3_path,
}
}, public_key_pem

def _prepare_and_upload_callable(
self, func: Callable, s3_base_uri: str, sagemaker_session: Session
) -> None:
"""Prepares and uploads callable to S3"""
) -> str:
"""Prepares and uploads callable to S3.

Returns:
str: The public key PEM string for signature verification on the remote side.
"""
private_key = ec.generate_private_key(ec.SECP256R1())
public_key_pem = (
private_key.public_key()
.public_bytes(
crypto_serialization.Encoding.PEM,
crypto_serialization.PublicFormat.SubjectPublicKeyInfo,
)
.decode("utf-8")
)
stored_function = StoredFunction(
sagemaker_session=sagemaker_session,
s3_base_uri=s3_base_uri,
signing_key=private_key,
s3_kms_key=self.remote_decorator_config.s3_kms_key,
)
stored_function.save(func)
return public_key_pem

def _prepare_and_upload_workspace(
self,
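To make the signing-key handling above easier to review, here is a small self-contained sketch of the pattern: signing a payload with an ECDSA P-256 private key and verifying it from the exported public-key PEM, using the same cryptography APIs imported at the top of the file. The StoredFunction internals that consume signing_key are not part of this diff, so only the key-pair handling is illustrated.

from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec

# Uploader side: generate the key pair and export the public key, as ConfigUploader now does.
private_key = ec.generate_private_key(ec.SECP256R1())
public_key_pem = private_key.public_key().public_bytes(
    serialization.Encoding.PEM,
    serialization.PublicFormat.SubjectPublicKeyInfo,
).decode("utf-8")

payload = b"serialized stored function bytes"  # stand-in for the uploaded artifact
signature = private_key.sign(payload, ec.ECDSA(hashes.SHA256()))

# Remote side: reload the PEM passed through the pipeline step and verify integrity.
public_key = serialization.load_pem_public_key(public_key_pem.encode("utf-8"))
public_key.verify(signature, payload, ec.ECDSA(hashes.SHA256()))  # raises InvalidSignature if tampered
print("signature verified")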
Changes to another file (path not shown): FeatureProcessorConfig
@@ -47,6 +47,7 @@ class FeatureProcessorConfig:
parameters: Optional[Dict[str, Union[str, Dict]]] = attr.ib()
enable_ingestion: bool = attr.ib()
spark_config: Dict[str, str] = attr.ib()
use_lake_formation_credentials: bool = attr.ib()

@staticmethod
def create(
@@ -59,6 +60,7 @@ def create(
parameters: Optional[Dict[str, Union[str, Dict]]],
enable_ingestion: bool,
spark_config: Dict[str, str],
use_lake_formation_credentials: bool = False,
) -> "FeatureProcessorConfig":
"""Static initializer."""
return FeatureProcessorConfig(
@@ -69,4 +71,5 @@ def create(
parameters=parameters,
enable_ingestion=enable_ingestion,
spark_config=spark_config,
use_lake_formation_credentials=use_lake_formation_credentials,
)
New file (path not shown): SageMaker Spark container image resolver
@@ -0,0 +1,65 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Resolves SageMaker Spark container image URIs based on installed PySpark and Python versions."""
from __future__ import absolute_import

import sys

import pyspark

from sagemaker.core import image_uris

SPARK_IMAGE_SUPPORT_MATRIX = {
"3.1": ["py37"],
"3.2": ["py39"],
"3.3": ["py39"],
"3.5": ["py39", "py312"],
}


def _get_spark_image_uri(session):
"""Resolve the SageMaker Spark container image URI for the installed PySpark and Python versions.

Args:
session: SageMaker Session with boto_region_name attribute.

Returns:
str: The ECR image URI for the matching Spark container.

Raises:
ValueError: If the Spark/Python version combination is not supported.
"""
spark_version = ".".join(pyspark.__version__.split(".")[:2])
py_version = f"py{sys.version_info[0]}{sys.version_info[1]}"

supported_py = SPARK_IMAGE_SUPPORT_MATRIX.get(spark_version)
if supported_py is None:
supported = ", ".join(sorted(SPARK_IMAGE_SUPPORT_MATRIX.keys()))
raise ValueError(
f"No SageMaker Spark container image available for Spark {spark_version}. "
f"Supported versions for remote execution: {supported}."
)

if py_version not in supported_py:
raise ValueError(
f"SageMaker Spark {spark_version} container images support "
f"{', '.join(supported_py)}. Current Python version: {py_version}."
)

return image_uris.retrieve(
framework="spark",
region=session.boto_region_name,
version=spark_version,
py_version=py_version,
container_version="v1",
)
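A quick usage sketch for the resolver above. The session object is a stand-in exposing only boto_region_name, which is all _get_spark_image_uri reads; the import path for _get_spark_image_uri is an assumption, since the new module's filename is not shown here.

import sys
from types import SimpleNamespace

# Stand-in for a SageMaker Session; only boto_region_name is needed by the resolver.
fake_session = SimpleNamespace(boto_region_name="us-west-2")

# With pyspark 3.5.x installed under Python 3.9 or 3.12 this resolves a Spark 3.5 image URI;
# an unsupported combination (e.g. Spark 3.3 on py312) raises ValueError instead.
print(f"Running py{sys.version_info[0]}{sys.version_info[1]}")
print(_get_spark_image_uri(fake_session))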
Changes to another file (path not shown): Spark session factory
@@ -13,11 +13,11 @@
"""Contains factory classes for instantiating Spark objects."""
from __future__ import absolute_import

import logging
from functools import lru_cache
from typing import List, Tuple, Dict

import feature_store_pyspark
import feature_store_pyspark.FeatureStoreManager as fsm
import pyspark
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql import SparkSession
@@ -26,6 +26,32 @@

SPARK_APP_NAME = "FeatureProcessor"

logger = logging.getLogger(__name__)

SPARK_TO_HADOOP_MAP = {
"3.1": "3.2.0",
"3.2": "3.3.1",
"3.3": "3.3.2",
"3.4": "3.3.4",
"3.5": "3.3.4",
}

_DEFAULT_HADOOP_VERSION = "3.3.4"

def _get_hadoop_version():
"""Resolve the Hadoop version for the installed PySpark version."""
spark_version = pyspark.__version__
major_minor = ".".join(spark_version.split(".")[:2])
hadoop_version = SPARK_TO_HADOOP_MAP.get(major_minor)
if hadoop_version is None:
hadoop_version = _DEFAULT_HADOOP_VERSION
logger.warning(
"Unknown Spark version %s. Falling back to Hadoop %s.",
spark_version,
hadoop_version,
)
return hadoop_version


class SparkSessionFactory:
"""Lazy loading, memoizing, instantiation of SparkSessions.
@@ -55,6 +81,10 @@ def spark_session(self) -> SparkSession:
is_training_job = self.environment_helper.is_training_job()
instance_count = self.environment_helper.get_instance_count()

# Copy version-matched Feature Store JAR to Spark's system classpath
# so it's available to the JVM even if SparkContext is already running.
self._install_feature_store_jars()

spark_configs = self._get_spark_configs(is_training_job)
spark_conf = SparkConf().setAll(spark_configs).setAppName(SPARK_APP_NAME)

@@ -69,6 +99,24 @@

return SparkSession(sparkContext=sc)

@staticmethod
def _install_feature_store_jars():
"""Copy the Spark-version-matched Feature Store JAR to Spark's system classpath."""
import feature_store_pyspark
import shutil
import os

spark_version = ".".join(pyspark.__version__.split(".")[:2])
target_dir = "/usr/lib/spark/jars"
if not os.path.isdir(target_dir):
return
for jar in feature_store_pyspark.classpath_jars():
if spark_version in os.path.basename(jar):
dest = os.path.join(target_dir, os.path.basename(jar))
if not os.path.exists(dest):
shutil.copy(jar, dest)
logger.info("Copied %s to %s", jar, target_dir)

def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]:
"""Generate Spark Configurations optimized for feature_processing functionality.

@@ -115,28 +163,37 @@ def _get_spark_configs(self, is_training_job) -> List[Tuple[str, str]]:
spark_configs.extend(self.spark_config.items())

if not is_training_job:
fp_spark_jars = feature_store_pyspark.classpath_jars()
hadoop_version = _get_hadoop_version()
fp_spark_packages = [
"org.apache.hadoop:hadoop-aws:3.3.1",
"org.apache.hadoop:hadoop-common:3.3.1",
f"org.apache.hadoop:hadoop-aws:{hadoop_version}",
f"org.apache.hadoop:hadoop-common:{hadoop_version}",
]

if self.spark_config and "spark.jars" in self.spark_config:
fp_spark_jars.append(self.spark_config.get("spark.jars"))

if self.spark_config and "spark.jars.packages" in self.spark_config:
fp_spark_packages.append(self.spark_config.get("spark.jars.packages"))

spark_configs.extend(
(
("spark.jars", ",".join(fp_spark_jars)),
(
"spark.jars.packages",
",".join(fp_spark_packages),
),
)
spark_configs.append(
("spark.jars.packages", ",".join(fp_spark_packages))
)

# Always add Feature Store JARs so they are on the classpath
# regardless of whether we are in a training job or not.
import feature_store_pyspark
import os

spark_version = ".".join(pyspark.__version__.split(".")[:2])
fp_spark_jars = [
j for j in feature_store_pyspark.classpath_jars()
if spark_version in os.path.basename(j)
]
if not fp_spark_jars:
fp_spark_jars = feature_store_pyspark.classpath_jars()

if self.spark_config and "spark.jars" in self.spark_config:
fp_spark_jars.append(self.spark_config.get("spark.jars"))

spark_configs.append(("spark.jars", ",".join(fp_spark_jars)))

return spark_configs

def _get_jsc_hadoop_configs(self) -> List[Tuple[str, str]]:
@@ -197,6 +254,8 @@ class FeatureStoreManagerFactory:

@property
@lru_cache()
def feature_store_manager(self) -> fsm.FeatureStoreManager:
def feature_store_manager(self) -> "fsm.FeatureStoreManager":
"""Instansiate a new FeatureStoreManager."""
import feature_store_pyspark.FeatureStoreManager as fsm

return fsm.FeatureStoreManager()
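To make the Spark-to-Hadoop mapping and packages wiring above concrete, a small self-contained sketch of the spark.jars.packages value the factory would emit for a given pyspark version. The map is duplicated here rather than imported, since the module is private.

SPARK_TO_HADOOP_MAP = {
    "3.1": "3.2.0",
    "3.2": "3.3.1",
    "3.3": "3.3.2",
    "3.4": "3.3.4",
    "3.5": "3.3.4",
}
_DEFAULT_HADOOP_VERSION = "3.3.4"

def hadoop_packages(pyspark_version: str) -> str:
    """Build the spark.jars.packages entry for a given pyspark version string."""
    major_minor = ".".join(pyspark_version.split(".")[:2])
    hadoop = SPARK_TO_HADOOP_MAP.get(major_minor, _DEFAULT_HADOOP_VERSION)
    return ",".join(
        [f"org.apache.hadoop:hadoop-aws:{hadoop}", f"org.apache.hadoop:hadoop-common:{hadoop}"]
    )

assert hadoop_packages("3.5.1") == (
    "org.apache.hadoop:hadoop-aws:3.3.4,org.apache.hadoop:hadoop-common:3.3.4"
)
assert hadoop_packages("3.3.2") == (
    "org.apache.hadoop:hadoop-aws:3.3.2,org.apache.hadoop:hadoop-common:3.3.2"
)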
Changes to another file (path not shown): Feature Store ingestion wrapper
@@ -81,6 +81,7 @@ def ingest_udf_output(self, output: DataFrame, fp_config: FeatureProcessorConfig
input_data_frame=output,
feature_group_arn=fp_config.output,
target_stores=fp_config.target_stores,
use_lake_formation_credentials=fp_config.use_lake_formation_credentials
)
except Py4JJavaError as e:
if e.java_exception.getClass().getSimpleName() == "StreamIngestionFailureException":
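End to end, the new use_lake_formation_credentials flag simply rides from FeatureProcessorConfig into the Feature Store ingestion call above. A rough mock-based sketch of that call site follows; the manager method name ingest_data is an assumption (it is cut off in this diff), and FakeConfig is a hypothetical stand-in.

from dataclasses import dataclass
from typing import List, Optional
from unittest.mock import MagicMock

@dataclass
class FakeConfig:  # hypothetical stand-in for FeatureProcessorConfig
    output: str = "arn:aws:sagemaker:us-west-2:123456789012:feature-group/my-fg"
    target_stores: Optional[List[str]] = None
    use_lake_formation_credentials: bool = True

fp_config = FakeConfig()
feature_store_manager = MagicMock()  # stand-in for the FeatureStoreManager
output_df = MagicMock()              # stand-in for the Spark DataFrame produced by the UDF

# Mirrors the call in ingest_udf_output above.
feature_store_manager.ingest_data(
    input_data_frame=output_df,
    feature_group_arn=fp_config.output,
    target_stores=fp_config.target_stores,
    use_lake_formation_credentials=fp_config.use_lake_formation_credentials,
)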