diff --git a/datafusion/physical-plan/Cargo.toml b/datafusion/physical-plan/Cargo.toml index 7acb21b8f3b93..3a453f851f4e4 100644 --- a/datafusion/physical-plan/Cargo.toml +++ b/datafusion/physical-plan/Cargo.toml @@ -106,3 +106,11 @@ required-features = ["test_utils"] harness = false name = "aggregate_vectorized" required-features = ["test_utils"] + +[[bench]] +name = "single_column_aggr" +harness = false + +[profile.profiling] +inherits = "release" +debug = true diff --git a/datafusion/physical-plan/benches/single_column_aggr.rs b/datafusion/physical-plan/benches/single_column_aggr.rs new file mode 100644 index 0000000000000..d7a80902f5a06 --- /dev/null +++ b/datafusion/physical-plan/benches/single_column_aggr.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, StringDictionaryBuilder}; +use arrow::datatypes::{DataType, Field, Schema, UInt8Type}; +use criterion::{Criterion, criterion_group, criterion_main}; +use datafusion_expr::EmitTo; +use datafusion_physical_plan::aggregates::group_values::single_group_by::dictionary::GroupValuesDictionary; +use datafusion_physical_plan::aggregates::group_values::{GroupValues, new_group_values}; +use datafusion_physical_plan::aggregates::order::GroupOrdering; +use std::sync::Arc; +#[derive(Debug)] +enum Cardinality { + Xsmall, // 1 + Small, // 10 + Medium, // 50 + Large, // 200 +} +#[derive(Debug)] +enum BatchSize { + Small, // 8192 + Medium, // 32768 + Large, // 65536 +} +#[derive(Debug)] +enum NullRate { + Zero, // 0% + Low, // 1% + Medium, // 5% + High, // 20% +} +#[derive(Debug, Clone)] +enum GroupType { + Dictionary, + GroupValueRows, +} +fn create_string_values(cardinality: &Cardinality) -> Vec { + let num_values = match cardinality { + Cardinality::Xsmall => 3, + Cardinality::Small => 10, + Cardinality::Medium => 50, + Cardinality::Large => 200, + }; + (0..num_values) + .map(|i| format!("group_value_{i:06}")) + .collect() +} +fn create_batch(batch_size: &BatchSize, cardinality: &Cardinality) -> Vec { + let size = match batch_size { + BatchSize::Small => 8192, + BatchSize::Medium => 32768, + BatchSize::Large => 65536, + }; + let unique_strings = create_string_values(cardinality); + if unique_strings.is_empty() { + return Vec::new(); + } + + unique_strings.iter().cycle().take(size).cloned().collect() +} +fn strings_to_dict_array(values: Vec>) -> ArrayRef { + let mut builder = StringDictionaryBuilder::::new(); + for v in values { + match v { + Some(v) => builder.append_value(v), + None => builder.append_null(), + } + } + Arc::new(builder.finish()) +} +fn introduce_nulls(values: Vec, null_rate: &NullRate) -> Vec> { + let rate = match null_rate { + NullRate::Zero => 0.0, + NullRate::Low => 0.01, + NullRate::Medium => 0.05, + NullRate::High => 0.20, + }; + values + .into_iter() + .map(|v| { + if rand::random::() < rate { + None + } else { + Some(v) + } + }) + .collect() +} + +fn generate_group_values(kind: &GroupType) -> Box { + match kind { + GroupType::GroupValueRows => { + // we know this is going to hit the fallback path I.E GroupValueRows, but for the sake of avoiding making private items public call the public api + let schema = Arc::new(Schema::new(vec![Field::new( + "group_col", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + )])); + new_group_values(schema, &GroupOrdering::None).unwrap() + } + GroupType::Dictionary => { + // call custom path directly + Box::new(GroupValuesDictionary::::new(&DataType::Utf8)) + } + } +} + +fn bench_single_column_group_values(c: &mut Criterion) { + let group_types = [GroupType::GroupValueRows, GroupType::Dictionary]; + let cardinalities = [ + Cardinality::Xsmall, + Cardinality::Small, + Cardinality::Medium, + Cardinality::Large, + ]; + let batch_sizes = [BatchSize::Small, BatchSize::Medium, BatchSize::Large]; + let null_rates = [ + NullRate::Zero, + NullRate::Low, + NullRate::Medium, + NullRate::High, + ]; + + for cardinality in &cardinalities { + for batch_size in &batch_sizes { + for null_rate in &null_rates { + for group_type in &group_types { + let group_name = format!( + "t1_{group_type:?}_cardinality_{cardinality:?}_batch_{batch_size:?}_null_rate_{null_rate:?}" + ); + + let string_vec = create_batch(batch_size, cardinality); + let nullable_values = introduce_nulls(string_vec, null_rate); + let col_ref = match group_type { + GroupType::Dictionary | GroupType::GroupValueRows => { + strings_to_dict_array(nullable_values.clone()) + } + }; + c.bench_function(&group_name, |b| { + b.iter_batched( + || { + //create fresh group values for each iteration + let gv = generate_group_values(group_type); + let col = col_ref.clone(); + (gv, col) + }, + |(mut group_values, col)| { + let mut groups = Vec::new(); + group_values.intern(&[col], &mut groups).unwrap(); + //group_values.emit(EmitTo::All).unwrap(); + }, + criterion::BatchSize::SmallInput, + ); + }); + + // Second benchmark that alternates between intern and emit to simulate more realistic usage patterns where the same group values is used across multiple batches of the same grouping column + let multi_batch_name = format!( + "multi_batch/{group_type:?}_cardinality_{cardinality:?}_batch_{batch_size:?}_null_rate_{null_rate:?}" + ); + c.bench_function(&multi_batch_name, |b| { + b.iter_batched( + || { + // setup - create 3 batches to simulate multiple record batches + let gv = generate_group_values(group_type); + let batch1 = col_ref.clone(); + let batch2 = col_ref.clone(); + let batch3 = col_ref.clone(); + (gv, batch1, batch2, batch3) + }, + |(mut group_values, batch1, batch2, batch3)| { + // simulate realistic aggregation flow: + // multiple intern calls (one per record batch) + // followed by emit + let mut groups = Vec::new(); + + group_values.intern(&[batch1], &mut groups).unwrap(); + groups.clear(); + group_values.intern(&[batch2], &mut groups).unwrap(); + groups.clear(); + group_values.intern(&[batch3], &mut groups).unwrap(); + + // emit once at the end like the real aggregation flow + group_values.emit(EmitTo::All).unwrap(); + }, + criterion::BatchSize::SmallInput, + ); + }); + } + } + } + } +} + +fn bench_repeated_intern_prefab_cols(c: &mut Criterion) { + let cardinality = Cardinality::Small; + let batch_size = BatchSize::Large; + let null_rate = NullRate::Low; + let group_types = [GroupType::GroupValueRows, GroupType::Dictionary]; + + for group_type in &group_types { + let group_type = group_type.clone(); + let string_vec = create_batch(&batch_size, &cardinality); + let nullable_values = introduce_nulls(string_vec, &null_rate); + let col_ref = match group_type { + GroupType::Dictionary | GroupType::GroupValueRows => { + strings_to_dict_array(nullable_values.clone()) + } + }; + + // Build once outside the benchmark iteration and reuse in intern calls. + let arr1 = col_ref.clone(); + let arr2 = col_ref.clone(); + let arr3 = col_ref.clone(); + let arr4 = col_ref.clone(); + + let group_name = format!( + "repeated_intern/{group_type:?}_cardinality_{cardinality:?}_batch_{batch_size:?}_null_rate_{null_rate:?}" + ); + c.bench_function(&group_name, |b| { + b.iter_batched( + || generate_group_values(&group_type), + |mut group_values| { + let mut groups = Vec::new(); + + group_values + .intern(std::slice::from_ref(&arr1), &mut groups) + .unwrap(); + groups.clear(); + group_values + .intern(std::slice::from_ref(&arr2), &mut groups) + .unwrap(); + groups.clear(); + group_values + .intern(std::slice::from_ref(&arr3), &mut groups) + .unwrap(); + groups.clear(); + group_values + .intern(std::slice::from_ref(&arr4), &mut groups) + .unwrap(); + }, + criterion::BatchSize::SmallInput, + ); + }); + } +} + +criterion_group!( + benches, + bench_single_column_group_values, + bench_repeated_intern_prefab_cols +); +criterion_main!(benches); diff --git a/datafusion/physical-plan/profile.json.gz b/datafusion/physical-plan/profile.json.gz new file mode 100644 index 0000000000000..6b0a0bc551b6a Binary files /dev/null and b/datafusion/physical-plan/profile.json.gz differ diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index 2f3b1a19e7d73..041e4cdbb4c38 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -30,8 +30,8 @@ use datafusion_expr::EmitTo; pub mod multi_group_by; -mod row; -mod single_group_by; +pub mod row; +pub mod single_group_by; use datafusion_physical_expr::binary_map::OutputType; use multi_group_by::GroupValuesColumn; use row::GroupValuesRows; @@ -41,7 +41,8 @@ pub(crate) use single_group_by::primitive::HashValue; use crate::aggregates::{ group_values::single_group_by::{ boolean::GroupValuesBoolean, bytes::GroupValuesBytes, - bytes_view::GroupValuesBytesView, primitive::GroupValuesPrimitive, + bytes_view::GroupValuesBytesView, dictionary::GroupValuesDictionary, + primitive::GroupValuesPrimitive, }, order::GroupOrdering, }; @@ -196,6 +197,56 @@ pub fn new_group_values( DataType::Boolean => { return Ok(Box::new(GroupValuesBoolean::new())); } + DataType::Dictionary(key_type, value_type) => { + if supported_single_dictionary_value(value_type) { + return match key_type.as_ref() { + // TODO: turn this into a macro + DataType::Int8 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::Int8Type, + >::new(value_type))) + } + DataType::Int16 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::Int16Type, + >::new(value_type))) + } + DataType::Int32 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::Int32Type, + >::new(value_type))) + } + DataType::Int64 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::Int64Type, + >::new(value_type))) + } + DataType::UInt8 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::UInt8Type, + >::new(value_type))) + } + DataType::UInt16 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::UInt16Type, + >::new(value_type))) + } + DataType::UInt32 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::UInt32Type, + >::new(value_type))) + } + DataType::UInt64 => { + Ok(Box::new(GroupValuesDictionary::< + arrow::datatypes::UInt64Type, + >::new(value_type))) + } + _ => Err(datafusion_common::DataFusionError::NotImplemented( + format!("Unsupported dictionary key type: {key_type:?}",), + )), + }; + } + } _ => {} } } @@ -207,6 +258,19 @@ pub fn new_group_values( Ok(Box::new(GroupValuesColumn::::try_new(schema)?)) } } else { + // TODO: add specialized implementation for dictionary encoding columns for 2+ group by columns case Ok(Box::new(GroupValuesRows::try_new(schema)?)) } } + +fn supported_single_dictionary_value(t: &DataType) -> bool { + matches!( + t, + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::Utf8View + | DataType::BinaryView + ) +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/dictionary.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/dictionary.rs new file mode 100644 index 0000000000000..fa66540908b0d --- /dev/null +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/dictionary.rs @@ -0,0 +1,1488 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::aggregates::group_values::GroupValues; +use crate::hash_utils::RandomState; +use arrow::array::{ + Array, ArrayRef, BinaryArray, BinaryBuilder, BinaryViewArray, BinaryViewBuilder, + DictionaryArray, LargeBinaryArray, LargeBinaryBuilder, LargeStringArray, + LargeStringBuilder, PrimitiveArray, PrimitiveBuilder, StringArray, StringBuilder, + StringViewArray, StringViewBuilder, UInt64Array, +}; +use arrow::datatypes::{ArrowDictionaryKeyType, ArrowNativeType, DataType}; +use datafusion_common::Result; +use datafusion_common::hash_utils::create_hashes; +use datafusion_expr::EmitTo; +use std::collections::HashMap; +use std::marker::PhantomData; +use std::sync::Arc; + +pub struct GroupValuesDictionary { + /* + We know that every single &[ArrayRef] that is passed in is a dictionary array + + self.inter() will be called across record batches, this means that + we cannot rely on a trivial approach where we just store the dictionary mapping as it is + + + + Possible soluitions: + 1A. store a hashmap that last across .intern() calls + | cast cols:&[ArrayRef] to generic Dictionary array, check if weve already stored its values (unique values) before + | if we have check the current mapping internally and update the groups array with the initial mapping for this value + | if it does not exist already (hashmap.size) is the group_id for this element + 1B. how do we retrieve the dictionary encoded array this function expects? + | NOTE: emit returns one value per group not one value per row. The group values are the distinct values in the order they were first seen — not the full expanded key array [one per group index] + | keep a value_order array that stores unique elements the first time their seen, this maintains order for self.emit() + | the return type of the array self.emit() returns is based on the value type of the dictionary, may be smart to have an internal Group values that handles that logic + | + + Possible optimizations (Ignore for now) + 2A. dont rely directly in a hashmap we could hash all of the values at once and then as we iterate the keys array refer to them as the values are assumed to be smaller than the keys + | at the start of self.intern hash every value in the dictionary + | iterate through the keys section of dict_array + | for each key check its corresponding value and if it exist + + + */ + // stores the order new unique elements are seen for self.emit() + seen_elements: Vec>, // Box doesnt provide the flexibility of building partition arrays that wed need to support emit::First(N) + value_dt: DataType, + _phantom: PhantomData, + // keeps track of which values weve already seen. stored as -> + unique_dict_value_mapping: HashMap)>>, + // fixed seeds ensure consistent hashing across GroupValuesDictionary instances + // this is critical for correct behavior in multi-partition aggregation where + // partial phase emits are re-interned by the final phase + random_state: RandomState, + null_group_id: Option, // cache the group id for nulls since they all map to the same group +} + +impl GroupValuesDictionary { + pub fn new(data_type: &DataType) -> Self { + Self { + seen_elements: Vec::new(), + unique_dict_value_mapping: HashMap::new(), + value_dt: data_type.clone(), + _phantom: PhantomData, + random_state: RandomState::with_seed(0), + null_group_id: None, + } + } + fn compute_value_hashes(&mut self, values: &ArrayRef) -> Result> { + let mut hashes = vec![0u64; values.len()]; + create_hashes([Arc::clone(values)], &self.random_state, &mut hashes)?; + Ok(hashes) + } + /*fn keys_to_usize(keys: &PrimitiveArray) -> Vec> { + keys.iter() + .map(|k| k.map(|v| v.to_usize().unwrap())) + .collect() + }*/ + + fn get_raw_bytes(values: &ArrayRef, index: usize) -> &[u8] { + match values.data_type() { + DataType::Utf8 => values + .as_any() + .downcast_ref::() + .expect("Expected StringArray") + .value(index) + .as_bytes(), + DataType::LargeUtf8 => values + .as_any() + .downcast_ref::() + .expect("Expected LargeStringArray") + .value(index) + .as_bytes(), + DataType::Utf8View => values + .as_any() + .downcast_ref::() + .expect("Expected StringViewArray") + .value(index) + .as_bytes(), + DataType::Binary => values + .as_any() + .downcast_ref::() + .expect("Expected BinaryArray") + .value(index), + DataType::LargeBinary => values + .as_any() + .downcast_ref::() + .expect("Expected LargeBinaryArray") + .value(index), + DataType::BinaryView => values + .as_any() + .downcast_ref::() + .expect("Expected BinaryViewArray") + .value(index), + other => unimplemented!("get_raw_bytes not implemented for {other:?}"), + } + } + + fn sentinel_repr(dt: &DataType) -> Vec { + match dt { + // 0xFF bytes cannot appear in valid UTF8 so no risk of collision with real values + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { + vec![0xFF, 0xFF, 0xFF, 0xFF] + } + // TODO: binary types need a better sentinel + DataType::Binary | DataType::LargeBinary | DataType::BinaryView => { + vec![0xFF, 0xFF, 0xFF, 0xFF] + } + // for primitives use a byte sequence that is a different length than the native type + // a real i8 is always exactly 1 byte so 2 bytes can never match a real value + other => unimplemented!("sentinel_repr not implemented for {other:?}"), + } + } + + #[inline] + fn get_null_group_id(&mut self) -> usize { + if let Some(group_id) = self.null_group_id { + group_id + } else { + if let Some(entries) = self + .unique_dict_value_mapping + .get(&((usize::MAX - 1) as u64)) + { + entries[0].0 + } else { + // first time we've seen a null + let new_group_id = self.seen_elements.len(); + let raw_bytes = Self::sentinel_repr(&self.value_dt); + self.seen_elements.push(raw_bytes.clone()); + self.unique_dict_value_mapping + .insert((usize::MAX - 1) as u64, vec![(new_group_id, raw_bytes)]); + self.null_group_id = Some(new_group_id); // cache it + new_group_id + } + } + } + fn transform_into_array(&self, raw: &[Vec]) -> Result { + let sentinel = Self::sentinel_repr(&self.value_dt); + match &self.value_dt { + DataType::Utf8 => { + let mut builder = StringBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + let s = std::str::from_utf8(raw_bytes).map_err(|e| { + datafusion_common::DataFusionError::Internal(format!( + "Invalid utf8 in seen_elements: {e}" + )) + })?; + builder.append_value(s); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + DataType::LargeUtf8 => { + let mut builder = LargeStringBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + let s = std::str::from_utf8(raw_bytes).map_err(|e| { + datafusion_common::DataFusionError::Internal(format!( + "Invalid utf8 in seen_elements: {e}" + )) + })?; + builder.append_value(s); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + DataType::Utf8View => { + let mut builder = StringViewBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + let s = std::str::from_utf8(raw_bytes).map_err(|e| { + datafusion_common::DataFusionError::Internal(format!( + "Invalid utf8 in seen_elements: {e}" + )) + })?; + builder.append_value(s); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + DataType::Binary => { + let mut builder = BinaryBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + builder.append_value(raw_bytes); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + DataType::LargeBinary => { + let mut builder = LargeBinaryBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + builder.append_value(raw_bytes); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + DataType::BinaryView => { + let mut builder = BinaryViewBuilder::new(); + for raw_bytes in raw { + if raw_bytes == &sentinel { + builder.append_null(); + } else { + builder.append_value(raw_bytes); + } + } + Ok(Arc::new(builder.finish()) as ArrayRef) + } + other => Err(datafusion_common::DataFusionError::NotImplemented(format!( + "transform_into_array not implemented for {other:?}" + ))), + } + } + fn normalize_dict_array( + values: &ArrayRef, + key_array: &PrimitiveArray, + ) -> (ArrayRef, Vec>) { + // maps old value index -> new canonical index + let mut old_to_new: Vec> = vec![None; values.len()]; + let mut canonical_indices: Vec = Vec::new(); + + for (i, slot) in old_to_new.iter_mut().enumerate() { + if values.is_null(i) { + continue; + } + let raw = Self::get_raw_bytes(values, i); + let canonical = canonical_indices + .iter() + .position(|&j| Self::get_raw_bytes(values, j) == raw); + if let Some(idx) = canonical { + *slot = Some(idx); + } else { + *slot = Some(canonical_indices.len()); + canonical_indices.push(i); + } + } + // build new deduplicated values array using take + let indices = UInt64Array::from( + canonical_indices + .iter() + .map(|&i| i as u64) + .collect::>(), + ); + let new_values = arrow::compute::take(values.as_ref(), &indices, None).unwrap(); + + // remap keys + let new_keys: Vec> = (0..key_array.len()) + .map(|i| { + if key_array.is_null(i) { + None + } else { + let old_key = key_array.value(i).to_usize().unwrap(); + old_to_new[old_key] + } + }) + .collect(); + + (new_values, new_keys) + } +} + +impl GroupValues for GroupValuesDictionary { + // not really sure how to return the size of strings and binary values so this is a best effort approach + fn size(&self) -> usize { + size_of::() + + self + .seen_elements + .iter() + .map(|b| b.capacity()) + .sum::() + + self.unique_dict_value_mapping.capacity() + * size_of::<(u64, Vec<(usize, Vec)>)>() + } + fn len(&self) -> usize { + self.seen_elements.len() + } + fn is_empty(&self) -> bool { + self.seen_elements.is_empty() + } + fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()> { + if cols.len() != 1 { + return Err(datafusion_common::DataFusionError::Internal( + "GroupValuesDictionary only supports single column group by".to_string(), + )); + } + let array = Arc::clone(&cols[0]); + groups.clear(); // zero out buffer + let dict_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + datafusion_common::DataFusionError::Internal(format!( + "GroupValuesDictionary expected DictionaryArray but got {:?}", + array.data_type() + )) + })?; + + // pre-allocate space for seen_elements using occupancy + // occupancy count gives us the number of truly distinct non-null values in this batch + let occupied = dict_array.occupancy().count_set_bits(); + self.seen_elements.reserve(occupied); + + let values = dict_array.values(); + let key_array = dict_array.keys(); + if key_array.is_empty() { + return Ok(()); // nothing to intern, just return early + } + let (values, keys_as_usize) = Self::normalize_dict_array(values, key_array); + let values = &values; + // compute hashes for all values in the values array upfront + // value_hashes[i] corresponds to values[i] + let value_hashes = self.compute_value_hashes(values)?; + + // convert key array to Vec for cheap indexed access + // avoids repeated .value(i).to_usize() calls in the hot loop + //let keys_as_usize = Self::keys_to_usize(key_array); + + // Pass 1: iterate values array (d iterations) - build a mapping of value hash -> group id for all unique values in the dictionary + // this allows us to do a single hashmap lookup per key in the hot loop instead of + let mut key_to_group: Vec> = vec![None; values.len()]; + for value_idx in 0..values.len() { + if values.is_null(value_idx) { + // this will be handled in phase 2 + continue; + } + let hash = value_hashes[value_idx]; + if let Some(entries) = self.unique_dict_value_mapping.get(&hash) { + let raw = Self::get_raw_bytes(values, value_idx); + if let Some((group_id, _)) = entries + .iter() + .find(|(_, stored_bytes)| raw == stored_bytes.as_slice()) + { + key_to_group[value_idx] = Some(*group_id); + continue; + } + } + } + // Pass 2: iterate keys array (n iterations) - + // only d insertions at most, repeated work is cached + for key_opt in &keys_as_usize { + let group_id = match key_opt { + None => self.get_null_group_id(), + Some(key) => { + if let Some(group_id) = key_to_group[*key] { + group_id + } else if values.is_null(*key) { + let gid = self.get_null_group_id(); + key_to_group[*key] = Some(gid); // cache it for future keys that point to null values + gid + } else { + // new unique value we havent seen before, assign a new group id and store it in the map + let new_group_id = self.seen_elements.len(); + let raw_bytes = Self::get_raw_bytes(values, *key).to_vec(); + self.seen_elements.push(raw_bytes.clone()); + if let Some(entries) = + self.unique_dict_value_mapping.get_mut(&value_hashes[*key]) + { + entries.push((new_group_id, raw_bytes)); + } else { + self.unique_dict_value_mapping.insert( + value_hashes[*key], + vec![(new_group_id, raw_bytes)], + ); + } + key_to_group[*key] = Some(new_group_id); + new_group_id + } + } + }; + groups.push(group_id); + } + Ok(()) + } + // This needs to return a dictionary encoded array + fn emit(&mut self, emit_to: EmitTo) -> Result> { + let elements_to_emit = match emit_to { + EmitTo::All => { + self.null_group_id = None; + self.unique_dict_value_mapping.clear(); + std::mem::take(&mut self.seen_elements) + } + EmitTo::First(n) => { + let first_n = self.seen_elements.drain(..n).collect::>(); + // update null_group_id if the null group was in the first n + if let Some(null_id) = self.null_group_id { + if null_id < n { + self.null_group_id = None; + } else { + self.null_group_id = Some(null_id - n); + } + } + // shift all remaining group indices down by n in the map + self.unique_dict_value_mapping.retain(|_, entries| { + entries.retain_mut(|(group_id, _)| { + if *group_id < n { + false + } else { + *group_id -= n; + true + } + }); + !entries.is_empty() + }); + first_n + } + }; + + let n = elements_to_emit.len(); + let values_array = self.transform_into_array(&elements_to_emit)?; + + // reconstruct dictionary keys 0..n + let mut keys_builder = PrimitiveBuilder::::with_capacity(n); + for i in 0..n { + if Some(i) == self.null_group_id { + keys_builder.append_null(); + } else { + keys_builder.append_value(K::Native::usize_as(i)); + } + } + let dict_array = + DictionaryArray::::try_new(keys_builder.finish(), values_array)?; + Ok(vec![Arc::new(dict_array)]) + } + fn clear_shrink(&mut self, num_rows: usize) { + self.seen_elements.clear(); + self.seen_elements.shrink_to(num_rows); + self.null_group_id = None; + self.unique_dict_value_mapping.clear(); + self.unique_dict_value_mapping.shrink_to(num_rows); + } +} + +#[cfg(test)] +mod group_values_trait_test { + use super::*; + use arrow::array::{DictionaryArray, Int32Array, StringArray, UInt8Array}; + use arrow::datatypes::{Int32Type, UInt8Type}; + use std::sync::Arc; + + fn create_dict_array(keys: Vec, values: Vec<&str>) -> ArrayRef { + let values = StringArray::from(values); + let keys = UInt8Array::from(keys); + Arc::new(DictionaryArray::::try_new(keys, Arc::new(values)).unwrap()) + } + + // Helper function to validate that emitted arrays are DictionaryArrays with the correct type + fn assert_emitted_is_dict_array(result: &[ArrayRef]) { + assert_eq!(result.len(), 1, "Expected exactly one array in emit result"); + let array = &result[0]; + + match array.data_type() { + DataType::Dictionary(key_type, value_type) => { + // Verify it's the expected key type (UInt8 in our tests) + match key_type.as_ref() { + DataType::UInt8 => {} + other => panic!("Expected UInt8 key type, got {other:?}"), + } + + // Verify it's the expected value type (Utf8 in our tests) + match value_type.as_ref() { + DataType::Utf8 => {} + other => panic!("Expected Utf8 value type, got {other:?}"), + } + } + other => panic!("Expected DictionaryArray, got {other:?}"), + } + + // Now verify we can actually downcast to the expected types + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Failed to downcast to DictionaryArray"); + + let _values = dict_array + .values() + .as_any() + .downcast_ref::() + .expect("Dictionary values should be StringArray"); + } + + mod basic_functionality { + use super::*; + + #[test] + fn test_single_group_all_same_values() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array(vec![0, 0, 0], vec!["red"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(groups_vector.len(), 3); + assert_eq!(group_values_trait_obj.len(), 1); + assert!(!group_values_trait_obj.is_empty()); + } + + #[test] + fn test_multiple_groups() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 0, 2, 1], vec!["red", "blue", "green"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 3); + assert_eq!(groups_vector.len(), 5); + } + + #[test] + fn test_multiple_groups_with_nulls() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let keys = UInt8Array::from(vec![Some(0), None, Some(1), None, Some(0)]); + let values = StringArray::from(vec!["red", "blue"]); + let dict_array = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(groups_vector.len(), 5); + assert_eq!(group_values_trait_obj.len(), 3); + assert_eq!(groups_vector[1], groups_vector[3]); + assert_eq!(groups_vector[0], groups_vector[4]); + assert_ne!(groups_vector[0], groups_vector[1]); + assert_ne!(groups_vector[2], groups_vector[1]); + } + + #[test] + fn test_all_different_values() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array( + vec![0, 1, 2, 3, 4], + vec!["apple", "banana", "cherry", "date", "elderberry"], + ); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 5); + assert_eq!(groups_vector.len(), 5); + } + } + + mod edge_cases { + use super::*; + + #[test] + fn test_empty_batch() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array(vec![], vec!["red"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 0); + assert_eq!(groups_vector.len(), 0); + assert!(group_values_trait_obj.is_empty()); + } + + #[test] + fn test_single_row() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array(vec![0], vec!["apple"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 1); + assert_eq!(groups_vector.len(), 1); + assert_eq!(groups_vector[0], 0); + } + + #[test] + fn test_repeated_pattern() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 2, 0, 1, 2, 0, 1, 2], vec!["a", "b", "c"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 3); + assert_eq!(groups_vector.len(), 9); + } + + #[test] + fn test_null_heavy_mixed_values() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let keys = UInt8Array::from(vec![ + None, + None, + Some(0u8), + None, + Some(1u8), + None, + Some(0u8), + Some(1u8), + None, + Some(2u8), + None, + ]); + let values = StringArray::from(vec!["red", "blue", "green"]); + let dict_array = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + // groups are: null + red + blue + green + assert_eq!(group_values_trait_obj.len(), 4); + assert_eq!(groups_vector.len(), 11); + + // all null rows should map to one group + let null_group = groups_vector[0]; + assert_eq!(groups_vector[1], null_group); + assert_eq!(groups_vector[3], null_group); + assert_eq!(groups_vector[5], null_group); + assert_eq!(groups_vector[8], null_group); + assert_eq!(groups_vector[10], null_group); + + // repeated non-null values should map consistently + assert_eq!(groups_vector[2], groups_vector[6]); // red + assert_eq!(groups_vector[4], groups_vector[7]); // blue + + // null and non-null groups should remain distinct + assert_ne!(groups_vector[2], null_group); + assert_ne!(groups_vector[4], null_group); + assert_ne!(groups_vector[9], null_group); + } + + #[test] + fn test_null_group_stable_across_batches_with_reordered_dict() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let batch1_keys = UInt8Array::from(vec![None, Some(0u8), None, Some(1u8)]); + let batch1_values = StringArray::from(vec!["a", "b"]); + let batch1 = Arc::new( + DictionaryArray::::try_new( + batch1_keys, + Arc::new(batch1_values), + ) + .unwrap(), + ) as ArrayRef; + + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[batch1], &mut groups_vector1) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 3); // null + a + b + let null_group = groups_vector1[0]; + let a_group = groups_vector1[1]; + let b_group = groups_vector1[3]; + assert_eq!(groups_vector1[2], null_group); + + // Same logical values, but dictionary value ordering changed: ["a", "c", "b"] + let batch2_keys = + UInt8Array::from(vec![Some(0u8), None, Some(2u8), None, Some(1u8)]); + let batch2_values = StringArray::from(vec!["a", "c", "b"]); + let batch2 = Arc::new( + DictionaryArray::::try_new( + batch2_keys, + Arc::new(batch2_values), + ) + .unwrap(), + ) as ArrayRef; + + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[batch2], &mut groups_vector2) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 4); // adds only new value "c" + assert_eq!(groups_vector2[0], a_group); // "a" should reuse prior group + assert_eq!(groups_vector2[1], null_group); + assert_eq!(groups_vector2[3], null_group); + assert_eq!(groups_vector2[2], b_group); // "b" should reuse prior group + assert_ne!(groups_vector2[4], null_group); // "c" is not null + assert_ne!(groups_vector2[4], a_group); + assert_ne!(groups_vector2[4], b_group); + } + + #[test] + fn test_null_values_in_values_array() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + // Reproduce Sql::aggregates::basic::count_distinct_dictionary_mixed_values + let keys = UInt8Array::from(vec![0, 1, 2, 0, 1, 3]); + let values = StringArray::from(vec![None, Some("abc"), Some("def"), None]); + let dict = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict], &mut groups_vector) + .unwrap(); + + assert_eq!(group_values_trait_obj.len(), 3); + assert_eq!(groups_vector.len(), 6); + assert_eq!(groups_vector[0], groups_vector[3]); // both null + assert_eq!(groups_vector[0], groups_vector[5]); // both null + assert_eq!(groups_vector[1], groups_vector[4]); // both "abc" + assert_ne!(groups_vector[1], groups_vector[2]); // "abc" != "def" + assert_ne!(groups_vector[0], groups_vector[1]); // null != "abc" + + // emit and verify output + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_eq!(result.len(), 1); // single column + + let emitted = result[0] + .as_any() + .downcast_ref::>() + .expect("Expected DictionaryArray"); + // should have 3 entries - null, "abc", "def" + assert_eq!(emitted.values().len(), 3); + + // verify the values array has correct nulls + assert!(emitted.values().is_null(groups_vector[0])); // null group should be null + assert!(!emitted.values().is_null(groups_vector[1])); // "abc" should not be null + assert!(!emitted.values().is_null(groups_vector[2])); // "def" should not be null + + // verify string values + let string_values = emitted + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(string_values.value(groups_vector[1]), "abc"); + assert_eq!(string_values.value(groups_vector[2]), "def"); + + // group_values should be empty after EmitTo::All + assert!(group_values_trait_obj.is_empty()); + } + } + + mod multi_column { + use super::*; + + #[test] + fn test_multiple_columns_passed() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array1 = create_dict_array(vec![0, 1, 0], vec!["red", "blue"]); + + let dict_array2 = create_dict_array(vec![0, 0, 1], vec!["x", "y"]); + + let mut groups_vector = Vec::new(); + let result = group_values_trait_obj + .intern(&[dict_array1, dict_array2], &mut groups_vector); + assert!( + result.is_err(), + "Should error when multiple columns are passed (only single column supported)" + ); + } + } + + mod consecutive_batches { + use super::*; + + #[test] + fn test_consecutive_batches_then_emit() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let batch1 = create_dict_array(vec![0, 1, 0], vec!["red", "blue"]); + + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[batch1], &mut groups_vector1) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 2); + assert_eq!(groups_vector1.len(), 3); + + let batch2 = create_dict_array(vec![0, 1, 2], vec!["green", "red", "blue"]); + + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[batch2], &mut groups_vector2) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 3); + assert_eq!(groups_vector2.len(), 3); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!(group_values_trait_obj.is_empty()); + } + + #[test] + fn test_three_consecutive_batches_with_partial_emit() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let batch1 = create_dict_array(vec![0, 1], vec!["a", "b"]); + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[batch1], &mut groups_vector1) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 2); + + let batch2 = create_dict_array(vec![0, 1, 2], vec!["a", "b", "c"]); + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[batch2], &mut groups_vector2) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 3); + + let batch3 = create_dict_array( + vec![0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 2, 1, 2], + vec!["c", "d", "e"], + ); + let mut groups_vector3 = Vec::new(); + group_values_trait_obj + .intern(&[batch3], &mut groups_vector3) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 5); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!(group_values_trait_obj.is_empty()); + result.iter().for_each(|array| { + let dict_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + let values = dict_array.values(); + let string_array = values.as_any().downcast_ref::().unwrap(); + let value_strings: Vec = (0..string_array.len()) + .map(|i| string_array.value(i).to_string()) + .collect(); + let unexpected_values: Vec<&String> = value_strings + .iter() + .filter(|v| { + **v != "a" && **v != "b" && **v != "c" && **v != "d" && **v != "e" + }) + .collect(); + assert!( + unexpected_values.is_empty(), + "Emitted unexpected values: {unexpected_values:#?}" + ); + }); + } + } + + mod state_management { + use super::*; + + #[test] + fn test_initial_state_is_empty() { + let group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + assert!(group_values_trait_obj.is_empty()); + assert_eq!(group_values_trait_obj.len(), 0); + } + + #[test] + fn test_size_grows_after_intern() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let initial_size = group_values_trait_obj.size(); + + let dict_array1 = + create_dict_array(vec![0, 1, 0, 1, 2], vec!["red", "blue", "green"]); + + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[dict_array1], &mut groups_vector1) + .unwrap(); + + let size_after_first_intern = group_values_trait_obj.size(); + assert!( + size_after_first_intern > initial_size, + "Size should grow after first intern" + ); + + let dict_array2 = create_dict_array( + vec![0, 1, 2, 3, 4], + vec!["yellow", "orange", "purple", "pink", "brown"], + ); + + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[dict_array2], &mut groups_vector2) + .unwrap(); + + let size_after_second_intern = group_values_trait_obj.size(); + assert!( + size_after_second_intern > size_after_first_intern, + "Size should grow after second intern with new items" + ); + + let dict_array3 = + create_dict_array(vec![0, 1, 2], vec!["red", "blue", "green"]); + + let mut groups_vector3 = Vec::new(); + group_values_trait_obj + .intern(&[dict_array3], &mut groups_vector3) + .unwrap(); + + let size_after_third_intern = group_values_trait_obj.size(); + assert_eq!( + size_after_third_intern, size_after_second_intern, + "Size should not grow when interning previously seen values" + ); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!( + group_values_trait_obj.is_empty(), + "Should be empty after emit all" + ); + } + + #[test] + fn test_clear_shrink_resets_state() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array(vec![0, 1, 0], vec!["red", "blue"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 2); + + group_values_trait_obj.clear_shrink(100); + assert_eq!(group_values_trait_obj.len(), 0); + assert!(group_values_trait_obj.is_empty()); + } + + #[test] + fn test_clear_shrink_with_zero() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 2, 1, 0], vec!["red", "blue", "green"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + group_values_trait_obj.clear_shrink(0); + assert!(group_values_trait_obj.is_empty()); + assert_eq!(group_values_trait_obj.len(), 0); + } + + #[test] + fn test_emit_all_clears_state() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = create_dict_array(vec![0, 1, 0], vec!["red", "blue"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 2); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + + assert!(group_values_trait_obj.is_empty()); + assert_eq!(group_values_trait_obj.len(), 0); + } + + #[test] + fn test_emit_first_n() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 2], vec!["apple", "banana", "cherry"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 3); + + let result = group_values_trait_obj.emit(EmitTo::First(1)).unwrap(); + assert_emitted_is_dict_array(&result); + assert_eq!(group_values_trait_obj.len(), 2); + + let result = group_values_trait_obj.emit(EmitTo::First(2)).unwrap(); + assert_emitted_is_dict_array(&result); + assert!(group_values_trait_obj.is_empty()); + } + + #[test] + fn test_complex_emit_flow_with_multiple_intern() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let batch1 = create_dict_array(vec![0, 1, 2, 3], vec!["a", "b", "c", "d"]); + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[batch1], &mut groups_vector1) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 4); + + let result = group_values_trait_obj.emit(EmitTo::First(2)).unwrap(); + assert_emitted_is_dict_array(&result); + assert_eq!( + group_values_trait_obj.len(), + 2, + "After emitting 2, should have 2 left (c, d)" + ); + + let batch2 = create_dict_array(vec![0, 1, 2], vec!["a", "b", "e"]); + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[batch2], &mut groups_vector2) + .unwrap(); + assert_eq!( + group_values_trait_obj.len(), + 5, + "After second intern: 2 remaining (c,d) + 3 new from batch2 (a,b,e) = 5 groups" + ); + + let result = group_values_trait_obj.emit(EmitTo::First(1)).unwrap(); + assert_emitted_is_dict_array(&result); + assert_eq!( + group_values_trait_obj.len(), + 4, + "After emitting 1 more (c), should have 4 left (d,a,b,e)" + ); + + let batch3 = create_dict_array(vec![0, 1, 2], vec!["a", "f", "g"]); + let mut groups_vector3 = Vec::new(); + group_values_trait_obj + .intern(&[batch3], &mut groups_vector3) + .unwrap(); + assert_eq!( + group_values_trait_obj.len(), + 6, + "After third intern: 4 remaining (d,a,b,e) + 2 new from batch3 (f,g) = 6 groups (a already exists)" + ); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!( + group_values_trait_obj.is_empty(), + "After emitting all, should be empty" + ); + assert_eq!(group_values_trait_obj.len(), 0); + } + } + + mod data_correctness { + use super::*; + + #[test] + fn test_group_assignment_order() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 0, 2, 1], vec!["red", "blue", "green"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(groups_vector.len(), 5); + assert_eq!(groups_vector[0], groups_vector[2]); + assert_eq!(groups_vector[1], groups_vector[4]); + } + + #[test] + fn test_groups_vector_correctness_first_appearance() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 2, 0, 1, 2], vec!["x", "y", "z"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(groups_vector.len(), 6); + let group_x = groups_vector[0]; + let group_y = groups_vector[1]; + let group_z = groups_vector[2]; + + assert_eq!( + groups_vector[3], group_x, + "Fourth row should match first row group" + ); + assert_eq!( + groups_vector[4], group_y, + "Fifth row should match second row group" + ); + assert_eq!( + groups_vector[5], group_z, + "Sixth row should match third row group" + ); + } + + #[test] + fn test_groups_vector_sequential_assignment() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![2, 0, 1], vec!["first", "second", "third"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + + assert_eq!(groups_vector.len(), 3); + assert_eq!( + group_values_trait_obj.len(), + 3, + "Should have exactly 3 unique groups" + ); + let all_different = groups_vector[0] != groups_vector[1] + && groups_vector[1] != groups_vector[2] + && groups_vector[0] != groups_vector[2]; + assert!( + all_different, + "All rows should have different group assignments" + ); + } + + #[test] + fn test_emit_partial_preserves_state() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let dict_array = + create_dict_array(vec![0, 1, 2, 3], vec!["a", "b", "c", "d"]); + + let mut groups_vector = Vec::new(); + group_values_trait_obj + .intern(&[dict_array], &mut groups_vector) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 4); + + let emitted = group_values_trait_obj.emit(EmitTo::First(2)).unwrap(); + assert_emitted_is_dict_array(&emitted); + assert_eq!( + group_values_trait_obj.len(), + 2, + "Should have 2 groups remaining after partial emit" + ); + + let emitted_remaining = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&emitted_remaining); + assert!( + group_values_trait_obj.is_empty(), + "Should be empty after final emit" + ); + } + + #[test] + fn test_emit_restores_intern_ability() { + let mut group_values_trait_obj = + GroupValuesDictionary::::new(&DataType::Utf8); + let batch1 = create_dict_array(vec![0, 1], vec!["alpha", "beta"]); + + let mut groups_vector1 = Vec::new(); + group_values_trait_obj + .intern(&[batch1], &mut groups_vector1) + .unwrap(); + assert_eq!(group_values_trait_obj.len(), 2); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!(group_values_trait_obj.is_empty()); + + let batch2 = + create_dict_array(vec![0, 1, 2], vec!["gamma", "delta", "epsilon"]); + + let mut groups_vector2 = Vec::new(); + group_values_trait_obj + .intern(&[batch2], &mut groups_vector2) + .unwrap(); + assert_eq!( + group_values_trait_obj.len(), + 3, + "Should be able to intern new groups after emit" + ); + + let result = group_values_trait_obj.emit(EmitTo::All).unwrap(); + assert_emitted_is_dict_array(&result); + assert!( + group_values_trait_obj.is_empty(), + "Should be empty after second emit" + ); + } + + #[test] + fn test_null_keys_form_single_group() { + let mut group_values = + GroupValuesDictionary::::new(&DataType::Utf8); + // keys: [0, null, 1, null, 0] + // values: ["a", "b"] + // null keys should all map to the same group + let keys = Int32Array::from(vec![Some(0), None, Some(1), None, Some(0)]); + let values = StringArray::from(vec!["a", "b"]); + let dict = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef; + + let mut groups = Vec::new(); + group_values.intern(&[dict], &mut groups).unwrap(); + + // should have 3 groups: "a", "b", null + assert_eq!(group_values.len(), 3); + // null rows (index 1 and 3) should map to same group + assert_eq!(groups[1], groups[3]); + // non null rows should map to correct groups + assert_eq!(groups[0], groups[4]); // both "a" + assert_ne!(groups[0], groups[2]); // "a" != "b" + } + + #[test] + fn test_null_values_in_dictionary_form_single_group() { + let mut group_values = + GroupValuesDictionary::::new(&DataType::Utf8); + // keys: [0, 1, 2, 1, 0] + // values: ["a", null, "b"] + // keys pointing to null value should all map to same group + let keys = Int32Array::from(vec![0, 1, 2, 1, 0]); + let values = StringArray::from(vec![Some("a"), None, Some("b")]); + let dict = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef; + + let mut groups = Vec::new(); + group_values.intern(&[dict], &mut groups).unwrap(); + + // should have 3 groups: "a", null, "b" + assert_eq!(group_values.len(), 3); + // rows pointing to null value (index 1 and 3) should map to same group + assert_eq!(groups[1], groups[3]); + // non null rows should map correctly + assert_eq!(groups[0], groups[4]); // both "a" + assert_ne!(groups[0], groups[2]); // "a" != "b" + } + } + #[cfg(test)] + mod null_value_edge_cases { + use super::*; + + /// Regression test for COUNT DISTINCT with mixed null and non-null dictionary values + /// Expected: only non-null values "abc" and "def" are counted = 2 + #[test] + fn test_count_distinct_mixed_nulls() { + let mut group_values = + GroupValuesDictionary::::new(&DataType::Utf8); + // keys: [0, 1, 2, 0, 1, 3] + // values: [None, "abc", "def", None] + // rows 0, 3, 5 point to null values → should all map to null group + // rows 1, 4 point to "abc" → same group + // row 2 points to "def" → own group + let keys = UInt8Array::from(vec![0, 1, 2, 0, 1, 3]); + let values = StringArray::from(vec![None, Some("abc"), Some("def"), None]); + let dict = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups = Vec::new(); + group_values.intern(&[dict], &mut groups).unwrap(); + // 3 groups: null, "abc", "def" + assert_eq!(group_values.len(), 3); + assert_eq!(groups.len(), 6); + + // null group - rows 0, 3, 5 all map to same group + assert_eq!(groups[0], groups[3]); + assert_eq!(groups[0], groups[5]); + // "abc" group - rows 1 and 4 + assert_eq!(groups[1], groups[4]); + // all three groups are distinct + assert_ne!(groups[0], groups[1]); + assert_ne!(groups[1], groups[2]); + assert_ne!(groups[0], groups[2]); + + // emit and verify null is correctly represented + let result = group_values.emit(EmitTo::All).unwrap(); + assert_eq!(result.len(), 1); + + let emitted = result[0] + .as_any() + .downcast_ref::>() + .expect("Expected DictionaryArray"); + + // null key should be null in the emitted array + let null_key = emitted.keys().value(groups[0]); + assert!(emitted.values().is_null(null_key as usize)); + + // check that non-null groups point to non-null values + let abc_key = emitted.keys().value(groups[1]); + assert!(!emitted.values().is_null(abc_key as usize)); + + let def_key = emitted.keys().value(groups[2]); + assert!(!emitted.values().is_null(def_key as usize)); + + assert!(group_values.is_empty()); + } + + /// Regression test for GROUP BY with null keys in dictionary + /// Expected: null keys form a single group, non-null keys form their own groups + #[test] + fn test_group_by_null_keys() { + let mut group_values = + GroupValuesDictionary::::new(&DataType::Utf8); + // keys: [Some(0), None, Some(1), None, Some(0)] + // values: ["group_a", "group_b"] + // null key rows 1 and 3 should map to same null group + let keys = UInt8Array::from(vec![Some(0), None, Some(1), None, Some(0)]); + let values = StringArray::from(vec![Some("group_a"), Some("group_b")]); + let dict = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups = Vec::new(); + group_values.intern(&[dict], &mut groups).unwrap(); + + // 3 groups: "group_a", "group_b", null + assert_eq!(group_values.len(), 3); + assert_eq!(groups.len(), 5); + + // null keys map to same group + assert_eq!(groups[1], groups[3]); + // "group_a" rows map to same group + assert_eq!(groups[0], groups[4]); + // all three groups are distinct + assert_ne!(groups[0], groups[1]); + assert_ne!(groups[0], groups[2]); + assert_ne!(groups[1], groups[2]); + + // emit and verify + let result = group_values.emit(EmitTo::All).unwrap(); + assert_eq!(result.len(), 1); + + let emitted = result[0] + .as_any() + .downcast_ref::>() + .expect("Expected DictionaryArray"); + let string_values = emitted + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + // null group key should be null in emitted array + let null_key = emitted.keys().value(groups[1]); + assert!(string_values.is_null(null_key as usize)); + + // non-null groups point to non-null values + let group_a_key = emitted.keys().value(groups[0]); + assert!(!string_values.is_null(group_a_key as usize)); + assert_eq!(string_values.value(group_a_key as usize), "group_a"); + + let group_b_key = emitted.keys().value(groups[2]); + assert!(!string_values.is_null(group_b_key as usize)); + assert_eq!(string_values.value(group_b_key as usize), "group_b"); + + assert!(group_values.is_empty()); + } + + /// Regression test for GROUP BY with null values in dictionary values array + /// Expected: keys pointing to null values form a single null group + #[test] + fn test_group_by_null_values_in_dict() { + let mut group_values = + GroupValuesDictionary::::new(&DataType::Utf8); + // keys: [0, 1, 2, 1, 0] + // values: ["val_x", None, "val_y"] + // key 1 points to null value - rows 1 and 3 should map to null group + let keys = UInt8Array::from(vec![0u8, 1, 2, 1, 0]); + let values = StringArray::from(vec![Some("val_x"), None, Some("val_y")]); + let dict = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ) as ArrayRef; + + let mut groups = Vec::new(); + group_values.intern(&[dict], &mut groups).unwrap(); + + // 3 groups: "val_x", null, "val_y" + assert_eq!(group_values.len(), 3); + assert_eq!(groups.len(), 5); + + // rows pointing to null value map to same group + assert_eq!(groups[1], groups[3]); + // "val_x" rows map to same group + assert_eq!(groups[0], groups[4]); + // all three groups are distinct + assert_ne!(groups[0], groups[1]); + assert_ne!(groups[1], groups[2]); + assert_ne!(groups[0], groups[2]); + + // emit and verify + let result = group_values.emit(EmitTo::All).unwrap(); + assert_eq!(result.len(), 1); + + let emitted = result[0] + .as_any() + .downcast_ref::>() + .expect("Expected DictionaryArray"); + + // null group should be null in emitted array + let null_key = emitted.keys().value(groups[1]); + let string_values = emitted + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(string_values.is_null(null_key as usize)); + + let val_x_key = emitted.keys().value(groups[0]); + assert_eq!(string_values.value(val_x_key as usize), "val_x"); + + let val_y_key = emitted.keys().value(groups[2]); + assert_eq!(string_values.value(val_y_key as usize), "val_y"); + assert!(group_values.is_empty()); + } + } +} diff --git a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs index 89c6b624e8e0a..0dac3e72d9e45 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/single_group_by/mod.rs @@ -20,4 +20,5 @@ pub(crate) mod boolean; pub(crate) mod bytes; pub(crate) mod bytes_view; +pub mod dictionary; pub(crate) mod primitive; diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index 5b41a47406797..eefae699787e1 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -1700,3 +1700,91 @@ mod tests { Ok(()) } } + +#[cfg(test)] +mod dictionary_aggregation { + use super::*; + use crate::aggregates::{ArrayRef, DataType, Field, RecordBatch, Schema}; + use crate::expressions::col; + use crate::test::TestMemoryExec; + use arrow::datatypes::UInt8Type; + use datafusion_functions_aggregate::count::count_udaf; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + + /// Equivalent SQL: + /// SELECT region, COUNT(*) + /// FROM events + /// GROUP BY region + /// + /// Smoke test to verify that aggregation over a dictionary-encoded + /// GROUP BY column produces output without panicking or erroring. + /// region is a low cardinality dictionary-encoded string column + /// with 3 distinct values across 6 rows, mirroring a realistic + /// events table where region is always present. + #[tokio::test] + async fn test_count_group_by_dictionary_column() -> Result<()> { + // dictionary encoded region column + // 3 distinct values across 6 rows + let keys = UInt8Array::from(vec![0, 1, 0, 2, 1, 0]); + let values = StringArray::from(vec!["us-east", "us-west", "eu-central"]); + let region_col: ArrayRef = Arc::new( + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(), + ); + + // event_id column to count + let event_id_col: ArrayRef = + Arc::new(Int64Array::from(vec![1001, 1002, 1003, 1004, 1005, 1006])); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "region", + DataType::Dictionary(Box::new(DataType::UInt8), Box::new(DataType::Utf8)), + false, + ), + Field::new("event_id", DataType::Int64, false), + ])); + + let batch = + RecordBatch::try_new(Arc::clone(&schema), vec![region_col, event_id_col])?; + + let exec = Arc::new(TestMemoryExec::try_new( + &[vec![batch]], + Arc::clone(&schema), + None, + )?); + + let aggregate_exec = AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![( + col("region", &schema)?, + "region".to_string(), + )]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("event_id", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("count") + .build()?, + )], + vec![None], + exec, + Arc::clone(&schema), + )?; + + let task_ctx = Arc::new(TaskContext::default()); + let mut stream = GroupedHashAggregateStream::new(&aggregate_exec, &task_ctx, 0)?; + + let mut batches = vec![]; + while let Some(result) = stream.next().await { + batches.push(result?); + } + + // verify we got output + assert!(!batches.is_empty()); + // verify we got 3 groups - one per distinct region + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + assert_eq!(total_rows, 3); + dbg!("record batches: {batches:#?}"); + + Ok(()) + } +}