From 2f9a5e2066fb0e7de436c818ba828d77b9f0a9c7 Mon Sep 17 00:00:00 2001 From: Joe Isaacs Date: Fri, 17 Apr 2026 18:56:38 +0100 Subject: [PATCH] Revert "Add extension constant pushdown rule and fix `InnerProduct` rule (#7507)" This reverts commit 869b2d11acc23adeaaeebc7c299f9083a60874c0. --- .../src/arrays/extension/compute/rules.rs | 61 -------------- vortex-tensor/src/scalar_fns/inner_product.rs | 82 ++++++------------- 2 files changed, 25 insertions(+), 118 deletions(-) diff --git a/vortex-array/src/arrays/extension/compute/rules.rs b/vortex-array/src/arrays/extension/compute/rules.rs index 7408488a0f1..6a58e4838be 100644 --- a/vortex-array/src/arrays/extension/compute/rules.rs +++ b/vortex-array/src/arrays/extension/compute/rules.rs @@ -6,23 +6,18 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::IntoArray; use crate::array::ArrayView; -use crate::arrays::Constant; -use crate::arrays::ConstantArray; use crate::arrays::Extension; use crate::arrays::ExtensionArray; use crate::arrays::Filter; use crate::arrays::extension::ExtensionArrayExt; use crate::arrays::filter::FilterReduceAdaptor; use crate::arrays::slice::SliceReduceAdaptor; -use crate::matcher::AnyArray; use crate::optimizer::rules::ArrayParentReduceRule; use crate::optimizer::rules::ParentRuleSet; -use crate::scalar::Scalar; use crate::scalar_fn::fns::cast::CastReduceAdaptor; use crate::scalar_fn::fns::mask::MaskReduceAdaptor; pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ - ParentRuleSet::lift(&ExtensionConstantParentRule), ParentRuleSet::lift(&ExtensionFilterPushDownRule), ParentRuleSet::lift(&CastReduceAdaptor(Extension)), ParentRuleSet::lift(&FilterReduceAdaptor(Extension)), @@ -30,36 +25,6 @@ pub(crate) const PARENT_RULES: ParentRuleSet = ParentRuleSet::new(&[ ParentRuleSet::lift(&SliceReduceAdaptor(Extension)), ]); -/// Normalize `Extension(Constant(storage))` children to `Constant(Extension(storage))`. -#[derive(Debug)] -struct ExtensionConstantParentRule; - -impl ArrayParentReduceRule for ExtensionConstantParentRule { - type Parent = AnyArray; - - fn reduce_parent( - &self, - child: ArrayView<'_, Extension>, - parent: &ArrayRef, - child_idx: usize, - ) -> VortexResult> { - let Some(const_array) = child.storage_array().as_opt::() else { - return Ok(None); - }; - - let storage_scalar = const_array.scalar().clone(); - let ext_scalar = Scalar::extension_ref(child.ext_dtype().clone(), storage_scalar); - - let constant_with_extension_scalar = - ConstantArray::new(ext_scalar, child.len()).into_array(); - - parent - .clone() - .with_slot(child_idx, constant_with_extension_scalar) - .map(Some) - } -} - /// Push filter operations into the storage array of an extension array. #[derive(Debug)] struct ExtensionFilterPushDownRule; @@ -93,7 +58,6 @@ mod tests { use crate::IntoArray; #[expect(deprecated)] use crate::ToCanonical as _; - use crate::arrays::Constant; use crate::arrays::ConstantArray; use crate::arrays::Extension; use crate::arrays::ExtensionArray; @@ -213,31 +177,6 @@ mod tests { assert_eq!(canonical.len(), 3); } - #[test] - fn test_extension_constant_child_normalizes_under_scalar_fn() { - let ext_dtype = test_ext_dtype(); - - let constant_storage = ConstantArray::new(Scalar::from(10i64), 3).into_array(); - let constant_ext = ExtensionArray::new(ext_dtype.clone(), constant_storage).into_array(); - - let storage = buffer![15i64, 25, 35].into_array(); - let ext_array = ExtensionArray::new(ext_dtype, storage).into_array(); - - let scalar_fn_array = Binary - .try_new_array(3, Operator::Lt, [constant_ext, ext_array]) - .unwrap(); - - let optimized = scalar_fn_array.optimize().unwrap(); - let scalar_fn = optimized.as_opt::().unwrap(); - let children = scalar_fn.children(); - let constant = children[0] - .as_opt::() - .expect("constant extension child should be normalized"); - - assert!(constant.scalar().as_extension_opt().is_some()); - assert_eq!(constant.len(), 3); - } - #[test] fn test_scalar_fn_no_pushdown_different_ext_types() { #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index dd9c2a7381f..5928335ccf8 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -448,10 +448,18 @@ impl InnerProduct { return Ok(None); } - // The other side must be a constant tensor. - let Some(const_storage) = constant_tensor_storage(const_ref) else { + // The other side must be a constant-backed tensor-like extension whose scalar is + // non-null. + let Some(const_ext) = const_ref.as_opt::() else { return Ok(None); }; + let const_storage = const_ext.storage_array(); + let Some(const_backing) = const_storage.as_opt::() else { + return Ok(None); + }; + if const_backing.scalar().is_null() { + return Ok(None); + } let dim = sorf_view.options.dimension as usize; let num_rounds = sorf_view.options.num_rounds as usize; @@ -459,7 +467,7 @@ impl InnerProduct { let padded_dim = dim.next_power_of_two(); // Extract the single stored row of the constant via the stride-0 short-circuit. - let flat = extract_flat_elements(&const_storage, dim, ctx)?; + let flat = extract_flat_elements(const_storage, dim, ctx)?; if flat.ptype() != PType::F32 { // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The // standard path handles them correctly. @@ -474,9 +482,9 @@ impl InnerProduct { let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); - // Build the rewritten constant as a `Vector` extension scalar. We reuse - // the original storage FSL nullability so the new extension dtype stays consistent with - // whatever the original tree expected. + // Build the rewritten constant as a `Vector` extension wrapping a + // `ConstantArray` of length `len`. We reuse the original storage FSL nullability so + // the new extension dtype stays consistent with whatever the original tree expected. let storage_fsl_nullability = const_storage.dtype().nullability(); let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); let children: Vec = rotated_query @@ -485,6 +493,7 @@ impl InnerProduct { .collect(); let fsl_scalar = Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability); + let new_storage = ConstantArray::new(fsl_scalar, len).into_array(); // Build a fresh `Vector` extension dtype. We cannot reuse the // original extension dtype because that one has `dim`, not `padded_dim`. @@ -495,8 +504,7 @@ impl InnerProduct { storage_fsl_nullability, ); let new_ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl_dtype)?.erased(); - let new_constant = - ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array(); + let new_constant = ExtensionArray::new(new_ext_dtype, new_storage).into_array(); // Extract the SorfTransform child (the already-padded Vector). let sorf_child = sorf_view @@ -564,9 +572,16 @@ impl InnerProduct { }; // Navigate the constant side and require its scalar be non-null. - let Some(const_storage) = constant_tensor_storage(const_candidate) else { + let Some(const_ext) = const_candidate.as_opt::() else { return Ok(None); }; + let const_storage = const_ext.storage_array(); + let Some(const_backing) = const_storage.as_opt::() else { + return Ok(None); + }; + if const_backing.scalar().is_null() { + return Ok(None); + } // Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper // than falling through to the standard path (which would also canonicalize). @@ -587,7 +602,7 @@ impl InnerProduct { let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize"); - let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?; + let flat = extract_flat_elements(const_storage, padded_dim, ctx)?; if flat.ptype() != PType::F32 { // TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard // path, which computes the inner product with the correct element type. @@ -622,16 +637,6 @@ impl InnerProduct { } } -/// Return the storage constant for a canonical tensor-like constant query. -fn constant_tensor_storage(array: &ArrayRef) -> Option { - let constant = array.as_opt::()?; - if constant.scalar().is_null() { - return None; - } - let ext_scalar = constant.scalar().as_extension_opt()?; - Some(ConstantArray::new(ext_scalar.to_storage_scalar(), array.len()).into_array()) -} - /// Computes the inner product (dot product) of two equal-length float slices. /// /// Returns `sum(a_i * b_i)`. @@ -954,7 +959,6 @@ mod tests { use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; - use vortex_array::arrays::Constant; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; @@ -974,11 +978,9 @@ mod tests { use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; - use crate::scalar_fns::inner_product::constant_tensor_storage; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; - use crate::utils::extract_flat_elements; use crate::vector::Vector; static SESSION: LazyLock = @@ -1009,19 +1011,6 @@ mod tests { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension. - fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let vector_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(vector_scalar, len).into_array() - } - /// Build an `ExtensionArray>` whose storage is /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that /// TurboQuant produces as the SorfTransform child. @@ -1126,27 +1115,6 @@ mod tests { // ---- Case 1: SorfTransform + Constant pull-through ---- - #[test] - fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { - let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5); - let storage = - constant_tensor_storage(&literal).expect("literal vector should be recognized"); - - assert_eq!(storage.len(), 5); - let const_storage = storage - .as_opt::() - .expect("storage should remain constant-backed"); - assert!(matches!( - const_storage.scalar().dtype(), - DType::FixedSizeList(_, 3, Nullability::NonNullable) - )); - - let mut ctx = SESSION.create_execution_ctx(); - let flat = extract_flat_elements(&storage, 3, &mut ctx)?; - assert_eq!(flat.row::(0), &[1.0, 2.0, 3.0]); - Ok(()) - } - /// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim` /// so the zero-padding branch is exercised. #[test]