From 37cdf9bab3f2cb6b3a11ff76c431f9f0141e2161 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:03:44 +0000 Subject: [PATCH 01/15] Fix FixedShapeTensorMetadata Display for permuted tensors The `Display` impl emitted `]` as a closing bracket without a matching opener and without a separator from the shape list, producing output like `Tensor(2, 3, 41, 0, 2])` instead of `Tensor(2, 3, 4, [1, 0, 2])`. Emit the opening `, [` so the permutation is printed as a bracketed list after the shape, and add regression tests covering all four display cases (shape-only, 0d scalar, dim names, permutation, and the combination of dim names + permutation). Signed-off-by: Claude --- vortex-tensor/src/fixed_shape/metadata.rs | 39 +++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/vortex-tensor/src/fixed_shape/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs index 264d18453c4..757138d3e50 100644 --- a/vortex-tensor/src/fixed_shape/metadata.rs +++ b/vortex-tensor/src/fixed_shape/metadata.rs @@ -215,6 +215,7 @@ impl fmt::Display for FixedShapeTensorMetadata { } if let Some(perm) = &self.permutation { + write!(f, ", [")?; for (i, p) in perm.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -353,6 +354,44 @@ mod tests { Ok(()) } + // -- Display -- + + #[test] + fn display_shape_only() { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.to_string(), "Tensor(2, 3, 4)"); + } + + #[test] + fn display_scalar_0d() { + let m = FixedShapeTensorMetadata::new(vec![]); + assert_eq!(m.to_string(), "Tensor()"); + } + + #[test] + fn display_with_dim_names() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()])?; + assert_eq!(m.to_string(), "Tensor(rows: 3, cols: 4)"); + Ok(()) + } + + #[test] + fn display_with_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![1, 0, 2])?; + assert_eq!(m.to_string(), 
"Tensor(2, 3, 4, [1, 0, 2])"); + Ok(()) + } + + #[test] + fn display_with_dim_names_and_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?; + assert_eq!(m.to_string(), "Tensor(x: 2, y: 3, z: 4, [1, 2, 0])"); + Ok(()) + } + #[test] fn dim_names_wrong_length() { let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]); From 599487b3125cf53970d3328166ac21093ca7fcd3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:04:27 +0000 Subject: [PATCH 02/15] Tidy scalar_fn vtable boilerplate in vortex-tensor - Standardize `ScalarFnId::from("...")` on `ScalarFnId::new("...")` so all five tensor scalar functions construct their id the same way (l2_denorm and sorf_transform already used `::new`). - Fix a typo in the `L2Norm` readthrough expect message that said "L2Denom must have at 2 children". No behavior change; the two methods are equivalent aside from wording. 
Signed-off-by: Claude --- vortex-tensor/src/scalar_fns/cosine_similarity.rs | 2 +- vortex-tensor/src/scalar_fns/inner_product.rs | 2 +- vortex-tensor/src/scalar_fns/l2_norm.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 9f8eff0b361..764bc23b815 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -84,7 +84,7 @@ impl ScalarFnVTable for CosineSimilarity { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.cosine_similarity") + ScalarFnId::new("vortex.tensor.cosine_similarity") } fn arity(&self, _options: &Self::Options) -> Arity { diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index dd9c2a7381f..905c2d3fcac 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -99,7 +99,7 @@ impl ScalarFnVTable for InnerProduct { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.inner_product") + ScalarFnId::new("vortex.tensor.inner_product") } fn arity(&self, _options: &Self::Options) -> Arity { diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59a2e5a45b1..94c0b41f3d2 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -84,7 +84,7 @@ impl ScalarFnVTable for L2Norm { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.l2_norm") + ScalarFnId::new("vortex.tensor.l2_norm") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -142,7 +142,7 @@ impl ScalarFnVTable for L2Norm { if let Some(sfn) = input_ref.as_opt::>() { let norms = sfn .nth_child(1) - .vortex_expect("L2Denom must have at 2 children"); + .vortex_expect("L2Denorm must have 2 
children"); vortex_ensure_eq!(norms.dtype(), &norm_dtype); From ebd85b06da0eefc366134f13d8ff04b1de3e0df9 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:04:37 +0000 Subject: [PATCH 03/15] Use `Buffer::copy_from` in TurboQuant centroids buffer build `build_quantized_fsl` was constructing a `BufferMut`, copying the centroid slice into it via `extend_from_slice`, and then freezing it into a `Buffer`. `Buffer::copy_from` does the same thing in one call. Signed-off-by: Claude --- vortex-tensor/src/encodings/turboquant/compress.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index b0173bbe36c..35e51cc88fe 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -25,6 +25,7 @@ use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -136,10 +137,8 @@ fn build_quantized_fsl( padded_dim: usize, ) -> VortexResult { let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); - - let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); - centroids_buf.extend_from_slice(centroids); - let centroids_array = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + let centroids_array = + PrimitiveArray::new::(Buffer::copy_from(centroids), Validity::NonNullable); let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; From 860321a0360b29c5719e132924dd801bf84256f3 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:30:17 +0000 Subject: [PATCH 04/15] Dedupe L2Denorm-orientation pattern and share l2_denorm_array test helper MIME-Version: 1.0 Content-Type: 
text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Introduce `DenormOrientation::classify`, a small enum helper that normalises a `(lhs, rhs)` operand pair into `Both`, `One { denorm, plain }`, or `Neither`. Symmetric binary tensor operators (`CosineSimilarity`, `InnerProduct`) use it instead of hand-rolling "swap lhs and rhs if only the rhs is denormalised" at each call site. - Move the `l2_denorm_array` test helper — previously duplicated byte-for-byte between `cosine_similarity` and `inner_product` tests — into `utils::test_helpers` so both suites share one implementation and threading an `ExecutionCtx` through is explicit. - Fix "have be" -> "have been" in two comments and replace the stale "four independent accumulators" comment in `execute_dict_constant_inner_product` with a reference to the `PARTIAL_SUMS` constant (currently 8). Signed-off-by: Claude --- .../src/scalar_fns/cosine_similarity.rs | 77 ++++++++----------- vortex-tensor/src/scalar_fns/inner_product.rs | 64 ++++++--------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 46 +++++++++++ vortex-tensor/src/utils.rs | 17 ++++ 4 files changed, 119 insertions(+), 85 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 764bc23b815..9ea5c185827 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -11,7 +11,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; @@ -37,7 +36,7 @@ use vortex_session::VortexSession; use crate::scalar_fns::inner_product::BinaryTensorOpMetadata; use crate::scalar_fns::inner_product::InnerProduct; 
-use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extract_l2_denorm_children; @@ -59,6 +58,7 @@ use crate::utils::validate_tensor_float_input; /// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor /// [`Vector`]: crate::vector::Vector +/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm #[derive(Clone)] pub struct CosineSimilarity; @@ -141,32 +141,24 @@ impl ScalarFnVTable for CosineSimilarity { let len = args.row_count(); // If either side is a constant tensor-like extension array, eagerly normalize the single - // stored row and re-wrap it as an `L2Denorm` whose children are both [`ConstantArray`]s. + // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. // The L2Denorm fast path below then picks it up. - if let Some(lhs_constant) = - try_build_constant_l2_denorm(&lhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - lhs_ref = lhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&lhs_ref, len, ctx)? { + lhs_ref = sfn.into_array(); } - if let Some(rhs_constant) = - try_build_constant_l2_denorm(&rhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - rhs_ref = rhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&rhs_ref, len, ctx)? { + rhs_ref = sfn.into_array(); } - // Check if any of our children have be already normalized. - { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. 
+ match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len, ctx); } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); + } + DenormOrientation::Neither => {} } // Compute combined validity. @@ -348,6 +340,7 @@ mod tests { use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -536,25 +529,13 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. - fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm_self_similarity() -> VortexResult<()> { // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8]. // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0]. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); @@ -565,8 +546,9 @@ mod tests { fn both_denorm_orthogonal() -> VortexResult<()> { // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0. 
// [0.0, 4.0] normalized [0.0, 1.0], norm 4.0. - let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?; - let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); Ok(()) @@ -575,8 +557,9 @@ mod tests { #[test] fn both_denorm_zero_norm() -> VortexResult<()> { // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); @@ -588,7 +571,8 @@ mod tests { // LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS is plain [3.0, 4.0]. // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[3.0, 4.0])?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); @@ -599,8 +583,9 @@ mod tests { fn one_side_denorm_rhs() -> VortexResult<()> { // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6. 
+ let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); Ok(()) @@ -609,11 +594,11 @@ mod tests { #[test] fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on rhs). - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); - let mut ctx = SESSION.create_execution_ctx(); let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 905c2d3fcac..e2f92d91a62 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -56,7 +56,7 @@ use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::utils::extract_flat_elements; @@ -167,23 +167,19 @@ impl ScalarFnVTable for InnerProduct { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let mut lhs_ref = args.get(0)?; - let mut rhs_ref = args.get(1)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let len = args.row_count(); - // Check if any of our children have be already normalized. 
- { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. + match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len, ctx); } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); + } + DenormOrientation::Neither => {} } // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to @@ -647,8 +643,8 @@ fn inner_product_row(a: &[T], b: &[T]) -> T { /// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook /// `values` directly instead of decoding the dictionary into dense vectors. /// -/// The inner loop uses four independent accumulators so the CPU can pipeline FP additions -/// instead of waiting for each `fadd` to retire before starting the next. +/// The inner loop uses `PARTIAL_SUMS` independent accumulators so the CPU can pipeline FP +/// additions instead of waiting for each `fadd` to retire before starting the next. fn execute_dict_constant_inner_product( q: &[f32], values: &[f32], @@ -704,6 +700,7 @@ mod tests { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -818,28 +815,14 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. 
- fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - use vortex_array::IntoArray; - - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm() -> VortexResult<()> { // LHS: [3.0, 4.0] = L2Denorm([0.6, 0.8], 5.0). // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0). // dot([3.0, 4.0], [1.0, 0.0]) = 3.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; - let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); @@ -850,8 +833,9 @@ mod tests { fn both_denorm_multiple_rows() -> VortexResult<()> { // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0. // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); Ok(()) @@ -862,7 +846,8 @@ mod tests { // LHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS: plain [1.0, 2.0]. // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0. 
- let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[1.0, 2.0])?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); @@ -874,8 +859,9 @@ mod tests { // LHS: plain [1.0, 2.0]. // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0. + let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 2.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -889,7 +875,7 @@ mod tests { let mut ctx = SESSION.create_execution_ctx(); let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 673c454b56b..323570076d0 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -22,6 +22,7 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; @@ -697,6 +698,51 @@ fn validate_l2_normalized_rows_impl( Ok(()) } +/// Classification of a binary operand pair by which side (if any) 
is wrapped in [`L2Denorm`]. +/// +/// Symmetric binary tensor operators (e.g. [`CosineSimilarity`], [`InnerProduct`]) have identical +/// fast paths for "only the lhs is denormalized" and "only the rhs is denormalized", and a separate +/// fast path for "both are denormalized". Rather than hand-rolling the commutative swap at every +/// call site, callers classify their operands with [`Self::classify`] and pattern-match on the +/// returned variant. +/// +/// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity +/// [`InnerProduct`]: crate::scalar_fns::inner_product::InnerProduct +pub(crate) enum DenormOrientation<'a> { + /// Both operands are [`ExactScalarFn`] arrays. + Both { + lhs: &'a ArrayRef, + rhs: &'a ArrayRef, + }, + /// Exactly one operand is an [`ExactScalarFn`]; the other is plain. + One { + denorm: &'a ArrayRef, + plain: &'a ArrayRef, + }, + /// Neither operand is an [`ExactScalarFn`]. + Neither, +} + +impl<'a> DenormOrientation<'a> { + /// Classify `(lhs, rhs)` by which side (if any) is wrapped in [`L2Denorm`]. 
+ pub(crate) fn classify(lhs: &'a ArrayRef, rhs: &'a ArrayRef) -> Self { + let lhs_denorm = lhs.is::>(); + let rhs_denorm = rhs.is::>(); + match (lhs_denorm, rhs_denorm) { + (true, true) => Self::Both { lhs, rhs }, + (true, false) => Self::One { + denorm: lhs, + plain: rhs, + }, + (false, true) => Self::One { + denorm: rhs, + plain: lhs, + }, + (false, false) => Self::Neither, + } + } +} + #[cfg(test)] mod tests { diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 76d65ce9eef..d01ee632c27 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -154,10 +154,12 @@ pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) { #[cfg(test)] pub mod test_helpers { use vortex_array::ArrayRef; + use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; @@ -170,6 +172,7 @@ pub mod test_helpers { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::scalar_fns::l2_denorm::L2Denorm; use crate::vector::Vector; /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. @@ -247,6 +250,20 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + /// Creates an [`L2Denorm`] scalar function array from pre-normalized f64 tensor elements and + /// f64 norms. The caller must ensure every row of `normalized_elements` is unit-norm or zero. 
+ pub fn l2_denorm_array( + shape: &[usize], + normalized_elements: &[f64], + norms: &[f64], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let len = norms.len(); + let normalized = tensor_array(shape, normalized_elements)?; + let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); + Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) + } + /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` /// value, with support for NaN (NaN == NaN is considered equal). #[track_caller] From e096f2ff4c948ae6ec46509b530a39339bc90041 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:30:50 +0000 Subject: [PATCH 05/15] Replace FlatElements `stride` field with an explicit `is_constant` flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `FlatElements` used a `stride: usize` field that was either `0` (single stored row, used for `ConstantArray` storage) or `list_size` (full row per index). `row(i)` then multiplied the index by `stride`, which works but hides the real invariant — "stride is 0 xor stride == list_size". Store the invariant directly: - `is_constant: bool` replaces the tri-value `stride` field. - `row(i)` selects `0` or `i` via a single branch, so the two cases compile to the same code as before but the struct's semantics are self-describing. - `extract_flat_elements` collapses the duplicated `FixedSizeListArray::execute` + `PrimitiveArray::execute` path into a single tail block, rewriting the source once up front based on whether the storage is constant-backed. 
Signed-off-by: Claude --- vortex-tensor/src/utils.rs | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index d01ee632c27..9919944f8c8 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -80,12 +80,12 @@ pub fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a -/// constant input materializes only a single row (stride=0), while a full array uses -/// stride=list_size. +/// constant-backed input materializes only a single row that every index reads (`is_constant = +/// true`), while a full array stores one row per index. pub struct FlatElements { elems: PrimitiveArray, - stride: usize, list_size: usize, + is_constant: bool, } impl FlatElements { @@ -96,10 +96,14 @@ impl FlatElements { } /// Returns the `i`-th row as a typed slice of length `list_size`. + /// + /// When the source was a constant-backed storage, all indices resolve to the single stored + /// row. #[must_use] pub fn row(&self, i: usize) -> &[T] { + let row_idx = if self.is_constant { 0 } else { i }; let slice = self.elems.as_slice::(); - &slice[i * self.stride..][..self.list_size] + &slice[row_idx * self.list_size..][..self.list_size] } } @@ -114,26 +118,21 @@ pub fn extract_flat_elements( list_size: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - if let Some(constant) = storage.as_opt::() { - // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge - // amount of data. + // Constant-backed storage: materialize just the single stored row so canonicalization does + // not expand the array to the full column length. 
+ let (source, is_constant) = if let Some(constant) = storage.as_opt::() { let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl: FixedSizeListArray = single.execute(ctx)?; - let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - return Ok(FlatElements { - elems, - stride: 0, - list_size, - }); - } + (single, true) + } else { + (storage.clone(), false) + }; - // Otherwise we have to fully expand all of the data. - let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let fsl: FixedSizeListArray = source.execute(ctx)?; let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; Ok(FlatElements { elems, - stride: list_size, list_size, + is_constant, }) } From a98b0033ef09faa5f58458907eaa36ae4f563abb Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:31:04 +0000 Subject: [PATCH 06/15] Tidy SorfTransform vtable: round-trippable metadata and option-aware SQL - `fmt_sql` now prints the [`SorfOptions`] next to the child expression (e.g. `sorf_transform(col, SorfOptions(seed=42, rounds=3, dim=100, ptype=f32))`) so the SQL form uniquely identifies the transform instead of collapsing distinct configurations to the same string. - Factor `SorfTransformMetadata <-> SorfOptions` conversion into an `impl From<&SorfOptions>` and a `to_options()` method. `serialize` and `deserialize` become thin wrappers and the `num_rounds: u32 -> u8` range-check plus `validate_sorf_options` call live in exactly one place. 
Signed-off-by: Claude --- .../src/scalar_fns/sorf_transform/vtable.rs | 62 ++++++++++++------- 1 file changed, 38 insertions(+), 24 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 64f92da384f..26a56c245dc 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -70,13 +70,13 @@ impl ScalarFnVTable for SorfTransform { fn fmt_sql( &self, - _options: &Self::Options, + options: &Self::Options, expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { write!(f, "sorf_transform(")?; expr.child(0).fmt_sql(f)?; - write!(f, ")") + write!(f, ", {options})") } fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { @@ -214,21 +214,46 @@ pub(super) struct SorfTransformMetadata { element_ptype: i32, } +impl From<&SorfOptions> for SorfTransformMetadata { + fn from(options: &SorfOptions) -> Self { + Self { + seed: options.seed, + num_rounds: u32::from(options.num_rounds), + dimension: options.dimension, + element_ptype: options.element_ptype as i32, + } + } +} + +impl SorfTransformMetadata { + /// Rebuild the [`SorfOptions`] this metadata was serialized from, validating that the wire + /// values are in range. 
+ fn to_options(&self) -> VortexResult { + let num_rounds = u8::try_from(self.num_rounds).map_err(|_| { + vortex_err!( + "SorfTransform num_rounds {} does not fit in u8", + self.num_rounds + ) + })?; + let options = SorfOptions { + seed: self.seed, + num_rounds, + dimension: self.dimension, + element_ptype: self.element_ptype(), + }; + validate_sorf_options(&options)?; + Ok(options) + } +} + impl ScalarFnArrayVTable for SorfTransform { fn serialize( &self, view: &ScalarFnArrayView, _session: &VortexSession, ) -> VortexResult>> { - let options = view.options; Ok(Some( - SorfTransformMetadata { - seed: options.seed, - num_rounds: u32::from(options.num_rounds), - dimension: options.dimension, - element_ptype: options.element_ptype as i32, - } - .encode_to_vec(), + SorfTransformMetadata::from(view.options).encode_to_vec(), )) } @@ -240,20 +265,9 @@ impl ScalarFnArrayVTable for SorfTransform { children: &dyn ArrayChildren, _session: &VortexSession, ) -> VortexResult> { - let metadata = SorfTransformMetadata::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; - let options = SorfOptions { - seed: metadata.seed, - num_rounds: u8::try_from(metadata.num_rounds).map_err(|_| { - vortex_err!( - "SorfTransform num_rounds {} does not fit in u8", - metadata.num_rounds - ) - })?, - dimension: metadata.dimension, - element_ptype: metadata.element_ptype(), - }; - validate_sorf_options(&options)?; + let options = SorfTransformMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))? + .to_options()?; // `return_dtype` sets the output FSL's nullability to the child's nullability (see // `return_dtype` above), so we read the child nullability back from the parent dtype. 
From 1d040d2e2e249850e5033a1e3cec5b2ed3555cd5 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:31:21 +0000 Subject: [PATCH 07/15] Add `turboquant_compress` to consolidate the full compression pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The full TurboQuant compression pipeline — normalize a tensor via `normalize_as_l2_denorm`, quantize the normalized child via `turboquant_encode_unchecked`, and reattach the stored norms as an outer `L2Denorm` wrapper — was spelled out three times: in `TurboQuantScheme::compress`, in `vector_search::compress_turboquant`, and in the `normalize_and_encode` test helper. Add `turboquant_compress(input, config, ctx)` alongside the existing encode helpers and use it in all three call sites. Each caller collapses to a single line that threads its own config through (the scheme and `compress_turboquant` keep the `TurboQuantConfig::default()` choice, and the test helper now takes an explicit config). 
Signed-off-by: Claude --- vortex-tensor/public-api.lock | 4 +- .../src/encodings/turboquant/compress.rs | 43 +++++++++++++++++++ vortex-tensor/src/encodings/turboquant/mod.rs | 1 + .../src/encodings/turboquant/scheme.rs | 34 +-------------- .../src/encodings/turboquant/tests/mod.rs | 15 +------ vortex-tensor/src/vector_search.rs | 23 +--------- 6 files changed, 53 insertions(+), 67 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 96b95e1e91e..11778465928 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -80,6 +80,8 @@ pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_compress(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult @@ -502,7 +504,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::child_name(&sel pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut 
vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::id(&self) -> vortex_array::scalar_fn::ScalarFnId diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index 35e51cc88fe..d521cb5d86a 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -21,6 +21,7 @@ use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; @@ -36,6 +37,8 @@ use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; +use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; @@ -272,3 +275,43 @@ fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult { let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); 
     Ok(ExtensionArray::new(ext_dtype, fsl).into_array())
 }
+
+/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector)
+/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized
+/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer
+/// [`L2Denorm`] wrapper.
+///
+/// The returned array has the canonical TurboQuant shape:
+///
+/// ```text
+/// ScalarFnArray(L2Denorm, [
+///     ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]),
+///     norms,
+/// ])
+/// ```
+///
+/// # Errors
+///
+/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or
+/// if [`turboquant_encode_unchecked`] rejects the input shape.
+pub fn turboquant_compress(
+    input: ArrayRef,
+    config: &TurboQuantConfig,
+    ctx: &mut ExecutionCtx,
+) -> VortexResult<ArrayRef> {
+    let l2_denorm = normalize_as_l2_denorm(input, ctx)?;
+    let normalized = l2_denorm.child_at(0).clone();
+    let norms = l2_denorm.child_at(1).clone();
+    let num_rows = l2_denorm.len();
+
+    let normalized_ext = normalized
+        .as_opt::<Extension>()
+        .vortex_expect("normalize_as_l2_denorm always produces an Extension array child");
+
+    // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows).
+    let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?;
+
+    // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally
+    // bypass the strict normalized-row validation when reattaching the stored norms.
+ Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) +} diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 49e53effd0d..09704f0052c 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -144,6 +144,7 @@ pub(crate) mod compress; mod scheme; pub use compress::TurboQuantConfig; +pub use compress::turboquant_compress; pub use compress::turboquant_encode; pub use compress::turboquant_encode_unchecked; pub use scheme::TurboQuantScheme; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index b603f12e16e..ee908378da8 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -23,9 +23,6 @@ use vortex_array::ArrayRef; use vortex_array::Canonical; -use vortex_array::IntoArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::estimate::CompressionEstimate; @@ -38,9 +35,7 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::MAX_CENTROIDS; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::tq_validate_vector_dtype; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::encodings::turboquant::turboquant_compress; /// TurboQuant compression scheme for [`Vector`] extension types. /// @@ -105,33 +100,8 @@ impl Scheme for TurboQuantScheme { data: &mut ArrayAndStats, _ctx: CompressorContext, ) -> VortexResult { - let ext_array = data - .array() - .as_opt::() - .vortex_expect("expected an extension array"); - let mut ctx = compressor.execution_ctx(); - - // 1. 
Normalize: produces L2Denorm(normalized_vectors, norms). - let l2_denorm = normalize_as_l2_denorm(ext_array.as_ref().clone(), &mut ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - // 2. Quantize the normalized child: SorfTransform(FSL(Dict)). - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - - let config = TurboQuantConfig::default(); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`, so all rows are - // guaranteed to be unit-norm (or zero for originally-null rows). - let sorf_dict = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? }; - - // 3. Wrap back in L2Denorm: the SorfTransform is the "normalized" child. - // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. - Ok(unsafe { L2Denorm::new_array_unchecked(sorf_dict, norms, num_rows) }?.into_array()) + turboquant_compress(data.array().clone(), &TurboQuantConfig::default(), &mut ctx) } } diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index b111c6e28ba..58de354561b 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -33,8 +33,8 @@ use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuantConfig; +use crate::encodings::turboquant::turboquant_compress; use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::tests::SESSION; use crate::vector::Vector; @@ -84,18 +84,7 @@ fn normalize_and_encode( config: &TurboQuantConfig, ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult { - let 
l2_denorm = normalize_as_l2_denorm(ext.as_ref().clone(), ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? }; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) + turboquant_compress(ext.as_ref().clone(), config, ctx) } /// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 81a379683db..e7c717d87b5 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -46,9 +46,7 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; @@ -59,13 +57,10 @@ use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::encodings::turboquant::turboquant_compress; use crate::scalar_fns::cosine_similarity::CosineSimilarity; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::vector::Vector; /// Apply the canonical TurboQuant encoding pipeline to a `Vector` array. 
@@ -86,21 +81,7 @@ use crate::vector::Vector; /// Returns an error if `data` is not a [`Vector`] extension array, if normalization fails, or /// if the underlying TurboQuant encoder rejects the input shape. pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(data, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let Some(normalized_ext) = normalized.as_opt::() else { - vortex_bail!("normalize_as_l2_denorm must produce an Extension array child"); - }; - - let config = TurboQuantConfig::default(); - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is - // the invariant `turboquant_encode_unchecked` expects. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) + turboquant_compress(data, &TurboQuantConfig::default(), ctx) } /// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single From dd5b665dd956c7c06407609a62c8a51efbf7242d Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:49:43 +0000 Subject: [PATCH 08/15] Unify binary tensor validation and fix assorted matcher polish MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Extract `validate_binary_tensor_float_inputs(op_name, lhs, rhs)` for the common "two operands must share the same float tensor dtype" precondition. `CosineSimilarity::return_dtype` and `InnerProduct::return_dtype` now both call it instead of hand-rolling the `eq_ignore_nullability` check, the extension-type downcast, the `AnyTensor` metadata lookup, and the float element-type check. 
The old `InnerProduct` path also had a redundant `is::()` plus `metadata_opt::()` pair — both collapse to the one `metadata_opt` call inside `validate_tensor_float_input`. - Add `VectorMatcherMetadata::list_size() -> usize` so it mirrors `FixedShapeTensorMatcherMetadata::list_size`. `TensorMatch::list_size` no longer needs the inline `as usize` cast for the `Vector` arm. - Drop the unused `_ctx` parameter from `CosineSimilarity::execute_both_denorm`; it never used the context since the both-denorm path is pure array composition. - Fix the assert message in `vector::matcher` that said "element dtype must be primitive" — the check is `element_dtype.is_float()`, so the message now says "must be float". Signed-off-by: Claude --- vortex-tensor/public-api.lock | 2 ++ vortex-tensor/src/matcher.rs | 2 +- .../src/scalar_fns/cosine_similarity.rs | 16 ++--------- vortex-tensor/src/scalar_fns/inner_product.rs | 28 ++----------------- vortex-tensor/src/utils.rs | 17 +++++++++++ vortex-tensor/src/vector/matcher.rs | 9 +++++- 6 files changed, 34 insertions(+), 40 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 11778465928..33c4fb671d9 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -576,6 +576,8 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32 pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType +pub fn vortex_tensor::vector::VectorMatcherMetadata::list_size(&self) -> usize + pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 973562be3da..258a464b57f 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -45,7 +45,7 @@ impl 
TensorMatch<'_> { pub fn list_size(self) -> usize { match self { Self::FixedShapeTensor(metadata) => metadata.list_size(), - Self::Vector(metadata) => metadata.dimensions() as usize, + Self::Vector(metadata) => metadata.list_size(), } } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 9ea5c185827..a8db1cb79c8 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -31,7 +31,6 @@ use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; -use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::scalar_fns::inner_product::BinaryTensorOpMetadata; @@ -40,7 +39,7 @@ use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extract_l2_denorm_children; -use crate::utils::validate_tensor_float_input; +use crate::utils::validate_binary_tensor_float_inputs; /// Cosine similarity between two columns. /// @@ -116,16 +115,8 @@ impl ScalarFnVTable for CosineSimilarity { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // We don't need to look at rhs anymore since we know lhs and rhs are equal. - let tensor_match = validate_tensor_float_input(lhs)?; + let tensor_match = validate_binary_tensor_float_inputs("CosineSimilarity", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -153,7 +144,7 @@ impl ScalarFnVTable for CosineSimilarity { // Take any L2Denorm-wrapped fast path that applies. 
match DenormOrientation::classify(&lhs_ref, &rhs_ref) { DenormOrientation::Both { lhs, rhs } => { - return self.execute_both_denorm(lhs, rhs, len, ctx); + return self.execute_both_denorm(lhs, rhs, len); } DenormOrientation::One { denorm, plain } => { return self.execute_one_denorm(denorm, plain, len, ctx); @@ -258,7 +249,6 @@ impl CosineSimilarity { lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, - _ctx: &mut ExecutionCtx, ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index e2f92d91a62..8fc71e4fd86 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -61,6 +61,7 @@ use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; +use crate::utils::validate_binary_tensor_float_inputs; use crate::vector::Vector; /// Inner product (dot product) between two columns. @@ -131,32 +132,9 @@ impl ScalarFnVTable for InnerProduct { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "InnerProduct requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // Both inputs must be tensor-like extension types. - let lhs_ext = lhs - .as_extension_opt() - .ok_or_else(|| vortex_err!("InnerProduct lhs must be an extension type, got {lhs}"))?; - - vortex_ensure!( - lhs_ext.is::(), - "InnerProduct inputs must be an `AnyTensor`, got {lhs}" - ); - - let tensor_match = lhs_ext - .metadata_opt::() - .ok_or_else(|| vortex_err!("InnerProduct inputs must be an `AnyTensor`, got {lhs}"))?; + // TODO(connor): relax the float-only gate once integer tensors are supported. 
+ let tensor_match = validate_binary_tensor_float_inputs("InnerProduct", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - // TODO(connor): This should support integer tensors! - vortex_ensure!( - ptype.is_float(), - "InnerProduct element dtype must be a float primitive, got {ptype}" - ); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 9919944f8c8..866d7d2325b 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -43,6 +43,23 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult( + op_name: &str, + lhs: &'a DType, + rhs: &DType, +) -> VortexResult> { + vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "{op_name} requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + validate_tensor_float_input(lhs) +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. /// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively diff --git a/vortex-tensor/src/vector/matcher.rs b/vortex-tensor/src/vector/matcher.rs index 0da5384f303..4e0781c225e 100644 --- a/vortex-tensor/src/vector/matcher.rs +++ b/vortex-tensor/src/vector/matcher.rs @@ -49,7 +49,7 @@ impl Matcher for AnyVector { let dimensions = *list_size; - assert!(element_dtype.is_float(), "element dtype must be primitive"); + assert!(element_dtype.is_float(), "element dtype must be float"); assert!( !element_dtype.is_nullable(), "element dtype must be non-nullable" @@ -87,6 +87,13 @@ impl VectorMatcherMetadata { pub fn dimensions(&self) -> u32 { self.dimensions } + + /// Returns the flattened element count per vector row. Always equal to + /// [`dimensions`](Self::dimensions); exists as a `usize`-typed alias that mirrors + /// [`FixedShapeTensorMatcherMetadata::list_size`](crate::fixed_shape::FixedShapeTensorMatcherMetadata::list_size). 
+ pub fn list_size(&self) -> usize { + self.dimensions as usize + } } #[cfg(test)] From 0e296ad887ecf9df59b83b2c1b9148e66999bf3b Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 02:53:46 +0000 Subject: [PATCH 09/15] Extract `Vector::wrap_storage` and use it at four call sites The two-line `ExtDType::::try_new(EmptyMetadata, storage.dtype().clone()) + ExtensionArray::new(ext_dtype, storage).into_array()` incantation for wrapping a storage array in a [`Vector`] extension appeared verbatim in: - `compress::wrap_padded_as_vector` (private, used twice) - `sorf_transform::vtable` (empty-array branch + `inverse_rotate_typed`) - `vector_search::build_constant_query_vector` Promote it to `Vector::wrap_storage(storage)`, an associated function on the [`Vector`] vtable struct that is the natural home for the operation. Each call site drops to a single line and the `ExtDType`/`EmptyMetadata` imports go away where they were only pulled in for this pattern. The old private helper is deleted. 
Signed-off-by: Claude --- vortex-tensor/public-api.lock | 4 +++ .../src/encodings/turboquant/compress.rs | 14 ++-------- .../src/scalar_fns/sorf_transform/vtable.rs | 8 ++---- vortex-tensor/src/vector/mod.rs | 26 +++++++++++++++++++ vortex-tensor/src/vector_search.rs | 8 +----- 5 files changed, 35 insertions(+), 25 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 33c4fb671d9..72f089c13a4 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -528,6 +528,10 @@ pub fn vortex_tensor::vector::AnyVector::try_match<'a>(ext_dtype: &'a vortex_arr pub struct vortex_tensor::vector::Vector +impl vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::wrap_storage(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::vector::Vector pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index d521cb5d86a..bb194a7796a 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -15,7 +15,6 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; @@ -23,8 +22,6 @@ use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_buffer::BufferMut; @@ -242,7 +239,7 @@ 
pub unsafe fn turboquant_encode_unchecked( Validity::NonNullable, 0, )?; - let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?; + let empty_padded_vector = Vector::wrap_storage(empty_fsl.into_array())?; let sorf_options = SorfOptions { seed, @@ -258,7 +255,7 @@ pub unsafe fn turboquant_encode_unchecked( let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; let quantized_fsl = build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?; - let padded_vector = wrap_padded_as_vector(quantized_fsl)?; + let padded_vector = Vector::wrap_storage(quantized_fsl)?; let sorf_options = SorfOptions { seed, @@ -269,13 +266,6 @@ pub unsafe fn turboquant_encode_unchecked( Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Wrap an `FSL` in a [`Vector`](crate::vector::Vector) extension so it can be -/// passed as the child of [`SorfTransform`], which expects a `Vector` input. -fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl).into_array()) -} - /// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) /// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized /// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 26a56c245dc..2c0ca420d8d 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -143,9 +143,7 @@ impl ScalarFnVTable for SorfTransform { validity, 0, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + 
Vector::wrap_storage(fsl.into_array()) }); } @@ -330,7 +328,5 @@ fn inverse_rotate_typed( let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::wrap_storage(fsl.into_array()) } diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs index 3c6a8a8c8cc..14f3b9249ce 100644 --- a/vortex-tensor/src/vector/mod.rs +++ b/vortex-tensor/src/vector/mod.rs @@ -3,10 +3,36 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_error::VortexResult; + /// The Vector extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Vector; +impl Vector { + /// Wrap a `FixedSizeList`-valued `storage` array in a [`Vector`] extension array. + /// + /// The storage's dtype is reused verbatim for the extension's storage dtype, so the caller + /// is responsible for having already constructed an FSL with the float element ptype and + /// non-nullable elements that [`Vector::validate_dtype`](ExtVTable::validate_dtype) requires. + /// + /// [`ExtVTable::validate_dtype`]: vortex_array::dtype::extension::ExtVTable::validate_dtype + /// + /// # Errors + /// + /// Returns an error if `storage` does not satisfy [`Vector`]'s storage-dtype contract (e.g. + /// it is not a `FixedSizeList` of non-nullable floats). 
+ pub fn wrap_storage(storage: ArrayRef) -> VortexResult { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } +} + mod matcher; pub use matcher::AnyVector; diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index e7c717d87b5..8df955e476e 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -46,13 +46,10 @@ use vortex_array::ArrayRef; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::ExtensionArray; use vortex_array::builtins::ArrayBuiltins; use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::operators::Operator; @@ -106,11 +103,8 @@ pub fn build_constant_query_vector>( .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) .collect(); let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let storage = ConstantArray::new(storage_scalar, num_rows).into_array(); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + Vector::wrap_storage(storage) } /// Build the lazy similarity-search expression tree for a prepared database array and a From 532e1da79ec03ca9436ce76bafc59919fc7215a4 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 13:22:38 +0000 Subject: [PATCH 10/15] Drop normalize_and_encode test wrapper The `normalize_and_encode` helper was a one-line forwarder to `turboquant_compress` after the pipeline was extracted. 
Inline the call at each test site (passing `ext.into_array()` since `ext` was never used afterwards, except for one test with `as_ref().clone()`) and delete the wrapper. Signed-off-by: Claude --- .../src/encodings/turboquant/tests/compute.rs | 10 +++++----- vortex-tensor/src/encodings/turboquant/tests/mod.rs | 12 +----------- .../src/encodings/turboquant/tests/nullable.rs | 8 ++++---- .../src/encodings/turboquant/tests/roundtrip.rs | 6 +++--- .../src/encodings/turboquant/tests/structural.rs | 12 ++++++------ 5 files changed, 19 insertions(+), 29 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs index 0a9e0ab7a18..26165ffe3f7 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -44,7 +44,7 @@ fn slice_preserves_data() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); @@ -89,7 +89,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let full_decoded = encoded.clone().execute::(&mut ctx)?; @@ -112,7 +112,7 @@ fn l2_norm_readthrough() -> VortexResult<()> { num_rounds: 5, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); // Stored norms should match the actual L2 norms of the input. 
@@ -150,7 +150,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; @@ -187,7 +187,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let encoded_cos = execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index 58de354561b..214ff792050 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -78,15 +78,6 @@ fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { ExtensionArray::new(ext_dtype, fsl.clone().into_array()) } -/// Full encode pipeline: normalize → TQ-encode → wrap in L2Denorm. -fn normalize_and_encode( - ext: &ExtensionArray, - config: &TurboQuantConfig, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - turboquant_compress(ext.as_ref().clone(), config, ctx) -} - /// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). 
fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded @@ -166,8 +157,7 @@ fn encode_decode( let prim = fsl.elements().clone().execute::(&mut ctx)?; prim.as_slice::().to_vec() }; - let ext = make_vector_ext(fsl); - let encoded = normalize_and_encode(&ext, config, &mut ctx)?; + let encoded = turboquant_compress(make_vector_ext(fsl).into_array(), config, &mut ctx)?; let decoded_ext = encoded.execute::(&mut ctx)?; let decoded_fsl = decoded_ext .storage_array() diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs index 6fc19bb93ec..9bef41807e0 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -27,7 +27,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; assert_eq!(encoded.len(), 10); assert!(encoded.dtype().is_nullable()); @@ -88,7 +88,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let norms_validity = norms_child.validity()?; @@ -118,7 +118,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let norm_sfn = L2Norm::try_new_array(encoded, 5)?; let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; @@ -160,7 +160,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> { 
num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let sliced = encoded.slice(1..6)?; assert_eq!(sliced.len(), 5); diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index cd61d9193da..e9a61b3dd01 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -130,7 +130,7 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); Ok(()) @@ -255,7 +255,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); Ok(()) @@ -288,7 +288,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 87b59836b38..3429bbf1fb5 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ 
b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -24,7 +24,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; let stored = centroids.as_slice::(); @@ -52,7 +52,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { // Encode twice with the same seed → should produce identical results. let mut ctx = SESSION.create_execution_ctx(); - let encoded1 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded1 = turboquant_compress(ext.as_ref().clone(), &config, &mut ctx)?; let decoded1 = encoded1.execute::(&mut ctx)?; let fsl1 = decoded1 .storage_array() @@ -64,7 +64,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { .execute::(&mut ctx)?; let mut ctx = SESSION.create_execution_ctx(); - let encoded2 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded2 = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let decoded2 = encoded2.execute::(&mut ctx)?; let fsl2 = decoded2 .storage_array() @@ -94,7 +94,7 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; assert!( encoded.dtype().is_extension(), @@ -119,7 +119,7 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = 
input_prim.as_slice::(); @@ -176,7 +176,7 @@ fn dot_product_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); From e8bfd984d9eff27d1abed0b56aed10df7c235565 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 13:40:00 +0000 Subject: [PATCH 11/15] Eliminate test-helper duplication and the `compress_turboquant` wrapper A pile of test-only and convenience wrappers were thin layers on top of the shared building blocks introduced in the previous commits. Remove them and route every call site through the shared helpers. - Drop `vector_search::compress_turboquant`: a one-liner around `turboquant_compress(data, &TurboQuantConfig::default(), ctx)`. Update the module-level docstring example and the one test caller. - Make `utils::test_helpers::{tensor_array, vector_array, constant_tensor_array, constant_vector_array, l2_denorm_array}` generic over `T: NativePType` (some also `+ Into`) and refactor their bodies through two new private helpers (`flat_fsl`, `fsl_scalar`) so the FSL-construction and constant-FSL-scalar logic each lives in one place. - Add `utils::test_helpers::literal_vector_array` for the second Vector-literal shape (`ConstantArray`) that was duplicated in `l2_norm` (`constant_vector_extension_array`) and `inner_product` (`literal_vector_f32`). Both are deleted. - In `inner_product::tests::constant_query_optimizations`: drop the module-private `vector_f32` / `constant_vector_f32` helpers in favour of the now-generic shared ones, simplify `dict_vector_f32` via `Vector::wrap_storage`, and switch the local `LazyLock` to the crate-wide `tests::SESSION`. 
- In `l2_denorm` tests: delete `integer_tensor_array` (just `tensor_array(shape, &[i32; ...])` now) and `f16_vector_array` (the per-element `f16::from_f32` map fits inline at the one call site). - In `vector_search` tests: delete the duplicated `vector_array` (use the shared one, now f32-friendly) and the `test_session()` helper (use `crate::tests::SESSION`). - In `turboquant::tests`: change `make_vector_ext` to return `ArrayRef` via `Vector::wrap_storage` instead of materialising an `ExtensionArray`, drop the trivially redundant `unwrap_sorf` and `make_fsl_small` helpers, and inline `unwrap_sorf` at its single call site inside `unwrap_codes_centroids_norms`. Replace `ext.into_array()` and `ext.as_ref().clone()` with the now-equivalent `ext` and `ext.clone()` everywhere. Net diff: ~170 lines removed. Signed-off-by: Claude --- vortex-tensor/public-api.lock | 2 - .../src/encodings/turboquant/tests/compute.rs | 10 +- .../src/encodings/turboquant/tests/mod.rs | 49 ++------ .../encodings/turboquant/tests/nullable.rs | 8 +- .../encodings/turboquant/tests/roundtrip.rs | 30 +++-- .../encodings/turboquant/tests/structural.rs | 12 +- vortex-tensor/src/scalar_fns/inner_product.rs | 103 ++++------------ vortex-tensor/src/scalar_fns/l2_denorm.rs | 33 +----- vortex-tensor/src/scalar_fns/l2_norm.rs | 17 +-- vortex-tensor/src/utils.rs | 112 +++++++++--------- vortex-tensor/src/vector_search.rs | 91 +++----------- 11 files changed, 146 insertions(+), 321 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 72f089c13a4..ddcc3e13a40 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -612,8 +612,6 @@ pub fn vortex_tensor::vector_search::build_constant_query_vector>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult -pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) 
-> vortex_error::VortexResult - pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession) diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs index 26165ffe3f7..678607e90fe 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -44,7 +44,7 @@ fn slice_preserves_data() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); @@ -89,7 +89,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let full_decoded = encoded.clone().execute::(&mut ctx)?; @@ -112,7 +112,7 @@ fn l2_norm_readthrough() -> VortexResult<()> { num_rounds: 5, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); // Stored norms should match the actual L2 norms of the input. 
@@ -150,7 +150,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; @@ -187,7 +187,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let encoded_cos = execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index 214ff792050..22b852b66a3 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -16,7 +16,6 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Dict; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -25,8 +24,6 @@ use vortex_array::arrays::dict::DictArraySlotsExt; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; @@ -34,8 +31,6 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuantConfig; use 
crate::encodings::turboquant::turboquant_compress; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::tests::SESSION; use crate::vector::Vector; @@ -71,11 +66,9 @@ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { } /// Wrap a `FixedSizeListArray` in a `Vector` extension array. -fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.clone().into_array()) +fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { + Vector::wrap_storage(fsl.clone().into_array()) + .vortex_expect("test FSL satisfies Vector storage constraints") } /// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). @@ -83,17 +76,7 @@ fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded .as_opt::() .expect("expected ScalarFnArray (L2Denorm)"); - let sorf_child = sfn.child_at(0).clone(); - let norms_child = sfn.child_at(1).clone(); - (sorf_child, norms_child) -} - -/// Unwrap a SorfTransform ScalarFnArray to get the FSL(Dict) child. -fn unwrap_sorf(sorf: &ArrayRef) -> ArrayRef { - let sfn = sorf - .as_opt::() - .expect("expected ScalarFnArray (SorfTransform)"); - sfn.child_at(0).clone() + (sfn.child_at(0).clone(), sfn.child_at(1).clone()) } /// Navigate the full tree to get (codes, centroids, norms) as flat arrays. 
@@ -102,7 +85,11 @@ fn unwrap_codes_centroids_norms( ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { let (sorf_child, norms_child) = unwrap_l2denorm(encoded); - let padded_vector_child = unwrap_sorf(&sorf_child); + let padded_vector_child = sorf_child + .as_opt::() + .expect("expected SorfTransform ScalarFnArray") + .child_at(0) + .clone(); // Vector wrapping FSL(Dict(codes, centroids)) let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; @@ -157,7 +144,7 @@ fn encode_decode( let prim = fsl.elements().clone().execute::(&mut ctx)?; prim.as_slice::().to_vec() }; - let encoded = turboquant_compress(make_vector_ext(fsl).into_array(), config, &mut ctx)?; + let encoded = turboquant_compress(make_vector_ext(fsl), config, &mut ctx)?; let decoded_ext = encoded.execute::(&mut ctx)?; let decoded_fsl = decoded_ext .storage_array() @@ -172,19 +159,3 @@ fn encode_decode( }; Ok((original, decoded_elements)) } - -fn make_fsl_small(dim: usize) -> FixedSizeListArray { - let mut buf = BufferMut::::with_capacity(dim); - for i in 0..dim { - buf.push(i as f32 + 1.0); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - 1, - ) - .unwrap() -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs index 9bef41807e0..c5041364ce8 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -27,7 +27,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; assert_eq!(encoded.len(), 10); 
assert!(encoded.dtype().is_nullable()); @@ -88,7 +88,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let norms_validity = norms_child.validity()?; @@ -118,7 +118,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let norm_sfn = L2Norm::try_new_array(encoded, 5)?; let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; @@ -160,7 +160,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let sliced = encoded.slice(1..6)?; assert_eq!(sliced.len(), 5); diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index e9a61b3dd01..c3335f35b45 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -4,6 +4,7 @@ use rstest::rstest; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -12,6 +13,8 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use super::*; +use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; #[rstest] #[case(128, 1)] @@ -130,7 +133,7 @@ 
fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); Ok(()) @@ -141,7 +144,17 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { #[case(64)] #[case(127)] fn rejects_dimension_below_128(#[case] dim: usize) { - let fsl = make_fsl_small(dim); + let elements = PrimitiveArray::new::( + BufferMut::from_iter((0..dim).map(|i| i as f32 + 1.0)).freeze(), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().expect("dim fits u32"), + Validity::NonNullable, + 1, + ) + .unwrap(); let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 2, @@ -149,9 +162,8 @@ fn rejects_dimension_below_128(#[case] dim: usize) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - assert!( - crate::encodings::turboquant::turboquant_encode(ext.as_view(), &config, &mut ctx).is_err() - ); + let view = ext.as_opt::().expect("Vector extension"); + assert!(crate::encodings::turboquant::turboquant_encode(view, &config, &mut ctx).is_err()); } #[rstest] @@ -166,7 +178,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx) + let normalized = normalize_as_l2_denorm(ext.clone(), &mut ctx) .unwrap() .child_at(0) .clone(); @@ -255,7 +267,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); 
assert_eq!(norms_child.len(), num_rows); Ok(()) @@ -288,7 +300,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); @@ -329,7 +341,7 @@ fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx)? + let normalized = normalize_as_l2_denorm(ext.clone(), &mut ctx)? .child_at(0) .clone(); let normalized_ext = normalized diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 3429bbf1fb5..37b857a8ad0 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -24,7 +24,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; let stored = centroids.as_slice::(); @@ -52,7 +52,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { // Encode twice with the same seed → should produce identical results. 
let mut ctx = SESSION.create_execution_ctx(); - let encoded1 = turboquant_compress(ext.as_ref().clone(), &config, &mut ctx)?; + let encoded1 = turboquant_compress(ext.clone(), &config, &mut ctx)?; let decoded1 = encoded1.execute::(&mut ctx)?; let fsl1 = decoded1 .storage_array() @@ -64,7 +64,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { .execute::(&mut ctx)?; let mut ctx = SESSION.create_execution_ctx(); - let encoded2 = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded2 = turboquant_compress(ext, &config, &mut ctx)?; let decoded2 = encoded2.execute::(&mut ctx)?; let fsl2 = decoded2 .storage_array() @@ -94,7 +94,7 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; assert!( encoded.dtype().is_extension(), @@ -119,7 +119,7 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); @@ -176,7 +176,7 @@ fn dot_product_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = turboquant_compress(ext.into_array(), &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 8fc71e4fd86..89c27de32f7 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ 
b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -912,15 +912,11 @@ mod tests { reason = "tests build small fixtures with deterministic in-range indices" )] mod constant_query_optimizations { - use std::sync::LazyLock; - use rstest::rstest; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Constant; - use vortex_array::arrays::ConstantArray; - use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; @@ -928,67 +924,24 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::scalar::Scalar; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::inner_product::constant_tensor_storage; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; + use crate::tests::SESSION; use crate::utils::extract_flat_elements; + use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::vector_array; use crate::vector::Vector; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - - /// Compact f32 Vector extension over a column-major `elements` slice. 
- fn vector_f32(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - /// Compact constant-backed f32 Vector extension with a single stored row. - fn constant_vector_f32(elements: &[f32], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let storage = ConstantArray::new(storage_scalar, len).into_array(); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) - } - - /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension. - fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let vector_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(vector_scalar, len).into_array() - } - - /// Build an `ExtensionArray>` whose storage is - /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that - /// TurboQuant produces as the SorfTransform child. + /// Build a `Vector` whose storage is `FSL(DictArray(codes: u8, values: + /// f32))`. 
This mirrors the shape that TurboQuant produces as the SorfTransform child. fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult { let num_rows = codes.len() / list_size as usize; let codes_arr = @@ -1004,9 +957,7 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::wrap_storage(fsl.into_array()) } /// Execute an inner product and return the flat `f32` results. @@ -1092,7 +1043,7 @@ mod tests { #[test] fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { - let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5); + let literal = literal_vector_array(&[1.0f32, 2.0, 3.0], 5); let storage = constant_tensor_storage(&literal).expect("literal vector should be recognized"); @@ -1128,7 +1079,7 @@ mod tests { // Query has `dim` elements. let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = constant_vector_array(&query_elems, num_rows)?; // Ground truth: decode LHS to plain f32 vectors, dot each with the query. 
let decoded = decode_sorf_dict( @@ -1166,7 +1117,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect(); - let const_lhs = constant_vector_f32(&query_elems, num_rows)?; + let const_lhs = constant_vector_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1204,7 +1155,7 @@ mod tests { assert_eq!(padded_dim, dim as usize); let query_elems: Vec = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = constant_vector_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1241,7 +1192,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = vec![0.0; dim as usize]; - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = constant_vector_array(&query_elems, num_rows)?; let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1264,7 +1215,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1294,7 +1245,7 @@ mod tests { let dict_rhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.5, -1.0, 2.5, -0.25]; - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = constant_vector_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1339,12 +1290,10 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - let dict_lhs = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + let dict_lhs = Vector::wrap_storage(fsl.into_array())?; let query: Vec 
= vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; // Build expected by decoding by hand. let expected: Vec = (0..num_rows) @@ -1372,10 +1321,10 @@ mod tests { let lhs_elems: Vec = (0..num_rows * dim as usize) .map(|i| i as f32 * 0.25) .collect(); - let plain_lhs = vector_f32(dim, &lhs_elems)?; + let plain_lhs = vector_array(dim, &lhs_elems)?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1402,7 +1351,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.0; 4]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1422,7 +1371,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| ((i as f32) * 0.15).sin() * 0.4).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = constant_vector_array(&query_elems, num_rows)?; // Ground truth via full decode + naive dot. let decoded = decode_sorf_dict( @@ -1495,7 +1444,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1539,7 +1488,7 @@ mod tests { // has cancellation. 
let mut rng = XorShift64::new(seed ^ 0xABCD_1234); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1585,7 +1534,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1635,7 +1584,7 @@ mod tests { SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1685,7 +1634,7 @@ mod tests { let mut rng = XorShift64::new(seed ^ (num_rounds as u64)); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = constant_vector_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1719,7 +1668,7 @@ mod tests { let mut rng = XorShift64::new(seed); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = constant_vector_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 323570076d0..379f2095d46 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -765,16 +765,12 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; use 
vortex_array::extension::datetime::Date; use vortex_array::extension::datetime::TimeUnit; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; - use vortex_buffer::Buffer; use vortex_error::VortexResult; - use crate::fixed_shape::FixedShapeTensor; - use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; @@ -784,7 +780,6 @@ mod tests { use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; - use crate::vector::Vector; /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { @@ -793,20 +788,6 @@ mod tests { result.into_array().execute(&mut ctx) } - fn integer_tensor_array(shape: &[usize], elements: &[i32]) -> VortexResult { - let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn non_tensor_extension_array() -> VortexResult { let storage = PrimitiveArray::from_iter([1i32, 2]).into_array(); let ext_dtype = @@ -814,16 +795,6 @@ mod tests { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - fn f16_vector_array(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let values: Vec<_> = elements.iter().copied().map(half::f16::from_f32).collect(); - let elems: ArrayRef = 
Buffer::copy_from(values.as_slice()).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn tensor_snapshot(array: ArrayRef) -> VortexResult<(DType, Vec, Vec)> { let mut ctx = SESSION.create_execution_ctx(); let ext: ExtensionArray = array.execute(&mut ctx)?; @@ -912,7 +883,7 @@ mod tests { #[test] fn l2_denorm_rejects_integer_tensor_lhs() -> VortexResult<()> { - let lhs = integer_tensor_array(&[2], &[1, 2, 3, 4])?; + let lhs = tensor_array(&[2], &[1i32, 2, 3, 4])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); @@ -934,7 +905,7 @@ mod tests { #[test] fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { - let input = f16_vector_array(2, &[3.0, 4.0, 0.0, 0.0])?; + let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; validate_l2_normalized_rows(&roundtrip.child_at(0).clone(), &mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 94c0b41f3d2..730bc67b976 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -290,6 +290,7 @@ mod tests { use crate::scalar_fns::l2_norm::L2Norm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::literal_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -363,27 +364,13 @@ mod tests { Ok(()) } - /// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a - /// fixed-size list of `elements`, broadcast to `len` rows. 
- fn constant_vector_extension_array(elements: &[f64], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(ext_scalar, len).into_array() - } - /// A constant input whose scalar is a non-null tensor should short-circuit to a /// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so /// execution stops at the [`Constant`] encoding instead of canonicalizing into a /// [`PrimitiveArray`]. #[test] fn constant_non_null_input_yields_constant_output() -> VortexResult<()> { - let input = constant_vector_extension_array(&[3.0, 4.0], 4); + let input = literal_vector_array(&[3.0f64, 4.0], 4); let scalar_fn = L2Norm::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array(); diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 866d7d2325b..ff07a700738 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -177,10 +177,10 @@ pub mod test_helpers { use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; + use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; - use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -191,92 +191,90 @@ pub mod test_helpers { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::vector::Vector; - /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. 
+ /// Builds a `FixedSizeList` storage array from flat `elements`. The row count is + /// inferred from `elements.len() / list_size`. + fn flat_fsl(elements: &[T], list_size: u32) -> ArrayRef { + let row_count = elements.len() / list_size as usize; + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count).into_array() + } + + /// Builds an FSL-valued [`Scalar`] from `elements` for use as a constant query. + fn fsl_scalar>(elements: &[T]) -> Scalar { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable) + } + + /// Builds a [`FixedShapeTensor`] extension array from flat `elements` and a logical shape. /// /// The number of rows is inferred from the total element count divided by the product of the /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. 
- pub fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + pub fn tensor_array(shape: &[usize], elements: &[T]) -> VortexResult { let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - + let storage = flat_fsl(elements, list_size); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. - pub fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { - let row_count = elements.len() / dim as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + /// Builds a [`Vector`] extension array from flat `elements` and a vector dimension size. + pub fn vector_array(dim: u32, elements: &[T]) -> VortexResult { + Vector::wrap_storage(flat_fsl(elements, dim)) } /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], /// representing a single query tensor broadcast to `len` rows. 
- pub fn constant_tensor_array( + pub fn constant_tensor_array>( shape: &[usize], - elements: &[f64], + elements: &[T], len: usize, ) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - + let storage = ConstantArray::new(fsl_scalar(elements), len).into_array(); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a /// single query vector broadcast to `len` rows. - pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + pub fn constant_vector_array>( + elements: &[T], + len: usize, + ) -> VortexResult { + Vector::wrap_storage(ConstantArray::new(fsl_scalar(elements), len).into_array()) + } - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + /// Builds a [`ConstantArray`] whose scalar is itself a [`Vector`] extension scalar, broadcast + /// to `len` rows. 
This is the shape produced by an `lit(vector_scalar)` literal expression — + /// the constant lives at the extension level rather than inside the FSL storage, in contrast + /// to [`constant_vector_array`]. + pub fn literal_vector_array>( + elements: &[T], + len: usize, + ) -> ArrayRef { + use vortex_array::extension::EmptyMetadata; + let ext_scalar = Scalar::extension::(EmptyMetadata, fsl_scalar(elements)); + ConstantArray::new(ext_scalar, len).into_array() } - /// Creates an [`L2Denorm`] scalar function array from pre-normalized f64 tensor elements and - /// f64 norms. The caller must ensure every row of `normalized_elements` is unit-norm or zero. - pub fn l2_denorm_array( + /// Creates an [`L2Denorm`] scalar function array from pre-normalized tensor elements and + /// matching norms. The caller must ensure every row of `normalized_elements` is unit-norm or + /// zero. + pub fn l2_denorm_array( shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], + normalized_elements: &[T], + norms: &[T], ctx: &mut ExecutionCtx, ) -> VortexResult { let len = norms.len(); let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); + let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) } diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 8df955e476e..4323cede395 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -4,13 +4,11 @@ //! Reusable helpers for building brute-force vector similarity search expressions over //! [`Vector`] extension arrays. //! -//! This module exposes three small building blocks that together make it straightforward to -//! stand up a cosine-similarity-plus-threshold scan on top of a prepared data array: +//! This module exposes two small building blocks that, combined with +//! 
[`turboquant_compress`](crate::encodings::turboquant::turboquant_compress), make it +//! straightforward to stand up a cosine-similarity-plus-threshold scan on top of a prepared +//! data array: //! -//! - [`compress_turboquant`] applies the canonical TurboQuant encoding pipeline -//! (`L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)`) to a raw -//! `Vector` array without requiring the caller to plumb the -//! `unstable_encodings` feature flag on the `vortex` facade. //! - [`build_constant_query_vector`] wraps a single query vector into a //! [`Vector`] extension array whose storage is a [`ConstantArray`] broadcast //! across `num_rows` rows. This is the shape expected by @@ -28,11 +26,12 @@ //! use vortex_array::{ArrayRef, VortexSessionExecute}; //! use vortex_array::arrays::BoolArray; //! use vortex_session::VortexSession; -//! use vortex_tensor::vector_search::{build_similarity_search_tree, compress_turboquant}; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_compress}; +//! use vortex_tensor::vector_search::build_similarity_search_tree; //! //! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> { //! let mut ctx = session.create_execution_ctx(); -//! let data = compress_turboquant(data, &mut ctx)?; +//! let data = turboquant_compress(data, &TurboQuantConfig::default(), &mut ctx)?; //! let tree = build_similarity_search_tree(data, query, 0.8)?; //! let _matches: BoolArray = tree.execute(&mut ctx)?; //! Ok(()) @@ -43,7 +42,6 @@ //! 
[`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; @@ -55,32 +53,9 @@ use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_compress; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::vector::Vector; -/// Apply the canonical TurboQuant encoding pipeline to a `Vector` array. -/// -/// The returned array has the shape -/// `L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)` — exactly what -/// [`crate::encodings::turboquant::TurboQuantScheme`] produces when invoked through -/// `BtrBlocksCompressorBuilder::with_turboquant()`, but without requiring callers to enable -/// the `unstable_encodings` feature on the `vortex` facade. -/// -/// The input `data` must be a [`Vector`] extension array whose element type is `f32` and whose -/// dimensionality is at least -/// [`turboquant::MIN_DIMENSION`](crate::encodings::turboquant::MIN_DIMENSION). The TurboQuant -/// configuration used is [`TurboQuantConfig::default()`] (8-bit codes, 3 SORF rounds, seed 42). -/// -/// # Errors -/// -/// Returns an error if `data` is not a [`Vector`] extension array, if normalization fails, or -/// if the underlying TurboQuant encoder rejects the input shape. -pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - turboquant_compress(data, &TurboQuantConfig::default(), ctx) -} - /// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single /// query vector across `num_rows` rows. 
/// @@ -150,54 +125,18 @@ pub fn build_similarity_search_tree>( #[cfg(test)] mod tests { - use vortex_array::ArrayRef; - use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; use vortex_array::arrays::Extension; - use vortex_array::arrays::ExtensionArray; - use vortex_array::arrays::FixedSizeListArray; - use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::bool::BoolArrayExt; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::session::ArraySession; - use vortex_array::validity::Validity; - use vortex_buffer::BufferMut; use vortex_error::VortexResult; - use vortex_session::VortexSession; use super::build_constant_query_vector; use super::build_similarity_search_tree; - use super::compress_turboquant; - use crate::vector::Vector; - - /// Build a `Vector` extension array from a flat f32 slice. Each contiguous - /// group of `DIM` values becomes one row. - fn vector_array(dim: u32, values: &[f32]) -> VortexResult { - let dim_usize = dim as usize; - assert_eq!(values.len() % dim_usize, 0); - let num_rows = values.len() / dim_usize; - - let mut buf = BufferMut::::with_capacity(values.len()); - for &v in values { - buf.push(v); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim, - Validity::NonNullable, - num_rows, - )?; - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - fn test_session() -> VortexSession { - VortexSession::empty().with::() - } + use crate::encodings::turboquant::TurboQuantConfig; + use crate::encodings::turboquant::turboquant_compress; + use crate::tests::SESSION; + use crate::utils::test_helpers::vector_array; #[test] fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> { @@ -215,7 +154,7 @@ mod 
tests { let data = vector_array( 3, &[ - 1.0, 0.0, 0.0, // + 1.0f32, 0.0, 0.0, // 0.0, 1.0, 0.0, // 0.0, 0.0, 1.0, // 1.0, 0.0, 0.0, // @@ -224,7 +163,7 @@ mod tests { let query = [1.0f32, 0.0, 0.0]; let tree = build_similarity_search_tree(data, &query, 0.5)?; - let mut ctx = test_session().create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let result: BoolArray = tree.execute(&mut ctx)?; let bits = result.to_bit_buffer(); @@ -262,8 +201,8 @@ mod tests { } let data = vector_array(DIM, &values)?; - let mut ctx = test_session().create_execution_ctx(); - let compressed = compress_turboquant(data, &mut ctx)?; + let mut ctx = SESSION.create_execution_ctx(); + let compressed = turboquant_compress(data, &TurboQuantConfig::default(), &mut ctx)?; assert_eq!(compressed.len(), NUM_ROWS); // Build a tree with a low threshold so row 0 (cosine=1.0 exact) matches. From d9a5f1b6cf5e0f9f09f9a4d4c6c430d5babcf097 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 13:51:07 +0000 Subject: [PATCH 12/15] Add `Vector::constant_array` and dedupe more validation/test wrappers Continue the wrapper-removal pass: - Add `Vector::constant_array(elements, len)` next to `wrap_storage` on the `Vector` impl. It captures the "broadcast a single vector value across `len` rows" pattern that previously had three duplicate implementations (the public `vector_search::build_constant_query_vector`, the test helper `utils::test_helpers::constant_vector_array`, and an inline build inside `inner_product` tests). Drop both the public `build_constant_query_vector` and the `constant_vector_array` test helper; every call site now uses `Vector::constant_array` directly. - Collapse the trivially nested `validate_l2_normalized_rows_impl` and its `validate_l2_denorm_children` adapter into a single descriptive function `validate_l2_normalized_rows_against_norms(normalized, norms, ctx)`. 
`validate_l2_normalized_rows` now calls it with `None` and `try_new_array` calls it with `Some(&norms)` directly. - Simplify the `try_execute_sorf_constant` doc-block: replace the two TODO comments that said the same thing in different words plus the stale `vtable.rs line ~218` reference with a single "F32-only" doc section above the function. - Rewrite the `encodings::turboquant` module-level doctest to use `turboquant_compress` + `Vector::wrap_storage` instead of manually constructing the FSL/extension dtype and stitching the normalize/encode pipeline. - Update two stale comments that still said "stride-0" after `FlatElements` switched to the explicit `is_constant` flag. - Remove the redundant `ext.clone()` calls in `normalize_as_l2_denorm(ext, ...)` test sites that clippy flagged. Net diff: ~30 more lines removed. Signed-off-by: Claude --- vortex-tensor/public-api.lock | 4 +- vortex-tensor/src/encodings/turboquant/mod.rs | 25 ++----- .../encodings/turboquant/tests/roundtrip.rs | 4 +- .../src/scalar_fns/cosine_similarity.rs | 6 +- vortex-tensor/src/scalar_fns/inner_product.rs | 60 ++++++++--------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 34 ++++------ vortex-tensor/src/utils.rs | 11 +--- vortex-tensor/src/vector/mod.rs | 47 ++++++++++++++ vortex-tensor/src/vector_search.rs | 65 +++---------------- 9 files changed, 111 insertions(+), 145 deletions(-) diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index ddcc3e13a40..146364aa52a 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -530,6 +530,8 @@ pub struct vortex_tensor::vector::Vector impl vortex_tensor::vector::Vector +pub fn vortex_tensor::vector::Vector::constant_array>(elements: &[T], len: usize) -> vortex_error::VortexResult + pub fn vortex_tensor::vector::Vector::wrap_storage(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::vector::Vector @@ -608,8 +610,6 @@ impl 
core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherM pub mod vortex_tensor::vector_search -pub fn vortex_tensor::vector_search::build_constant_query_vector>(query: &[T], num_rows: usize) -> vortex_error::VortexResult - pub fn vortex_tensor::vector_search::build_similarity_search_tree>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 09704f0052c..b07f87549a1 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -92,19 +92,13 @@ //! ``` //! use vortex_array::IntoArray; //! use vortex_array::VortexSessionExecute; -//! use vortex_array::arrays::ExtensionArray; //! use vortex_array::arrays::FixedSizeListArray; //! use vortex_array::arrays::PrimitiveArray; -//! use vortex_array::arrays::Extension; -//! use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -//! use vortex_array::dtype::extension::ExtDType; -//! use vortex_array::extension::EmptyMetadata; +//! use vortex_array::session::ArraySession; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; -//! use vortex_array::session::ArraySession; //! use vortex_session::VortexSession; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_unchecked}; -//! use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_compress}; //! use vortex_tensor::vector::Vector; //! //! // Create a Vector extension array of 100 random 128-d vectors. @@ -118,22 +112,13 @@ //! let fsl = FixedSizeListArray::try_new( //! elements.into_array(), dim, Validity::NonNullable, num_rows, //! ).unwrap(); -//! let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) -//! 
.unwrap().erased(); -//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array()); +//! let vector = Vector::wrap_storage(fsl.into_array()).unwrap(); //! -//! // Normalize, then quantize the normalized child at 2 bits per coordinate. +//! // Normalize and quantize at 2 bits per coordinate in one pass. //! let session = VortexSession::empty().with::(); //! let mut ctx = session.create_execution_ctx(); -//! let l2_denorm = normalize_as_l2_denorm(ext.into_array(), &mut ctx).unwrap(); -//! let normalized = l2_denorm.child_at(0).clone(); -//! -//! let normalized_ext = normalized.as_opt::().unwrap(); //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 }; -//! // SAFETY: We just normalized the input. -//! let tq = unsafe { -//! turboquant_encode_unchecked(normalized_ext, &config, &mut ctx).unwrap() -//! }; +//! let tq = turboquant_compress(vector, &config, &mut ctx).unwrap(); //! //! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! assert!(tq.nbytes() < 51200); diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index c3335f35b45..fccaabe4344 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -178,7 +178,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.clone(), &mut ctx) + let normalized = normalize_as_l2_denorm(ext, &mut ctx) .unwrap() .child_at(0) .clone(); @@ -341,7 +341,7 @@ fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.clone(), &mut ctx)? + let normalized = normalize_as_l2_denorm(ext, &mut ctx)? 
.child_at(0) .clone(); let normalized_ext = normalized diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index a8db1cb79c8..85d16236c8c 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -329,10 +329,10 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; + use crate::vector::Vector; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { @@ -491,7 +491,7 @@ mod tests { 1.0, 0.0, 0.0, // vector 3 ], )?; - let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?; + let query = Vector::constant_array(&[1.0, 0.0, 0.0], 4)?; assert_close( &eval_cosine_similarity(data, query, 4)?, @@ -686,7 +686,7 @@ mod tests { #[test] fn vector_constant_matches_plain() -> VortexResult<()> { // Exercise the `Vector` extension variant through the new pre-pass. - let lhs = constant_vector_array(&[1.0, 2.0, 2.0], 4)?; + let lhs = Vector::constant_array(&[1.0, 2.0, 2.0], 4)?; let rhs = vector_array( 3, &[ diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 89c27de32f7..018dbb01e5d 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -387,16 +387,17 @@ impl InnerProduct { /// from `padded_dim` to `dim` applied by `SorfTransform` and `R` is the SORF forward /// matrix. See the proof in the crate-level docs and in the plan file. /// - /// Returns `Ok(None)` if neither side matches or when `element_ptype` is not `F32`. 
The - /// caller is expected to fall through to the standard path in that case. + /// Returns `Ok(None)` if neither side matches, when the operand element type is not `F32`, + /// or when the constant side is not a constant-backed tensor extension. The caller is + /// expected to fall through to the standard path in that case. /// - /// # TODO(connor): + /// # F32-only /// - /// This rewrite is only sound for `PType::F32` because `SorfTransform` applies an - /// `f32 -> element_ptype` cast at the end of its execute (see `sorf_transform/vtable.rs` - /// line ~218). For F16/F64 the cast changes the inner product's rounding and would - /// change the semantics of the rewrite. Until we push the cast through `InnerProduct`, - /// this path only fires for F32. + /// TODO(connor): this rewrite is only sound for `PType::F32` because `SorfTransform` + /// applies an `f32 -> element_ptype` cast at the end of its `execute`. For `F16`/`F64` + /// the cast changes the inner product's rounding and the rewrite would not be + /// semantically equivalent. Until we push the cast through `InnerProduct`, both the + /// SorfTransform output ptype and the constant-side element ptype must be `F32` here. fn try_execute_sorf_constant( &self, lhs_ref: &ArrayRef, @@ -414,10 +415,6 @@ impl InnerProduct { return Ok(None); }; - // TODO(connor): pull-through is only sound for F32 because SorfTransform applies an - // `f32 -> element_ptype` cast at the end of its execute. For F16/F64 the rewrite - // would change the inner product's rounding semantics. Fall through so the standard - // path (which does the cast before inner product) handles it. if sorf_view.options.element_ptype != PType::F32 { return Ok(None); } @@ -432,11 +429,9 @@ impl InnerProduct { let seed = sorf_view.options.seed; let padded_dim = dim.next_power_of_two(); - // Extract the single stored row of the constant via the stride-0 short-circuit. + // Extract the single stored row of the constant via the `is_constant` short-circuit. 
let flat = extract_flat_elements(&const_storage, dim, ctx)?; if flat.ptype() != PType::F32 { - // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The - // standard path handles them correctly. return Ok(None); } @@ -935,8 +930,7 @@ mod tests { use crate::scalar_fns::sorf_transform::SorfTransform; use crate::tests::SESSION; use crate::utils::extract_flat_elements; - use crate::utils::test_helpers::constant_vector_array; - use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::literal_vector_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -1079,7 +1073,7 @@ mod tests { // Query has `dim` elements. let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect(); - let const_rhs = constant_vector_array(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth: decode LHS to plain f32 vectors, dot each with the query. let decoded = decode_sorf_dict( @@ -1117,7 +1111,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect(); - let const_lhs = constant_vector_array(&query_elems, num_rows)?; + let const_lhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1155,7 +1149,7 @@ mod tests { assert_eq!(padded_dim, dim as usize); let query_elems: Vec = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect(); - let const_rhs = constant_vector_array(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1192,7 +1186,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = vec![0.0; dim as usize]; - let const_rhs = constant_vector_array(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let actual = eval_ip_f32(sorf, const_rhs, 
num_rows)?; assert_eq!(actual.len(), 0); @@ -1215,7 +1209,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1245,7 +1239,7 @@ mod tests { let dict_rhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.5, -1.0, 2.5, -0.25]; - let const_lhs = constant_vector_array(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1293,7 +1287,7 @@ mod tests { let dict_lhs = Vector::wrap_storage(fsl.into_array())?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; // Build expected by decoding by hand. let expected: Vec = (0..num_rows) @@ -1324,7 +1318,7 @@ mod tests { let plain_lhs = vector_array(dim, &lhs_elems)?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1351,7 +1345,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.0; 4]; - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1371,7 +1365,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| ((i as f32) * 0.15).sin() * 0.4).collect(); - let const_rhs = constant_vector_array(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth via full decode + naive dot. 
let decoded = decode_sorf_dict( @@ -1444,7 +1438,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1488,7 +1482,7 @@ mod tests { // has cancellation. let mut rng = XorShift64::new(seed ^ 0xABCD_1234); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1534,7 +1528,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1584,7 +1578,7 @@ mod tests { SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1634,7 +1628,7 @@ mod tests { let mut rng = XorShift64::new(seed ^ (num_rounds as u64)); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_array(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1668,7 +1662,7 @@ mod tests { let mut rng = XorShift64::new(seed); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_lhs = constant_vector_array(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, diff --git 
a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 379f2095d46..b16443e8abe 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -114,7 +114,7 @@ impl L2Denorm { len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - validate_l2_denorm_children(&normalized, &norms, ctx)?; + validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; // SAFETY: We just validated that it is normalized. unsafe { Self::new_array_unchecked(normalized, norms, len) } @@ -521,8 +521,8 @@ pub(crate) fn try_build_constant_l2_denorm( let ext_dtype = input.dtype().as_extension().clone(); let storage_fsl_nullability = storage.dtype().nullability(); - // `extract_flat_elements` takes the stride-0 single-row path for `Constant` storage, so - // this is cheap and does not expand the constant to the full column length. + // `extract_flat_elements` takes the `is_constant` single-row path for `Constant` storage, + // so this is cheap and does not expand the constant to the full column length. let flat = extract_flat_elements(storage, list_size, ctx)?; let (normalized_fsl_scalar, norms_scalar) = match_each_float_ptype!(flat.ptype(), |T| { @@ -603,23 +603,17 @@ fn unit_norm_tolerance(element_ptype: PType) -> f64 { /// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0). pub fn validate_l2_normalized_rows(input: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - validate_l2_normalized_rows_impl(input, None, ctx) + validate_l2_normalized_rows_against_norms(input, None, ctx) } -/// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`] -/// invariants, which are: +/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the +/// [`L2Denorm`] invariants: /// -/// - All vectors in the normalized array have length 1 or 0. 
-/// - If the vector has a norm of 0, then the vector must be all 0s. -fn validate_l2_denorm_children( - normalized: &ArrayRef, - norms: &ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult<()> { - validate_l2_normalized_rows_impl(normalized, Some(norms), ctx) -} - -fn validate_l2_normalized_rows_impl( +/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision +/// tolerance). +/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored +/// norm is `0.0` is exactly the zero vector in `normalized`. +fn validate_l2_normalized_rows_against_norms( normalized: &ArrayRef, norms: Option<&ArrayRef>, ctx: &mut ExecutionCtx, @@ -777,9 +771,9 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; + use crate::vector::Vector; /// Evaluates L2 denorm on a tensor/vector array and returns the executed array. fn eval_l2_denorm(normalized: ArrayRef, norms: ArrayRef, len: usize) -> VortexResult { @@ -999,7 +993,7 @@ mod tests { #[test] fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { - let input = constant_vector_array(&[3.0, 4.0], 2)?; + let input = Vector::constant_array(&[3.0, 4.0], 2)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -1013,7 +1007,7 @@ mod tests { // The constant fast path in `normalize_as_l2_denorm` must produce an `L2Denorm` whose // normalized storage and norms child are both still `ConstantArray`s. This is what // allows downstream ops (cosine similarity, inner product) to short-circuit. 
- let input = constant_vector_array(&[3.0, 4.0], 16)?; + let input = Vector::constant_array(&[3.0, 4.0], 16)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index ff07a700738..326c49cc950 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -241,19 +241,10 @@ pub mod test_helpers { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a - /// single query vector broadcast to `len` rows. - pub fn constant_vector_array>( - elements: &[T], - len: usize, - ) -> VortexResult { - Vector::wrap_storage(ConstantArray::new(fsl_scalar(elements), len).into_array()) - } - /// Builds a [`ConstantArray`] whose scalar is itself a [`Vector`] extension scalar, broadcast /// to `len` rows. This is the shape produced by an `lit(vector_scalar)` literal expression — /// the constant lives at the extension level rather than inside the FSL storage, in contrast - /// to [`constant_vector_array`]. + /// to [`Vector::constant_array`]. pub fn literal_vector_array>( elements: &[T], len: usize, diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs index 14f3b9249ce..20b5c51047b 100644 --- a/vortex-tensor/src/vector/mod.rs +++ b/vortex-tensor/src/vector/mod.rs @@ -5,9 +5,15 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; +use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::PValue; +use vortex_array::scalar::Scalar; use vortex_error::VortexResult; /// The Vector extension type. 
@@ -31,6 +37,47 @@ impl Vector { let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } + + /// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a + /// single vector `elements` across `len` rows. + /// + /// This is the array shape that [`CosineSimilarity::try_new_array`] and similar binary tensor + /// scalar functions expect for the constant-query side of a database-vs-query scan: the inner + /// `ScalarFnArray` contract requires both children to have the same length, so the query is + /// broadcast rather than represented as a literal length-1 input. + /// + /// [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array + /// + /// # Errors + /// + /// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. + pub fn constant_array>( + elements: &[T], + len: usize, + ) -> VortexResult { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + Self::wrap_storage(ConstantArray::new(storage_scalar, len).into_array()) + } +} + +#[cfg(test)] +mod ctor_tests { + use vortex_array::arrays::Extension; + + use super::*; + + #[test] + fn constant_array_produces_vector_extension() { + let array = Vector::constant_array(&[1.0f32, 0.0, 0.0, 0.0], 5).unwrap(); + assert_eq!(array.len(), 5); + assert!(array.as_opt::().is_some()); + } } mod matcher; diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 4323cede395..5ca319ee6c9 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -4,21 +4,14 @@ //! 
Reusable helpers for building brute-force vector similarity search expressions over //! [`Vector`] extension arrays. //! -//! This module exposes two small building blocks that, combined with -//! [`turboquant_compress`](crate::encodings::turboquant::turboquant_compress), make it -//! straightforward to stand up a cosine-similarity-plus-threshold scan on top of a prepared -//! data array: +//! [`build_similarity_search_tree`] wires together [`Vector::constant_array`] (which broadcasts +//! the query into the shape expected by [`CosineSimilarity`]) and +//! [`turboquant_compress`](crate::encodings::turboquant::turboquant_compress) (when the data is +//! pre-compressed) into a lazy `Binary(Gt, [CosineSimilarity(data, query), threshold])` +//! expression. //! -//! - [`build_constant_query_vector`] wraps a single query vector into a -//! [`Vector`] extension array whose storage is a [`ConstantArray`] broadcast -//! across `num_rows` rows. This is the shape expected by -//! [`CosineSimilarity::try_new_array`] for the RHS of a database-vs-query scan. -//! - [`build_similarity_search_tree`] wires everything together into a lazy -//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. -//! -//! Executing the tree from [`build_similarity_search_tree`] into a -//! [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean per row indicating whether -//! that row's cosine similarity to the query exceeds `threshold`. +//! Executing the tree into a [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean +//! per row indicating whether that row's cosine similarity to the query exceeds `threshold`. //! //! # Example //! @@ -39,13 +32,13 @@ //! ``` //! //! [`Vector`]: crate::vector::Vector -//! [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array +//! [`Vector::constant_array`]: crate::vector::Vector::constant_array +//! 
[`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::scalar::PValue; @@ -56,32 +49,6 @@ use vortex_error::VortexResult; use crate::scalar_fns::cosine_similarity::CosineSimilarity; use crate::vector::Vector; -/// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single -/// query vector across `num_rows` rows. -/// -/// The element type is inferred from `T` (e.g. `f32` or `f64`). This is the shape expected for -/// the RHS of a database-vs-query [`CosineSimilarity`] scan: the `ScalarFnArray` contract -/// requires both children to have the same length, so rather than hand-rolling a 1-row input we -/// broadcast the query across the whole database. -/// -/// # Errors -/// -/// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. -pub fn build_constant_query_vector>( - query: &[T], - num_rows: usize, -) -> VortexResult { - let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); - - let children: Vec = query - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let storage = ConstantArray::new(storage_scalar, num_rows).into_array(); - Vector::wrap_storage(storage) -} - /// Build the lazy similarity-search expression tree for a prepared database array and a /// single query vector. 
/// @@ -113,7 +80,7 @@ pub fn build_similarity_search_tree>( threshold: T, ) -> VortexResult { let num_rows = data.len(); - let query_vec = build_constant_query_vector(query, num_rows)?; + let query_vec = Vector::constant_array(query, num_rows)?; let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array(); @@ -127,27 +94,15 @@ pub fn build_similarity_search_tree>( mod tests { use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; - use vortex_array::arrays::Extension; use vortex_array::arrays::bool::BoolArrayExt; use vortex_error::VortexResult; - use super::build_constant_query_vector; use super::build_similarity_search_tree; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::turboquant_compress; use crate::tests::SESSION; use crate::utils::test_helpers::vector_array; - #[test] - fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> { - let query = vec![1.0f32, 0.0, 0.0, 0.0]; - let rhs = build_constant_query_vector(&query, 5)?; - - assert_eq!(rhs.len(), 5); - assert!(rhs.as_opt::().is_some()); - Ok(()) - } - #[test] fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> { // 4 rows of 3-dim vectors; the first and last match the query [1, 0, 0]. From 65052fe45259084e2d1554e8a011e7830ec37a86 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 13:53:37 +0000 Subject: [PATCH 13/15] Use shared `extract_l2_denorm_children` in `L2Norm` readthrough The `L2Norm` execute path that extracts the stored norms from an `L2Denorm` wrapper was hand-rolling the `as_opt::>()` + `nth_child(1).vortex_expect(...)` dance instead of using the shared `extract_l2_denorm_children` helper that already encodes the same expectations. Switch to the helper. 
Signed-off-by: Claude --- vortex-tensor/src/scalar_fns/l2_norm.rs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 730bc67b976..057bb3c841a 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -47,6 +47,7 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::extract_flat_elements; +use crate::utils::extract_l2_denorm_children; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -139,13 +140,9 @@ impl ScalarFnVTable for L2Norm { // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics // instead of forcing a decode-and-recompute path here. - if let Some(sfn) = input_ref.as_opt::>() { - let norms = sfn - .nth_child(1) - .vortex_expect("L2Denorm must have 2 children"); - + if input_ref.is::>() { + let (_, norms) = extract_l2_denorm_children(&input_ref); vortex_ensure_eq!(norms.dtype(), &norm_dtype); - return Ok(norms); } From e3ee206396930959b9adfd5a2916b061d0c3d3a1 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 18 Apr 2026 13:57:58 +0000 Subject: [PATCH 14/15] Use `Vector::constant_array` to build SORF rewrite query `InnerProduct::try_execute_sorf_constant` was hand-rolling the `ExtDType::::try_new(EmptyMetadata, FSL(...)) + ConstantArray::new(extension_ref(...), len)` dance to wrap the forward-rotated query as a `Vector` constant. That same shape is exactly what `Vector::constant_array` produces, so collapse the 20-line builder down to a single call and drop the now-unused `Arc`/`ExtDType`/`EmptyMetadata`/`Scalar` imports. 
Also: - Refresh the `encodings::turboquant::scheme` module docstring so it describes the post-extraction reality (the scheme is a thin adapter over `turboquant_compress`) instead of the pre-extraction normalize + encode pipeline. - Add a context message to the previously bare `vortex_ensure!` that guards the `L2Denorm` constant-norms ptype check. - Replace the misleading "with unchecked indexing" comment on the `try_execute_dict_constant` hot loop with one that matches what the helper actually does (chunked indices that LLVM can prove in-bounds). Signed-off-by: Claude --- .../src/encodings/turboquant/scheme.rs | 6 ++-- vortex-tensor/src/scalar_fns/inner_product.rs | 36 ++++--------------- vortex-tensor/src/scalar_fns/l2_denorm.rs | 6 +++- 3 files changed, 14 insertions(+), 34 deletions(-) diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index ee908378da8..9f1179cc5dd 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -3,8 +3,7 @@ //! TurboQuant compression scheme. //! -//! The scheme first normalizes the input via [`normalize_as_l2_denorm`], then encodes the -//! normalized child via [`turboquant_encode_unchecked`]. The result is: +//! The scheme is a thin [`Scheme`] adapter over [`turboquant_compress`], which produces: //! //! ```text //! ScalarFnArray(L2Denorm, [ @@ -18,8 +17,7 @@ //! //! Decompression is automatic: executing the outer array walks the ScalarFn tree. //! -//! [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm -//! [`turboquant_encode_unchecked`]: crate::encodings::turboquant::turboquant_encode_unchecked +//! 
[`turboquant_compress`]: crate::encodings::turboquant::turboquant_compress use vortex_array::ArrayRef; use vortex_array::Canonical; diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index 018dbb01e5d..7c28ffc37de 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -4,7 +4,6 @@ //! Inner product expression for tensor-like types. use std::fmt::Formatter; -use std::sync::Arc; use num_traits::Float; use prost::Message; @@ -32,13 +31,10 @@ use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; -use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; -use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; -use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::EmptyOptions; @@ -443,29 +439,10 @@ impl InnerProduct { let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); - // Build the rewritten constant as a `Vector` extension scalar. We reuse - // the original storage FSL nullability so the new extension dtype stays consistent with - // whatever the original tree expected. - let storage_fsl_nullability = const_storage.dtype().nullability(); - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = rotated_query - .into_iter() - .map(|v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let fsl_scalar = - Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability); - - // Build a fresh `Vector` extension dtype. We cannot reuse the - // original extension dtype because that one has `dim`, not `padded_dim`. 
- let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim fits u32"); - let new_fsl_dtype = DType::FixedSizeList( - Arc::new(element_dtype), - padded_dim_u32, - storage_fsl_nullability, - ); - let new_ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl_dtype)?.erased(); - let new_constant = - ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array(); + // Wrap the rotated query as a `Vector` constant broadcast to `len` + // rows. The new extension dtype has `padded_dim` instead of `dim`, matching the + // SorfTransform child we are about to dot it with. + let new_constant = Vector::constant_array(&rotated_query, len)?; // Extract the SorfTransform child (the already-padded Vector). let sorf_child = sorf_view @@ -581,8 +558,9 @@ impl InnerProduct { let values: &[f32] = values_prim.as_slice::(); debug_assert_eq!(codes.len(), len * padded_dim); - // The hot loop is extracted into [`execute_dict_constant_inner_product`] with - // unchecked indexing so the compiler can vectorize the inner gather-accumulate. + // The hot loop is extracted into [`execute_dict_constant_inner_product`] so the + // compiler can prove the chunked indices stay in bounds and vectorize the inner + // gather-accumulate. let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim); // SAFETY: the buffer length equals `len`, which matches the validity length. 
diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs
index b16443e8abe..12fb8ca59b7 100644
--- a/vortex-tensor/src/scalar_fns/l2_denorm.rs
+++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs
@@ -213,7 +213,11 @@ impl ScalarFnVTable for L2Denorm {
 
         if let Some(const_norms) = norms_ref.as_opt::<ConstantArray>() {
             let norm_scalar = const_norms.scalar();
-            vortex_ensure!(norm_scalar.dtype().is_float());
+            vortex_ensure!(
+                norm_scalar.dtype().is_float(),
+                "L2Denorm constant norms must be a float scalar, got {}",
+                norm_scalar.dtype(),
+            );
 
             if let Some(norm_value) = norm_scalar.value() {
                 return execute_l2_denorm_constant_norms(

From ee3e882dcf8a4c7cd70d01936b60adca452092ae Mon Sep 17 00:00:00 2001
From: Claude
Date: Sun, 19 Apr 2026 13:34:51 +0000
Subject: [PATCH 15/15] Fix unresolved intra-doc link on `Vector::wrap_storage`

The doc comment for `Vector::wrap_storage` referenced
`ExtVTable::validate_dtype` as a link target without a path declaration
in scope, so `cargo doc` flagged it as an unresolved intra-doc link.
Drop the link form and just reference the trait method by name in
prose.

Signed-off-by: Claude
---
 vortex-tensor/src/vector/mod.rs | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs
index 20b5c51047b..1d0bdf4014c 100644
--- a/vortex-tensor/src/vector/mod.rs
+++ b/vortex-tensor/src/vector/mod.rs
@@ -25,9 +25,7 @@ impl Vector {
     ///
     /// The storage's dtype is reused verbatim for the extension's storage dtype, so the caller
     /// is responsible for having already constructed an FSL with the float element ptype and
-    /// non-nullable elements that [`Vector::validate_dtype`](ExtVTable::validate_dtype) requires.
-    ///
-    /// [`ExtVTable::validate_dtype`]: vortex_array::dtype::extension::ExtVTable::validate_dtype
+    /// non-nullable elements that [`Vector`]'s `validate_dtype` requires.
     ///
     /// # Errors
     ///