Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/ffi/src/udwf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ use stabby::vec::Vec as SVec;
mod partition_evaluator;
mod partition_evaluator_args;
mod range;
mod window_state;

use crate::arrow_wrappers::WrappedSchema;
use crate::util::{
Expand Down
175 changes: 172 additions & 3 deletions datafusion/ffi/src/udwf/partition_evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ use datafusion_expr::PartitionEvaluator;
use datafusion_expr::window_state::WindowAggState;
use prost::Message;

use stabby::string::String as SString;
use stabby::vec::Vec as SVec;

use super::range::FFI_Range;
use crate::arrow_wrappers::WrappedArray;
use crate::udwf::window_state::FFI_WindowAggState;
use crate::util::FFI_Result;
use crate::{df_result, sresult, sresult_return};

Expand All @@ -40,6 +42,11 @@ use crate::{df_result, sresult, sresult_return};
#[repr(C)]
#[derive(Debug)]
pub struct FFI_PartitionEvaluator {
pub memoize: unsafe extern "C" fn(
evaluator: &mut Self,
state: FFI_WindowAggState,
) -> FFI_Result<FFI_WindowAggState>,

pub evaluate_all: unsafe extern "C" fn(
evaluator: &mut Self,
values: SVec<WrappedArray>,
Expand Down Expand Up @@ -179,6 +186,27 @@ unsafe extern "C" fn evaluate_all_with_rank_fn_wrapper(
}
}

/// FFI entry point backing [`FFI_PartitionEvaluator::memoize`].
///
/// Converts the incoming [`FFI_WindowAggState`] into a native
/// [`WindowAggState`], lets the wrapped evaluator's `memoize` mutate it, and
/// converts the mutated state back so it can be returned across the boundary.
///
/// Any failure (either conversion, or `memoize` itself) is reported as
/// `FFI_Result::Err` carrying the error's display text.
unsafe extern "C" fn memoize_fn_wrapper(
    evaluator: &mut FFI_PartitionEvaluator,
    state: FFI_WindowAggState,
) -> FFI_Result<FFI_WindowAggState> {
    unsafe {
        let partition_evaluator = evaluator.inner_mut();
        let mut window_state = sresult_return!(WindowAggState::try_from(state));

        // Run memoize, then (only if it succeeded) turn the mutated state
        // back into its FFI representation.
        let outcome = partition_evaluator
            .memoize(&mut window_state)
            .and_then(|()| FFI_WindowAggState::try_from(window_state));

        match outcome {
            Ok(updated_state) => FFI_Result::Ok(updated_state),
            Err(err) => FFI_Result::Err(SString::from(err.to_string())),
        }
    }
}

unsafe extern "C" fn get_range_fn_wrapper(
evaluator: &FFI_PartitionEvaluator,
idx: usize,
Expand Down Expand Up @@ -221,6 +249,7 @@ impl From<Box<dyn PartitionEvaluator>> for FFI_PartitionEvaluator {
let private_data = PartitionEvaluatorPrivateData { evaluator };

Self {
memoize: memoize_fn_wrapper,
evaluate: evaluate_fn_wrapper,
evaluate_all: evaluate_all_fn_wrapper,
evaluate_all_with_rank: evaluate_all_with_rank_fn_wrapper,
Expand Down Expand Up @@ -271,9 +300,20 @@ impl From<FFI_PartitionEvaluator> for Box<dyn PartitionEvaluator> {
}

impl PartitionEvaluator for ForeignPartitionEvaluator {
fn memoize(&mut self, _state: &mut WindowAggState) -> Result<()> {
// Exposing `memoize` increases the surface area of the FFI work
// so for now we do not support it.
fn memoize(&mut self, state: &mut WindowAggState) -> Result<()> {
let ffi_state = FFI_WindowAggState::try_from(state.clone())?;
// Temporarily store the context
// so we are preserving it AS IS for now
// TODO: possibly there is a better way of doing this
let saved_ctx = state.window_frame_ctx.take();

let result = unsafe { (self.evaluator.memoize)(&mut self.evaluator, ffi_state) };

let updated_ffi_state = df_result!(result)?;

*state = WindowAggState::try_from(updated_ffi_state)?;
state.window_frame_ctx = saved_ctx;

Ok(())
}

Expand Down Expand Up @@ -365,6 +405,15 @@ impl PartitionEvaluator for ForeignPartitionEvaluator {
mod tests {
use arrow::array::ArrayRef;
use datafusion::logical_expr::PartitionEvaluator;
use datafusion_common::scalar::ScalarValue;
use datafusion_expr::{
WindowFrame,
window_state::{WindowAggState, WindowFrameContext},
};
use std::ops::Range;

use arrow::array::Int32Array;
use std::sync::Arc;

use crate::udwf::partition_evaluator::{
FFI_PartitionEvaluator, ForeignPartitionEvaluator,
Expand Down Expand Up @@ -413,4 +462,124 @@ mod tests {

Ok(())
}

#[test]
fn test_memoize_state_preservation() -> datafusion_common::Result<()> {
    /// Evaluator whose `memoize` visibly rewrites the state it is handed.
    #[derive(Debug)]
    struct StateModifyingEvaluator;

    impl PartitionEvaluator for StateModifyingEvaluator {
        fn memoize(
            &mut self,
            state: &mut WindowAggState,
        ) -> datafusion_common::Result<()> {
            // Shrink the frame so that it only covers its last row
            let frame_end = state.window_frame_range.end;
            state.window_frame_range.start = frame_end - 1;
            Ok(())
        }

        fn evaluate(
            &mut self,
            _values: &[ArrayRef],
            _range: &Range<usize>,
        ) -> datafusion_common::Result<ScalarValue> {
            Ok(ScalarValue::Int32(Some(42)))
        }
    }

    // Wrap the evaluator for FFI, then swap the marker so the wrapper is
    // treated as coming from a foreign library.
    let mut ffi_evaluator = FFI_PartitionEvaluator::from(
        Box::new(StateModifyingEvaluator) as Box<dyn PartitionEvaluator>,
    );
    ffi_evaluator.library_marker_id = crate::mock_foreign_marker_id;
    let mut foreign_evaluator: Box<dyn PartitionEvaluator> =
        Box::new(ForeignPartitionEvaluator {
            evaluator: ffi_evaluator,
        });

    // Initial state: a ten-row frame and no frame context
    let out_col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));
    let mut state = WindowAggState {
        window_frame_range: 0..10,
        window_frame_ctx: None,
        last_calculated_index: 0,
        offset_pruned_rows: 0,
        out_col,
        n_row_result_missing: 0,
        is_end: false,
    };

    foreign_evaluator.memoize(&mut state)?;

    // The mutation made on the far side must be visible here
    assert_eq!(
        state.window_frame_range,
        9..10,
        "window_frame_range should have been pruned"
    );

    Ok(())
}

#[test]
fn test_memoize_preserves_window_frame_ctx() -> datafusion_common::Result<()> {
    /// Evaluator that records that `memoize` ran and shrinks the frame range.
    #[derive(Debug)]
    struct CtxAwareEvaluator {
        // Shared flag: the evaluator itself is moved behind the FFI wrapper,
        // so the test keeps its own handle to observe the call afterwards.
        memoize_called: Arc<std::sync::atomic::AtomicBool>,
    }

    impl PartitionEvaluator for CtxAwareEvaluator {
        fn memoize(
            &mut self,
            state: &mut WindowAggState,
        ) -> datafusion_common::Result<()> {
            self.memoize_called
                .store(true, std::sync::atomic::Ordering::SeqCst);
            // Leave window_frame_ctx alone; only the range is modified
            state.window_frame_range.start = state.window_frame_range.end - 1;
            Ok(())
        }
    }

    let memoize_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
    let evaluator: Box<dyn PartitionEvaluator> = Box::new(CtxAwareEvaluator {
        memoize_called: Arc::clone(&memoize_called),
    });
    let mut ffi: FFI_PartitionEvaluator = evaluator.into();

    // Force foreign path
    ffi.library_marker_id = crate::mock_foreign_marker_id;
    let mut foreign: Box<dyn PartitionEvaluator> =
        Box::new(ForeignPartitionEvaluator { evaluator: ffi });

    // Create a real WindowFrameContext
    let window_frame = Arc::new(WindowFrame::new(Some(true)));
    let original_ctx = WindowFrameContext::new(window_frame, vec![]);

    let mut state = WindowAggState {
        window_frame_range: 0..10,
        window_frame_ctx: Some(original_ctx),
        last_calculated_index: 0,
        offset_pruned_rows: 0,
        out_col: Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])) as ArrayRef,
        n_row_result_missing: 0,
        is_end: false,
    };

    foreign.memoize(&mut state)?;

    // The call must actually have reached the wrapped evaluator
    assert!(
        memoize_called.load(std::sync::atomic::Ordering::SeqCst),
        "memoize should have been forwarded to the wrapped evaluator"
    );

    assert_eq!(state.window_frame_range, 9..10);

    // The context must survive the round trip, and still be the same variant
    match state.window_frame_ctx {
        Some(WindowFrameContext::Rows(_)) => {}
        Some(_) => panic!("Expected Rows variant, got something else"),
        None => panic!("window_frame_ctx should be Some after memoize"),
    }

    Ok(())
}
}
68 changes: 68 additions & 0 deletions datafusion/ffi/src/udwf/window_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::array::ArrayRef;
use datafusion_common::DataFusionError;
use datafusion_expr::window_state::WindowAggState;

use crate::{arrow_wrappers::WrappedArray, udwf::range::FFI_Range};

/// Holds the state of evaluating a window function
///
/// FFI representation (`#[repr(C)]`) of [`WindowAggState`]. Each field carries
/// the `WindowAggState` field of the same name. There is no counterpart for
/// `WindowAggState::window_frame_ctx`: converting back into a
/// `WindowAggState` always yields `window_frame_ctx: None`, so callers that
/// need the context must keep and restore it themselves.
#[repr(C)]
#[derive(Debug)]
pub struct FFI_WindowAggState {
/// Carries `WindowAggState::window_frame_range`, converted to [`FFI_Range`]
pub window_frame_range: FFI_Range,
/// Carries `WindowAggState::last_calculated_index` unchanged
pub last_calculated_index: usize,
/// Carries `WindowAggState::offset_pruned_rows` unchanged
pub offset_pruned_rows: usize,
/// The accumulated output column
pub out_col: WrappedArray,
/// Carries `WindowAggState::n_row_result_missing` unchanged
pub n_row_result_missing: usize,
/// Carries `WindowAggState::is_end` unchanged
pub is_end: bool,
}

impl TryFrom<WindowAggState> for FFI_WindowAggState {
type Error = DataFusionError;

fn try_from(s: WindowAggState) -> Result<Self, Self::Error> {
Ok(Self {
window_frame_range: FFI_Range::from(s.window_frame_range.clone()),
last_calculated_index: s.last_calculated_index,
offset_pruned_rows: s.offset_pruned_rows,
out_col: WrappedArray::try_from(&s.out_col).map_err(DataFusionError::from)?,
n_row_result_missing: s.n_row_result_missing,
is_end: s.is_end,
})
}
}

impl TryFrom<FFI_WindowAggState> for WindowAggState {
type Error = DataFusionError;

fn try_from(s: FFI_WindowAggState) -> Result<Self, Self::Error> {
let out_col: ArrayRef = s.out_col.try_into().map_err(DataFusionError::from)?;

Ok(WindowAggState {
window_frame_range: s.window_frame_range.into(),
window_frame_ctx: None,
last_calculated_index: s.last_calculated_index,
offset_pruned_rows: s.offset_pruned_rows,
out_col,
n_row_result_missing: s.n_row_result_missing,
is_end: s.is_end,
})
}
}