Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! Adapter that makes [`GroupsAccumulator`] out of [`Accumulator`]
pub mod accumulate;
pub mod avg;
pub mod bool_op;
pub mod nulls;
pub mod prim_op;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Shared helpers for average group accumulator state handling.

use arrow::array::{ArrowNumericType, BooleanArray, PrimitiveArray};

use super::nulls::{filtered_null_mask, set_nulls};

/// Converts an AVG input value array into nullable per-row state arrays.
///
/// The returned arrays are `(sum_state, count_state)`. Callers keep control of
/// their aggregate-specific state field order when wrapping these arrays in the
/// final state vector.
///
/// Rows with NULL input values, `false` filters, or NULL filters are marked NULL
/// in both output arrays so later merge steps can ignore them consistently.
pub fn convert_to_avg_state<SumType, CountType>(
sums: PrimitiveArray<SumType>,
count_value: CountType::Native,
opt_filter: Option<&BooleanArray>,
) -> (PrimitiveArray<SumType>, PrimitiveArray<CountType>)
where
SumType: ArrowNumericType + Send,
CountType: ArrowNumericType + Send,
{
let counts = PrimitiveArray::<CountType>::from_value(count_value, sums.len());
let nulls = filtered_null_mask(opt_filter, &sums);
let counts = set_nulls(counts, nulls.clone());
let sums = set_nulls(sums, nulls);

(sums, counts)
}

#[cfg(test)]
mod tests {
use arrow::array::{Array, BooleanArray, Float64Array};
use arrow::datatypes::{DataType, Float64Type, Int64Type};

use super::convert_to_avg_state;

type CountType = Int64Type;

fn assert_validity(array: &dyn Array, expected: &[bool]) {
assert_eq!(array.len(), expected.len());
for (idx, expected_valid) in expected.iter().copied().enumerate() {
assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}");
}
}

#[test]
fn convert_to_avg_state_applies_input_nulls_to_sum_and_count() {
let sums = Float64Array::from(vec![Some(1.0), None, Some(3.0)]);

let (sums, counts) =
convert_to_avg_state::<Float64Type, CountType>(sums, 1, None);

assert_validity(&sums, &[true, false, true]);
assert_validity(&counts, &[true, false, true]);
assert_eq!(counts.values().as_ref(), &[1, 1, 1]);
}

#[test]
fn convert_to_avg_state_applies_filter_nulls_to_sum_and_count() {
let sums = Float64Array::from(vec![Some(1.0), Some(2.0), Some(3.0), Some(4.0)]);
let filter = BooleanArray::from(vec![Some(true), Some(false), None, Some(true)]);

let (sums, counts) =
convert_to_avg_state::<Float64Type, CountType>(sums, 1, Some(&filter));

assert_eq!(sums.null_count(), 2);
assert_validity(&sums, &[true, false, false, true]);

assert_eq!(counts.null_count(), 2);
assert_validity(&counts, &[true, false, false, true]);
assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]);
}

#[test]
fn convert_to_avg_state_preserves_sum_data_type() {
let sums = Float64Array::from(vec![1.0, 2.0]).with_data_type(DataType::Float64);

let (sums, _counts) =
convert_to_avg_state::<Float64Type, CountType>(sums, 1, None);

assert_eq!(sums.data_type(), &DataType::Float64);
}
}
126 changes: 116 additions & 10 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ use datafusion_functions_aggregate_common::aggregate::avg_distinct::{
DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{
filtered_null_mask, set_nulls,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::avg::convert_to_avg_state;
use datafusion_functions_aggregate_common::utils::DecimalAverager;
use datafusion_macros::user_doc;
use log::debug;
Expand Down Expand Up @@ -955,13 +953,7 @@ where
.as_primitive::<T>()
.clone()
.with_data_type(self.sum_data_type.clone());
let counts = UInt64Array::from_value(1, sums.len());

let nulls = filtered_null_mask(opt_filter, &sums);

// set nulls on the arrays
let counts = set_nulls(counts, nulls.clone());
let sums = set_nulls(sums, nulls);
let (sums, counts) = convert_to_avg_state::<T, UInt64Type>(sums, 1, opt_filter);

Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)])
}
Expand All @@ -974,3 +966,117 @@ where
self.counts.capacity() * size_of::<u64>() + self.sums.capacity() * size_of::<T>()
}
}

#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Array, Float64Array};
use arrow::datatypes::{Decimal128Type, DurationNanosecondType};

fn assert_validity(array: &dyn Array, expected: &[bool]) {
assert_eq!(array.len(), expected.len());
for (idx, expected_valid) in expected.iter().copied().enumerate() {
assert_eq!(!array.is_null(idx), expected_valid, "validity at row {idx}");
}
}

fn avg_state<T: ArrowPrimitiveType>(
state: &[ArrayRef],
) -> (&PrimitiveArray<UInt64Type>, &PrimitiveArray<T>) {
assert_eq!(state.len(), 2);
(
state[0].as_primitive::<UInt64Type>(),
state[1].as_primitive::<T>(),
)
}

fn float64_acc() -> AvgGroupsAccumulator<Float64Type, impl Fn(f64, u64) -> Result<f64>>
{
AvgGroupsAccumulator::<Float64Type, _>::new(
&DataType::Float64,
&DataType::Float64,
|sum, count| Ok(sum / count as f64),
)
}

fn decimal128_acc()
-> AvgGroupsAccumulator<Decimal128Type, impl Fn(i128, u64) -> Result<i128>> {
AvgGroupsAccumulator::<Decimal128Type, _>::new(
&DataType::Decimal128(10, 2),
&DataType::Decimal128(14, 6),
// convert_to_state does not evaluate averages, so avg_fn is unused here.
|sum, _count| Ok(sum),
)
}

fn duration_acc()
-> AvgGroupsAccumulator<DurationNanosecondType, impl Fn(i64, u64) -> Result<i64>>
{
AvgGroupsAccumulator::<DurationNanosecondType, _>::new(
&DataType::Duration(TimeUnit::Nanosecond),
&DataType::Duration(TimeUnit::Nanosecond),
|sum, count| Ok(sum / count as i64),
)
}

#[test]
fn float64_convert_to_state_uses_count_sum_order_and_null_filter() {
let acc = float64_acc();
let values: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
Some(10.0),
Some(20.0),
None,
Some(40.0),
]))];
let filter = BooleanArray::from(vec![Some(true), None, Some(true), Some(false)]);

let state = acc.convert_to_state(&values, Some(&filter)).unwrap();

let (counts, sums) = avg_state::<Float64Type>(&state);
assert_eq!(counts.values().as_ref(), &[1, 1, 1, 1]);
assert_validity(counts, &[true, false, false, false]);
assert_validity(sums, &[true, false, false, false]);
}

#[test]
fn decimal_convert_to_state_preserves_sum_type_and_nulls() {
let acc = decimal128_acc();
let values: Vec<ArrayRef> = vec![Arc::new(
PrimitiveArray::<Decimal128Type>::from(vec![
Some(100_i128),
None,
Some(300_i128),
])
.with_data_type(DataType::Decimal128(10, 2)),
)];

let state = acc.convert_to_state(&values, None).unwrap();

let (counts, sums) = avg_state::<Decimal128Type>(&state);
assert_eq!(sums.data_type(), &DataType::Decimal128(10, 2));
assert_eq!(counts.values().as_ref(), &[1, 1, 1]);
assert_validity(counts, &[true, false, true]);
assert_validity(sums, &[true, false, true]);
}

#[test]
fn duration_convert_to_state_preserves_sum_type_and_applies_filter() {
let acc = duration_acc();
let values: Vec<ArrayRef> = vec![Arc::new(
PrimitiveArray::<DurationNanosecondType>::from(vec![
Some(10_i64),
Some(20_i64),
])
.with_data_type(DataType::Duration(TimeUnit::Nanosecond)),
)];
let filter = BooleanArray::from(vec![Some(false), Some(true)]);

let state = acc.convert_to_state(&values, Some(&filter)).unwrap();

let (counts, sums) = avg_state::<DurationNanosecondType>(&state);
assert_eq!(sums.data_type(), &DataType::Duration(TimeUnit::Nanosecond));
assert_eq!(counts.values().as_ref(), &[1, 1]);
assert_validity(counts, &[false, true]);
assert_validity(sums, &[false, true]);
}
}
Loading
Loading