From 059248ade26b9849c4505e4863cd5324b1c3253a Mon Sep 17 00:00:00 2001 From: Connor Tsui Date: Fri, 17 Apr 2026 14:25:10 -0400 Subject: [PATCH] fix reduction rule Signed-off-by: Connor Tsui --- .../src/arrays/extension/compute/rules.rs | 55 +++++++-------- .../src/arrays/extension/vtable/mod.rs | 67 ++++++++++--------- 2 files changed, 60 insertions(+), 62 deletions(-) diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index 7408488a0f1..568fe1c061c 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -14,52 +14,44 @@ use crate::arrays::Filter; use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; -use crate::matcher::AnyArray; use crate::optimizer::rules::ArrayParentReduceRule; +use crate::optimizer::rules::ArrayReduceRule; use crate::optimizer::rules::ParentRuleSet; +use crate::optimizer::rules::ReduceRuleSet; use crate::scalar::Scalar; use crate::scalar_fn::fns::cast::CastReduceAdaptor; use crate::scalar_fn::fns::mask::MaskReduceAdaptor; -pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&ExtensionConstantParentRule), - ParentRuleSet::lift(&ExtensionFilterPushDownRule), - ParentRuleSet::lift(&CastReduceAdaptor(Extension)), - ParentRuleSet::lift(&FilterReduceAdaptor(Extension)), - ParentRuleSet::lift(&MaskReduceAdaptor(Extension)), - ParentRuleSet::lift(&SliceReduceAdaptor(Extension)), -]); +pub(crate) const RULES: ReduceRuleSet = ReduceRuleSet::new(&[&ExtensionConstantRule]); /// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`. #[derive(Debug)] -struct ExtensionConstantParentRule; - -impl ArrayParentReduceRule for ExtensionConstantParentRule { - type Parent = AnyArray; +struct ExtensionConstantRule; - fn reduce_parent( - &self, - child: ArrayView<'_, Extension>, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - let Some(const_array) = child.storage_array().as_opt::() else { +impl ArrayReduceRule for ExtensionConstantRule { + fn reduce(&self, array: ArrayView<'_, Extension>) -> VortexResult> { + let Some(const_array) = array.storage_array().as_opt::() else { return Ok(None); }; let storage_scalar = const_array.scalar().clone(); - let ext_scalar = Scalar::extension_ref(child.ext_dtype().clone(), storage_scalar); + let ext_scalar = Scalar::extension_ref(array.ext_dtype().clone(), storage_scalar); let constant_with_extension_scalar = - ConstantArray::new(ext_scalar, child.len()).into_array(); + ConstantArray::new(ext_scalar, array.len()).into_array(); - parent - .clone() - .with_slot(child_idx, constant_with_extension_scalar) - .map(Some) + Ok(Some(constant_with_extension_scalar.into_array())) } } +pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ + ParentRuleSet::lift(&ExtensionFilterPushDownRule), + ParentRuleSet::lift(&CastReduceAdaptor(Extension)), + ParentRuleSet::lift(&FilterReduceAdaptor(Extension)), + ParentRuleSet::lift(&MaskReduceAdaptor(Extension)), + ParentRuleSet::lift(&SliceReduceAdaptor(Extension)), +]); + /// Push filter operations into the storage array of an extension array. #[derive(Debug)] struct ExtensionFilterPushDownRule; @@ -99,6 +91,7 @@ mod tests { use crate::arrays::ExtensionArray; use crate::arrays::FilterArray; use crate::arrays::PrimitiveArray; + use crate::arrays::ScalarFnVTable; use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::scalar_fn::ScalarFnArrayExt; use crate::arrays::scalar_fn::ScalarFnFactoryExt; @@ -227,8 +220,8 @@ mod tests { .try_new_array(3, Operator::Lt, [constant_ext, ext_array]) .unwrap(); - let optimized = scalar_fn_array.optimize().unwrap(); - let scalar_fn = optimized.as_opt::().unwrap(); + let optimized = scalar_fn_array.optimize_recursive().unwrap(); + let scalar_fn = optimized.as_opt::().unwrap(); let children = scalar_fn.children(); let constant = children[0] .as_opt::() @@ -291,7 +284,7 @@ mod tests { let optimized = scalar_fn_array.optimize().unwrap(); // The first child should still be an ExtensionArray (no pushdown happened) - let scalar_fn = optimized.as_opt::().unwrap(); + let scalar_fn = optimized.as_opt::().unwrap(); assert!( scalar_fn.children()[0].as_opt::().is_some(), "Expected first child to remain ExtensionArray when ext types differ" @@ -316,7 +309,7 @@ mod tests { let optimized = scalar_fn_array.optimize().unwrap(); // No pushdown should happen because sibling is not a constant - let scalar_fn = optimized.as_opt::().unwrap(); + let scalar_fn = optimized.as_opt::().unwrap(); assert!( scalar_fn.children()[0].as_opt::().is_some(), "Expected first child to remain ExtensionArray when sibling is not constant" @@ -339,7 +332,7 @@ mod tests { let optimized = scalar_fn_array.optimize().unwrap(); // No pushdown should happen because constant is not an extension scalar - let scalar_fn = optimized.as_opt::().unwrap(); + let scalar_fn = optimized.as_opt::().unwrap(); assert!( scalar_fn.children()[0].as_opt::().is_some(), "Expected first child to remain ExtensionArray when constant is not extension" diff --git a/vortex-array/src/arrays/extension/vtable/mod.rs b/vortex-array/src/arrays/extension/vtable/mod.rs index b82d4b8e0df..a3fde75f79b 100644 --- a/vortex-array/src/arrays/extension/vtable/mod.rs +++ b/vortex-array/src/arrays/extension/vtable/mod.rs @@ -31,10 +31,14 @@ use crate::arrays::extension::ExtensionData; use crate::arrays::extension::array::SLOT_NAMES; use crate::arrays::extension::array::STORAGE_SLOT; use crate::arrays::extension::compute::rules::PARENT_RULES; +use crate::arrays::extension::compute::rules::RULES; use crate::buffer::BufferHandle; use crate::dtype::DType; use crate::serde::ArrayChildren; +#[derive(Clone, Debug)] +pub struct Extension; + /// A [`Extension`]-encoded Vortex array. pub type ExtensionArray = Array; @@ -59,29 +63,6 @@ impl VTable for Extension { *ID } - fn nbuffers(_array: ArrayView<'_, Self>) -> usize { - 0 - } - - fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle { - vortex_panic!("ExtensionArray buffer index {idx} out of bounds") - } - - fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option { - None - } - - fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String { - SLOT_NAMES[idx].to_string() - } - - fn serialize( - _array: ArrayView<'_, Self>, - _session: &VortexSession, - ) -> VortexResult>> { - Ok(Some(vec![])) - } - fn validate( &self, data: &ExtensionData, @@ -111,6 +92,25 @@ impl VTable for Extension { Ok(()) } + fn nbuffers(_array: ArrayView<'_, Self>) -> usize { + 0 + } + + fn buffer(_array: ArrayView<'_, Self>, idx: usize) -> BufferHandle { + vortex_panic!("ExtensionArray buffer index {idx} out of bounds") + } + + fn buffer_name(_array: ArrayView<'_, Self>, _idx: usize) -> Option { + None + } + + fn serialize( + _array: ArrayView<'_, Self>, + _session: &VortexSession, + ) -> VortexResult>> { + Ok(Some(vec![])) + } + fn deserialize( &self, dtype: &DType, @@ -143,27 +143,32 @@ impl VTable for Extension { .with_slots(vec![Some(storage)])) } + fn slot_name(_array: ArrayView<'_, Self>, idx: usize) -> String { + SLOT_NAMES[idx].to_string() + } + fn execute(array: Array, _ctx: &mut ExecutionCtx) -> VortexResult { Ok(ExecutionResult::done(array)) } - fn reduce_parent( + fn execute_parent( array: ArrayView<'_, Self>, parent: &ArrayRef, child_idx: usize, + ctx: &mut ExecutionCtx, ) -> VortexResult> { - PARENT_RULES.evaluate(array, parent, child_idx) + PARENT_KERNELS.execute(array, parent, child_idx, ctx) } - fn execute_parent( + fn reduce(array: ArrayView<'_, Self>) -> VortexResult> { + RULES.evaluate(array) + } + + fn reduce_parent( array: ArrayView<'_, Self>, parent: &ArrayRef, child_idx: usize, - ctx: &mut ExecutionCtx, ) -> VortexResult> { - PARENT_KERNELS.execute(array, parent, child_idx, ctx) + PARENT_RULES.evaluate(array, parent, child_idx) } } - -#[derive(Clone, Debug)] -pub struct Extension;