diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index ad2a21bb4733c..936316e333b6b 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -19,6 +19,7 @@ //! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`] pub mod accumulate; +pub mod avg; pub mod bool_op; pub mod nulls; pub mod prim_op; diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs new file mode 100644 index 0000000000000..cc6f889082d1d --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/avg.rs @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Shared helpers for average group accumulator state handling. + +use arrow::array::{ArrowNumericType, BooleanArray, PrimitiveArray}; + +use super::nulls::{filtered_null_mask, set_nulls}; + +/// Converts an AVG input value array into nullable per-row state arrays. +/// +/// The returned arrays are `(sum_state, count_state)`. Callers keep control of +/// their aggregate-specific state field order when wrapping these arrays in the +/// final state vector. +/// +/// Rows with NULL input values, `false` filters, or NULL filters are marked NULL +/// in both output arrays so later merge steps can ignore them consistently. +pub fn convert_to_avg_state( + sums: PrimitiveArray, + count_value: CountType::Native, + opt_filter: Option<&BooleanArray>, +) -> (PrimitiveArray, PrimitiveArray) +where + SumType: ArrowNumericType + Send, + CountType: ArrowNumericType + Send, +{ + let counts = PrimitiveArray::::from_value(count_value, sums.len()); + let nulls = filtered_null_mask(opt_filter, &sums); + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + (sums, counts) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray, Float64Array}; + use arrow::datatypes::{DataType, Float64Type, Int64Type}; + + use super::convert_to_avg_state; + + type CountType = Int64Type; + + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + #[test] + fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]); + + let (sums, counts) = + convert_to_avg_state::(sums, 1, None); + + assert_validity(&sums, &[true, false, true]); + assert_validity(&counts, &[true, false, true]); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_applies_filter_nulls_to_sum_and_count() { + let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]); + let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]); + + let (sums, counts) = + convert_to_avg_state::(sums, 1, Some(&filter)); + + assert_eq!(sums.null_count(), 2); + assert_validity(&sums, &[true, false, false, true]); + + assert_eq!(counts.null_count(), 2); + assert_validity(&counts, &[true, false, false, true]); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + } + + #[test] + fn convert_to_avg_state_preserves_sum_data_type() { + let sums = Float64Array::from(vec![1.0, 2.0]).with_data_type(DataType::Float64); + + let (sums, _counts) = + convert_to_avg_state::(sums, 1, None); + + assert_eq!(sums.data_type(), &DataType::Float64); + } +} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddeb9b0870a16..a5fc4ec938654 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -44,9 +44,7 @@ use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; @@ -955,13 +953,7 @@ where .as_primitive::() .clone() .with_data_type(self.sum_data_type.clone()); - let counts = UInt64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - - // set nulls on the arrays - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) } @@ -974,3 +966,117 @@ where self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Array, Float64Array}; + use arrow::datatypes::{Decimal128Type, DurationNanosecondType}; + + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + fn avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + + fn float64_acc() -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Float64, + &DataType::Float64, + |sum, count| Ok(sum / count as f64), + ) + } + + fn decimal128_acc() + -> AvgGroupsAccumulator Result> { + AvgGroupsAccumulator::::new( + &DataType::Decimal128(10, 2), + &DataType::Decimal128(14, 6), + // convert_to_state does not evaluate averages, so avg_fn is unused here. + |sum, _count| Ok(sum), + ) + } + + fn duration_acc() + -> AvgGroupsAccumulator Result> + { + AvgGroupsAccumulator::::new( + &DataType::Duration(TimeUnit::Nanosecond), + &DataType::Duration(TimeUnit::Nanosecond), + |sum, count| Ok(sum / count as i64), + ) + } + + #[test] + fn float64_convert_to_state_uses_count_sum_order_and_null_filter() { + let acc = float64_acc(); + let values: Vec = vec![Arc::new(Float64Array::from(vec![ + Some(10.0), + Some(20.0), + None, + Some(40.0), + ]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]); + assert_validity(counts, &[true, false, false, false]); + assert_validity(sums, &[true, false, false, false]); + } + + #[test] + fn decimal_convert_to_state_preserves_sum_type_and_nulls() { + let acc = decimal128_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(100_i128), + None, + Some(300_i128), + ]) + .with_data_type(DataType::Decimal128(10, 2)), + )]; + + let state = acc.convert_to_state(&values, None).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2)); + assert_eq!(counts.values().as_ref(), &[1, 1, 1]); + assert_validity(counts, &[true, false, true]); + assert_validity(sums, &[true, false, true]); + } + + #[test] + fn duration_convert_to_state_preserves_sum_type_and_applies_filter() { + let acc = duration_acc(); + let values: Vec = vec![Arc::new( + PrimitiveArray::::from(vec![ + Some(10_i64), + Some(20_i64), + ]) + .with_data_type(DataType::Duration(TimeUnit::Nanosecond)), + )]; + let filter = BooleanArray::from(vec![Some(false), Some(true)]); + + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let (counts, sums) = avg_state::(&state); + assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond)); + assert_eq!(counts.values().as_ref(), &[1, 1]); + assert_validity(counts, &[false, true]); + assert_validity(sums, &[false, true]); + } +} diff --git a/datafusion/spark/src/function/aggregate/avg.rs b/datafusion/spark/src/function/aggregate/avg.rs index 5f4d2c253a2dc..d288c380a3cf6 100644 --- a/datafusion/spark/src/function/aggregate/avg.rs +++ b/datafusion/spark/src/function/aggregate/avg.rs @@ -32,9 +32,7 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, Coercion, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, TypeSignatureClass, Volatility, }; -use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ - filtered_null_mask, set_nulls, -}; +use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state; use std::sync::Arc; /// AVG aggregate expression @@ -289,7 +287,7 @@ where &mut self, values: &[ArrayRef], group_indices: &[usize], - _opt_filter: Option<&BooleanArray>, + opt_filter: Option<&BooleanArray>, total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 2, "two arguments to merge_batch"); @@ -302,8 +300,12 @@ where for (idx, &group_index) in group_indices.iter().enumerate() { // Skip null state entries emitted by convert_to_state for - // filtered / null input rows. - if partial_counts.is_null(idx) || partial_sums.is_null(idx) { + // filtered / null input rows, and rows filtered during merge. + if partial_counts.is_null(idx) + || partial_sums.is_null(idx) + || opt_filter + .is_some_and(|filter| filter.is_null(idx) || !filter.value(idx)) + { continue; } self.counts[group_index] += partial_counts.value(idx); @@ -356,11 +358,7 @@ where .as_primitive::() .clone() .with_data_type(self.return_data_type.clone()); - let counts = Int64Array::from_value(1, sums.len()); - - let nulls = filtered_null_mask(opt_filter, &sums); - let counts = set_nulls(counts, nulls.clone()); - let sums = set_nulls(sums, nulls); + let (sums, counts) = convert_to_avg_state::(sums, 1, opt_filter); // [sum, count] - must match state() and merge_batch() Ok(vec![ @@ -389,6 +387,23 @@ mod tests { }) } + fn assert_validity(array: &dyn Array, expected: &[bool]) { + assert_eq!(array.len(), expected.len()); + for (idx, expected_valid) in expected.iter().copied().enumerate() { + assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}"); + } + } + + fn spark_avg_state( + state: &[ArrayRef], + ) -> (&PrimitiveArray, &PrimitiveArray) { + assert_eq!(state.len(), 2); + ( + state[0].as_primitive::(), + state[1].as_primitive::(), + ) + } + #[test] fn supports_convert_to_state() { assert!(make_acc().supports_convert_to_state()); @@ -401,9 +416,7 @@ mod tests { vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; let state = acc.convert_to_state(&values, None).unwrap(); - assert_eq!(state.len(), 2); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); assert_eq!(sums.values().as_ref(), &[1.0, 2.0, 3.0]); assert_eq!(counts.values().as_ref(), &[1, 1, 1]); @@ -421,15 +434,12 @@ mod tests { ]))]; let state = acc.convert_to_state(&values, None).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } @@ -441,18 +451,48 @@ mod tests { let filter = BooleanArray::from(vec![true, false, true]); let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); - let sums = state[0].as_primitive::(); - let counts = state[1].as_primitive::(); + let (sums, counts) = spark_avg_state(&state); - assert!(!sums.is_null(0)); - assert!(sums.is_null(1)); - assert!(!sums.is_null(2)); + assert_validity(sums, &[true, false, true]); assert_eq!(counts.value(0), 1); - assert!(counts.is_null(1)); + assert_validity(counts, &[true, false, true]); assert_eq!(counts.value(2), 1); } + #[test] + fn convert_to_state_with_null_filter() { + let acc = make_acc(); + let values: Vec = + vec![Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0]))]; + let filter = BooleanArray::from(vec![Some(true), None, Some(true)]); + let state = acc.convert_to_state(&values, Some(&filter)).unwrap(); + + let (sums, counts) = spark_avg_state(&state); + + assert_validity(sums, &[true, false, true]); + + assert_eq!(counts.value(0), 1); + assert_validity(counts, &[true, false, true]); + assert_eq!(counts.value(2), 1); + } + + #[test] + fn merge_batch_applies_filter() { + let mut acc = make_acc(); + let input: Vec = + vec![Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0]))]; + let state = acc.convert_to_state(&input, None).unwrap(); + let filter = BooleanArray::from(vec![Some(true), Some(false), None]); + + acc.merge_batch(&state, &[0, 0, 0], Some(&filter), 1) + .unwrap(); + + let result = acc.evaluate(EmitTo::All).unwrap(); + let result = result.as_primitive::(); + assert_eq!(result.value(0), 10.0); + } + #[test] fn convert_to_state_roundtrips_through_merge() { let mut acc = make_acc();