diff --git a/datafusion/ffi/src/udwf/mod.rs b/datafusion/ffi/src/udwf/mod.rs index bff46386709f9..9871d88ffb8fe 100644 --- a/datafusion/ffi/src/udwf/mod.rs +++ b/datafusion/ffi/src/udwf/mod.rs @@ -40,6 +40,7 @@ use stabby::vec::Vec as SVec; mod partition_evaluator; mod partition_evaluator_args; mod range; +mod window_state; use crate::arrow_wrappers::WrappedSchema; use crate::util::{ diff --git a/datafusion/ffi/src/udwf/partition_evaluator.rs b/datafusion/ffi/src/udwf/partition_evaluator.rs index c4c43f00d81fa..c4fedf00473ad 100644 --- a/datafusion/ffi/src/udwf/partition_evaluator.rs +++ b/datafusion/ffi/src/udwf/partition_evaluator.rs @@ -27,10 +27,12 @@ use datafusion_expr::PartitionEvaluator; use datafusion_expr::window_state::WindowAggState; use prost::Message; +use stabby::string::String as SString; use stabby::vec::Vec as SVec; use super::range::FFI_Range; use crate::arrow_wrappers::WrappedArray; +use crate::udwf::window_state::FFI_WindowAggState; use crate::util::FFI_Result; use crate::{df_result, sresult, sresult_return}; @@ -40,6 +42,11 @@ use crate::{df_result, sresult, sresult_return}; #[repr(C)] #[derive(Debug)] pub struct FFI_PartitionEvaluator { + pub memoize: unsafe extern "C" fn( + evaluator: &mut Self, + state: FFI_WindowAggState, + ) -> FFI_Result, + pub evaluate_all: unsafe extern "C" fn( evaluator: &mut Self, values: SVec, @@ -179,6 +186,27 @@ unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper( } } +unsafe extern "C" fn memoize_fn_wrapper( + evaluator: &mut FFI_PartitionEvaluator, + state: FFI_WindowAggState, +) -> FFI_Result { + unsafe { + let inner = evaluator.inner_mut(); + let mut native_state = sresult_return!(WindowAggState::try_from(state)); + + // Propagate errors from memoize + if let Err(e) = inner.memoize(&mut native_state) { + return FFI_Result::Err(SString::from(format!("{}", e))); + } + + // Convert mutated state back to FFI + match FFI_WindowAggState::try_from(native_state) { + Ok(ffi_state) => FFI_Result::Ok(ffi_state), + Err(e) => FFI_Result::Err(SString::from(format!("{}", e))), + } + } +} + unsafe extern "C" fn get_range_fn_wrapper( evaluator: &FFI_PartitionEvaluator, idx: usize, @@ -221,6 +249,7 @@ impl From> for FFI_PartitionEvaluator { let private_data = PartitionEvaluatorPrivateData { evaluator }; Self { + memoize: memoize_fn_wrapper, evaluate: evaluate_fn_wrapper, evaluate_all: evaluate_all_fn_wrapper, evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper, @@ -271,9 +300,20 @@ impl From for Box { } impl PartitionEvaluator for ForeignPartitionEvaluator { - fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> { - // Exposing `memoize` increases the surface are of the FFI work - // so for now we dot support it. + fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> { + let ffi_state = FFI_WindowAggState::try_from(state.clone())?; + // Temporarily store the context + // so we are preserving it AS IS for now + // TODO: possibly there is a better way of doing this + let saved_ctx = state.window_frame_ctx.take(); + + let result = unsafe { (self.evaluator.memoize)(&mut self.evaluator, ffi_state) }; + + let updated_ffi_state = df_result!(result)?; + + *state = WindowAggState::try_from(updated_ffi_state)?; + state.window_frame_ctx = saved_ctx; + Ok(()) } @@ -365,6 +405,15 @@ impl PartitionEvaluator for ForeignPartitionEvaluator { mod tests { use arrow::array::ArrayRef; use datafusion::logical_expr::PartitionEvaluator; + use datafusion_common::scalar::ScalarValue; + use datafusion_expr::{ + WindowFrame, + window_state::{WindowAggState, WindowFrameContext}, + }; + use std::ops::Range; + + use arrow::array::Int32Array; + use std::sync::Arc; use crate::udwf::partition_evaluator::{ FFI_PartitionEvaluator, ForeignPartitionEvaluator, @@ -413,4 +462,124 @@ mod tests { Ok(()) } + + #[test] + fn test_memoize_state_preservation() -> datafusion_common::Result<()> { + // Create a test evaluator that actually modifies state + #[derive(Debug)] + struct StateModifyingEvaluator; + + impl PartitionEvaluator for StateModifyingEvaluator { + fn memoize( + &mut self, + state: &mut WindowAggState, + ) -> datafusion_common::Result<()> { + // Modify the window frame range + state.window_frame_range.start = state.window_frame_range.end - 1; + Ok(()) + } + + fn evaluate( + &mut self, + _values: &[ArrayRef], + _range: &Range, + ) -> datafusion_common::Result { + Ok(ScalarValue::Int32(Some(42))) + } + } + + let evaluator: Box = Box::new(StateModifyingEvaluator); + let mut ffi: FFI_PartitionEvaluator = evaluator.into(); + + // Make it act as foreign + ffi.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign: Box = + Box::new(ForeignPartitionEvaluator { evaluator: ffi }); + + // Create state with a specific range + let mut state = WindowAggState { + window_frame_range: 0..10, + window_frame_ctx: None, + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, + n_row_result_missing: 0, + is_end: false, + }; + + // Call memoize + foreign.memoize(&mut state)?; + + // Verify state was properly mutated + assert_eq!( + state.window_frame_range, + 9..10, + "window_frame_range should have been pruned" + ); + + Ok(()) + } + + #[test] + fn test_memoize_preserves_window_frame_ctx() -> datafusion_common::Result<()> { + #[derive(Debug)] + struct CtxAwareEvaluator { + // Track whether memoize was called + memoize_called: std::cell::Cell, + } + + impl PartitionEvaluator for CtxAwareEvaluator { + fn memoize( + &mut self, + state: &mut WindowAggState, + ) -> datafusion_common::Result<()> { + self.memoize_called.set(true); + // Don't touch window_frame_ctx - just verify it's there + state.window_frame_range.start = state.window_frame_range.end - 1; + Ok(()) + } + } + + let evaluator: Box = Box::new(CtxAwareEvaluator { + memoize_called: std::cell::Cell::new(false), + }); + let mut ffi: FFI_PartitionEvaluator = evaluator.into(); + + // Force foreign path + ffi.library_marker_id = crate::mock_foreign_marker_id; + let mut foreign: Box = + Box::new(ForeignPartitionEvaluator { evaluator: ffi }); + + // Create a real WindowFrameContext + let window_frame = Arc::new(WindowFrame::new(Some(true))); + let original_ctx = WindowFrameContext::new(window_frame, vec![]); + + let mut state = WindowAggState { + window_frame_range: 0..10, + window_frame_ctx: Some(original_ctx), + last_calculated_index: 0, + offset_pruned_rows: 0, + out_col: Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef, + n_row_result_missing: 0, + is_end: false, + }; + + foreign.memoize(&mut state)?; + + assert_eq!(state.window_frame_range, 9..10); + + // Verify that window_frame_ctx isn't lost + assert!( + state.window_frame_ctx.is_some(), + "window_frame_ctx should be Some after memoize" + ); + + // Verify it's the same variant at least + match &state.window_frame_ctx.unwrap() { + WindowFrameContext::Rows(_) => {} + _ => panic!("Expected Rows variant, got something else"), + } + + Ok(()) + } } diff --git a/datafusion/ffi/src/udwf/window_state.rs b/datafusion/ffi/src/udwf/window_state.rs new file mode 100644 index 0000000000000..40b359cd895ff --- /dev/null +++ b/datafusion/ffi/src/udwf/window_state.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::ArrayRef; +use datafusion_common::DataFusionError; +use datafusion_expr::window_state::WindowAggState; + +use crate::{arrow_wrappers::WrappedArray, udwf::range::FFI_Range}; + +/// Holds the state of evaluating a window function +#[repr(C)] +#[derive(Debug)] +pub struct FFI_WindowAggState { + pub window_frame_range: FFI_Range, + pub last_calculated_index: usize, + pub offset_pruned_rows: usize, + /// The accumulated output column + pub out_col: WrappedArray, + pub n_row_result_missing: usize, + pub is_end: bool, +} + +impl TryFrom for FFI_WindowAggState { + type Error = DataFusionError; + + fn try_from(s: WindowAggState) -> Result { + Ok(Self { + window_frame_range: FFI_Range::from(s.window_frame_range.clone()), + last_calculated_index: s.last_calculated_index, + offset_pruned_rows: s.offset_pruned_rows, + out_col: WrappedArray::try_from(&s.out_col).map_err(DataFusionError::from)?, + n_row_result_missing: s.n_row_result_missing, + is_end: s.is_end, + }) + } +} + +impl TryFrom for WindowAggState { + type Error = DataFusionError; + + fn try_from(s: FFI_WindowAggState) -> Result { + let out_col: ArrayRef = s.out_col.try_into().map_err(DataFusionError::from)?; + + Ok(WindowAggState { + window_frame_range: s.window_frame_range.into(), + window_frame_ctx: None, + last_calculated_index: s.last_calculated_index, + offset_pruned_rows: s.offset_pruned_rows, + out_col, + n_row_result_missing: s.n_row_result_missing, + is_end: s.is_end, + }) + } +}