Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions vortex-array/src/arrays/extension/compute/rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,25 @@ use vortex_error::VortexResult;
use crate::ArrayRef;
use crate::IntoArray;
use crate::array::ArrayView;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::Extension;
use crate::arrays::ExtensionArray;
use crate::arrays::Filter;
use crate::arrays::extension::ExtensionArrayExt;
use crate::arrays::filter::FilterReduceAdaptor;
use crate::arrays::slice::SliceReduceAdaptor;
use crate::matcher::AnyArray;
use crate::optimizer::rules::ArrayParentReduceRule;
use crate::optimizer::rules::ParentRuleSet;
use crate::scalar::Scalar;
use crate::scalar_fn::fns::cast::CastReduceAdaptor;
use crate::scalar_fn::fns::mask::MaskReduceAdaptor;

/// Parent-reduce rules registered for [`Extension`] arrays.
///
/// Each entry is lifted into the set with `ParentRuleSet::lift`:
/// - `ExtensionConstantParentRule` and `ExtensionFilterPushDownRule` are rules defined in this
///   module.
/// - The `*ReduceAdaptor(Extension)` entries adapt the cast, filter, mask and slice reductions
///   to `Extension` children.
///
/// NOTE(review): whether the order of entries affects which rule fires first is not visible
/// from this file — confirm against `ParentRuleSet` before reordering.
pub(crate) const PARENT_RULES: ParentRuleSet<Extension> = ParentRuleSet::new(&[
ParentRuleSet::lift(&ExtensionConstantParentRule),
ParentRuleSet::lift(&ExtensionFilterPushDownRule),
ParentRuleSet::lift(&CastReduceAdaptor(Extension)),
ParentRuleSet::lift(&FilterReduceAdaptor(Extension)),
ParentRuleSet::lift(&MaskReduceAdaptor(Extension)),
ParentRuleSet::lift(&SliceReduceAdaptor(Extension)),
]);

/// Rewrites `Extension(Constant(storage))` children into `Constant(Extension(storage))`.
///
/// When an extension array's storage is a constant array, the child is replaced in its parent
/// by a constant array of the same length whose scalar is the extension-typed version of the
/// storage scalar.
#[derive(Debug)]
struct ExtensionConstantParentRule;

impl ArrayParentReduceRule<Extension> for ExtensionConstantParentRule {
    type Parent = AnyArray;

    /// Returns `Ok(None)` when the child's storage is not a constant array; otherwise returns
    /// a copy of `parent` with slot `child_idx` swapped for the normalized constant.
    fn reduce_parent(
        &self,
        child: ArrayView<'_, Extension>,
        parent: &ArrayRef,
        child_idx: usize,
    ) -> VortexResult<Option<ArrayRef>> {
        // The rule only applies to constant-backed extension arrays.
        let Some(constant_storage) = child.storage_array().as_opt::<Constant>() else {
            return Ok(None);
        };

        // Lift the storage scalar into the extension dtype, then broadcast it to the
        // child's length as a plain constant array.
        let lifted_scalar = Scalar::extension_ref(
            child.ext_dtype().clone(),
            constant_storage.scalar().clone(),
        );
        let replacement = ConstantArray::new(lifted_scalar, child.len()).into_array();

        let rewritten_parent = parent.clone().with_slot(child_idx, replacement)?;
        Ok(Some(rewritten_parent))
    }
}

/// Push filter operations into the storage array of an extension array.
#[derive(Debug)]
struct ExtensionFilterPushDownRule;
Expand Down Expand Up @@ -93,7 +58,6 @@ mod tests {
use crate::IntoArray;
#[expect(deprecated)]
use crate::ToCanonical as _;
use crate::arrays::Constant;
use crate::arrays::ConstantArray;
use crate::arrays::Extension;
use crate::arrays::ExtensionArray;
Expand Down Expand Up @@ -213,31 +177,6 @@ mod tests {
assert_eq!(canonical.len(), 3);
}

/// An extension array over constant storage, used as the first operand of a binary `Lt`
/// scalar function, should come out of `optimize` as a constant array carrying an extension
/// scalar.
#[test]
fn test_extension_constant_child_normalizes_under_scalar_fn() {
    let dtype = test_ext_dtype();

    // Left operand: extension array over a constant storage array.
    let lhs = ExtensionArray::new(
        dtype.clone(),
        ConstantArray::new(Scalar::from(10i64), 3).into_array(),
    )
    .into_array();

    // Right operand: extension array over a plain buffer.
    let rhs = ExtensionArray::new(dtype, buffer![15i64, 25, 35].into_array()).into_array();

    let comparison = Binary.try_new_array(3, Operator::Lt, [lhs, rhs]).unwrap();
    let optimized = comparison.optimize().unwrap();

    let scalar_fn_view = optimized.as_opt::<crate::arrays::ScalarFnVTable>().unwrap();
    let operands = scalar_fn_view.children();
    let normalized = operands[0]
        .as_opt::<Constant>()
        .expect("constant extension child should be normalized");

    assert!(normalized.scalar().as_extension_opt().is_some());
    assert_eq!(normalized.len(), 3);
}

#[test]
fn test_scalar_fn_no_pushdown_different_ext_types() {
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
Expand Down
82 changes: 25 additions & 57 deletions vortex-tensor/src/scalar_fns/inner_product.rs
Original file line number Diff line number Diff line change
Expand Up @@ -448,18 +448,26 @@ impl InnerProduct {
return Ok(None);
}

// The other side must be a constant tensor.
let Some(const_storage) = constant_tensor_storage(const_ref) else {
// The other side must be a constant-backed tensor-like extension whose scalar is
// non-null.
let Some(const_ext) = const_ref.as_opt::<Extension>() else {
return Ok(None);
};
let const_storage = const_ext.storage_array();
let Some(const_backing) = const_storage.as_opt::<Constant>() else {
return Ok(None);
};
if const_backing.scalar().is_null() {
return Ok(None);
}

let dim = sorf_view.options.dimension as usize;
let num_rounds = sorf_view.options.num_rounds as usize;
let seed = sorf_view.options.seed;
let padded_dim = dim.next_power_of_two();

// Extract the single stored row of the constant via the stride-0 short-circuit.
let flat = extract_flat_elements(&const_storage, dim, ctx)?;
let flat = extract_flat_elements(const_storage, dim, ctx)?;
if flat.ptype() != PType::F32 {
// TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The
// standard path handles them correctly.
Expand All @@ -474,9 +482,9 @@ impl InnerProduct {
let mut rotated_query = vec![0.0f32; padded_dim];
rotation.rotate(&padded_query, &mut rotated_query);

// Build the rewritten constant as a `Vector<padded_dim, f32>` extension scalar. We reuse
// the original storage FSL nullability so the new extension dtype stays consistent with
// whatever the original tree expected.
// Build the rewritten constant as a `Vector<padded_dim, f32>` extension wrapping a
// `ConstantArray` of length `len`. We reuse the original storage FSL nullability so
// the new extension dtype stays consistent with whatever the original tree expected.
let storage_fsl_nullability = const_storage.dtype().nullability();
let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable);
let children: Vec<Scalar> = rotated_query
Expand All @@ -485,6 +493,7 @@ impl InnerProduct {
.collect();
let fsl_scalar =
Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability);
let new_storage = ConstantArray::new(fsl_scalar, len).into_array();

// Build a fresh `Vector<padded_dim, f32>` extension dtype. We cannot reuse the
// original extension dtype because that one has `dim`, not `padded_dim`.
Expand All @@ -495,8 +504,7 @@ impl InnerProduct {
storage_fsl_nullability,
);
let new_ext_dtype = ExtDType::<Vector>::try_new(EmptyMetadata, new_fsl_dtype)?.erased();
let new_constant =
ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array();
let new_constant = ExtensionArray::new(new_ext_dtype, new_storage).into_array();

// Extract the SorfTransform child (the already-padded Vector<padded_dim, f32>).
let sorf_child = sorf_view
Expand Down Expand Up @@ -564,9 +572,16 @@ impl InnerProduct {
};

// Navigate the constant side and require its scalar be non-null.
let Some(const_storage) = constant_tensor_storage(const_candidate) else {
let Some(const_ext) = const_candidate.as_opt::<Extension>() else {
return Ok(None);
};
let const_storage = const_ext.storage_array();
let Some(const_backing) = const_storage.as_opt::<Constant>() else {
return Ok(None);
};
if const_backing.scalar().is_null() {
return Ok(None);
}

// Canonicalize codes and values. Codes may be e.g. BitPacked; executing is cheaper
// than falling through to the standard path (which would also canonicalize).
Expand All @@ -587,7 +602,7 @@ impl InnerProduct {

let padded_dim = usize::try_from(fsl.list_size()).vortex_expect("fsl list_size fits usize");

let flat = extract_flat_elements(&const_storage, padded_dim, ctx)?;
let flat = extract_flat_elements(const_storage, padded_dim, ctx)?;
if flat.ptype() != PType::F32 {
// TODO(connor): case 2 is f32-only. For f16/f64 we fall through to the standard
// path, which computes the inner product with the correct element type.
Expand Down Expand Up @@ -622,16 +637,6 @@ impl InnerProduct {
}
}

/// Extracts the storage-level constant from a constant array whose scalar is an extension
/// scalar.
///
/// Returns `None` when `array` is not a constant array, when its scalar is null, or when the
/// scalar is not an extension scalar. Otherwise returns a new constant array with the same
/// length as `array`, holding the extension scalar's storage scalar.
fn constant_tensor_storage(array: &ArrayRef) -> Option<ArrayRef> {
    let constant_view = array.as_opt::<Constant>()?;
    let scalar = constant_view.scalar();
    if scalar.is_null() {
        return None;
    }

    scalar
        .as_extension_opt()
        .map(|ext| ConstantArray::new(ext.to_storage_scalar(), array.len()).into_array())
}

/// Computes the inner product (dot product) of two equal-length float slices.
///
/// Returns `sum(a_i * b_i)`.
Expand Down Expand Up @@ -954,7 +959,6 @@ mod tests {
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::VortexSessionExecute;
use vortex_array::arrays::Constant;
use vortex_array::arrays::ConstantArray;
use vortex_array::arrays::ExtensionArray;
use vortex_array::arrays::FixedSizeListArray;
Expand All @@ -974,11 +978,9 @@ mod tests {
use vortex_session::VortexSession;

use crate::scalar_fns::inner_product::InnerProduct;
use crate::scalar_fns::inner_product::constant_tensor_storage;
use crate::scalar_fns::sorf_transform::SorfMatrix;
use crate::scalar_fns::sorf_transform::SorfOptions;
use crate::scalar_fns::sorf_transform::SorfTransform;
use crate::utils::extract_flat_elements;
use crate::vector::Vector;

static SESSION: LazyLock<VortexSession> =
Expand Down Expand Up @@ -1009,19 +1011,6 @@ mod tests {
Ok(ExtensionArray::new(ext_dtype, storage).into_array())
}

/// Builds the expression-literal shape: a `ConstantArray` of length `len` whose scalar is
/// itself a `Vector` extension scalar wrapping a non-nullable fixed-size list of the given
/// `f32` elements.
fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef {
    let items: Vec<Scalar> = elements
        .iter()
        .copied()
        .map(|value| Scalar::primitive(value, Nullability::NonNullable))
        .collect();

    let fsl_scalar = Scalar::fixed_size_list(
        DType::Primitive(PType::F32, Nullability::NonNullable),
        items,
        Nullability::NonNullable,
    );

    ConstantArray::new(Scalar::extension::<Vector>(EmptyMetadata, fsl_scalar), len).into_array()
}

/// Build an `ExtensionArray<Vector<list_size, f32>>` whose storage is
/// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that
/// TurboQuant produces as the SorfTransform child.
Expand Down Expand Up @@ -1126,27 +1115,6 @@ mod tests {

// ---- Case 1: SorfTransform + Constant pull-through ----

/// A constant array whose scalar is a `Vector` extension scalar must be recognized by
/// `constant_tensor_storage`, yielding a constant-backed storage array of the same length
/// whose single row matches the literal's elements.
#[test]
fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> {
    let query = literal_vector_f32(&[1.0, 2.0, 3.0], 5);

    let unwrapped =
        constant_tensor_storage(&query).expect("literal vector should be recognized");
    assert_eq!(unwrapped.len(), 5);

    // The unwrapped storage must still be a constant, now typed as the FSL storage dtype.
    let backing = unwrapped
        .as_opt::<Constant>()
        .expect("storage should remain constant-backed");
    let is_expected_dtype = matches!(
        backing.scalar().dtype(),
        DType::FixedSizeList(_, 3, Nullability::NonNullable)
    );
    assert!(is_expected_dtype);

    let mut exec_ctx = SESSION.create_execution_ctx();
    let elements = extract_flat_elements(&unwrapped, 3, &mut exec_ctx)?;
    assert_eq!(elements.row::<f32>(0), &[1.0, 2.0, 3.0]);

    Ok(())
}

/// Case 1: SorfTransform on LHS, constant query on RHS, with `dim < padded_dim`
/// so the zero-padding branch is exercised.
#[test]
Expand Down
Loading