diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs index e407be5e390dc..350ec13712652 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/bytes.rs @@ -16,7 +16,7 @@ // under the License. use crate::aggregates::group_values::multi_group_by::{ - GroupColumn, Nulls, nulls_equal_to, + GroupColumn, Nulls, nulls_equal_to, split_vec_min_alloc, }; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::{ @@ -380,11 +380,10 @@ where // Given offsets like [0, 2, 4, 5] and n = 1, we expect to get // offsets [0, 2, 3]. We first create two offsets for first_n as [0, 2] and the remaining as [2, 4, 5]. // And we shift the offset starting from 0 for the remaining one, [2, 4, 5] -> [0, 2, 3]. - let mut first_n_offsets = self.offsets.drain(0..n).collect::>(); - let offset_n = *self.offsets.first().unwrap(); - self.offsets - .iter_mut() - .for_each(|offset| *offset = offset.sub(offset_n)); + let offset_n = self.offsets[n]; + let mut first_n_offsets = split_vec_min_alloc(&mut self.offsets, n); + // After the split, self.offsets[0] == offset_n in both branches; normalize in-place. + self.offsets.iter_mut().for_each(|o| *o = o.sub(offset_n)); first_n_offsets.push(offset_n); // SAFETY: the offsets were constructed correctly in `insert_if_new` -- diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs index 2115e9a34da64..e742fab0812a4 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/mod.rs @@ -107,6 +107,19 @@ pub trait GroupColumn: Send + Sync { fn take_n(&mut self, n: usize) -> ArrayRef; } +/// Splits `vec` at `n`, returning the first `n` elements and leaving the +/// remainder in `vec`. Allocates for whichever portion is smaller to minimize +/// peak memory: `drain+collect` when `n <= remaining`, `split_off+replace` +/// when `remaining < n`. +pub(super) fn split_vec_min_alloc(vec: &mut Vec, n: usize) -> Vec { + if n * 2 <= vec.len() { + vec.drain(0..n).collect() + } else { + let remaining = vec.split_off(n); + mem::replace(vec, remaining) + } +} + /// Determines if the nullability of the existing and new input array can be used /// to short-circuit the comparison of the two values. /// @@ -1273,7 +1286,50 @@ mod tests { GroupValues, multi_group_by::GroupValuesColumn, }; - use super::GroupIndexView; + use super::{GroupIndexView, split_vec_min_alloc}; + + #[test] + fn test_split_vec_min_alloc_drain_branch() { + // n * 2 <= len → drain+collect branch (allocates n elements) + let mut v = vec![1, 2, 3, 4, 5, 6]; + let first = split_vec_min_alloc(&mut v, 2); + assert_eq!(first, vec![1, 2]); + assert_eq!(v, vec![3, 4, 5, 6]); + } + + #[test] + fn test_split_vec_min_alloc_split_off_branch() { + // remaining < n → split_off+replace branch (allocates remaining elements) + let mut v = vec![1, 2, 3, 4, 5, 6]; + let first = split_vec_min_alloc(&mut v, 4); + assert_eq!(first, vec![1, 2, 3, 4]); + assert_eq!(v, vec![5, 6]); + } + + #[test] + fn test_split_vec_min_alloc_exactly_half() { + // n * 2 == len → drain branch (boundary condition) + let mut v = vec![1, 2, 3, 4]; + let first = split_vec_min_alloc(&mut v, 2); + assert_eq!(first, vec![1, 2]); + assert_eq!(v, vec![3, 4]); + } + + #[test] + fn test_split_vec_min_alloc_take_all() { + let mut v = vec![1, 2, 3]; + let first = split_vec_min_alloc(&mut v, 3); + assert_eq!(first, vec![1, 2, 3]); + assert!(v.is_empty()); + } + + #[test] + fn test_split_vec_min_alloc_take_none() { + let mut v = vec![1, 2, 3]; + let first = split_vec_min_alloc(&mut v, 0); + assert!(first.is_empty()); + assert_eq!(v, vec![1, 2, 3]); + } #[test] fn test_intern_for_vectorized_group_values() { diff --git a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs index bdc06fa553de5..4aae996f6811d 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/multi_group_by/primitive.rs @@ -16,7 +16,7 @@ // under the License. use crate::aggregates::group_values::multi_group_by::{ - GroupColumn, Nulls, nulls_equal_to, + GroupColumn, Nulls, nulls_equal_to, split_vec_min_alloc, }; use crate::aggregates::group_values::null_builder::MaybeNullBufferBuilder; use arrow::array::ArrowNativeTypeOp; @@ -267,8 +267,7 @@ impl GroupColumn } fn take_n(&mut self, n: usize) -> ArrayRef { - let first_n = self.group_values.drain(0..n).collect::>(); - + let first_n = split_vec_min_alloc(&mut self.group_values, n); let first_n_nulls = if NULLABLE { self.nulls.take_n(n) } else { None }; Arc::new( @@ -584,4 +583,40 @@ mod tests { assert!(results[3]); assert!(results[4]); } + + #[test] + fn test_primitive_take_n() { + // drain branch: n * 2 <= len + let mut builder = + PrimitiveGroupValueBuilder::::new(DataType::Int64); + let array = Arc::new(Int64Array::from(vec![ + Some(10), + None, + Some(30), + Some(40), + Some(50), + ])) as ArrayRef; + for i in 0..5 { + builder.append_val(&array, i).unwrap(); + } + // len=5, n=2, n*2=4 <= 5 → drain branch + let out = builder.take_n(2); + let expected = Arc::new(Int64Array::from(vec![Some(10), None])) as ArrayRef; + assert_eq!(&out, &expected); + // remaining: [30, 40, 50] + assert_eq!(builder.len(), 3); + + // split_off branch: remaining < n (len=3, n=2, n*2=4 > 3) + let out2 = builder.take_n(2); + let expected2 = Arc::new(Int64Array::from(vec![Some(30), Some(40)])) as ArrayRef; + assert_eq!(&out2, &expected2); + // remaining: [50] + assert_eq!(builder.len(), 1); + + // take the last element + let out3 = builder.take_n(1); + let expected3 = Arc::new(Int64Array::from(vec![Some(50)])) as ArrayRef; + assert_eq!(&out3, &expected3); + assert_eq!(builder.len(), 0); + } }