diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c597ab308..2cd792ea9 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,6 +29,9 @@ env: jobs: test-matrix: runs-on: ubuntu-latest + # Backstop: a hung multiprocessing worker (e.g. during a pickle regression) + # should not block CI longer than this. + timeout-minutes: 30 strategy: fail-fast: false matrix: diff --git a/crates/core/src/codec.rs b/crates/core/src/codec.rs index 088532df2..7992a8bd6 100644 --- a/crates/core/src/codec.rs +++ b/crates/core/src/codec.rs @@ -19,11 +19,11 @@ //! //! Datafusion-python plans can carry references to Python-defined //! objects that the upstream protobuf codecs do not know how to -//! serialize: pure-Python scalar / aggregate / window UDFs, Python -//! query-planning extensions, and so on. Their state lives inside -//! `Py` callables and closures rather than being recoverable -//! from a name in the receiver's function registry. To ship a plan -//! across a process boundary (pickle, `multiprocessing`, Ray actor, +//! serialize: pure-Python scalar UDFs, Python query-planning +//! extensions, and so on. Their state lives inside `Py` +//! callables and closures rather than being recoverable from a name +//! in the receiver's function registry. To ship a plan across a +//! process boundary (pickle, `multiprocessing`, Ray actor, //! `datafusion-distributed`, etc.) those payloads have to be encoded //! into the proto wire format itself. //! @@ -48,52 +48,135 @@ //! plans to survive a serialization round-trip. Both codecs share //! the same payload framing for that reason. //! -//! Payloads emitted by these codecs are tagged with an 8-byte magic -//! prefix so the decoder can distinguish them from arbitrary bytes -//! (empty `fun_definition` from the default codec, user FFI payloads -//! that picked a non-colliding prefix). Dispatch precedence on -//! decode: **Python-inline payload (magic prefix match) → `inner` -//! codec → caller's `FunctionRegistry` fallback.** +//! Payloads emitted by these codecs are framed as +//! ` `. The +//! family magic identifies the UDF flavor; the version byte lets the +//! decoder reject too-new or too-old payloads with a clean error +//! instead of falling into an opaque `cloudpickle` tuple-unpack +//! failure when the tuple shape changes. Dispatch precedence on +//! decode: **family match + supported version → `inner` codec → +//! caller's `FunctionRegistry` fallback.** //! -//! ## Wire-format magic prefix registry +//! ## Wire-format family registry //! -//! | Layer + kind | Magic prefix | -//! | ----------------------------- | ------------ | -//! | `PythonLogicalCodec` scalar | `DFPYUDF1` | -//! | `PythonLogicalCodec` agg | `DFPYUDA1` | -//! | `PythonLogicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` scalar | `DFPYUDF1` | -//! | `PythonPhysicalCodec` agg | `DFPYUDA1` | -//! | `PythonPhysicalCodec` window | `DFPYUDW1` | -//! | `PythonPhysicalCodec` expr | `DFPYPE1` | -//! | User FFI extension codec | user-chosen | -//! | Default codec | (none) | +//! | Layer + kind | Family prefix | +//! | ----------------------------- | ------------- | +//! | `PythonLogicalCodec` scalar | `DFPYUDF` | +//! | `PythonLogicalCodec` agg | `DFPYUDA` | +//! | `PythonLogicalCodec` window | `DFPYUDW` | +//! | `PythonPhysicalCodec` scalar | `DFPYUDF` | +//! | `PythonPhysicalCodec` agg | `DFPYUDA` | +//! | `PythonPhysicalCodec` window | `DFPYUDW` | +//! | User FFI extension codec | user-chosen | +//! | Default codec | (none) | //! -//! 
Downstream FFI codecs should pick non-colliding prefixes (use a -//! `DF` namespace plus a crate-specific suffix). The codec +//! Current wire-format version is [`WIRE_VERSION_CURRENT`]; supported +//! receive range is `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. +//! Bump [`WIRE_VERSION_CURRENT`] whenever the cloudpickle tuple shape +//! changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when dropping support +//! for an older shape. +//! +//! Downstream FFI codecs should pick non-colliding family prefixes +//! (use a `DF` namespace plus a crate-specific suffix). The codec //! implementations in this module currently delegate every method to //! `inner`; the encoder/decoder hooks for each kind are added as the //! corresponding Python-side type becomes serializable. use std::sync::Arc; -use arrow::datatypes::SchemaRef; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::ipc::reader::StreamReader; +use arrow::ipc::writer::StreamWriter; use datafusion::common::{Result, TableReference}; use datafusion::datasource::TableProvider; use datafusion::datasource::file_format::FileFormatFactory; use datafusion::execution::TaskContext; -use datafusion::logical_expr::{AggregateUDF, Extension, LogicalPlan, ScalarUDF, WindowUDF}; +use datafusion::logical_expr::{ + AggregateUDF, AggregateUDFImpl, Extension, LogicalPlan, ScalarUDF, ScalarUDFImpl, Signature, + TypeSignature, Volatility, WindowUDF, WindowUDFImpl, +}; use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::{DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use datafusion_proto::physical_plan::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use pyo3::prelude::*; +use pyo3::sync::PyOnceLock; +use pyo3::types::{PyBytes, PyTuple}; + +use crate::udaf::PythonFunctionAggregateUDF; +use crate::udf::PythonFunctionScalarUDF; +use crate::udwf::PythonFunctionWindowUDF; + +// Wire-format framing for inlined Python UDF payloads. +// +// Layout: ` `. +// The family magic identifies the UDF flavor; the version byte lets +// the decoder reject too-new or too-old payloads with a clean error +// instead of falling into an opaque `cloudpickle` tuple-unpack failure +// when the tuple shape changes. Bump [`WIRE_VERSION_CURRENT`] whenever +// the tuple shape changes; raise [`WIRE_VERSION_MIN_SUPPORTED`] when +// dropping support for an older shape. + +/// Family prefix for an inlined Python scalar UDF +/// (cloudpickled tuple of name, callable, input schema, return field, +/// volatility). +pub(crate) const PY_SCALAR_UDF_FAMILY: &[u8] = b"DFPYUDF"; + +/// Family prefix for an inlined Python aggregate UDF +/// (cloudpickled tuple of name, accumulator factory, input schema, +/// return type, state types schema, volatility). +pub(crate) const PY_AGG_UDF_FAMILY: &[u8] = b"DFPYUDA"; + +/// Family prefix for an inlined Python window UDF +/// (cloudpickled tuple of name, evaluator factory, input schema, +/// return type, volatility). +pub(crate) const PY_WINDOW_UDF_FAMILY: &[u8] = b"DFPYUDW"; + +/// Wire-format version this build emits. +pub(crate) const WIRE_VERSION_CURRENT: u8 = 1; + +/// Oldest wire-format version this build still decodes. Bump when +/// retiring support for an older payload shape. +pub(crate) const WIRE_VERSION_MIN_SUPPORTED: u8 = 1; + +/// Tag `buf` with the framing header for `family` at the current +/// wire-format version. Append-only — the caller writes the +/// cloudpickle payload after. 
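+///
+/// A minimal framing sketch (illustrative only): for a scalar-UDF
+/// payload at wire version 1 the header is the family magic followed
+/// by the version byte.
+///
+/// ```ignore
+/// let mut buf = Vec::new();
+/// write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY);
+/// assert_eq!(&buf[..], b"DFPYUDF\x01");
+/// // The caller appends the cloudpickle blob after this header.
+/// ```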
+fn write_wire_header(buf: &mut Vec, family: &[u8]) { + buf.extend_from_slice(family); + buf.push(WIRE_VERSION_CURRENT); +} -/// Wire-format prefix that tags a `fun_definition` payload as an -/// inlined Python scalar UDF (cloudpickled tuple of name, callable, -/// input schema, return field, volatility). Defined once here so -/// the encoder and decoder cannot drift. -#[allow(dead_code)] -pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1"; +/// Inspect the framing on `buf`. +/// +/// * `Ok(None)` — `buf` does not carry `family`. The caller should +/// delegate to its `inner` codec. +/// * `Ok(Some(payload))` — `buf` carries `family` at a version this +/// build accepts; `payload` is the cloudpickle blob. +/// * `Err(_)` — `buf` carries `family` but at a version outside +/// `WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT`. The error +/// names the version and the supported range so an operator can +/// diagnose sender/receiver version drift instead of seeing an +/// opaque cloudpickle tuple-unpack failure. +fn strip_wire_header<'a>(buf: &'a [u8], family: &[u8], kind: &str) -> Result> { + if !buf.starts_with(family) { + return Ok(None); + } + let version_idx = family.len(); + let Some(&version) = buf.get(version_idx) else { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Truncated inline Python {kind} payload: missing wire-format version byte" + ))); + }; + if !(WIRE_VERSION_MIN_SUPPORTED..=WIRE_VERSION_CURRENT).contains(&version) { + return Err(datafusion::error::DataFusionError::Execution(format!( + "Inline Python {kind} payload wire-format version v{version}; \ + this build supports v{WIRE_VERSION_MIN_SUPPORTED}..=v{WIRE_VERSION_CURRENT}. \ + Align datafusion-python versions on sender and receiver." + ))); + } + Ok(Some(&buf[version_idx + 1..])) +} /// `LogicalExtensionCodec` parked on every `SessionContext`. Holds /// the Python-aware encoding hooks for logical-layer types @@ -108,16 +191,44 @@ pub(crate) const PY_SCALAR_UDF_MAGIC: &[u8] = b"DFPYUDF1"; #[derive(Debug)] pub struct PythonLogicalCodec { inner: Arc, + python_udf_inlining: bool, } impl PythonLogicalCodec { pub fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + python_udf_inlining: true, + } } pub fn inner(&self) -> &Arc { &self.inner } + + /// Whether Python-defined UDFs are encoded inline (and decoded + /// from cloudpickle blobs). Defaults to `true`. Set to `false` + /// when the codec sits on a session that must produce + /// cross-language wire bytes, or reject `cloudpickle.loads` on + /// untrusted `from_bytes` input. + /// + /// Security scope: strict mode (`false`) narrows only the codec + /// layer — it stops `Expr::from_bytes` from invoking + /// `cloudpickle.loads` on the inline `DFPY*` payload. It does + /// **not** make `pickle.loads(untrusted_bytes)` safe; treat every + /// `pickle.loads` on untrusted input as unsafe regardless of this + /// setting. See Python's [pickle module security warning][1] for + /// why `pickle.loads` is unsafe in general. 
+ /// + /// [1]: https://docs.python.org/3/library/pickle.html#module-pickle + pub fn with_python_udf_inlining(mut self, enabled: bool) -> Self { + self.python_udf_inlining = enabled; + self + } + + pub fn python_udf_inlining(&self) -> bool { + self.python_udf_inlining + } } impl Default for PythonLogicalCodec { @@ -177,30 +288,76 @@ impl LogicalExtensionCodec for PythonLogicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } + } else if buf.starts_with(PY_SCALAR_UDF_FAMILY) { + return Err(refuse_inline_payload("scalar UDF", name)); + } self.inner.try_decode_udf(name, buf) } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_agg_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } + } else if buf.starts_with(PY_AGG_UDF_FAMILY) { + return Err(refuse_inline_payload("aggregate UDF", name)); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } + } else if buf.starts_with(PY_WINDOW_UDF_FAMILY) { + return Err(refuse_inline_payload("window UDF", name)); + } self.inner.try_decode_udwf(name, buf) } } +/// Build the error returned by a strict codec when it receives an +/// inline Python-UDF payload it has been told not to deserialize. +fn refuse_inline_payload(kind: &str, name: &str) -> datafusion::error::DataFusionError { + // `Execution`, not `Plan`: this is a wire-format decode refusal at + // codec time, not a planner-stage failure. Downstream error + // classification keys off the variant — surfacing this as a planner + // error would mis-route it into "fix your SQL" buckets. + datafusion::error::DataFusionError::Execution(format!( + "Refusing to deserialize inline Python {kind} '{name}': Python UDF \ + inlining is disabled on this session. Ask the sender to re-encode \ + with inlining disabled (so the UDF travels by name), or register \ + '{name}' on this receiver's session and enable inlining on both \ + sides — receivers cannot re-encode bytes they did not produce." + )) +} + /// `PhysicalExtensionCodec` mirror of [`PythonLogicalCodec`] parked /// on the same `SessionContext`. Carries the Python-aware encoding /// hooks for physical-layer types (`ExecutionPlan`, `PhysicalExpr`) @@ -212,20 +369,34 @@ impl LogicalExtensionCodec for PythonLogicalCodec { /// encoding on this layer too — otherwise a plan with a Python UDF /// would round-trip at the logical level but break at the physical /// level. Both layers reuse the shared payload framing -/// ([`PY_SCALAR_UDF_MAGIC`] et al.) so the wire format is identical. +/// ([`PY_SCALAR_UDF_FAMILY`]) so the wire format is identical. 
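+///
+/// Construction sketch (illustrative; `strict` is only an example
+/// binding — both codec layers expose the same builder):
+///
+/// ```ignore
+/// let strict = PythonPhysicalCodec::default().with_python_udf_inlining(false);
+/// assert!(!strict.python_udf_inlining());
+/// ```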
#[derive(Debug)] pub struct PythonPhysicalCodec { inner: Arc, + python_udf_inlining: bool, } impl PythonPhysicalCodec { pub fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + python_udf_inlining: true, + } } pub fn inner(&self) -> &Arc { &self.inner } + + /// See [`PythonLogicalCodec::with_python_udf_inlining`]. + pub fn with_python_udf_inlining(mut self, enabled: bool) -> Self { + self.python_udf_inlining = enabled; + self + } + + pub fn python_udf_inlining(&self) -> bool { + self.python_udf_inlining + } } impl Default for PythonPhysicalCodec { @@ -249,10 +420,20 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_scalar_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udf(node, buf) } fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udf) = try_decode_python_scalar_udf(buf)? { + return Ok(udf); + } + } else if buf.starts_with(PY_SCALAR_UDF_FAMILY) { + return Err(refuse_inline_payload("scalar UDF", name)); + } self.inner.try_decode_udf(name, buf) } @@ -269,18 +450,527 @@ impl PhysicalExtensionCodec for PythonPhysicalCodec { } fn try_encode_udaf(&self, node: &AggregateUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_agg_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udaf(node, buf) } fn try_decode_udaf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udaf) = try_decode_python_agg_udf(buf)? { + return Ok(udaf); + } + } else if buf.starts_with(PY_AGG_UDF_FAMILY) { + return Err(refuse_inline_payload("aggregate UDF", name)); + } self.inner.try_decode_udaf(name, buf) } fn try_encode_udwf(&self, node: &WindowUDF, buf: &mut Vec) -> Result<()> { + if self.python_udf_inlining && try_encode_python_window_udf(node, buf)? { + return Ok(()); + } self.inner.try_encode_udwf(node, buf) } fn try_decode_udwf(&self, name: &str, buf: &[u8]) -> Result> { + if self.python_udf_inlining { + if let Some(udwf) = try_decode_python_window_udf(buf)? { + return Ok(udwf); + } + } else if buf.starts_with(PY_WINDOW_UDF_FAMILY) { + return Err(refuse_inline_payload("window UDF", name)); + } self.inner.try_decode_udwf(name, buf) } } + +// ============================================================================= +// Shared Python scalar UDF encode / decode helpers +// +// Both `PythonLogicalCodec` and `PythonPhysicalCodec` consult these on +// every `try_encode_udf` / `try_decode_udf` call. Same wire format on +// both layers — a Python `ScalarUDF` referenced inside a `LogicalPlan` +// or an `ExecutionPlan` round-trips identically. +// ============================================================================= + +/// Encode a Python scalar UDF inline if `node` is one. Returns +/// `Ok(true)` when the payload (`DFPYUDF` family prefix, version byte, +/// cloudpickled tuple) was written and the caller should skip its +/// inner codec. Returns `Ok(false)` for any non-Python UDF, signalling +/// the caller to delegate to its `inner`. 
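+///
+/// Dispatch sketch — the caller pattern both codecs' `try_encode_udf`
+/// hooks use:
+///
+/// ```ignore
+/// if self.python_udf_inlining && try_encode_python_scalar_udf(node, buf)? {
+///     return Ok(());
+/// }
+/// self.inner.try_encode_udf(node, buf)
+/// ```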
+pub(crate) fn try_encode_python_scalar_udf(node: &ScalarUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_scalar_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +/// Decode an inline Python scalar UDF payload. Returns `Ok(None)` +/// when `buf` does not carry the `DFPYUDF` family prefix, signalling +/// the caller to delegate to its `inner` codec (and eventually the +/// `FunctionRegistry`). +pub(crate) fn try_decode_python_scalar_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF")? else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_scalar_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(ScalarUDF::new_from_impl(udf)))) + }) +} + +/// Build the cloudpickle payload for a `PythonFunctionScalarUDF`. +/// +/// Layout: `cloudpickle.dumps((name, func, input_schema_bytes, +/// return_schema_bytes, volatility_str))`. Schema blobs are produced +/// by arrow-rs's native IPC stream writer (no pyarrow round-trip) and +/// decoded with the matching stream reader on the receiver. See +/// [`build_input_schema_bytes`] for what the input blob carries. +fn encode_python_scalar_udf(py: Python<'_>, udf: &PythonFunctionScalarUDF) -> PyResult> { + let signature = udf.signature(); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionScalarUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_schema_bytes = build_single_field_schema_bytes(udf.return_field().as_ref())?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + udf.name().into_pyobject(py)?.into_any(), + udf.func().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +/// Inverse of [`encode_python_scalar_udf`]. +fn decode_python_scalar_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let func: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_field = read_single_return_field(&return_schema_bytes, "PythonFunctionScalarUDF")?; + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionScalarUDF::from_parts( + name, + func, + input_types, + return_field, + volatility, + )) +} + +/// Serialize a `Schema` to a self-contained IPC stream containing +/// only the schema message (no record batches). Inverse: +/// [`schema_from_ipc_bytes`]. 
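+///
+/// Round-trip sketch (illustrative; the field is an arbitrary example):
+///
+/// ```ignore
+/// let schema = Schema::new(vec![Field::new("x", DataType::Int64, true)]);
+/// let bytes = schema_to_ipc_bytes(&schema)?;
+/// assert_eq!(schema_from_ipc_bytes(&bytes)?, schema);
+/// ```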
+fn schema_to_ipc_bytes(schema: &Schema) -> arrow::error::Result> { + let mut buf: Vec = Vec::new(); + { + let mut writer = StreamWriter::try_new(&mut buf, schema)?; + writer.finish()?; + } + Ok(buf) +} + +/// Decode an IPC stream containing only a schema message back into a +/// `Schema`. Inverse: [`schema_to_ipc_bytes`]. +fn schema_from_ipc_bytes(bytes: &[u8]) -> arrow::error::Result { + let reader = StreamReader::try_new(std::io::Cursor::new(bytes), None)?; + Ok(reader.schema().as_ref().clone()) +} + +/// Extract the per-arg `DataType`s from a `Signature` known to be +/// `TypeSignature::Exact` (all Python-defined UDFs are constructed +/// with `Signature::exact`). Any other variant indicates the impl was +/// not built by this crate's UDF/UDAF/UDWF constructors. +fn signature_input_dtypes(signature: &Signature, kind: &str) -> PyResult> { + match &signature.type_signature { + TypeSignature::Exact(types) => Ok(types.clone()), + other => Err(pyo3::exceptions::PyValueError::new_err(format!( + "{kind} expected Signature::Exact, got {other:?}" + ))), + } +} + +/// Wrap per-arg `DataType`s in synthetic `arg_{i}` fields and emit +/// the IPC schema blob the encoder writes into the cloudpickle tuple. +/// +/// The names and `nullable: true` are arbitrary: the underlying +/// `TypeSignature::Exact` carries no per-input nullability or +/// metadata, and the receiver collapses these fields back to +/// `Vec` via [`read_input_dtypes`], so anything set here +/// beyond the data type is discarded on decode. +fn build_input_schema_bytes(dtypes: &[DataType]) -> PyResult> { + let fields: Vec = dtypes + .iter() + .enumerate() + .map(|(i, dt)| Field::new(format!("arg_{i}"), dt.clone(), true)) + .collect(); + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + +/// Emit a single-field IPC schema blob. Used for return-type and +/// state-field payloads where the receiver needs to recover field +/// metadata (names, nullability, key/value attributes) verbatim. +fn build_single_field_schema_bytes(field: &Field) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(vec![field.clone()])).map_err(arrow_to_py_err) +} + +/// Emit a multi-field IPC schema blob. +fn build_schema_bytes(fields: Vec) -> PyResult> { + schema_to_ipc_bytes(&Schema::new(fields)).map_err(arrow_to_py_err) +} + +/// Decode the per-arg `DataType`s the encoder wrote via +/// [`build_input_schema_bytes`]. +fn read_input_dtypes(bytes: &[u8]) -> PyResult> { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + Ok(schema + .fields() + .iter() + .map(|f| f.data_type().clone()) + .collect()) +} + +/// Decode a single-field IPC schema blob and return that field by +/// value. `kind` names the UDF flavor in the error message produced +/// when the blob is empty (should be unreachable for sender-side +/// payloads built via [`build_single_field_schema_bytes`]). 
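+///
+/// Pairing sketch (illustrative; the field is an arbitrary example):
+///
+/// ```ignore
+/// let field = Field::new("result", DataType::Float64, true);
+/// let bytes = build_single_field_schema_bytes(&field)?;
+/// assert_eq!(read_single_return_field(&bytes, "scalar UDF")?, field);
+/// ```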
+fn read_single_return_field(bytes: &[u8], kind: &str) -> PyResult { + let schema = schema_from_ipc_bytes(bytes).map_err(arrow_to_py_err)?; + let field = schema.fields().first().ok_or_else(|| { + pyo3::exceptions::PyValueError::new_err(format!( + "{kind} return schema must contain exactly one field" + )) + })?; + Ok(field.as_ref().clone()) +} + +fn arrow_to_py_err(e: arrow::error::ArrowError) -> PyErr { + pyo3::exceptions::PyValueError::new_err(format!("{e}")) +} + +fn parse_volatility_str(s: &str) -> PyResult { + datafusion_python_util::parse_volatility(s) + .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}"))) +} + +/// Stable wire-format string for a `Volatility`. Pinned to the three +/// tokens [`datafusion_python_util::parse_volatility`] accepts, so an +/// upstream change to `Volatility`'s `Debug` repr cannot silently +/// produce bytes the decoder rejects. +fn volatility_wire_str(v: Volatility) -> &'static str { + match v { + Volatility::Immutable => "immutable", + Volatility::Stable => "stable", + Volatility::Volatile => "volatile", + } +} + +/// Cached handle to the `cloudpickle` module. +/// +/// The encode/decode helpers above would otherwise re-resolve the +/// module on every call. `py.import` is backed by `sys.modules` and +/// therefore cheap, but each call still walks a dict and re-binds the +/// result; a plan with many Python UDFs pays that cost per UDF. +/// +/// `PyOnceLock` scopes the cached `Py` to the current +/// interpreter, so the slot drops cleanly on interpreter teardown +/// (relevant under CPython subinterpreters, PEP 684) instead of +/// resurrecting a `Py` rooted in a dead interpreter on the next call. +fn cloudpickle<'py>(py: Python<'py>) -> PyResult> { + static CLOUDPICKLE: PyOnceLock> = PyOnceLock::new(); + CLOUDPICKLE + .get_or_try_init(py, || Ok(py.import("cloudpickle")?.unbind().into_any())) + .map(|cached| cached.bind(py).clone()) +} + +// ============================================================================= +// Shared Python window UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, evaluator_factory, input_schema_bytes, +// return_schema_bytes, volatility_str)`. The evaluator factory is the +// Python callable that produces a new evaluator instance per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_window_udf(node: &WindowUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_window_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_WINDOW_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +pub(crate) fn try_decode_python_window_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_WINDOW_UDF_FAMILY, "window UDF")? 
else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_window_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(WindowUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_window_udf(py: Python<'_>, udf: &PythonFunctionWindowUDF) -> PyResult> { + let signature = WindowUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionWindowUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + WindowUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.evaluator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_window_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let evaluator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let volatility_str: String = tuple.get_item(4)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionWindowUDF")? + .data_type() + .clone(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionWindowUDF::new( + name, + evaluator, + input_types, + return_type, + volatility, + )) +} + +// ============================================================================= +// Shared Python aggregate UDF encode / decode helpers +// +// Cloudpickle tuple shape: `(name, accumulator_factory, input_schema_bytes, +// return_type_bytes, state_schema_bytes, volatility_str)`. The accumulator +// factory is the Python callable that produces a new accumulator instance +// per partition. +// ============================================================================= + +pub(crate) fn try_encode_python_agg_udf(node: &AggregateUDF, buf: &mut Vec) -> Result { + let Some(py_udf) = node + .inner() + .as_any() + .downcast_ref::() + else { + return Ok(false); + }; + + Python::attach(|py| -> Result { + let bytes = encode_python_agg_udf(py, py_udf) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + write_wire_header(buf, PY_AGG_UDF_FAMILY); + buf.extend_from_slice(&bytes); + Ok(true) + }) +} + +pub(crate) fn try_decode_python_agg_udf(buf: &[u8]) -> Result>> { + let Some(payload) = strip_wire_header(buf, PY_AGG_UDF_FAMILY, "aggregate UDF")? 
else { + return Ok(None); + }; + + Python::attach(|py| -> Result>> { + let udf = decode_python_agg_udf(py, payload) + .map_err(|e| datafusion::error::DataFusionError::External(Box::new(e)))?; + Ok(Some(Arc::new(AggregateUDF::new_from_impl(udf)))) + }) +} + +fn encode_python_agg_udf(py: Python<'_>, udf: &PythonFunctionAggregateUDF) -> PyResult> { + let signature = AggregateUDFImpl::signature(udf); + let input_dtypes = signature_input_dtypes(signature, "PythonFunctionAggregateUDF")?; + let input_schema_bytes = build_input_schema_bytes(&input_dtypes)?; + let return_field = Field::new("result", udf.return_type().clone(), true); + let return_schema_bytes = build_single_field_schema_bytes(&return_field)?; + let state_fields: Vec = udf + .state_fields_ref() + .iter() + .map(|f| f.as_ref().clone()) + .collect(); + let state_schema_bytes = build_schema_bytes(state_fields)?; + let volatility = volatility_wire_str(signature.volatility); + + let payload = PyTuple::new( + py, + [ + AggregateUDFImpl::name(udf).into_pyobject(py)?.into_any(), + udf.accumulator().bind(py).clone().into_any(), + PyBytes::new(py, &input_schema_bytes).into_any(), + PyBytes::new(py, &return_schema_bytes).into_any(), + PyBytes::new(py, &state_schema_bytes).into_any(), + volatility.into_pyobject(py)?.into_any(), + ], + )?; + + cloudpickle(py)? + .call_method1("dumps", (payload,))? + .extract::>() +} + +fn decode_python_agg_udf(py: Python<'_>, payload: &[u8]) -> PyResult { + let tuple = cloudpickle(py)? + .call_method1("loads", (PyBytes::new(py, payload),))? + .cast_into::()?; + + let name: String = tuple.get_item(0)?.extract()?; + let accumulator: Py = tuple.get_item(1)?.unbind(); + let input_schema_bytes: Vec = tuple.get_item(2)?.extract()?; + let return_schema_bytes: Vec = tuple.get_item(3)?.extract()?; + let state_schema_bytes: Vec = tuple.get_item(4)?.extract()?; + let volatility_str: String = tuple.get_item(5)?.extract()?; + + let input_types = read_input_dtypes(&input_schema_bytes)?; + let return_type = read_single_return_field(&return_schema_bytes, "PythonFunctionAggregateUDF")? + .data_type() + .clone(); + // Preserve the encoded state field metadata (names, nullability, + // arbitrary key/value attributes) so the post-decode UDF reports + // the same state schema as the sender's instance — important for + // accumulators whose `StateFieldsArgs` consumers key off names or + // nullability rather than positional `DataType`. 
+ let state_schema = schema_from_ipc_bytes(&state_schema_bytes).map_err(arrow_to_py_err)?; + let state_fields: Vec = + state_schema.fields().iter().cloned().collect(); + let volatility = parse_volatility_str(&volatility_str)?; + + Ok(PythonFunctionAggregateUDF::from_parts( + name, + accumulator, + input_types, + return_type, + state_fields, + volatility, + )) +} + +#[cfg(test)] +mod wire_header_tests { + use super::*; + + #[test] + fn strip_returns_none_when_family_absent() { + let buf = b"OTHER_PAYLOAD"; + assert!(matches!( + strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF"), + Ok(None) + )); + } + + #[test] + fn strip_errors_on_truncated_version_byte() { + let buf = PY_SCALAR_UDF_FAMILY; + let err = strip_wire_header(buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + assert!(format!("{err}").contains("missing wire-format version byte")); + } + + #[test] + fn strip_errors_on_too_new_version() { + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_CURRENT.saturating_add(1)); + buf.extend_from_slice(b"payload"); + let err = strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("wire-format version v")); + assert!(msg.contains("supports")); + assert!(msg.contains("Align datafusion-python versions")); + } + + #[test] + fn strip_errors_on_too_old_version() { + if WIRE_VERSION_MIN_SUPPORTED == 0 { + return; + } + let mut buf = PY_SCALAR_UDF_FAMILY.to_vec(); + buf.push(WIRE_VERSION_MIN_SUPPORTED - 1); + buf.extend_from_slice(b"payload"); + assert!(strip_wire_header(&buf, PY_SCALAR_UDF_FAMILY, "scalar UDF").is_err()); + } + + #[test] + fn write_then_strip_round_trips_payload() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_AGG_UDF_FAMILY); + buf.extend_from_slice(b"agg-payload"); + + let payload = strip_wire_header(&buf, PY_AGG_UDF_FAMILY, "aggregate UDF") + .unwrap() + .unwrap(); + assert_eq!(payload, b"agg-payload"); + } + + #[test] + fn strip_does_not_match_a_different_family() { + let mut buf = Vec::new(); + write_wire_header(&mut buf, PY_SCALAR_UDF_FAMILY); + buf.extend_from_slice(b"payload"); + assert!(matches!( + strip_wire_header(&buf, PY_WINDOW_UDF_FAMILY, "window UDF"), + Ok(None) + )); + } +} diff --git a/crates/core/src/context.rs b/crates/core/src/context.rs index 96de01889..1de8644ad 100644 --- a/crates/core/src/context.rs +++ b/crates/core/src/context.rs @@ -1407,6 +1407,22 @@ impl PySessionContext { physical_codec, }) } + + pub fn with_python_udf_inlining(&self, enabled: bool) -> Self { + let logical_codec = Arc::new( + PythonLogicalCodec::new(Arc::clone(self.logical_codec.inner())) + .with_python_udf_inlining(enabled), + ); + let physical_codec = Arc::new( + PythonPhysicalCodec::new(Arc::clone(self.physical_codec.inner())) + .with_python_udf_inlining(enabled), + ); + Self { + ctx: Arc::clone(&self.ctx), + logical_codec, + physical_codec, + } + } } impl PySessionContext { diff --git a/crates/core/src/udaf.rs b/crates/core/src/udaf.rs index 80ef51716..cb84fa375 100644 --- a/crates/core/src/udaf.rs +++ b/crates/core/src/udaf.rs @@ -15,16 +15,18 @@ // specific language governing permissions and limitations // under the License. 
+use std::any::Any; use std::ptr::NonNull; use std::sync::Arc; use datafusion::arrow::array::ArrayRef; -use datafusion::arrow::datatypes::DataType; +use datafusion::arrow::datatypes::{DataType, Field, FieldRef}; use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow}; use datafusion::common::ScalarValue; use datafusion::error::{DataFusionError, Result}; +use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs}; use datafusion::logical_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, AggregateUDFImpl, create_udaf, + Accumulator, AggregateUDF, AggregateUDFImpl, Signature, Volatility, }; use datafusion_ffi::udaf::FFI_AggregateUDF; use datafusion_python_util::parse_volatility; @@ -144,15 +146,161 @@ impl Accumulator for RustAccumulator { } } -pub fn to_rust_accumulator(accum: Py) -> AccumulatorFactoryFunction { - Arc::new(move |_args| -> Result> { - let accum = Python::attach(|py| { - accum - .call0(py) - .map_err(|e| DataFusionError::Execution(format!("{e}"))) - })?; - Ok(Box::new(RustAccumulator::new(accum))) - }) +fn instantiate_accumulator(accum: &Py) -> Result> { + let instance = Python::attach(|py| { + accum + .call0(py) + .map_err(|e| DataFusionError::Execution(format!("{e}"))) + })?; + Ok(Box::new(RustAccumulator::new(instance))) +} + +/// Named-struct `AggregateUDFImpl` for Python-defined aggregate UDFs. +/// Holds the Python accumulator factory directly so the codec can +/// downcast and cloudpickle it across process boundaries. +#[derive(Debug)] +pub(crate) struct PythonFunctionAggregateUDF { + name: String, + accumulator: Py, + signature: Signature, + return_type: DataType, + state_fields: Vec, +} + +impl PythonFunctionAggregateUDF { + fn new( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_types: Vec, + volatility: Volatility, + ) -> Self { + let signature = Signature::exact(input_types, volatility); + let state_fields = state_types + .into_iter() + .enumerate() + .map(|(i, t)| Arc::new(Field::new(format!("state_{i}"), t, true))) + .collect(); + Self { + name, + accumulator, + signature, + return_type, + state_fields, + } + } + + /// Stored Python callable that returns a fresh accumulator instance + /// per partition. Consumed by the codec to cloudpickle the factory + /// across process boundaries. + pub(crate) fn accumulator(&self) -> &Py { + &self.accumulator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } + + pub(crate) fn state_fields_ref(&self) -> &[FieldRef] { + &self.state_fields + } + + /// Reconstruct a `PythonFunctionAggregateUDF` from the parts emitted + /// by the codec. `state_fields` carries the full state schema + /// (names, data types, nullability, metadata) — the codec extracts + /// it from the IPC payload, so the post-decode state schema is + /// identical to the pre-encode one. Use [`Self::new`] when only + /// `Vec` is available (e.g. the Python constructor path, + /// where field names are synthesized). 
+ pub(crate) fn from_parts( + name: String, + accumulator: Py, + input_types: Vec, + return_type: DataType, + state_fields: Vec, + volatility: Volatility, + ) -> Self { + Self { + name, + accumulator, + signature: Signature::exact(input_types, volatility), + return_type, + state_fields, + } + } +} + +impl Eq for PythonFunctionAggregateUDF {} +impl PartialEq for PythonFunctionAggregateUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + && self.state_fields == other.state_fields + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.accumulator.as_ptr() == other.accumulator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `AggregateUDFImpl` exposes a fallible + // `PartialEq`. + self.accumulator + .bind(py) + .eq(other.accumulator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udaf", + "PythonFunctionAggregateUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionAggregateUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate callables. + self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + for f in &self.state_fields { + f.hash(state); + } + } +} + +impl AggregateUDFImpl for PythonFunctionAggregateUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(self.return_type.clone()) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + instantiate_accumulator(&self.accumulator) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(self.state_fields.clone()) + } } fn aggregate_udf_from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyDataFusionResult { @@ -190,14 +338,15 @@ impl PyAggregateUDF { state_type: PyArrowType>, volatility: &str, ) -> PyResult { - let function = create_udaf( - name, + let py_udf = PythonFunctionAggregateUDF::new( + name.to_string(), + accumulator, input_type.0, - Arc::new(return_type.0), + return_type.0, + state_type.0, parse_volatility(volatility)?, - to_rust_accumulator(accumulator), - Arc::new(state_type.0), ); + let function = AggregateUDF::new_from_impl(py_udf); Ok(Self { function }) } @@ -231,4 +380,9 @@ impl PyAggregateUDF { fn __repr__(&self) -> PyResult { Ok(format!("AggregateUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/crates/core/src/udf.rs b/crates/core/src/udf.rs index c0a39cb47..72cdddba1 100644 --- a/crates/core/src/udf.rs +++ b/crates/core/src/udf.rs @@ -43,7 +43,7 @@ use crate::expr::PyExpr; /// This struct holds the Python written function that is a /// ScalarUDF. #[derive(Debug)] -struct PythonFunctionScalarUDF { +pub(crate) struct PythonFunctionScalarUDF { name: String, func: Py, signature: Signature, @@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF { return_field: Arc::new(return_field), } } + + /// Stored Python callable. 
Consumed by the codec to cloudpickle + /// the function body across process boundaries. + pub(crate) fn func(&self) -> &Py { + &self.func + } + + pub(crate) fn return_field(&self) -> &FieldRef { + &self.return_field + } + + /// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted + /// by the codec. Inputs collapse to `Vec` because + /// `Signature::exact` cannot carry per-input nullability or + /// metadata — the encoder is free to discard that side of the + /// schema. `return_field` is kept as a `Field` so the post-decode + /// nullability and metadata match the sender's instance. + pub(crate) fn from_parts( + name: String, + func: Py, + input_types: Vec, + return_field: Field, + volatility: Volatility, + ) -> Self { + Self { + name, + func, + signature: Signature::exact(input_types, volatility), + return_field: Arc::new(return_field), + } + } } impl Eq for PythonFunctionScalarUDF {} @@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF { self.name == other.name && self.signature == other.signature && self.return_field == other.return_field - && Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false)) + // Identical pointers ⇒ same Python object. Most equality + // checks compare `Arc`-shared clones of the same UDF + // (e.g. expression rewriting), so the pointer match short- + // circuits before touching the GIL. + && (self.func.as_ptr() == other.func.as_ptr() + || Python::attach(|py| { + // Rust's `PartialEq` cannot return `Result`, so we + // have to pick a side when Python `__eq__` raises. + // `false` is the conservative choice — better to + // report two UDFs as distinct than to wrongly + // merge them — but the silent miss can still + // surface as expression-dedup or cache-lookup + // anomalies. Log at `debug` so the failure is + // observable without flooding production logs. + // FIXME: revisit if upstream `ScalarUDFImpl` + // exposes a fallible `PartialEq`. + self.func + .bind(py) + .eq(other.func.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udf", + "PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) } } impl Hash for PythonFunctionScalarUDF { fn hash(&self, state: &mut H) { + // Hash only the identifying header (name + signature + return + // field). Skipping `func` is intentional: the Rust `Hash` + // contract requires `a == b ⇒ hash(a) == hash(b)`, not the + // converse, so a coarser hash is sound — `PartialEq` still + // disambiguates two UDFs with the same header but distinct + // callables. Falling back to a sentinel on `py_hash` failure + // (as a prior revision did) silently mapped every unhashable + // closure to the same bucket; that is the worst case for a + // hashmap and is what this rewrite avoids. 
self.name.hash(state); self.signature.hash(state); self.return_field.hash(state); - - Python::attach(|py| { - let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects - - state.write_isize(py_hash); - }); } } @@ -220,4 +281,9 @@ impl PyScalarUDF { fn __repr__(&self) -> PyResult { Ok(format!("ScalarUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } diff --git a/crates/core/src/udwf.rs b/crates/core/src/udwf.rs index 1d3608ada..5ce09e6d2 100644 --- a/crates/core/src/udwf.rs +++ b/crates/core/src/udwf.rs @@ -25,10 +25,9 @@ use datafusion::arrow::datatypes::DataType; use datafusion::arrow::pyarrow::{FromPyArrow, PyArrowType, ToPyArrow}; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_expr::function::{PartitionEvaluatorArgs, WindowUDFFieldArgs}; -use datafusion::logical_expr::ptr_eq::PtrEq; use datafusion::logical_expr::window_state::WindowAggState; use datafusion::logical_expr::{ - PartitionEvaluator, PartitionEvaluatorFactory, Signature, Volatility, WindowUDF, WindowUDFImpl, + PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl, }; use datafusion::scalar::ScalarValue; use datafusion_ffi::udwf::FFI_WindowUDF; @@ -198,15 +197,13 @@ impl PartitionEvaluator for RustPartitionEvaluator { } } -pub fn to_rust_partition_evaluator(evaluator: Py) -> PartitionEvaluatorFactory { - Arc::new(move || -> Result> { - let evaluator = Python::attach(|py| { - evaluator - .call0(py) - .map_err(|e| DataFusionError::Execution(e.to_string())) - })?; - Ok(Box::new(RustPartitionEvaluator::new(evaluator))) - }) +fn instantiate_partition_evaluator(evaluator: &Py) -> Result> { + let instance = Python::attach(|py| { + evaluator + .call0(py) + .map_err(|e| DataFusionError::Execution(e.to_string())) + })?; + Ok(Box::new(RustPartitionEvaluator::new(instance))) } /// Represents an WindowUDF @@ -234,14 +231,14 @@ impl PyWindowUDF { volatility: &str, ) -> PyResult { let return_type = return_type.0; - let input_types = input_types.into_iter().map(|t| t.0).collect(); + let input_types: Vec = input_types.into_iter().map(|t| t.0).collect(); - let function = WindowUDF::from(MultiColumnWindowUDF::new( + let function = WindowUDF::from(PythonFunctionWindowUDF::new( name, + evaluator, input_types, return_type, parse_volatility(volatility)?, - to_rust_partition_evaluator(evaluator), )); Ok(Self { function }) } @@ -276,47 +273,94 @@ impl PyWindowUDF { fn __repr__(&self) -> PyResult { Ok(format!("WindowUDF({})", self.function.name())) } + + #[getter] + fn name(&self) -> &str { + self.function.name() + } } -#[derive(Hash, Eq, PartialEq)] -pub struct MultiColumnWindowUDF { +#[derive(Debug)] +pub(crate) struct PythonFunctionWindowUDF { name: String, + evaluator: Py, signature: Signature, return_type: DataType, - partition_evaluator_factory: PtrEq, } -impl std::fmt::Debug for MultiColumnWindowUDF { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - f.debug_struct("WindowUDF") - .field("name", &self.name) - .field("signature", &self.signature) - .field("return_type", &"") - .field("partition_evaluator_factory", &"") - .finish() - } -} - -impl MultiColumnWindowUDF { - pub fn new( +impl PythonFunctionWindowUDF { + pub(crate) fn new( name: impl Into, + evaluator: Py, input_types: Vec, return_type: DataType, volatility: Volatility, - partition_evaluator_factory: PartitionEvaluatorFactory, ) -> Self { let name = name.into(); let signature = Signature::exact(input_types, volatility); Self { name, + evaluator, 
signature, return_type, - partition_evaluator_factory: partition_evaluator_factory.into(), } } + + /// Stored Python callable that produces a fresh partition + /// evaluator instance per partition. Consumed by the codec to + /// cloudpickle the evaluator factory across process boundaries. + pub(crate) fn evaluator(&self) -> &Py { + &self.evaluator + } + + pub(crate) fn return_type(&self) -> &DataType { + &self.return_type + } +} + +impl Eq for PythonFunctionWindowUDF {} +impl PartialEq for PythonFunctionWindowUDF { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.signature == other.signature + && self.return_type == other.return_type + // Pointer-identity fast path: `Arc`-shared clones of the + // same UDF skip the GIL roundtrip. Falls through to Python + // `__eq__` only for two distinct callables. + && (self.evaluator.as_ptr() == other.evaluator.as_ptr() + || Python::attach(|py| { + // See `PythonFunctionScalarUDF::eq` for the + // rationale on swallowing the exception as `false` + // and logging at `debug`. FIXME: revisit if + // upstream `WindowUDFImpl` exposes a fallible + // `PartialEq`. + self.evaluator + .bind(py) + .eq(other.evaluator.bind(py)) + .unwrap_or_else(|e| { + log::debug!( + target: "datafusion_python::udwf", + "PythonFunctionWindowUDF {:?} __eq__ raised; treating as unequal: {e}", + self.name, + ); + false + }) + })) + } +} + +impl std::hash::Hash for PythonFunctionWindowUDF { + fn hash(&self, state: &mut H) { + // See `PythonFunctionScalarUDF`'s `Hash` impl for the + // rationale: hash the identifying header only and let + // `PartialEq` disambiguate evaluators. + self.name.hash(state); + self.signature.hash(state); + self.return_type.hash(state); + } } -impl WindowUDFImpl for MultiColumnWindowUDF { +impl WindowUDFImpl for PythonFunctionWindowUDF { fn as_any(&self) -> &dyn Any { self } @@ -339,7 +383,6 @@ impl WindowUDFImpl for MultiColumnWindowUDF { &self, _partition_evaluator_args: PartitionEvaluatorArgs, ) -> Result> { - let _ = _partition_evaluator_args; - (self.partition_evaluator_factory)() + instantiate_partition_evaluator(&self.evaluator) } } diff --git a/pyproject.toml b/pyproject.toml index 951f7adc3..a02f4608a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,13 @@ classifiers = [ "Programming Language :: Rust", ] dependencies = [ + # cloudpickle is invoked by the Rust-side PythonLogicalCodec / + # PythonPhysicalCodec via pyo3 to serialize Python UDF callables — + # scalar, aggregate, and window — into the proto wire format. + # Lazy-imported on the encode / decode hot paths (and cached after + # the first import), so users who never serialize a plan or + # expression incur no runtime cost beyond the install footprint. + "cloudpickle>=2.0", "pyarrow>=16.0.0;python_version<'3.14'", "pyarrow>=22.0.0;python_version>='3.14'", "typing-extensions;python_version<'3.13'", diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index f08b464bb..dfdeef07e 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -65,7 +65,7 @@ import importlib_metadata # type: ignore[import] # Public submodules -from . import functions, object_store, substrait, unparser +from . import functions, ipc, object_store, substrait, unparser # The following imports are okay to remain as opaque to the user. 
from ._internal import Config @@ -142,6 +142,7 @@ "configure_formatter", "expr", "functions", + "ipc", "lit", "literal", "object_store", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 5c3501941..e3ceb3b05 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -1769,3 +1769,46 @@ def with_physical_extension_codec(self, codec: Any) -> SessionContext: new = SessionContext.__new__(SessionContext) new.ctx = new_internal return new + + def with_python_udf_inlining(self, *, enabled: bool) -> SessionContext: + """Toggle inline encoding of Python-defined UDFs on this session. + + ``enabled`` is keyword-only: + ``with_python_udf_inlining(enabled=False)`` reads at the call + site as the inverse of + ``with_python_udf_inlining(enabled=True)``, where a positional + ``True`` / ``False`` would not. + + When ``True`` (the default), Python scalar, aggregate, and window + UDFs travel inside the serialized expression and are + reconstructed on the receiver without pre-registration. + + Set ``False`` to: + + * Produce serialized bytes that round-trip through a non-Python + decoder (cross-language portability). UDFs are stored by name + only; the receiver must have matching registrations. + * Refuse to reconstruct Python UDFs from + :meth:`Expr.from_bytes` input that may come from an untrusted + source — ``cloudpickle.loads`` will not be invoked. + + The toggle applies directly to :meth:`Expr.to_bytes` / + :meth:`Expr.from_bytes` calls that pass this session as their + ``ctx`` argument. To make the toggle apply through + :func:`pickle.dumps` (which calls :meth:`Expr.to_bytes` with no + context), install this session as the driver's sender context + via :func:`datafusion.ipc.set_sender_ctx` — and install it as + the worker's context via + :func:`datafusion.ipc.set_worker_ctx` for the corresponding + :func:`pickle.loads`. + + For the full security model, see + :doc:`/user-guide/io/distributing_work` (Security section). In + short: this toggle narrows only the :meth:`Expr.from_bytes` + surface; :func:`pickle.loads` on untrusted bytes remains + unsafe regardless of the toggle. + """ + new_internal = self.ctx.with_python_udf_inlining(enabled) + new = SessionContext.__new__(SessionContext) + new.ctx = new_internal + return new diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py index e0135e3ed..2ea060cc1 100644 --- a/python/datafusion/expr.py +++ b/python/datafusion/expr.py @@ -434,23 +434,70 @@ def variant_name(self) -> str: return self.expr.variant_name() def to_bytes(self, ctx: SessionContext | None = None) -> bytes: - """Serialize this expression to protobuf bytes. - - When ``ctx`` is supplied, encoding routes through the session's - installed :class:`LogicalExtensionCodec`. Without ``ctx`` a - default codec is used. + """Serialize this expression to bytes for shipping to another process. + + Use this — or :func:`pickle.dumps` — to send an expression to a + worker process for distributed evaluation. + + When ``ctx`` is supplied, encoding routes through that session's + installed :class:`LogicalExtensionCodec` (so settings like + :meth:`SessionContext.with_python_udf_inlining` take effect). + When ``ctx`` is ``None``, the default codec is used (Python UDF + inlining on, no user-installed extension codec). + + Built-in functions and Python UDFs (scalar, aggregate, window) + travel inside the returned bytes; the worker does not need to + pre-register them. 
UDFs imported via the FFI capsule protocol + travel by name only and must be registered on the worker. See + :doc:`/user-guide/io/distributing_work`. """ ctx_arg = ctx.ctx if ctx is not None else None return self.expr.to_bytes(ctx_arg) - @staticmethod - def from_bytes(ctx: SessionContext, data: bytes) -> Expr: - """Decode an expression from serialized protobuf bytes. - - ``ctx`` provides the function registry for resolving UDF - references and the logical codec for in-band Python payloads. + @classmethod + def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr: + """Reconstruct an expression from serialized bytes. + + Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`. + ``ctx`` is the :class:`SessionContext` used to resolve any + function references that travel by name — aggregate UDFs, window + UDFs, FFI UDFs. When ``ctx`` is ``None`` the worker context + installed via :func:`datafusion.ipc.set_worker_ctx` is consulted; + if no worker context is installed, the global + :class:`SessionContext` is used (sufficient for built-ins and + Python scalar UDFs, plus any UDFs registered on the global + context). """ - return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data)) + from datafusion.ipc import _resolve_ctx + + resolved = _resolve_ctx(ctx) + return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf)) + + def __reduce__(self) -> tuple: + """Pickle protocol hook. + + Lets expressions be shipped to worker processes via + :func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions + and Python UDFs travel inside the pickle bytes; only FFI-capsule + UDFs require pre-registration on the worker. The worker's + :class:`SessionContext` for resolving those references is + looked up via :func:`datafusion.ipc.set_worker_ctx`, falling + back to the global :class:`SessionContext` if none has been + installed on the worker. + + The encoding side honors a driver-side sender context installed + via :func:`datafusion.ipc.set_sender_ctx` — that is how + :meth:`SessionContext.with_python_udf_inlining` propagates + through ``pickle.dumps``. + """ + from datafusion.ipc import get_sender_ctx + + return (Expr._reconstruct, (self.to_bytes(get_sender_ctx()),)) + + @classmethod + def _reconstruct(cls, proto_bytes: bytes) -> Expr: + """Internal entry point used by :meth:`__reduce__` on unpickle.""" + return cls.from_bytes(proto_bytes) def __richcmp__(self, other: Expr, op: int) -> Expr: """Comparison operator.""" diff --git a/python/datafusion/ipc.py b/python/datafusion/ipc.py new file mode 100644 index 000000000..c97e54ec4 --- /dev/null +++ b/python/datafusion/ipc.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Driver- and worker-side setup for distributing DataFusion expressions. 
+ +When a :class:`Expr` is shipped to a worker process (e.g. through +:func:`multiprocessing.Pool` or a Ray actor), the worker reconstructs the +expression against a :class:`SessionContext`. If the expression references +UDFs imported via the FFI capsule protocol — or any UDF the worker would +otherwise resolve from its registered functions rather than from inside +the shipped expression — install a configured :class:`SessionContext` +once per worker: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_worker_ctx + + def init_worker(): + ctx = SessionContext() + ctx.register_udaf(my_ffi_aggregate) + set_worker_ctx(ctx) + +Built-in functions and Python UDFs (scalar, aggregate, window) travel +inside the shipped expression itself and do not need pre-registration +on the worker. + +On the driver side, call :func:`set_sender_ctx` to control how +:func:`pickle.dumps` encodes expressions — for example, to apply +:meth:`SessionContext.with_python_udf_inlining` to every pickled +expression on this thread: + +.. code-block:: python + + from datafusion import SessionContext + from datafusion.ipc import set_sender_ctx + + driver_ctx = SessionContext().with_python_udf_inlining(enabled=False) + set_sender_ctx(driver_ctx) + pickle.dumps(expr) # encoded with inlining disabled + +Without a sender context the default codec is used (Python UDF +inlining on). The sender context only affects pickle / ``to_bytes`` +encoding; explicit ``expr.to_bytes(ctx)`` calls still use the supplied +``ctx``. + +See :doc:`/user-guide/io/distributing_work` for the full pattern. +""" + +from __future__ import annotations + +import threading +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from datafusion.context import SessionContext + + +__all__ = [ + "clear_sender_ctx", + "clear_worker_ctx", + "get_sender_ctx", + "get_worker_ctx", + "set_sender_ctx", + "set_worker_ctx", +] + + +_local = threading.local() + + +def set_worker_ctx(ctx: SessionContext) -> None: + """Install this worker's :class:`SessionContext` for shipped expressions. + + Call once per worker — typically from a ``multiprocessing.Pool`` + initializer or a Ray actor ``__init__``. Idempotent: overwrites any + previous value. Stored in a thread-local slot, so each thread within a + worker may install its own context independently. + """ + _local.ctx = ctx + + +def clear_worker_ctx() -> None: + """Remove this worker's installed :class:`SessionContext`. + + After clearing, expressions reconstructed in this worker fall back to + the global :class:`SessionContext` — adequate for built-ins and Python + UDFs (scalar, aggregate, window), but anything imported via the FFI + capsule protocol must be registered on the global context to resolve. + """ + if hasattr(_local, "ctx"): + del _local.ctx + + +def get_worker_ctx() -> SessionContext | None: + """Return this worker's installed :class:`SessionContext`, or ``None``.""" + return getattr(_local, "ctx", None) + + +def set_sender_ctx(ctx: SessionContext) -> None: + """Install this driver's :class:`SessionContext` for outbound pickles. + + Controls how :func:`pickle.dumps` encodes :class:`Expr` instances on + this thread. The most useful application is propagating a session + configured with + :meth:`SessionContext.with_python_udf_inlining` so the toggle takes + effect through pickle (which otherwise calls + :meth:`Expr.to_bytes` with no context and uses the default codec). + + Idempotent: overwrites any previous value. 
Stored in a thread-local + slot, so worker threads on the driver may install their own contexts. + Does not affect :meth:`Expr.to_bytes` calls that pass an explicit + ``ctx`` — those continue to use the supplied context. + """ + _local.sender_ctx = ctx + + +def clear_sender_ctx() -> None: + """Remove this driver's installed sender :class:`SessionContext`. + + After clearing, pickled expressions fall back to the default codec + (Python UDF inlining on). + """ + if hasattr(_local, "sender_ctx"): + del _local.sender_ctx + + +def get_sender_ctx() -> SessionContext | None: + """Return this driver's installed sender :class:`SessionContext`, or ``None``.""" + return getattr(_local, "sender_ctx", None) + + +def _resolve_ctx( + explicit_ctx: SessionContext | None = None, +) -> SessionContext: + """Resolve a context for Expr reconstruction. + + Priority: explicit argument > worker context > global context. + Falling back to the global :class:`SessionContext` (instead of a + freshly constructed one) preserves any registrations the user has + installed on it. + """ + if explicit_ctx is not None: + return explicit_ctx + worker = get_worker_ctx() + if worker is not None: + return worker + # Lazy import: `datafusion/__init__.py` imports `datafusion.ipc` + # before `datafusion.context`, so a module-top import would force + # `datafusion.context` to load mid-init of `datafusion.ipc`. The + # cycle is benign today (context.py only pulls expr.py at module + # scope, neither pulls ipc.py back), but a single new import in + # context.py's transitive deps could turn it into a real cycle. + # Deferring keeps `datafusion.ipc` import-order-independent. + from datafusion.context import SessionContext # noqa: PLC0415 + + return SessionContext.global_ctx() diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py index 848ab4cee..da756473a 100644 --- a/python/datafusion/user_defined.py +++ b/python/datafusion/user_defined.py @@ -141,6 +141,16 @@ def __init__( name, func, input_fields, return_field, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDF. + + For UDFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udf.name + def __repr__(self) -> str: """Print a string representation of the Scalar UDF.""" return self._udf.__repr__() @@ -418,6 +428,16 @@ def __init__( str(volatility), ) + @property + def name(self) -> str: + """Return the registered name of this UDAF. + + For UDAFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). + """ + return self._udaf.name + def __repr__(self) -> str: """Print a string representation of the Aggregate UDF.""" return self._udaf.__repr__() @@ -828,6 +848,16 @@ def __init__( name, func, input_types, return_type, str(volatility) ) + @property + def name(self) -> str: + """Return the registered name of this UDWF. + + For UDWFs imported via the FFI capsule protocol, this is the + name the capsule itself reports — not the ``name`` argument + passed to the constructor (which is ignored on the FFI path). 
+ """ + return self._udwf.name + def __repr__(self) -> str: """Print a string representation of the Window UDF.""" return self._udwf.__repr__() diff --git a/python/tests/_pickle_multiprocessing_helpers.py b/python/tests/_pickle_multiprocessing_helpers.py new file mode 100644 index 000000000..4f04967f2 --- /dev/null +++ b/python/tests/_pickle_multiprocessing_helpers.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# The leading underscore is load-bearing: pytest with --import-mode=importlib +# (used in CI) assigns synthetic module names to test modules, which breaks +# subprocess imports during multiprocessing. An underscore-prefixed module is +# not collected as a test module, so it imports under its normal __name__ +# inside worker processes. + +from __future__ import annotations + +import pyarrow as pa +from datafusion import SessionContext, udf +from datafusion.ipc import clear_worker_ctx, set_worker_ctx + + +def make_double_udf(): + """Build the canonical UDF used in the multiprocessing tests.""" + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +def make_times_seven_udf(): + """Closure-capturing UDF — verifies cloudpickle preserves closed-over state.""" + multiplier = 7 + + def fn(arr): + return pa.array([(v.as_py() or 0) * multiplier for v in arr]) + + return udf( + fn, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="times_seven", + ) + + +def init_worker_empty(): + """Pool initializer: install an empty SessionContext (no UDFs).""" + set_worker_ctx(SessionContext()) + + +def init_worker_clear(): + """Pool initializer: explicitly clear any prior worker context.""" + clear_worker_ctx() + + +def unpickle_and_describe(blob: bytes) -> str: + """Unpickle a proto-bytes blob and return its canonical name.""" + import pickle + + expr = pickle.loads(blob) # noqa: S301 + return expr.canonical_name() + + +def unpickle_and_evaluate(blob: bytes, batch: list[int]) -> list[int]: + """Unpickle an expression and evaluate it against an in-memory batch. + + Returns the result column as a Python list. Used to verify that + cloudpickled UDFs (including closure state) execute correctly in + a fresh worker process. 
+ """ + import pickle + + expr = pickle.loads(blob) # noqa: S301 + ctx = SessionContext() + df = ctx.from_pydict({"a": batch}) + out = df.with_column("result", expr).select("result") + return out.to_pydict()["result"] diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py index 6a466f6f2..e1fdeab44 100644 --- a/python/tests/test_expr.py +++ b/python/tests/test_expr.py @@ -1186,7 +1186,7 @@ def test_expr_to_bytes_roundtrip(ctx: SessionContext) -> None: original = col("a") + lit(1) blob = original.to_bytes(ctx) - restored = Expr.from_bytes(ctx, blob) + restored = Expr.from_bytes(blob, ctx=ctx) # Canonical name preserves the structure of the expression even # though the underlying PyExpr instances are different. @@ -1201,6 +1201,6 @@ def test_expr_to_bytes_no_ctx_default_codec() -> None: fresh = SessionContext() original = col("a") * lit(2) blob = original.to_bytes() # encode side: default codec - restored = Expr.from_bytes(fresh, blob) + restored = Expr.from_bytes(blob, ctx=fresh) assert restored.canonical_name() == original.canonical_name() diff --git a/python/tests/test_pickle_expr.py b/python/tests/test_pickle_expr.py new file mode 100644 index 000000000..62692f8a6 --- /dev/null +++ b/python/tests/test_pickle_expr.py @@ -0,0 +1,479 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""In-process pickle round-trip tests for :class:`Expr`. + +Built-in functions and Python UDFs (scalar, aggregate, window) travel +with the pickled expression and do not need worker-side pre-registration. +The worker context (:mod:`datafusion.ipc`) is only consulted for UDFs +imported via the FFI capsule protocol. + +Cross-process tests live in ``test_pickle_multiprocessing.py``. 
+""" + +from __future__ import annotations + +import pickle +import threading + +import pyarrow as pa +import pytest +from datafusion import Expr, SessionContext, col, lit, udf +from datafusion.ipc import ( + clear_sender_ctx, + clear_worker_ctx, + get_sender_ctx, + get_worker_ctx, + set_sender_ctx, + set_worker_ctx, +) + + +@pytest.fixture(autouse=True) +def _reset_worker_ctx(): + """Ensure every test starts with no worker or sender context installed.""" + clear_worker_ctx() + clear_sender_ctx() + yield + clear_worker_ctx() + clear_sender_ctx() + + +def _double_udf(): + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + +class TestProtoRoundTrip: + def test_builtin_round_trip(self): + e = col("a") + lit(1) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert decoded.canonical_name() == e.canonical_name() + + def test_to_bytes_from_bytes(self): + e = col("x") * lit(7) + blob = e.to_bytes() + assert isinstance(blob, bytes) + decoded = Expr.from_bytes(blob) + assert decoded.canonical_name() == e.canonical_name() + + def test_explicit_ctx_used(self, ctx): + e = col("a") + lit(1) + decoded = Expr.from_bytes(e.to_bytes(), ctx=ctx) + assert decoded.canonical_name() == e.canonical_name() + + +class TestUDFCodec: + """Python scalar UDFs ride inside the proto blob via the Rust codec. + + No worker context needed on the receiver — the cloudpickled callable is + embedded in ``fun_definition`` and reconstructed automatically. + """ + + def test_udf_self_contained_blob(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + # The codec inlines the callable, so the blob is much bigger than a + # pure built-in blob but doesn't depend on receiver-side registration. + assert len(blob) > 200 + + def test_udf_decodes_into_fresh_ctx(self): + e = _double_udf()(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_no_worker_ctx(self): + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_udf_decodes_via_pickle_with_worker_ctx(self): + set_worker_ctx(SessionContext()) + e = _double_udf()(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "double" in decoded.canonical_name() + + def test_closure_capturing_udf_names_match(self): + captured_multiplier = 7 + + def fn(arr): + return pa.array([(v.as_py() or 0) * captured_multiplier for v in arr]) + + u = udf( + fn, + [pa.int64()], + pa.int64(), + volatility="immutable", + name="times_seven", + ) + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + # Round-trip names match; functional verification of captured state + # happens in test_pickle_multiprocessing via an actual UDF call. 
+ assert decoded.canonical_name() == e.canonical_name() + + +class TestAggregateUDFCodec: + """Python aggregate UDFs travel inline like scalar UDFs.""" + + def _build_aggregate_udf(self): + from datafusion import udaf + from datafusion.user_defined import Accumulator + + class CountAcc(Accumulator): + def __init__(self): + self._count = 0 + + def state(self): + return [pa.scalar(self._count, type=pa.int64())] + + def update(self, values): + self._count += len(values) + + def merge(self, states): + for s in states: + self._count += s[0].as_py() + + def evaluate(self): + return pa.scalar(self._count, type=pa.int64()) + + return udaf( + CountAcc, + [pa.int64()], + pa.int64(), + [pa.int64()], + "immutable", + name="count_all", + ) + + def test_agg_udf_self_contained_blob(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_agg_udf_decodes_into_fresh_ctx(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_aggregate_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_all" in decoded.canonical_name() + + def test_agg_udf_evaluates_after_roundtrip(self): + """End-to-end: the decoded aggregate UDF runs and merges across + partitions, exercising the round-tripped state-field schema.""" + u = self._build_aggregate_udf() + e = u(col("a")) + decoded = pickle.loads(pickle.dumps(e)) # noqa: S301 + + ctx = SessionContext() + df = ctx.from_pydict({"a": [1, 2, 3, 4, 5]}) + out = df.aggregate([], [decoded.alias("n")]).to_pydict() + assert out["n"] == [5] + + +class TestWindowUDFCodec: + """Python window UDFs travel inline like scalar UDFs.""" + + def _build_window_udf(self): + from datafusion import udwf + from datafusion.user_defined import WindowEvaluator + + class CountUpEvaluator(WindowEvaluator): + def evaluate_all(self, values, num_rows): + return pa.array(list(range(num_rows))) + + return udwf( + CountUpEvaluator, + [pa.int64()], + pa.int64(), + "immutable", + name="count_up", + ) + + def test_window_udf_self_contained_blob(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + assert len(blob) > 200 + + def test_window_udf_decodes_into_fresh_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = e.to_bytes() + fresh = SessionContext() + decoded = Expr.from_bytes(blob, ctx=fresh) + assert "count_up" in decoded.canonical_name() + + def test_window_udf_decodes_via_pickle_with_no_worker_ctx(self): + u = self._build_window_udf() + e = u(col("a")) + blob = pickle.dumps(e) + decoded = pickle.loads(blob) # noqa: S301 + assert "count_up" in decoded.canonical_name() + + +class TestPythonUdfInliningToggle: + """`SessionContext.with_python_udf_inlining(enabled=False)` opts out of + inline Python UDF encoding for both encode and decode paths.""" + + def _build_double_udf(self): + return udf( + lambda arr: pa.array([(v.as_py() or 0) * 2 for v in arr]), + [pa.int64()], + pa.int64(), + volatility="immutable", + name="double", + ) + + def test_strict_encoder_emits_smaller_blob(self): + """Strict mode skips cloudpickle of the Python callable, so the + encoded bytes are dramatically smaller than the inline form.""" + ctx_inline = SessionContext() + ctx_strict = ctx_inline.with_python_udf_inlining(enabled=False) + u = 
self._build_double_udf() + e = u(col("a")) + + blob_inline = e.to_bytes(ctx_inline) + blob_strict = e.to_bytes(ctx_strict) + + assert len(blob_strict) < len(blob_inline) // 4 + + def test_toggle_off_then_on_restores_inline_encoding(self): + """`with_python_udf_inlining` is per-call clone semantics: + flipping off and then on must produce a context that emits the + same inline form as a fresh default context, byte-for-byte. + + Guards against a regression where the off→on transition leaves + the codec in a sticky strict state (e.g. by mutating shared + codec state instead of cloning). + """ + u = self._build_double_udf() + e = u(col("a")) + + baseline = SessionContext() + toggled = ( + SessionContext() + .with_python_udf_inlining(enabled=False) + .with_python_udf_inlining(enabled=True) + ) + + blob_baseline = e.to_bytes(baseline) + blob_toggled = e.to_bytes(toggled) + + assert blob_baseline == blob_toggled + + # Sanity check the decoded form against a fresh ctx — the + # toggled-back blob should be self-contained inline, not a + # strict by-name payload that needs registry resolution. + decoded = Expr.from_bytes(blob_toggled, ctx=SessionContext()) + assert "double" in decoded.canonical_name() + + def test_strict_roundtrip_via_registry(self): + """When both sender and receiver disable inlining, the UDF + travels by name only and the receiver resolves it from its + registered functions.""" + strict_sender = SessionContext().with_python_udf_inlining(enabled=False) + u = self._build_double_udf() + blob = u(col("a")).to_bytes(strict_sender) + + receiver = SessionContext().with_python_udf_inlining(enabled=False) + receiver.register_udf(u) + restored = Expr.from_bytes(blob, ctx=receiver) + assert "double" in restored.canonical_name() + + def test_strict_decoder_refuses_inline_payload(self): + """An inline-encoded blob fed to a strict receiver raises with a + clear error rather than silently invoking cloudpickle.loads. + + The receiver is intentionally *not* given a matching + registration: the codec refusal must trip before the registry + is ever consulted, so registering the UDF here would only mask + a regression that moved the check after registry lookup. + """ + sender = SessionContext() + u = self._build_double_udf() + blob = u(col("a")).to_bytes(sender) + + strict_receiver = SessionContext().with_python_udf_inlining(enabled=False) + # `RuntimeError` (not bare `Exception`): the codec refusal is + # surfaced through `parse_expr` → `PyRuntimeError`. Tightening + # the assertion catches a regression that swallows the refusal + # as a different error type. + with pytest.raises(RuntimeError, match="inlining is disabled"): + Expr.from_bytes(blob, ctx=strict_receiver) + + def test_sender_ctx_propagates_through_pickle(self): + """`set_sender_ctx` makes `pickle.dumps` use a strict codec. + + Without a sender context, pickle defaults to the inline codec and + the blob is large. With a strict sender context installed, the + blob shrinks because the Python callable is encoded by name + instead of cloudpickled. + """ + u = self._build_double_udf() + e = u(col("a")) + + blob_default = pickle.dumps(e) + + strict_sender = SessionContext().with_python_udf_inlining(enabled=False) + set_sender_ctx(strict_sender) + try: + blob_strict = pickle.dumps(e) + finally: + clear_sender_ctx() + + assert len(blob_strict) < len(blob_default) // 4 + + def test_sender_ctx_strict_roundtrip_via_pickle(self): + """End-to-end pickle round-trip with strict mode on both sides. + + Driver installs a strict sender context. 
Worker installs a + matching strict context with the UDF registered. The UDF + travels by name through `pickle.dumps` / `pickle.loads`. + """ + u = self._build_double_udf() + e = u(col("a")) + + strict_sender = SessionContext().with_python_udf_inlining(enabled=False) + set_sender_ctx(strict_sender) + try: + blob = pickle.dumps(e) + finally: + clear_sender_ctx() + + worker = SessionContext().with_python_udf_inlining(enabled=False) + worker.register_udf(u) + set_worker_ctx(worker) + try: + decoded = pickle.loads(blob) # noqa: S301 + finally: + clear_worker_ctx() + + assert "double" in decoded.canonical_name() + + def test_sender_ctx_strict_pickle_accepted_by_inline_worker_with_registry(self): + """A strict-encoded blob still decodes fine on an inline worker + because the wire format is the same default-codec by-name form. + Sanity check: cross-config works as long as the receiver can + resolve the name.""" + u = self._build_double_udf() + e = u(col("a")) + + strict_sender = SessionContext().with_python_udf_inlining(enabled=False) + set_sender_ctx(strict_sender) + try: + blob = pickle.dumps(e) + finally: + clear_sender_ctx() + + worker = SessionContext() + worker.register_udf(u) + set_worker_ctx(worker) + try: + decoded = pickle.loads(blob) # noqa: S301 + finally: + clear_worker_ctx() + + assert "double" in decoded.canonical_name() + + +class TestWorkerCtxLifecycle: + def test_set_and_clear(self): + assert get_worker_ctx() is None + ctx = SessionContext() + set_worker_ctx(ctx) + assert get_worker_ctx() is ctx + clear_worker_ctx() + assert get_worker_ctx() is None + + def test_clear_when_unset_is_noop(self): + clear_worker_ctx() # no error + assert get_worker_ctx() is None + + def test_thread_local_isolation(self): + main_ctx = SessionContext() + set_worker_ctx(main_ctx) + + seen_in_thread: list = [] + + def worker(): + seen_in_thread.append(get_worker_ctx()) + set_worker_ctx(SessionContext()) + seen_in_thread.append(get_worker_ctx()) + + t = threading.Thread(target=worker) + t.start() + t.join() + + # Thread saw no ctx initially (thread-local), then its own. + assert seen_in_thread[0] is None + assert seen_in_thread[1] is not main_ctx + # Main thread's ctx is unchanged by the thread's actions. + assert get_worker_ctx() is main_ctx + + +class TestSenderCtxLifecycle: + def test_set_and_clear(self): + assert get_sender_ctx() is None + ctx = SessionContext() + set_sender_ctx(ctx) + assert get_sender_ctx() is ctx + clear_sender_ctx() + assert get_sender_ctx() is None + + def test_clear_when_unset_is_noop(self): + clear_sender_ctx() # no error + assert get_sender_ctx() is None + + def test_thread_local_isolation(self): + main_ctx = SessionContext() + set_sender_ctx(main_ctx) + + seen_in_thread: list = [] + + def worker(): + seen_in_thread.append(get_sender_ctx()) + set_sender_ctx(SessionContext()) + seen_in_thread.append(get_sender_ctx()) + + t = threading.Thread(target=worker) + t.start() + t.join() + + assert seen_in_thread[0] is None + assert seen_in_thread[1] is not main_ctx + assert get_sender_ctx() is main_ctx diff --git a/python/tests/test_pickle_multiprocessing.py b/python/tests/test_pickle_multiprocessing.py new file mode 100644 index 000000000..6eabaff9e --- /dev/null +++ b/python/tests/test_pickle_multiprocessing.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Cross-process pickle tests for :class:`Expr`. + +Workers run with each :mod:`multiprocessing` start method (``fork``, +``forkserver``, ``spawn``). Python UDFs (scalar, aggregate, window) travel +with the pickled expression and need no worker-side pre-registration. +Worker-side helpers live in ``_pickle_multiprocessing_helpers`` — the +underscore prefix avoids pytest collection so the module imports under +its real name in worker subprocesses. +""" + +from __future__ import annotations + +import functools +import multiprocessing as mp +import pickle +import sys + +import pytest +from datafusion import col, lit + +from . import _pickle_multiprocessing_helpers as helpers + + +@functools.cache +def _multiprocessing_available() -> tuple[bool, str]: + """Return (available, reason). Some sandboxed environments deny semaphore + creation; without semaphores, ``multiprocessing.Pool`` cannot start. + + Cached so the probe Pool only spawns once per session, and only when a + test in this module is actually about to run — collection-only runs + (e.g. ``pytest --collect-only`` on the full suite) skip the probe. + """ + try: + ctx = mp.get_context("spawn") + with ctx.Pool(processes=1) as pool: + pool.map(int, [0]) + except (PermissionError, OSError) as exc: + return False, f"multiprocessing.Pool unavailable: {exc}" + return True, "" + + +@pytest.fixture(autouse=True) +def _skip_if_multiprocessing_unavailable(): + available, reason = _multiprocessing_available() + if not available: + pytest.skip(reason) + + +START_METHODS = [ + pytest.param( + "fork", + marks=pytest.mark.skipif( + sys.platform == "darwin", + reason="fork start method is unsafe with PyArrow/tokio on macOS", + ), + ), + "forkserver", + "spawn", +] + + +@pytest.mark.parametrize("start_method", START_METHODS) +@pytest.mark.timeout(120) +def test_builtin_pickle_via_pool(start_method): + """Built-in expressions round-trip in every start method.""" + expr = col("a") + lit(1) + blob = pickle.dumps(expr) + + ctx = mp.get_context(start_method) + with ctx.Pool(processes=2) as pool: + results = pool.map(helpers.unpickle_and_describe, [blob, blob, blob]) + + assert all(r == expr.canonical_name() for r in results) + + +@pytest.mark.parametrize("start_method", START_METHODS) +@pytest.mark.timeout(120) +def test_udf_pickle_self_contained(start_method): + """Scalar UDF travels inside the proto blob — no worker pre-registration. + + Workers start with no UDF registered. The Rust-side ``PythonUDFCodec`` + reconstructs the UDF from bytes embedded in the pickle blob. 
+ """ + udf_obj = helpers.make_double_udf() + expr = udf_obj(col("a")) + blob = pickle.dumps(expr) + + ctx = mp.get_context(start_method) + with ctx.Pool(processes=2) as pool: + results = pool.starmap( + helpers.unpickle_and_evaluate, + [(blob, [1, 2, 3]), (blob, [10, 20, 30])], + ) + + assert results[0] == [2, 4, 6] + assert results[1] == [20, 40, 60] + + +@pytest.mark.parametrize("start_method", START_METHODS) +@pytest.mark.timeout(120) +def test_closure_capturing_udf_via_pool(start_method): + """Cloudpickle preserves closure state across the codec boundary.""" + udf_obj = helpers.make_times_seven_udf() + expr = udf_obj(col("a")) + blob = pickle.dumps(expr) + + ctx = mp.get_context(start_method) + with ctx.Pool(processes=2) as pool: + results = pool.starmap( + helpers.unpickle_and_evaluate, + [(blob, [1, 2, 3])], + ) + + assert results[0] == [7, 14, 21] diff --git a/uv.lock b/uv.lock index 3b7135e32..3fd3eec4b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.14'", @@ -257,6 +257,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/f6/65ecc6878a89bb1c23a086ea335ad4bf21a588990c3f535a227b9eea9108/charset_normalizer-3.4.1-py3-none-any.whl", hash = "sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85", size = 49767, upload-time = "2024-12-24T18:12:32.852Z" }, ] +[[package]] +name = "cloudpickle" +version = "3.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/27/fb/576f067976d320f5f0114a8d9fa1215425441bb35627b1993e5afd8111e5/cloudpickle-3.1.2.tar.gz", hash = "sha256:7fda9eb655c9c230dab534f1983763de5835249750e85fbcef43aaa30a9a2414", size = 22330, upload-time = "2025-11-03T09:25:26.604Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/39/799be3f2f0f38cc727ee3b4f1445fe6d5e4133064ec2e4115069418a5bb6/cloudpickle-3.1.2-py3-none-any.whl", hash = "sha256:9acb47f6afd73f60dc1df93bb801b472f05ff42fa6c84167d25cb206be1fbf4a", size = 22228, upload-time = "2025-11-03T09:25:25.534Z" }, +] + [[package]] name = "codespell" version = "2.4.1" @@ -316,6 +325,7 @@ wheels = [ name = "datafusion" source = { editable = "." } dependencies = [ + { name = "cloudpickle" }, { name = "pyarrow" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] @@ -351,6 +361,7 @@ docs = [ [package.metadata] requires-dist = [ + { name = "cloudpickle", specifier = ">=2.0" }, { name = "pyarrow", marker = "python_full_version < '3.14'", specifier = ">=16.0.0" }, { name = "pyarrow", marker = "python_full_version >= '3.14'", specifier = ">=22.0.0" }, { name = "typing-extensions", marker = "python_full_version < '3.13'" },