428 changes: 394 additions & 34 deletions crates/core/src/codec.rs

Large diffs are not rendered by default.

82 changes: 74 additions & 8 deletions crates/core/src/udf.rs
@@ -43,7 +43,7 @@ use crate::expr::PyExpr;
/// This struct holds the Python-written function that is a
/// ScalarUDF.
#[derive(Debug)]
-struct PythonFunctionScalarUDF {
+pub(crate) struct PythonFunctionScalarUDF {
name: String,
func: Py<PyAny>,
signature: Signature,
@@ -67,6 +67,37 @@ impl PythonFunctionScalarUDF {
return_field: Arc::new(return_field),
}
}

/// Stored Python callable. Consumed by the codec to cloudpickle
/// the function body across process boundaries.
pub(crate) fn func(&self) -> &Py<PyAny> {
&self.func
}

pub(crate) fn return_field(&self) -> &FieldRef {
&self.return_field
}

/// Reconstruct a `PythonFunctionScalarUDF` from the parts emitted
/// by the codec. Inputs collapse to `Vec<DataType>` because
/// `Signature::exact` cannot carry per-input nullability or
/// metadata — the encoder is free to discard that side of the
/// schema. `return_field` is kept as a `Field` so the post-decode
/// nullability and metadata match the sender's instance.
pub(crate) fn from_parts(
name: String,
func: Py<PyAny>,
input_types: Vec<DataType>,
return_field: Field,
volatility: Volatility,
) -> Self {
Self {
name,
func,
signature: Signature::exact(input_types, volatility),
return_field: Arc::new(return_field),
}
}
}
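
In schema terms, the asymmetry described above looks like the following pyarrow sketch (the field names and metadata are invented for illustration; the real encoding lives in the Rust codec):

import pyarrow as pa

# The sender's UDF carries full Fields on both sides.
input_fields = [
    pa.field("x", pa.int64(), nullable=False, metadata={"unit": "ms"}),
]
return_field = pa.field("y", pa.float64(), nullable=True, metadata={"unit": "s"})

# Input side of the wire format: bare DataTypes only, mirroring
# `Signature::exact(Vec<DataType>, volatility)`. Input nullability and
# metadata do not survive the round trip.
encoded_input_types = [f.type for f in input_fields]
assert encoded_input_types == [pa.int64()]

# Output side: the whole Field ships, so the decoded UDF reports the
# sender's nullability and metadata unchanged.
assert return_field.nullable and return_field.metadata == {b"unit": b"s"}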

impl Eq for PythonFunctionScalarUDF {}
@@ -75,21 +106,51 @@ impl PartialEq for PythonFunctionScalarUDF {
self.name == other.name
&& self.signature == other.signature
&& self.return_field == other.return_field
-&& Python::attach(|py| self.func.bind(py).eq(other.func.bind(py)).unwrap_or(false))
// Identical pointers ⇒ same Python object. Most equality
// checks compare `Arc`-shared clones of the same UDF
// (e.g. expression rewriting), so the pointer match short-
// circuits before touching the GIL.
&& (self.func.as_ptr() == other.func.as_ptr()
|| Python::attach(|py| {
// Rust's `PartialEq` cannot return `Result`, so we
// have to pick a side when Python `__eq__` raises.
// `false` is the conservative choice — better to
// report two UDFs as distinct than to wrongly
// merge them — but the silent miss can still
// surface as expression-dedup or cache-lookup
// anomalies. Log at `debug` so the failure is
// observable without flooding production logs.
// FIXME: revisit if upstream `ScalarUDFImpl`
// exposes a fallible `PartialEq`.
self.func
.bind(py)
.eq(other.func.bind(py))
.unwrap_or_else(|e| {
log::debug!(
target: "datafusion_python::udf",
"PythonFunctionScalarUDF {:?} __eq__ raised; treating as unequal: {e}",
self.name,
);
false
})
}))
}
}
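
For readers on the Python side of the codebase, a minimal Python analogue of the identity-then-fallback pattern above (the class is a hypothetical stand-in; the real comparison is this Rust impl):

class UdfHandle:
    """Hypothetical analogue of PythonFunctionScalarUDF equality."""

    def __init__(self, name: str, func):
        self.name = name
        self.func = func

    def __eq__(self, other):
        if not isinstance(other, UdfHandle):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.func is other.func:
            # Identity short-circuit: same object, no __eq__ call needed.
            return True
        try:
            return bool(self.func == other.func)
        except Exception:
            # An equality check has no way to propagate the error, so
            # report "not equal", the conservative side chosen above.
            return False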

impl Hash for PythonFunctionScalarUDF {
fn hash<H: Hasher>(&self, state: &mut H) {
// Hash only the identifying header (name + signature + return
// field). Skipping `func` is intentional: the Rust `Hash`
// contract requires `a == b ⇒ hash(a) == hash(b)`, not the
// converse, so a coarser hash is sound — `PartialEq` still
// disambiguates two UDFs with the same header but distinct
// callables. Falling back to a sentinel on `py_hash` failure
// (as a prior revision did) silently mapped every unhashable
// closure to the same bucket; that is the worst case for a
// hashmap and is what this rewrite avoids.
self.name.hash(state);
self.signature.hash(state);
self.return_field.hash(state);

-Python::attach(|py| {
-let py_hash = self.func.bind(py).hash().unwrap_or(0); // Handle unhashable objects
-
-state.write_isize(py_hash);
-});
}
}
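
The contract that comment leans on (equal implies equal hash, never the converse) can be demonstrated with a short Python sketch; the class is hypothetical:

class UdfHandle:
    def __init__(self, name: str, func):
        self.name = name
        self.func = func

    def __eq__(self, other):
        if not isinstance(other, UdfHandle):
            return NotImplemented
        return self.name == other.name and self.func == other.func

    def __hash__(self):
        # Coarser than __eq__ on purpose: hash only the stable header.
        # Equal handles still hash alike, which is all the contract
        # requires; same-name handles with different callables collide
        # and __eq__ disambiguates them within the bucket.
        return hash(self.name)

a, b = UdfHandle("f", print), UdfHandle("f", len)
assert hash(a) == hash(b) and a != b  # collision allowed; wrong merge is not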

@@ -220,4 +281,9 @@ impl PyScalarUDF {
fn __repr__(&self) -> PyResult<String> {
Ok(format!("ScalarUDF({})", self.function.name()))
}

#[getter]
fn name(&self) -> &str {
self.function.name()
}
}
7 changes: 7 additions & 0 deletions pyproject.toml
@@ -44,6 +44,13 @@ classifiers = [
"Programming Language :: Rust",
]
dependencies = [
# cloudpickle is invoked by the Rust-side PythonLogicalCodec /
# PythonPhysicalCodec via pyo3 to serialize Python UDF callables —
# scalar, aggregate, and window — into the proto wire format.
# Lazy-imported on the encode / decode hot paths (and cached after
# the first import), so users who never serialize a plan or
# expression incur no runtime cost beyond the install footprint.
"cloudpickle>=2.0",
"pyarrow>=16.0.0;python_version<'3.14'",
"pyarrow>=22.0.0;python_version>='3.14'",
"typing-extensions;python_version<'3.13'",
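A sketch of the lazy-import-and-cache pattern the dependency comment describes, written in Python for illustration (the real import happens inside the Rust codecs via pyo3; the helper names are hypothetical):

_cloudpickle = None

def _get_cloudpickle():
    # Deferred import: users who never serialize a plan or expression
    # never pay the import cost; the module is cached after first use.
    global _cloudpickle
    if _cloudpickle is None:
        import cloudpickle
        _cloudpickle = cloudpickle
    return _cloudpickle

def encode_callable(func) -> bytes:
    # Hypothetical helper mirroring the encode hot path.
    return _get_cloudpickle().dumps(func)
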
3 changes: 2 additions & 1 deletion python/datafusion/__init__.py
@@ -65,7 +65,7 @@
import importlib_metadata # type: ignore[import]

# Public submodules
-from . import functions, object_store, substrait, unparser
+from . import functions, ipc, object_store, substrait, unparser

# The following imports are okay to remain as opaque to the user.
from ._internal import Config
@@ -142,6 +142,7 @@
"configure_formatter",
"expr",
"functions",
"ipc",
"lit",
"literal",
"object_store",
103 changes: 92 additions & 11 deletions python/datafusion/expr.py
@@ -46,7 +46,7 @@

from __future__ import annotations

-from collections.abc import Iterable, Sequence
+from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa
@@ -434,23 +434,104 @@ def variant_name(self) -> str:
return self.expr.variant_name()

def to_bytes(self, ctx: SessionContext | None = None) -> bytes:
"""Serialize this expression to protobuf bytes.
"""Serialize this expression to bytes for shipping to another process.

-When ``ctx`` is supplied, encoding routes through the session's
-installed :class:`LogicalExtensionCodec`. Without ``ctx`` a
-default codec is used.
Use this — or :func:`pickle.dumps` — to send an expression to a
worker process for distributed evaluation.

When ``ctx`` is supplied, encoding routes through that session's
installed :class:`LogicalExtensionCodec`. When ``ctx`` is
``None``, the default codec is used.

Built-in functions and Python scalar UDFs travel inside the
returned bytes; the worker does not need to pre-register them.
UDFs imported via the FFI capsule protocol travel by name only
and must be registered on the worker.

.. warning::
Bytes returned here may embed a cloudpickled Python
callable (when the expression carries a Python scalar UDF).
Reconstructing them via :meth:`from_bytes` or
:func:`pickle.loads` executes arbitrary Python on the
receiver. Only accept payloads from trusted sources.

Examples:
>>> from datafusion import col, lit
>>> blob = (col("a") + lit(1)).to_bytes()
>>> isinstance(blob, bytes)
True
"""
ctx_arg = ctx.ctx if ctx is not None else None
return self.expr.to_bytes(ctx_arg)

-@staticmethod
-def from_bytes(ctx: SessionContext, data: bytes) -> Expr:
-"""Decode an expression from serialized protobuf bytes.
+@classmethod
+def from_bytes(cls, buf: bytes, ctx: SessionContext | None = None) -> Expr:
+"""Reconstruct an expression from serialized bytes.

Accepts output of :meth:`to_bytes` or :func:`pickle.dumps`.
``ctx`` is the :class:`SessionContext` used to resolve any
function references that travel by name (e.g. FFI UDFs). When
``ctx`` is ``None`` the worker context installed via
:func:`datafusion.ipc.set_worker_ctx` is consulted; if no worker
context is installed, the global :class:`SessionContext` is used
(sufficient for built-ins and Python scalar UDFs, plus any UDFs
registered on the global context).

.. warning::
Decoding may invoke ``cloudpickle.loads`` on bytes embedded
in the payload, which executes arbitrary Python code. Treat
``buf`` as code, not data — only decode bytes you produced
yourself or received from a trusted sender.

Examples:
>>> from datafusion import Expr, col, lit
>>> blob = (col("a") + lit(1)).to_bytes()
>>> Expr.from_bytes(blob).canonical_name()
'a + Int64(1)'
"""
from datafusion.ipc import _resolve_ctx

-``ctx`` provides the function registry for resolving UDF
-references and the logical codec for in-band Python payloads.
resolved = _resolve_ctx(ctx)
return cls(expr_internal.RawExpr.from_bytes(resolved.ctx, buf))

def __reduce__(self) -> tuple[Callable[[bytes], Expr], tuple[bytes]]:
"""Pickle protocol hook.

Lets expressions be shipped to worker processes via
:func:`pickle.dumps` / :func:`pickle.loads`. Built-in functions
and Python scalar UDFs travel inside the pickle bytes; only
FFI-capsule UDFs require pre-registration on the worker. On
unpickle, those references are resolved against the
:class:`SessionContext` installed via
:func:`datafusion.ipc.set_worker_ctx`, falling back to the global
:class:`SessionContext` if none has been installed on the worker.

.. warning::
:func:`pickle.loads` on the returned tuple executes
arbitrary Python on the receiver, including any
cloudpickled UDF callable embedded in the payload. Only
unpickle expressions from trusted sources.

Examples:
>>> import pickle
>>> from datafusion import col, lit
>>> e = col("a") * lit(2)
>>> pickle.loads(pickle.dumps(e)).canonical_name()
'a * Int64(2)'
"""
return (Expr._reconstruct, (self.to_bytes(),))

@classmethod
def _reconstruct(cls, proto_bytes: bytes) -> Expr:
"""Internal entry point used by :meth:`__reduce__` on unpickle.

Examples:
>>> from datafusion import Expr, col, lit
>>> blob = (col("a") + lit(1)).to_bytes()
>>> Expr._reconstruct(blob).canonical_name()
'a + Int64(1)'
"""
-return Expr(expr_internal.RawExpr.from_bytes(ctx.ctx, data))
+return cls.from_bytes(proto_bytes)

def __richcmp__(self, other: Expr, op: int) -> Expr:
"""Comparison operator."""
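Putting the expression APIs above together, a sketch of the worker-side flow (`set_worker_ctx` comes from this PR's `datafusion.ipc` module and is assumed to accept the worker's `SessionContext`; the in-process call to `worker_main` stands in for a real executor shipping bytes to another process):

import pickle

from datafusion import SessionContext, col, lit
from datafusion.ipc import set_worker_ctx

def worker_main(payload: bytes) -> None:
    # Runs in the worker process. Install the context that from_bytes
    # consults when resolving name-only (FFI) UDF references; built-ins
    # and Python scalar UDFs arrive inside the payload itself.
    ctx = SessionContext()
    set_worker_ctx(ctx)
    expr = pickle.loads(payload)  # executes embedded cloudpickle bytes
    print(expr.canonical_name())  # -> a + Int64(1)

if __name__ == "__main__":
    blob = pickle.dumps(col("a") + lit(1))
    worker_main(blob)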