From 7b4d1a6bc93a19438fb306450b21ad672109304d Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:06:16 +0800 Subject: [PATCH 1/4] feat(aggregate): enhance avg functionality with shared helper and Spark integration - Added shared helper for average calculations in `avg.rs` with conversion to average state. - Exported the aggregate module in `groups_accumulator.rs`. - Updated Spark's average function to maintain state order and count type. - Added tests for common helper null/filter semantics and Spark null filter cases. --- .../src/aggregate/groups_accumulator.rs | 1 + .../src/aggregate/groups_accumulator/avg.rs | 110 ++++++++++++++++++ .../spark/src/function/aggregate/avg.rs | 30 +++-- 3 files changed, 133 insertions(+), 8 deletions(-) create mode 100644 datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index ad2a21bb4733c..936316e333b6b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -19,6 +19,7 @@ //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] pub mod accumulate; +pub mod avg; pub mod bool_op; pub mod nulls; pub mod prim_op; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs new file mode 100644 index 0000000000000..47b7eea3cb23c --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared helpers for average group accumulator state handling. + +use arrow::array::{ArrowNumericType, BooleanArray, PrimitiveArray}; + +use super::nulls::{filtered_null_mask, set_nulls}; + +/// Converts an AVG input value array into nullable per-row state arrays. +/// +/// The returned arrays are `(sum_state, count_state)`. Callers keep control of +/// their aggregate-specific state field order when wrapping these arrays in the +/// final state vector. +/// +/// Rows with NULL input values, `false` filters, or NULL filters are marked NULL +/// in both output arrays so later merge steps can ignore them consistently. 
+pub fn convert_to_avg_state( + sums: PrimitiveArray, + count_value: CountType::Native, + opt_filter: Option<&BooleanArray>, +) -> (PrimitiveArray, PrimitiveArray) +where + SumType: ArrowNumericType + Send, + CountType: ArrowNumericType + Send, +{ + let counts = PrimitiveArray::::from_value(count_value, sums.len()); + let nulls = filtered_null_mask(opt_filter, &sums); + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + (sums, counts) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray, Float64Array}; + use arrow::datatypes::Float64Type; + + use super::convert_to_avg_state; + + #[test] + fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + + let (sums, counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, None); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(!counts.is_null(2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_applies_filter_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); + + let (sums, counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, Some(&filter)); + + assert_eq!(sums.null_count(), 2); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(sums.is_null(2)); + assert!(!sums.is_null(3)); + + assert_eq!(counts.null_count(), 2); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(counts.is_null(2)); + assert!(!counts.is_null(3)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_preserves_sum_data_type() { + let sums = Float64Array::from(vec![1.0, 2.0]) + 
.with_data_type(arrow::datatypes::DataType::Float64); + + let (sums, _counts) = convert_to_avg_state::< + Float64Type, + arrow::datatypes::Int64Type, + >(sums, 1, None); + + assert_eq!(sums.data_type(), &arrow::datatypes::DataType::Float64); + } +} diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 5f4d2c253a2dc..0fc541e1dbaef 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -32,9 +32,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, TypeSignatureClass, Volatility, }; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use std::sync::Arc; /// AVG aggregate expression @@ -356,11 +354,7 @@ where .as_primitive::() .clone() .with_data_type(self.return_data_type.clone()); - let counts = Int64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); // [sum, count] - must match state() and merge_batch() Ok(vec![ @@ -453,6 +447,26 @@ mod tests { assert_eq!(counts.value(2), 1); } + #[test] + fn convert_to_state_with_null_filter() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let sums = state[0].as_primitive::(); + let counts = state[1].as_primitive::(); + + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + + assert_eq!(counts.value(0), 1); + assert!(counts.is_null(1)); + 
assert_eq!(counts.value(2), 1); + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc(); From 484fe010a077bd45903b51fe87c2d8d71627ca54 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:10:22 +0800 Subject: [PATCH 2/4] feat(aggregate): refactor Avg function to use shared state and preserve data type integrity - Updated built-in Avg to utilize shared `convert_to_avg_state`. - Ensured the order of state is preserved as [count, sum]. - Maintained count type as UInt64. - Ensured sum data type consistency for Decimal and Duration. Added tests for: - Float64: validating count/sum order and null filter semantics. - Decimal128: checking sum type and input null semantics. - DurationNanosecond: verifying sum type and filter semantics. --- datafusion/functions-aggregate/src/average.rs | 123 ++++++++++++++++-- 1 file changed, 113 insertions(+), 10 deletions(-) diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddeb9b0870a16..08ed27f6fb330 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -44,9 +44,7 @@ use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; @@ -955,13 +953,7 @@ where .as_primitive::() .clone() .with_data_type(self.sum_data_type.clone()); - let counts = UInt64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - - // set nulls on the arrays - let 
counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) } @@ -974,3 +966,114 @@ where self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + + fn float64_acc() -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Float64, + &DataType::Float64, + |sum, count| Ok(sum / count as f64), + ) + } + + fn decimal128_acc() + -> AvgGroupsAccumulator Result> { + AvgGroupsAccumulator::::new( + &DataType::Decimal128(10, 2), + &DataType::Decimal128(14, 6), + |sum, _count| Ok(sum), + ) + } + + fn duration_acc() + -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Duration(TimeUnit::Nanosecond), + &DataType::Duration(TimeUnit::Nanosecond), + |sum, count| Ok(sum / count as i64), + ) + } + + #[test] + fn float64_convert_to_state_uses_count_sum_order_and_null_filter() { + let acc = float64_acc(); + let values: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(10.0), + Some(20.0), + None, + Some(40.0), + ]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(counts.is_null(2)); + assert!(counts.is_null(3)); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(sums.is_null(2)); + assert!(sums.is_null(3)); + } + + #[test] + fn decimal_convert_to_state_preserves_sum_type_and_nulls() { + let acc = decimal128_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + 
Some(100_i128), + None, + Some(300_i128), + ]) + .with_data_type(DataType::Decimal128(10, 2)), + )]; + + let state = acc.convert_to_state(&values, None).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + assert!(!counts.is_null(0)); + assert!(counts.is_null(1)); + assert!(!counts.is_null(2)); + assert!(!sums.is_null(0)); + assert!(sums.is_null(1)); + assert!(!sums.is_null(2)); + } + + #[test] + fn duration_convert_to_state_preserves_sum_type_and_applies_filter() { + let acc = duration_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(10_i64), + Some(20_i64), + ]) + .with_data_type(DataType::Duration(TimeUnit::Nanosecond)), + )]; + let filter = BooleanArray::from(vec![Some(false), Some(true)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let counts = state[0].as_primitive::(); + let sums = state[1].as_primitive::(); + assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(counts.values().as_ref(), &[1, 1]); + assert!(counts.is_null(0)); + assert!(!counts.is_null(1)); + assert!(sums.is_null(0)); + assert!(!sums.is_null(1)); + } +} From 13edd95754d8936932ed19b4d936bd1dc56968c9 Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 13:18:17 +0800 Subject: [PATCH 3/4] feat(tests): enhance test helpers and reduce redundancy - Added assert_validity test helpers for improved validation in tests - Reduced repeated null assertions to streamline code - Shortened common helper tests using local imports and type aliasing - Introduced built-in Avg avg_state test helper - Added comment for decimal test closure to clarify the unused avg_fn by convert_to_state --- .../src/aggregate/groups_accumulator/avg.rs | 54 ++++++++----------- datafusion/functions-aggregate/src/average.rs | 51 +++++++++--------- 
.../spark/src/function/aggregate/avg.rs | 25 ++++----- 3 files changed, 63 insertions(+), 67 deletions(-) diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs index 47b7eea3cb23c..cc6f889082d1d 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs @@ -49,25 +49,28 @@ where #[cfg(test)] mod tests { use arrow::array::{Array, BooleanArray, Float64Array}; - use arrow::datatypes::Float64Type; + use arrow::datatypes::{DataType, Float64Type, Int64Type}; use super::convert_to_avg_state; + type CountType = Int64Type; + + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + #[test] fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); - let (sums, counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, None); - - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(!counts.is_null(2)); + let (sums, counts) = + convert_to_avg_state::(sums, 1, None); + + assert_validity(&sums, &[true, false, true]); + assert_validity(&counts, &[true, false, true]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); } @@ -76,35 +79,24 @@ mod tests { let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); - let (sums, counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, Some(&filter)); + let (sums, counts) 
= + convert_to_avg_state::(sums, 1, Some(&filter)); assert_eq!(sums.null_count(), 2); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(sums.is_null(2)); - assert!(!sums.is_null(3)); + assert_validity(&sums, &[true, false, false, true]); assert_eq!(counts.null_count(), 2); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(counts.is_null(2)); - assert!(!counts.is_null(3)); + assert_validity(&counts, &[true, false, false, true]); assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); } #[test] fn convert_to_avg_state_preserves_sum_data_type() { - let sums = Float64Array::from(vec![1.0, 2.0]) - .with_data_type(arrow::datatypes::DataType::Float64); + let sums = Float64Array::from(vec![1.0, 2.0]).with_data_type(DataType::Float64); - let (sums, _counts) = convert_to_avg_state::< - Float64Type, - arrow::datatypes::Int64Type, - >(sums, 1, None); + let (sums, _counts) = + convert_to_avg_state::(sums, 1, None); - assert_eq!(sums.data_type(), &arrow::datatypes::DataType::Float64); + assert_eq!(sums.data_type(), &DataType::Float64); } } diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 08ed27f6fb330..a5fc4ec938654 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -973,6 +973,23 @@ mod tests { use arrow::array::{Array, Float64Array}; use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + fn avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + fn float64_acc() -> AvgGroupsAccumulator Result> { AvgGroupsAccumulator::::new( @@ 
-987,6 +1004,7 @@ mod tests { AvgGroupsAccumulator::::new( &DataType::Decimal128(10, 2), &DataType::Decimal128(14, 6), + // convert_to_state does not evaluate averages, so avg_fn is unused here. |sum, _count| Ok(sum), ) } @@ -1014,17 +1032,10 @@ mod tests { let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(counts.is_null(2)); - assert!(counts.is_null(3)); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(sums.is_null(2)); - assert!(sums.is_null(3)); + assert_validity(counts, &[true, false, false, false]); + assert_validity(sums, &[true, false, false, false]); } #[test] @@ -1041,16 +1052,11 @@ mod tests { let state = acc.convert_to_state(&values, None).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); - assert!(!counts.is_null(0)); - assert!(counts.is_null(1)); - assert!(!counts.is_null(2)); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(counts, &[true, false, true]); + assert_validity(sums, &[true, false, true]); } #[test] @@ -1067,13 +1073,10 @@ mod tests { let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let counts = state[0].as_primitive::(); - let sums = state[1].as_primitive::(); + let (counts, sums) = avg_state::(&state); assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); assert_eq!(counts.values().as_ref(), &[1, 1]); - assert!(counts.is_null(0)); - assert!(!counts.is_null(1)); - assert!(sums.is_null(0)); - assert!(!sums.is_null(1)); + assert_validity(counts, &[false, true]); + 
assert_validity(sums, &[false, true]); } } diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 0fc541e1dbaef..90fcc71783eec 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -383,6 +383,13 @@ mod tests { }) } + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -418,12 +425,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } @@ -438,12 +443,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } @@ -458,12 +461,10 @@ mod tests { let sums = state[0].as_primitive::(); let counts = state[1].as_primitive::(); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } From 7e8b4fe76f4a3c07fccd7814fe412eb543f5d02c Mon Sep 17 00:00:00 2001 From: Siew Kam Onn Date: Sat, 30 May 2026 14:31:04 +0800 Subject: 
[PATCH 4/4] feat(datafusion/spark): enhance Spark Avg merge_batch to honor opt_filter - Implemented logic to skip rows whose filter value is false or NULL in merge_batch. - Kept skipping NULL state rows emitted by convert_to_state. - Added regression test: merge_batch_applies_filter. - Introduced spark_avg_state test helper for better testing. - Refactored code to eliminate repeated state[0]/state[1] decode boilerplate. --- .../spark/src/function/aggregate/avg.rs | 49 ++++++++++++++----- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 90fcc71783eec..d288c380a3cf6 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -287,7 +287,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -300,8 +300,12 @@ where for (idx, &group_index) in group_indices.iter().enumerate() { // Skip null state entries emitted by convert_to_state for - // filtered / null input rows. - if partial_counts.is_null(idx) || partial_sums.is_null(idx) { + // filtered / null input rows, and rows filtered during merge. 
+ if partial_counts.is_null(idx) + || partial_sums.is_null(idx) + || opt_filter + .is_some_and(|filter| filter.is_null(idx) || !filter.value(idx)) + { continue; } self.counts[group_index] += partial_counts.value(idx); @@ -390,6 +394,16 @@ mod tests { } } + fn spark_avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -402,9 +416,7 @@ mod tests { vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; let state = acc.convert_to_state(&values, None).unwrap(); - assert_eq!(state.len(), 2); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); @@ -422,8 +434,7 @@ mod tests { ]))]; let state = acc.convert_to_state(&values, None).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -440,8 +451,7 @@ mod tests { let filter = BooleanArray::from(vec![true, false, true]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -458,8 +468,7 @@ mod tests { let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_validity(sums, &[true, false, true]); @@ -468,6 +477,22 @@ mod tests { assert_eq!(counts.value(2), 1); } + #[test] + fn 
merge_batch_applies_filter() { + let mut acc = make_acc(); + let input: Vec = + vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + let filter = BooleanArray::from(vec![Some(true), Some(false), None]); + + acc.merge_batch(&state, &[0, 0, 0], Some(&filter), 1) + .unwrap(); + + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 10.0); + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc();