diff --git a/vortex-tensor/public-api.lock b/vortex-tensor/public-api.lock index 96b95e1e91e..146364aa52a 100644 --- a/vortex-tensor/public-api.lock +++ b/vortex-tensor/public-api.lock @@ -80,6 +80,8 @@ pub const vortex_tensor::encodings::turboquant::MIN_DIMENSION: u32 pub fn vortex_tensor::encodings::turboquant::tq_validate_vector_dtype(dtype: &vortex_array::dtype::DType) -> vortex_error::VortexResult +pub fn vortex_tensor::encodings::turboquant::turboquant_compress(input: vortex_array::array::erased::ArrayRef, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult + pub fn vortex_tensor::encodings::turboquant::turboquant_encode(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult pub unsafe fn vortex_tensor::encodings::turboquant::turboquant_encode_unchecked(ext: vortex_array::array::view::ArrayView<'_, vortex_array::arrays::extension::vtable::Extension>, config: &vortex_tensor::encodings::turboquant::TurboQuantConfig, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult @@ -502,7 +504,7 @@ pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::child_name(&sel pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::execute(&self, options: &Self::Options, args: &dyn vortex_array::scalar_fn::vtable::ExecutionArgs, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult -pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, _options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result +pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::fmt_sql(&self, options: &Self::Options, expr: &vortex_array::expr::expression::Expression, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result pub fn vortex_tensor::scalar_fns::sorf_transform::SorfTransform::id(&self) -> vortex_array::scalar_fn::ScalarFnId @@ -526,6 +528,12 @@ pub fn vortex_tensor::vector::AnyVector::try_match<'a>(ext_dtype: &'a vortex_arr pub struct vortex_tensor::vector::Vector +impl vortex_tensor::vector::Vector + +pub fn vortex_tensor::vector::Vector::constant_array>(elements: &[T], len: usize) -> vortex_error::VortexResult + +pub fn vortex_tensor::vector::Vector::wrap_storage(storage: vortex_array::array::erased::ArrayRef) -> vortex_error::VortexResult + impl core::clone::Clone for vortex_tensor::vector::Vector pub fn vortex_tensor::vector::Vector::clone(&self) -> vortex_tensor::vector::Vector @@ -574,6 +582,8 @@ pub fn vortex_tensor::vector::VectorMatcherMetadata::dimensions(&self) -> u32 pub fn vortex_tensor::vector::VectorMatcherMetadata::element_ptype(&self) -> vortex_array::dtype::ptype::PType +pub fn vortex_tensor::vector::VectorMatcherMetadata::list_size(&self) -> usize + pub fn vortex_tensor::vector::VectorMatcherMetadata::try_new(element_ptype: vortex_array::dtype::ptype::PType, dimensions: u32) -> vortex_error::VortexResult impl core::clone::Clone for vortex_tensor::vector::VectorMatcherMetadata @@ -600,12 +610,8 @@ impl core::marker::StructuralPartialEq for vortex_tensor::vector::VectorMatcherM pub mod vortex_tensor::vector_search -pub fn vortex_tensor::vector_search::build_constant_query_vector>(query: &[T], num_rows: usize) -> vortex_error::VortexResult - pub fn vortex_tensor::vector_search::build_similarity_search_tree>(data: vortex_array::array::erased::ArrayRef, query: &[T], threshold: T) -> vortex_error::VortexResult -pub fn vortex_tensor::vector_search::compress_turboquant(data: vortex_array::array::erased::ArrayRef, ctx: &mut vortex_array::executor::ExecutionCtx) -> vortex_error::VortexResult - pub const vortex_tensor::SCALAR_FN_ARRAY_TENSOR_PLUGIN_ENV: &str pub fn vortex_tensor::initialize(session: &vortex_session::VortexSession) diff --git a/vortex-tensor/src/encodings/turboquant/compress.rs b/vortex-tensor/src/encodings/turboquant/compress.rs index b0173bbe36c..bb194a7796a 100644 --- a/vortex-tensor/src/encodings/turboquant/compress.rs +++ b/vortex-tensor/src/encodings/turboquant/compress.rs @@ -15,16 +15,15 @@ use vortex_array::ArrayView; use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::dict::DictArray; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; +use vortex_buffer::Buffer; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -35,6 +34,8 @@ use crate::encodings::turboquant::MIN_DIMENSION; use crate::encodings::turboquant::centroids::compute_centroid_boundaries; use crate::encodings::turboquant::centroids::find_nearest_centroid; use crate::encodings::turboquant::centroids::get_centroids; +use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; @@ -136,10 +137,8 @@ fn build_quantized_fsl( padded_dim: usize, ) -> VortexResult { let codes = PrimitiveArray::new::(all_indices.freeze(), Validity::NonNullable); - - let mut centroids_buf = BufferMut::::with_capacity(centroids.len()); - centroids_buf.extend_from_slice(centroids); - let centroids_array = PrimitiveArray::new::(centroids_buf.freeze(), Validity::NonNullable); + let centroids_array = + PrimitiveArray::new::(Buffer::copy_from(centroids), Validity::NonNullable); let dict = DictArray::try_new(codes.into_array(), centroids_array.into_array())?; @@ -240,7 +239,7 @@ pub unsafe fn turboquant_encode_unchecked( Validity::NonNullable, 0, )?; - let empty_padded_vector = wrap_padded_as_vector(empty_fsl.into_array())?; + let empty_padded_vector = Vector::wrap_storage(empty_fsl.into_array())?; let sorf_options = SorfOptions { seed, @@ -256,7 +255,7 @@ pub unsafe fn turboquant_encode_unchecked( let core = turboquant_quantize_core(&fsl, seed, config.bit_width, config.num_rounds, ctx)?; let quantized_fsl = build_quantized_fsl(num_rows, core.all_indices, &core.centroids, core.padded_dim)?; - let padded_vector = wrap_padded_as_vector(quantized_fsl)?; + let padded_vector = Vector::wrap_storage(quantized_fsl)?; let sorf_options = SorfOptions { seed, @@ -267,9 +266,42 @@ pub unsafe fn turboquant_encode_unchecked( Ok(SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array()) } -/// Wrap an `FSL` in a [`Vector`](crate::vector::Vector) extension so it can be -/// passed as the child of [`SorfTransform`], which expects a `Vector` input. -fn wrap_padded_as_vector(fsl: ArrayRef) -> VortexResult { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl).into_array()) +/// Apply the full TurboQuant compression pipeline to a [`Vector`](crate::vector::Vector) +/// extension array: normalize the rows via [`normalize_as_l2_denorm`], quantize the normalized +/// child via [`turboquant_encode_unchecked`], and reattach the stored norms as the outer +/// [`L2Denorm`] wrapper. +/// +/// The returned array has the canonical TurboQuant shape: +/// +/// ```text +/// ScalarFnArray(L2Denorm, [ +/// ScalarFnArray(SorfTransform, [FSL(Dict(codes, centroids))]), +/// norms, +/// ]) +/// ``` +/// +/// # Errors +/// +/// Returns an error if `input` is not a tensor-like extension array, if normalization fails, or +/// if [`turboquant_encode_unchecked`] rejects the input shape. +pub fn turboquant_compress( + input: ArrayRef, + config: &TurboQuantConfig, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let l2_denorm = normalize_as_l2_denorm(input, ctx)?; + let normalized = l2_denorm.child_at(0).clone(); + let norms = l2_denorm.child_at(1).clone(); + let num_rows = l2_denorm.len(); + + let normalized_ext = normalized + .as_opt::() + .vortex_expect("normalize_as_l2_denorm always produces an Extension array child"); + + // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero for null rows). + let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx) }?; + + // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally + // bypass the strict normalized-row validation when reattaching the stored norms. + Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) } diff --git a/vortex-tensor/src/encodings/turboquant/mod.rs b/vortex-tensor/src/encodings/turboquant/mod.rs index 49e53effd0d..b07f87549a1 100644 --- a/vortex-tensor/src/encodings/turboquant/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/mod.rs @@ -92,19 +92,13 @@ //! ``` //! use vortex_array::IntoArray; //! use vortex_array::VortexSessionExecute; -//! use vortex_array::arrays::ExtensionArray; //! use vortex_array::arrays::FixedSizeListArray; //! use vortex_array::arrays::PrimitiveArray; -//! use vortex_array::arrays::Extension; -//! use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -//! use vortex_array::dtype::extension::ExtDType; -//! use vortex_array::extension::EmptyMetadata; +//! use vortex_array::session::ArraySession; //! use vortex_array::validity::Validity; //! use vortex_buffer::BufferMut; -//! use vortex_array::session::ArraySession; //! use vortex_session::VortexSession; -//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_encode_unchecked}; -//! use vortex_tensor::scalar_fns::l2_denorm::normalize_as_l2_denorm; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_compress}; //! use vortex_tensor::vector::Vector; //! //! // Create a Vector extension array of 100 random 128-d vectors. @@ -118,22 +112,13 @@ //! let fsl = FixedSizeListArray::try_new( //! elements.into_array(), dim, Validity::NonNullable, num_rows, //! ).unwrap(); -//! let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) -//! .unwrap().erased(); -//! let ext = ExtensionArray::new(ext_dtype, fsl.into_array()); +//! let vector = Vector::wrap_storage(fsl.into_array()).unwrap(); //! -//! // Normalize, then quantize the normalized child at 2 bits per coordinate. +//! // Normalize and quantize at 2 bits per coordinate in one pass. //! let session = VortexSession::empty().with::(); //! let mut ctx = session.create_execution_ctx(); -//! let l2_denorm = normalize_as_l2_denorm(ext.into_array(), &mut ctx).unwrap(); -//! let normalized = l2_denorm.child_at(0).clone(); -//! -//! let normalized_ext = normalized.as_opt::().unwrap(); //! let config = TurboQuantConfig { bit_width: 2, seed: Some(42), num_rounds: 3 }; -//! // SAFETY: We just normalized the input. -//! let tq = unsafe { -//! turboquant_encode_unchecked(normalized_ext, &config, &mut ctx).unwrap() -//! }; +//! let tq = turboquant_compress(vector, &config, &mut ctx).unwrap(); //! //! // Verify compression: 100 vectors x 128 dims x 4 bytes = 51200 bytes input. //! assert!(tq.nbytes() < 51200); @@ -144,6 +129,7 @@ pub(crate) mod compress; mod scheme; pub use compress::TurboQuantConfig; +pub use compress::turboquant_compress; pub use compress::turboquant_encode; pub use compress::turboquant_encode_unchecked; pub use scheme::TurboQuantScheme; diff --git a/vortex-tensor/src/encodings/turboquant/scheme.rs b/vortex-tensor/src/encodings/turboquant/scheme.rs index b603f12e16e..9f1179cc5dd 100644 --- a/vortex-tensor/src/encodings/turboquant/scheme.rs +++ b/vortex-tensor/src/encodings/turboquant/scheme.rs @@ -3,8 +3,7 @@ //! TurboQuant compression scheme. //! -//! The scheme first normalizes the input via [`normalize_as_l2_denorm`], then encodes the -//! normalized child via [`turboquant_encode_unchecked`]. The result is: +//! The scheme is a thin [`Scheme`] adapter over [`turboquant_compress`], which produces: //! //! ```text //! ScalarFnArray(L2Denorm, [ @@ -18,14 +17,10 @@ //! //! Decompression is automatic: executing the outer array walks the ScalarFn tree. //! -//! [`normalize_as_l2_denorm`]: crate::scalar_fns::l2_denorm::normalize_as_l2_denorm -//! [`turboquant_encode_unchecked`]: crate::encodings::turboquant::turboquant_encode_unchecked +//! [`turboquant_compress`]: crate::encodings::turboquant::turboquant_compress use vortex_array::ArrayRef; use vortex_array::Canonical; -use vortex_array::IntoArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_compressor::CascadingCompressor; use vortex_compressor::ctx::CompressorContext; use vortex_compressor::estimate::CompressionEstimate; @@ -38,9 +33,7 @@ use vortex_error::VortexResult; use crate::encodings::turboquant::MAX_CENTROIDS; use crate::encodings::turboquant::TurboQuantConfig; use crate::encodings::turboquant::tq_validate_vector_dtype; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::encodings::turboquant::turboquant_compress; /// TurboQuant compression scheme for [`Vector`] extension types. /// @@ -105,33 +98,8 @@ impl Scheme for TurboQuantScheme { data: &mut ArrayAndStats, _ctx: CompressorContext, ) -> VortexResult { - let ext_array = data - .array() - .as_opt::() - .vortex_expect("expected an extension array"); - let mut ctx = compressor.execution_ctx(); - - // 1. Normalize: produces L2Denorm(normalized_vectors, norms). - let l2_denorm = normalize_as_l2_denorm(ext_array.as_ref().clone(), &mut ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - // 2. Quantize the normalized child: SorfTransform(FSL(Dict)). - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - - let config = TurboQuantConfig::default(); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`, so all rows are - // guaranteed to be unit-norm (or zero for originally-null rows). - let sorf_dict = unsafe { turboquant_encode_unchecked(normalized_ext, &config, &mut ctx)? }; - - // 3. Wrap back in L2Denorm: the SorfTransform is the "normalized" child. - // SAFETY: TurboQuant is a lossy approximation of the normalized child, so we intentionally - // bypass the strict normalized-row validation when reattaching the stored norms. - Ok(unsafe { L2Denorm::new_array_unchecked(sorf_dict, norms, num_rows) }?.into_array()) + turboquant_compress(data.array().clone(), &TurboQuantConfig::default(), &mut ctx) } } diff --git a/vortex-tensor/src/encodings/turboquant/tests/compute.rs b/vortex-tensor/src/encodings/turboquant/tests/compute.rs index 0a9e0ab7a18..678607e90fe 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/compute.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/compute.rs @@ -44,7 +44,7 @@ fn slice_preserves_data() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; // Full decompress then slice. let mut ctx = SESSION.create_execution_ctx(); @@ -89,7 +89,7 @@ fn scalar_at_matches_decompress() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let full_decoded = encoded.clone().execute::(&mut ctx)?; @@ -112,7 +112,7 @@ fn l2_norm_readthrough() -> VortexResult<()> { num_rounds: 5, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); // Stored norms should match the actual L2 norms of the input. @@ -150,7 +150,7 @@ fn l2_norm_readthrough_is_authoritative_for_lossy_storage() -> VortexResult<()> num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let stored_norms: PrimitiveArray = norms_child.execute(&mut ctx)?; @@ -187,7 +187,7 @@ fn cosine_similarity_readthrough_is_authoritative_for_lossy_storage() -> VortexR num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let encoded_cos = execute_cosine_similarity(encoded.clone(), encoded.clone(), num_rows, &mut ctx)?; diff --git a/vortex-tensor/src/encodings/turboquant/tests/mod.rs b/vortex-tensor/src/encodings/turboquant/tests/mod.rs index b111c6e28ba..22b852b66a3 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/mod.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/mod.rs @@ -16,7 +16,6 @@ use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Dict; -use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -25,17 +24,13 @@ use vortex_array::arrays::dict::DictArraySlotsExt; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::validity::Validity; use vortex_buffer::BufferMut; use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode_unchecked; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; +use crate::encodings::turboquant::turboquant_compress; use crate::tests::SESSION; use crate::vector::Vector; @@ -71,31 +66,9 @@ fn make_fsl(num_rows: usize, dim: usize, seed: u64) -> FixedSizeListArray { } /// Wrap a `FixedSizeListArray` in a `Vector` extension array. -fn make_vector_ext(fsl: &FixedSizeListArray) -> ExtensionArray { - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone()) - .unwrap() - .erased(); - ExtensionArray::new(ext_dtype, fsl.clone().into_array()) -} - -/// Full encode pipeline: normalize → TQ-encode → wrap in L2Denorm. -fn normalize_and_encode( - ext: &ExtensionArray, - config: &TurboQuantConfig, - ctx: &mut vortex_array::ExecutionCtx, -) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(ext.as_ref().clone(), ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let normalized_ext = normalized - .as_opt::() - .vortex_expect("normalized child should be an Extension array"); - // SAFETY: We just normalized the input via `normalize_as_l2_denorm`. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, config, ctx)? }; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) +fn make_vector_ext(fsl: &FixedSizeListArray) -> ArrayRef { + Vector::wrap_storage(fsl.clone().into_array()) + .vortex_expect("test FSL satisfies Vector storage constraints") } /// Unwrap an L2Denorm ScalarFnArray into (sorf_child, norms_child). @@ -103,17 +76,7 @@ fn unwrap_l2denorm(encoded: &ArrayRef) -> (ArrayRef, ArrayRef) { let sfn = encoded .as_opt::() .expect("expected ScalarFnArray (L2Denorm)"); - let sorf_child = sfn.child_at(0).clone(); - let norms_child = sfn.child_at(1).clone(); - (sorf_child, norms_child) -} - -/// Unwrap a SorfTransform ScalarFnArray to get the FSL(Dict) child. -fn unwrap_sorf(sorf: &ArrayRef) -> ArrayRef { - let sfn = sorf - .as_opt::() - .expect("expected ScalarFnArray (SorfTransform)"); - sfn.child_at(0).clone() + (sfn.child_at(0).clone(), sfn.child_at(1).clone()) } /// Navigate the full tree to get (codes, centroids, norms) as flat arrays. @@ -122,7 +85,11 @@ fn unwrap_codes_centroids_norms( ctx: &mut vortex_array::ExecutionCtx, ) -> VortexResult<(PrimitiveArray, PrimitiveArray, PrimitiveArray)> { let (sorf_child, norms_child) = unwrap_l2denorm(encoded); - let padded_vector_child = unwrap_sorf(&sorf_child); + let padded_vector_child = sorf_child + .as_opt::() + .expect("expected SorfTransform ScalarFnArray") + .child_at(0) + .clone(); // Vector wrapping FSL(Dict(codes, centroids)) let padded_vector: ExtensionArray = padded_vector_child.execute(ctx)?; @@ -177,8 +144,7 @@ fn encode_decode( let prim = fsl.elements().clone().execute::(&mut ctx)?; prim.as_slice::().to_vec() }; - let ext = make_vector_ext(fsl); - let encoded = normalize_and_encode(&ext, config, &mut ctx)?; + let encoded = turboquant_compress(make_vector_ext(fsl), config, &mut ctx)?; let decoded_ext = encoded.execute::(&mut ctx)?; let decoded_fsl = decoded_ext .storage_array() @@ -193,19 +159,3 @@ fn encode_decode( }; Ok((original, decoded_elements)) } - -fn make_fsl_small(dim: usize) -> FixedSizeListArray { - let mut buf = BufferMut::::with_capacity(dim); - for i in 0..dim { - buf.push(i as f32 + 1.0); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - FixedSizeListArray::try_new( - elements.into_array(), - dim.try_into() - .expect("somehow got dimension greater than u32::MAX"), - Validity::NonNullable, - 1, - ) - .unwrap() -} diff --git a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs index 6fc19bb93ec..c5041364ce8 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/nullable.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/nullable.rs @@ -27,7 +27,7 @@ fn nullable_vectors_roundtrip() -> VortexResult<()> { num_rounds: 4, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; assert_eq!(encoded.len(), 10); assert!(encoded.dtype().is_nullable()); @@ -88,7 +88,7 @@ fn nullable_norms_match_validity() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); let norms_validity = norms_child.validity()?; @@ -118,7 +118,7 @@ fn nullable_l2_norm_readthrough() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let norm_sfn = L2Norm::try_new_array(encoded, 5)?; let norms: PrimitiveArray = norm_sfn.into_array().execute(&mut ctx)?; @@ -160,7 +160,7 @@ fn nullable_slice_preserves_validity() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let sliced = encoded.slice(1..6)?; assert_eq!(sliced.len(), 5); diff --git a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs index cd61d9193da..fccaabe4344 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/roundtrip.rs @@ -4,6 +4,7 @@ use rstest::rstest; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; +use vortex_array::arrays::Extension; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; @@ -12,6 +13,8 @@ use vortex_buffer::BufferMut; use vortex_error::VortexResult; use super::*; +use crate::encodings::turboquant::turboquant_encode_unchecked; +use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; #[rstest] #[case(128, 1)] @@ -130,7 +133,7 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let decoded = encoded.execute::(&mut ctx)?; assert_eq!(decoded.len(), num_rows); Ok(()) @@ -141,7 +144,17 @@ fn roundtrip_edge_cases(#[case] num_rows: usize) -> VortexResult<()> { #[case(64)] #[case(127)] fn rejects_dimension_below_128(#[case] dim: usize) { - let fsl = make_fsl_small(dim); + let elements = PrimitiveArray::new::( + BufferMut::from_iter((0..dim).map(|i| i as f32 + 1.0)).freeze(), + Validity::NonNullable, + ); + let fsl = FixedSizeListArray::try_new( + elements.into_array(), + dim.try_into().expect("dim fits u32"), + Validity::NonNullable, + 1, + ) + .unwrap(); let ext = make_vector_ext(&fsl); let config = TurboQuantConfig { bit_width: 2, @@ -149,9 +162,8 @@ fn rejects_dimension_below_128(#[case] dim: usize) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - assert!( - crate::encodings::turboquant::turboquant_encode(ext.as_view(), &config, &mut ctx).is_err() - ); + let view = ext.as_opt::().expect("Vector extension"); + assert!(crate::encodings::turboquant::turboquant_encode(view, &config, &mut ctx).is_err()); } #[rstest] @@ -166,7 +178,7 @@ fn rejects_invalid_bit_width(#[case] bit_width: u8) { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx) + let normalized = normalize_as_l2_denorm(ext, &mut ctx) .unwrap() .child_at(0) .clone(); @@ -255,7 +267,7 @@ fn f64_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); Ok(()) @@ -288,7 +300,7 @@ fn f16_input_encodes_successfully() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_sorf_child, norms_child) = unwrap_l2denorm(&encoded); assert_eq!(norms_child.len(), num_rows); @@ -329,7 +341,7 @@ fn checked_encode_accepts_normalized_f16_input() -> VortexResult<()> { }; let mut ctx = SESSION.create_execution_ctx(); - let normalized = normalize_as_l2_denorm(ext.as_ref().clone(), &mut ctx)? + let normalized = normalize_as_l2_denorm(ext, &mut ctx)? .child_at(0) .clone(); let normalized_ext = normalized diff --git a/vortex-tensor/src/encodings/turboquant/tests/structural.rs b/vortex-tensor/src/encodings/turboquant/tests/structural.rs index 87b59836b38..37b857a8ad0 100644 --- a/vortex-tensor/src/encodings/turboquant/tests/structural.rs +++ b/vortex-tensor/src/encodings/turboquant/tests/structural.rs @@ -24,7 +24,7 @@ fn stored_centroids_match_computed() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let (_codes, centroids, _norms) = unwrap_codes_centroids_norms(&encoded, &mut ctx)?; let stored = centroids.as_slice::(); @@ -52,7 +52,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { // Encode twice with the same seed → should produce identical results. let mut ctx = SESSION.create_execution_ctx(); - let encoded1 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded1 = turboquant_compress(ext.clone(), &config, &mut ctx)?; let decoded1 = encoded1.execute::(&mut ctx)?; let fsl1 = decoded1 .storage_array() @@ -64,7 +64,7 @@ fn seed_deterministic_rotation_produces_correct_decode() -> VortexResult<()> { .execute::(&mut ctx)?; let mut ctx = SESSION.create_execution_ctx(); - let encoded2 = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded2 = turboquant_compress(ext, &config, &mut ctx)?; let decoded2 = encoded2.execute::(&mut ctx)?; let fsl2 = decoded2 .storage_array() @@ -94,7 +94,7 @@ fn encoded_dtype_is_vector_extension() -> VortexResult<()> { num_rounds: 2, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; assert!( encoded.dtype().is_extension(), @@ -119,7 +119,7 @@ fn cosine_similarity_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); @@ -176,7 +176,7 @@ fn dot_product_quantized_accuracy() -> VortexResult<()> { num_rounds: 3, }; let mut ctx = SESSION.create_execution_ctx(); - let encoded = normalize_and_encode(&ext, &config, &mut ctx)?; + let encoded = turboquant_compress(ext, &config, &mut ctx)?; let input_prim = fsl.elements().clone().execute::(&mut ctx)?; let input_f32 = input_prim.as_slice::(); diff --git a/vortex-tensor/src/fixed_shape/metadata.rs b/vortex-tensor/src/fixed_shape/metadata.rs index 264d18453c4..757138d3e50 100644 --- a/vortex-tensor/src/fixed_shape/metadata.rs +++ b/vortex-tensor/src/fixed_shape/metadata.rs @@ -215,6 +215,7 @@ impl fmt::Display for FixedShapeTensorMetadata { } if let Some(perm) = &self.permutation { + write!(f, ", [")?; for (i, p) in perm.iter().enumerate() { if i > 0 { write!(f, ", ")?; @@ -353,6 +354,44 @@ mod tests { Ok(()) } + // -- Display -- + + #[test] + fn display_shape_only() { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]); + assert_eq!(m.to_string(), "Tensor(2, 3, 4)"); + } + + #[test] + fn display_scalar_0d() { + let m = FixedShapeTensorMetadata::new(vec![]); + assert_eq!(m.to_string(), "Tensor()"); + } + + #[test] + fn display_with_dim_names() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![3, 4]) + .with_dim_names(vec!["rows".into(), "cols".into()])?; + assert_eq!(m.to_string(), "Tensor(rows: 3, cols: 4)"); + Ok(()) + } + + #[test] + fn display_with_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]).with_permutation(vec![1, 0, 2])?; + assert_eq!(m.to_string(), "Tensor(2, 3, 4, [1, 0, 2])"); + Ok(()) + } + + #[test] + fn display_with_dim_names_and_permutation() -> VortexResult<()> { + let m = FixedShapeTensorMetadata::new(vec![2, 3, 4]) + .with_dim_names(vec!["x".into(), "y".into(), "z".into()])? + .with_permutation(vec![1, 2, 0])?; + assert_eq!(m.to_string(), "Tensor(x: 2, y: 3, z: 4, [1, 2, 0])"); + Ok(()) + } + #[test] fn dim_names_wrong_length() { let result = FixedShapeTensorMetadata::new(vec![2, 3]).with_dim_names(vec!["x".into()]); diff --git a/vortex-tensor/src/matcher.rs b/vortex-tensor/src/matcher.rs index 973562be3da..258a464b57f 100644 --- a/vortex-tensor/src/matcher.rs +++ b/vortex-tensor/src/matcher.rs @@ -45,7 +45,7 @@ impl TensorMatch<'_> { pub fn list_size(self) -> usize { match self { Self::FixedShapeTensor(metadata) => metadata.list_size(), - Self::Vector(metadata) => metadata.dimensions() as usize, + Self::Vector(metadata) => metadata.list_size(), } } } diff --git a/vortex-tensor/src/scalar_fns/cosine_similarity.rs b/vortex-tensor/src/scalar_fns/cosine_similarity.rs index 9f8eff0b361..85d16236c8c 100644 --- a/vortex-tensor/src/scalar_fns/cosine_similarity.rs +++ b/vortex-tensor/src/scalar_fns/cosine_similarity.rs @@ -11,7 +11,6 @@ use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; -use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayVTable; @@ -32,16 +31,15 @@ use vortex_array::serde::ArrayChildren; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; -use vortex_error::vortex_ensure; use vortex_session::VortexSession; use crate::scalar_fns::inner_product::BinaryTensorOpMetadata; use crate::scalar_fns::inner_product::InnerProduct; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::l2_denorm::try_build_constant_l2_denorm; use crate::scalar_fns::l2_norm::L2Norm; use crate::utils::extract_l2_denorm_children; -use crate::utils::validate_tensor_float_input; +use crate::utils::validate_binary_tensor_float_inputs; /// Cosine similarity between two columns. /// @@ -59,6 +57,7 @@ use crate::utils::validate_tensor_float_input; /// /// [`FixedShapeTensor`]: crate::fixed_shape::FixedShapeTensor /// [`Vector`]: crate::vector::Vector +/// [`L2Denorm`]: crate::scalar_fns::l2_denorm::L2Denorm #[derive(Clone)] pub struct CosineSimilarity; @@ -84,7 +83,7 @@ impl ScalarFnVTable for CosineSimilarity { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.cosine_similarity") + ScalarFnId::new("vortex.tensor.cosine_similarity") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -116,16 +115,8 @@ impl ScalarFnVTable for CosineSimilarity { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "CosineSimilarity requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // We don't need to look at rhs anymore since we know lhs and rhs are equal. - let tensor_match = validate_tensor_float_input(lhs)?; + let tensor_match = validate_binary_tensor_float_inputs("CosineSimilarity", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -141,32 +132,24 @@ impl ScalarFnVTable for CosineSimilarity { let len = args.row_count(); // If either side is a constant tensor-like extension array, eagerly normalize the single - // stored row and re-wrap it as an `L2Denorm` whose children are both [`ConstantArray`]s. + // stored row and re-wrap it as an `L2Denorm` whose children are both `ConstantArray`s. // The L2Denorm fast path below then picks it up. - if let Some(lhs_constant) = - try_build_constant_l2_denorm(&lhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - lhs_ref = lhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&lhs_ref, len, ctx)? { + lhs_ref = sfn.into_array(); } - if let Some(rhs_constant) = - try_build_constant_l2_denorm(&rhs_ref, len, ctx)?.map(|sfn| sfn.into_array()) - { - rhs_ref = rhs_constant; + if let Some(sfn) = try_build_constant_l2_denorm(&rhs_ref, len, ctx)? { + rhs_ref = sfn.into_array(); } - // Check if any of our children have be already normalized. - { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. + match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len); + } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); } + DenormOrientation::Neither => {} } // Compute combined validity. @@ -266,7 +249,6 @@ impl CosineSimilarity { lhs_ref: &ArrayRef, rhs_ref: &ArrayRef, len: usize, - _ctx: &mut ExecutionCtx, ) -> VortexResult { let validity = lhs_ref.validity()?.and(rhs_ref.validity()?)?; @@ -347,9 +329,10 @@ mod tests { use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; + use crate::vector::Vector; /// Evaluates cosine similarity between two tensor arrays and returns the result as `Vec`. fn eval_cosine_similarity(lhs: ArrayRef, rhs: ArrayRef, len: usize) -> VortexResult> { @@ -508,7 +491,7 @@ mod tests { 1.0, 0.0, 0.0, // vector 3 ], )?; - let query = constant_vector_array(&[1.0, 0.0, 0.0], 4)?; + let query = Vector::constant_array(&[1.0, 0.0, 0.0], 4)?; assert_close( &eval_cosine_similarity(data, query, 4)?, @@ -536,25 +519,13 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. - fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm_self_similarity() -> VortexResult<()> { // [3.0, 4.0] has norm 5.0, normalized [0.6, 0.8]. // [1.0, 0.0] has norm 1.0, normalized [1.0, 0.0]. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Self-similarity should always be 1.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 1.0]); @@ -565,8 +536,9 @@ mod tests { fn both_denorm_orthogonal() -> VortexResult<()> { // [3.0, 0.0] normalized [1.0, 0.0], norm 3.0. // [0.0, 4.0] normalized [0.0, 1.0], norm 4.0. - let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0])?; - let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[1.0, 0.0], &[3.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.0, 1.0], &[4.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.0]); Ok(()) @@ -575,8 +547,9 @@ mod tests { #[test] fn both_denorm_zero_norm() -> VortexResult<()> { // Zero-norm row: normalized is [0.0, 0.0], norm is 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 0.0], &[5.0, 0.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; // Row 0: dot([0.6, 0.8], [0.6, 0.8]) = 1.0, row 1: dot([0,0], [1,0]) = 0.0. assert_close(&eval_cosine_similarity(lhs, rhs, 2)?, &[1.0, 0.0]); @@ -588,7 +561,8 @@ mod tests { // LHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS is plain [3.0, 4.0]. // cosine_similarity([3.0, 4.0], [3.0, 4.0]) = 1.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[3.0, 4.0])?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[1.0]); @@ -599,8 +573,9 @@ mod tests { fn one_side_denorm_rhs() -> VortexResult<()> { // LHS is plain [1.0, 0.0], RHS is L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // cosine_similarity([1.0, 0.0], [3.0, 4.0]) = 3.0 / (1.0 * 5.0) = 0.6. + let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 0.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_cosine_similarity(lhs, rhs, 1)?, &[0.6]); Ok(()) @@ -609,11 +584,11 @@ mod tests { #[test] fn both_denorm_null_norms() -> VortexResult<()> { // Row 0: valid, row 1: null (via nullable norms on rhs). - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let normalized_r = tensor_array(&[2], &[0.6, 0.8, 1.0, 0.0])?; let norms_r = PrimitiveArray::from_option_iter([Some(5.0f64), None]).into_array(); - let mut ctx = SESSION.create_execution_ctx(); let rhs = L2Denorm::try_new_array(normalized_r, norms_r, 2, &mut ctx)?.into_array(); let scalar_fn = CosineSimilarity::new().erased(); @@ -711,7 +686,7 @@ mod tests { #[test] fn vector_constant_matches_plain() -> VortexResult<()> { // Exercise the `Vector` extension variant through the new pre-pass. - let lhs = constant_vector_array(&[1.0, 2.0, 2.0], 4)?; + let lhs = Vector::constant_array(&[1.0, 2.0, 2.0], 4)?; let rhs = vector_array( 3, &[ diff --git a/vortex-tensor/src/scalar_fns/inner_product.rs b/vortex-tensor/src/scalar_fns/inner_product.rs index dd9c2a7381f..7c28ffc37de 100644 --- a/vortex-tensor/src/scalar_fns/inner_product.rs +++ b/vortex-tensor/src/scalar_fns/inner_product.rs @@ -4,7 +4,6 @@ //! Inner product expression for tensor-like types. use std::fmt::Formatter; -use std::sync::Arc; use num_traits::Float; use prost::Message; @@ -32,13 +31,10 @@ use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; -use vortex_array::dtype::extension::ExtDType; use vortex_array::dtype::proto::dtype as pb; use vortex_array::expr::Expression; use vortex_array::expr::and; -use vortex_array::extension::EmptyMetadata; use vortex_array::match_each_float_ptype; -use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::Arity; use vortex_array::scalar_fn::ChildName; use vortex_array::scalar_fn::EmptyOptions; @@ -56,11 +52,12 @@ use vortex_error::vortex_err; use vortex_session::VortexSession; use crate::matcher::AnyTensor; -use crate::scalar_fns::l2_denorm::L2Denorm; +use crate::scalar_fns::l2_denorm::DenormOrientation; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfTransform; use crate::utils::extract_flat_elements; use crate::utils::extract_l2_denorm_children; +use crate::utils::validate_binary_tensor_float_inputs; use crate::vector::Vector; /// Inner product (dot product) between two columns. @@ -99,7 +96,7 @@ impl ScalarFnVTable for InnerProduct { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.inner_product") + ScalarFnId::new("vortex.tensor.inner_product") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -131,32 +128,9 @@ impl ScalarFnVTable for InnerProduct { let lhs = &arg_dtypes[0]; let rhs = &arg_dtypes[1]; - // Both must have the same dtype (ignoring top-level nullability). - vortex_ensure!( - lhs.eq_ignore_nullability(rhs), - "InnerProduct requires both inputs to have the same dtype, got {lhs} and {rhs}" - ); - - // Both inputs must be tensor-like extension types. - let lhs_ext = lhs - .as_extension_opt() - .ok_or_else(|| vortex_err!("InnerProduct lhs must be an extension type, got {lhs}"))?; - - vortex_ensure!( - lhs_ext.is::(), - "InnerProduct inputs must be an `AnyTensor`, got {lhs}" - ); - - let tensor_match = lhs_ext - .metadata_opt::() - .ok_or_else(|| vortex_err!("InnerProduct inputs must be an `AnyTensor`, got {lhs}"))?; + // TODO(connor): relax the float-only gate once integer tensors are supported. + let tensor_match = validate_binary_tensor_float_inputs("InnerProduct", lhs, rhs)?; let ptype = tensor_match.element_ptype(); - // TODO(connor): This should support integer tensors! - vortex_ensure!( - ptype.is_float(), - "InnerProduct element dtype must be a float primitive, got {ptype}" - ); - let nullability = Nullability::from(lhs.is_nullable() || rhs.is_nullable()); Ok(DType::Primitive(ptype, nullability)) } @@ -167,23 +141,19 @@ impl ScalarFnVTable for InnerProduct { args: &dyn ExecutionArgs, ctx: &mut ExecutionCtx, ) -> VortexResult { - let mut lhs_ref = args.get(0)?; - let mut rhs_ref = args.get(1)?; + let lhs_ref = args.get(0)?; + let rhs_ref = args.get(1)?; let len = args.row_count(); - // Check if any of our children have be already normalized. - { - let lhs_is_denorm = lhs_ref.is::>(); - let rhs_is_denorm = rhs_ref.is::>(); - - if lhs_is_denorm && rhs_is_denorm { - return self.execute_both_denorm(&lhs_ref, &rhs_ref, len, ctx); - } else if lhs_is_denorm || rhs_is_denorm { - if rhs_is_denorm { - (lhs_ref, rhs_ref) = (rhs_ref, lhs_ref); - } - return self.execute_one_denorm(&lhs_ref, &rhs_ref, len, ctx); + // Take any L2Denorm-wrapped fast path that applies. + match DenormOrientation::classify(&lhs_ref, &rhs_ref) { + DenormOrientation::Both { lhs, rhs } => { + return self.execute_both_denorm(lhs, rhs, len, ctx); } + DenormOrientation::One { denorm, plain } => { + return self.execute_one_denorm(denorm, plain, len, ctx); + } + DenormOrientation::Neither => {} } // Reduction case 1: `InnerProduct(SorfTransform(x), const)` rewrites to @@ -413,16 +383,17 @@ impl InnerProduct { /// from `padded_dim` to `dim` applied by `SorfTransform` and `R` is the SORF forward /// matrix. See the proof in the crate-level docs and in the plan file. /// - /// Returns `Ok(None)` if neither side matches or when `element_ptype` is not `F32`. The - /// caller is expected to fall through to the standard path in that case. + /// Returns `Ok(None)` if neither side matches, when the operand element type is not `F32`, + /// or when the constant side is not a constant-backed tensor extension. The caller is + /// expected to fall through to the standard path in that case. /// - /// # TODO(connor): + /// # F32-only /// - /// This rewrite is only sound for `PType::F32` because `SorfTransform` applies an - /// `f32 -> element_ptype` cast at the end of its execute (see `sorf_transform/vtable.rs` - /// line ~218). For F16/F64 the cast changes the inner product's rounding and would - /// change the semantics of the rewrite. Until we push the cast through `InnerProduct`, - /// this path only fires for F32. + /// TODO(connor): this rewrite is only sound for `PType::F32` because `SorfTransform` + /// applies an `f32 -> element_ptype` cast at the end of its `execute`. For `F16`/`F64` + /// the cast changes the inner product's rounding and the rewrite would not be + /// semantically equivalent. Until we push the cast through `InnerProduct`, both the + /// SorfTransform output ptype and the constant-side element ptype must be `F32` here. fn try_execute_sorf_constant( &self, lhs_ref: &ArrayRef, @@ -440,10 +411,6 @@ impl InnerProduct { return Ok(None); }; - // TODO(connor): pull-through is only sound for F32 because SorfTransform applies an - // `f32 -> element_ptype` cast at the end of its execute. For F16/F64 the rewrite - // would change the inner product's rounding semantics. Fall through so the standard - // path (which does the cast before inner product) handles it. if sorf_view.options.element_ptype != PType::F32 { return Ok(None); } @@ -458,11 +425,9 @@ impl InnerProduct { let seed = sorf_view.options.seed; let padded_dim = dim.next_power_of_two(); - // Extract the single stored row of the constant via the stride-0 short-circuit. + // Extract the single stored row of the constant via the `is_constant` short-circuit. let flat = extract_flat_elements(&const_storage, dim, ctx)?; if flat.ptype() != PType::F32 { - // TODO(connor): as above, f16/f64 are not supported by this rewrite yet. The - // standard path handles them correctly. return Ok(None); } @@ -474,29 +439,10 @@ impl InnerProduct { let mut rotated_query = vec![0.0f32; padded_dim]; rotation.rotate(&padded_query, &mut rotated_query); - // Build the rewritten constant as a `Vector` extension scalar. We reuse - // the original storage FSL nullability so the new extension dtype stays consistent with - // whatever the original tree expected. - let storage_fsl_nullability = const_storage.dtype().nullability(); - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = rotated_query - .into_iter() - .map(|v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let fsl_scalar = - Scalar::fixed_size_list(element_dtype.clone(), children, storage_fsl_nullability); - - // Build a fresh `Vector` extension dtype. We cannot reuse the - // original extension dtype because that one has `dim`, not `padded_dim`. - let padded_dim_u32 = u32::try_from(padded_dim).vortex_expect("padded_dim fits u32"); - let new_fsl_dtype = DType::FixedSizeList( - Arc::new(element_dtype), - padded_dim_u32, - storage_fsl_nullability, - ); - let new_ext_dtype = ExtDType::::try_new(EmptyMetadata, new_fsl_dtype)?.erased(); - let new_constant = - ConstantArray::new(Scalar::extension_ref(new_ext_dtype, fsl_scalar), len).into_array(); + // Wrap the rotated query as a `Vector` constant broadcast to `len` + // rows. The new extension dtype has `padded_dim` instead of `dim`, matching the + // SorfTransform child we are about to dot it with. + let new_constant = Vector::constant_array(&rotated_query, len)?; // Extract the SorfTransform child (the already-padded Vector). let sorf_child = sorf_view @@ -612,8 +558,9 @@ impl InnerProduct { let values: &[f32] = values_prim.as_slice::(); debug_assert_eq!(codes.len(), len * padded_dim); - // The hot loop is extracted into [`execute_dict_constant_inner_product`] with - // unchecked indexing so the compiler can vectorize the inner gather-accumulate. + // The hot loop is extracted into [`execute_dict_constant_inner_product`] so the + // compiler can prove the chunked indices stay in bounds and vectorize the inner + // gather-accumulate. let out = execute_dict_constant_inner_product(q, values, codes, len, padded_dim); // SAFETY: the buffer length equals `len`, which matches the validity length. @@ -647,8 +594,8 @@ fn inner_product_row(a: &[T], b: &[T]) -> T { /// For each row, computes `sum(q[j] * values[codes[row * dim + j]])` using the codebook /// `values` directly instead of decoding the dictionary into dense vectors. /// -/// The inner loop uses four independent accumulators so the CPU can pipeline FP additions -/// instead of waiting for each `fadd` to retire before starting the next. +/// The inner loop uses `PARTIAL_SUMS` independent accumulators so the CPU can pipeline FP +/// additions instead of waiting for each `fadd` to retire before starting the next. fn execute_dict_constant_inner_product( q: &[f32], values: &[f32], @@ -704,6 +651,7 @@ mod tests { use crate::scalar_fns::l2_denorm::L2Denorm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::l2_denorm_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; @@ -818,28 +766,14 @@ mod tests { Ok(()) } - /// Creates an `L2Denorm` scalar function array from pre-normalized elements and norms. - fn l2_denorm_array( - shape: &[usize], - normalized_elements: &[f64], - norms: &[f64], - ) -> VortexResult { - use vortex_array::IntoArray; - - let len = norms.len(); - let normalized = tensor_array(shape, normalized_elements)?; - let norms = PrimitiveArray::from_iter(norms.iter().copied()).into_array(); - let mut ctx = SESSION.create_execution_ctx(); - Ok(L2Denorm::try_new_array(normalized, norms, len, &mut ctx)?.into_array()) - } - #[test] fn both_denorm() -> VortexResult<()> { // LHS: [3.0, 4.0] = L2Denorm([0.6, 0.8], 5.0). // RHS: [1.0, 0.0] = L2Denorm([1.0, 0.0], 1.0). // dot([3.0, 4.0], [1.0, 0.0]) = 3.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; - let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[1.0, 0.0], &[1.0], &mut ctx)?; // Expected: 5.0 * 1.0 * dot([0.6, 0.8], [1.0, 0.0]) = 5.0 * 0.6 = 3.0. assert_close(&eval_inner_product(lhs, rhs, 1)?, &[3.0]); @@ -850,8 +784,9 @@ mod tests { fn both_denorm_multiple_rows() -> VortexResult<()> { // Row 0: [3.0, 4.0] dot [3.0, 4.0] = 25.0. // Row 1: [1.0, 0.0] dot [0.0, 1.0] = 0.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 0.0, 1.0], &[5.0, 1.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 2)?, &[25.0, 0.0]); Ok(()) @@ -862,7 +797,8 @@ mod tests { // LHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // RHS: plain [1.0, 2.0]. // dot([3.0, 4.0], [1.0, 2.0]) = 3.0 + 8.0 = 11.0. - let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let mut ctx = SESSION.create_execution_ctx(); + let lhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; let rhs = tensor_array(&[2], &[1.0, 2.0])?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); @@ -874,8 +810,9 @@ mod tests { // LHS: plain [1.0, 2.0]. // RHS: L2Denorm([0.6, 0.8], 5.0) representing [3.0, 4.0]. // dot([1.0, 2.0], [3.0, 4.0]) = 3.0 + 8.0 = 11.0. + let mut ctx = SESSION.create_execution_ctx(); let lhs = tensor_array(&[2], &[1.0, 2.0])?; - let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8], &[5.0], &mut ctx)?; assert_close(&eval_inner_product(lhs, rhs, 1)?, &[11.0]); Ok(()) @@ -889,7 +826,7 @@ mod tests { let mut ctx = SESSION.create_execution_ctx(); let lhs = L2Denorm::try_new_array(normalized_l, norms_l, 2, &mut ctx)?.into_array(); - let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0])?; + let rhs = l2_denorm_array(&[2], &[0.6, 0.8, 1.0, 0.0], &[5.0, 1.0], &mut ctx)?; let scalar_fn = InnerProduct::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![lhs, rhs], 2)?; @@ -948,15 +885,11 @@ mod tests { reason = "tests build small fixtures with deterministic in-range indices" )] mod constant_query_optimizations { - use std::sync::LazyLock; - use rstest::rstest; use vortex_array::ArrayRef; use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::Constant; - use vortex_array::arrays::ConstantArray; - use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::ScalarFnArray; @@ -964,67 +897,23 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::PType; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::scalar::Scalar; - use vortex_array::session::ArraySession; use vortex_array::validity::Validity; use vortex_buffer::Buffer; use vortex_error::VortexResult; - use vortex_session::VortexSession; use crate::scalar_fns::inner_product::InnerProduct; use crate::scalar_fns::inner_product::constant_tensor_storage; use crate::scalar_fns::sorf_transform::SorfMatrix; use crate::scalar_fns::sorf_transform::SorfOptions; use crate::scalar_fns::sorf_transform::SorfTransform; + use crate::tests::SESSION; use crate::utils::extract_flat_elements; + use crate::utils::test_helpers::literal_vector_array; + use crate::utils::test_helpers::vector_array; use crate::vector::Vector; - static SESSION: LazyLock = - LazyLock::new(|| VortexSession::empty().with::()); - - /// Compact f32 Vector extension over a column-major `elements` slice. - fn vector_f32(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - /// Compact constant-backed f32 Vector extension with a single stored row. - fn constant_vector_f32(elements: &[f32], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let storage = ConstantArray::new(storage_scalar, len).into_array(); - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) - } - - /// Expression-literal shape: a ConstantArray whose scalar itself is a Vector extension. - fn literal_vector_f32(elements: &[f32], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F32, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let vector_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(vector_scalar, len).into_array() - } - - /// Build an `ExtensionArray>` whose storage is - /// `FSL(DictArray(codes: u8, values: f32))`. This mirrors the shape that - /// TurboQuant produces as the SorfTransform child. + /// Build a `Vector` whose storage is `FSL(DictArray(codes: u8, values: + /// f32))`. This mirrors the shape that TurboQuant produces as the SorfTransform child. fn dict_vector_f32(list_size: u32, codes: &[u8], values: &[f32]) -> VortexResult { let num_rows = codes.len() / list_size as usize; let codes_arr = @@ -1040,9 +929,7 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::wrap_storage(fsl.into_array()) } /// Execute an inner product and return the flat `f32` results. @@ -1128,7 +1015,7 @@ mod tests { #[test] fn constant_tensor_storage_accepts_extension_scalar_literal() -> VortexResult<()> { - let literal = literal_vector_f32(&[1.0, 2.0, 3.0], 5); + let literal = literal_vector_array(&[1.0f32, 2.0, 3.0], 5); let storage = constant_tensor_storage(&literal).expect("literal vector should be recognized"); @@ -1164,7 +1051,7 @@ mod tests { // Query has `dim` elements. let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.1).sin()).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth: decode LHS to plain f32 vectors, dot each with the query. let decoded = decode_sorf_dict( @@ -1202,7 +1089,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| (i as f32 * 0.2).cos()).collect(); - let const_lhs = constant_vector_f32(&query_elems, num_rows)?; + let const_lhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1240,7 +1127,7 @@ mod tests { assert_eq!(padded_dim, dim as usize); let query_elems: Vec = (0..dim).map(|i| i as f32 * 0.01 - 0.5).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1277,7 +1164,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = vec![0.0; dim as usize]; - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; let actual = eval_ip_f32(sorf, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1300,7 +1187,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|i| (i as f32 + 1.0) * 0.3).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1330,7 +1217,7 @@ mod tests { let dict_rhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.5, -1.0, 2.5, -0.25]; - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1375,12 +1262,10 @@ mod tests { Validity::NonNullable, num_rows, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - let dict_lhs = ExtensionArray::new(ext_dtype, fsl.into_array()).into_array(); + let dict_lhs = Vector::wrap_storage(fsl.into_array())?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; // Build expected by decoding by hand. let expected: Vec = (0..num_rows) @@ -1408,10 +1293,10 @@ mod tests { let lhs_elems: Vec = (0..num_rows * dim as usize) .map(|i| i as f32 * 0.25) .collect(); - let plain_lhs = vector_f32(dim, &lhs_elems)?; + let plain_lhs = vector_array(dim, &lhs_elems)?; let query: Vec = vec![1.0, 2.0, 3.0, 4.0]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1438,7 +1323,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = vec![0.0; 4]; - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let actual = eval_ip_f32(dict_lhs, const_rhs, num_rows)?; assert_eq!(actual.len(), 0); @@ -1458,7 +1343,7 @@ mod tests { build_sorf_with_dict_child(dim, num_rows, seed, num_rounds)?; let query_elems: Vec = (0..dim).map(|i| ((i as f32) * 0.15).sin() * 0.4).collect(); - let const_rhs = constant_vector_f32(&query_elems, num_rows)?; + let const_rhs = Vector::constant_array(&query_elems, num_rows)?; // Ground truth via full decode + naive dot. let decoded = decode_sorf_dict( @@ -1531,7 +1416,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1575,7 +1460,7 @@ mod tests { // has cancellation. let mut rng = XorShift64::new(seed ^ 0xABCD_1234); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1621,7 +1506,7 @@ mod tests { let dict_lhs = dict_vector_f32(list_size, &codes, &values)?; let query: Vec = (0..list_size).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let expected: Vec = (0..num_rows) .map(|row| { @@ -1671,7 +1556,7 @@ mod tests { SorfTransform::try_new_array(&sorf_options, padded_vector, num_rows)?.into_array(); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1721,7 +1606,7 @@ mod tests { let mut rng = XorShift64::new(seed ^ (num_rounds as u64)); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_rhs = constant_vector_f32(&query, num_rows)?; + let const_rhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, @@ -1755,7 +1640,7 @@ mod tests { let mut rng = XorShift64::new(seed); let query: Vec = (0..dim).map(|_| rng.next_f32()).collect(); - let const_lhs = constant_vector_f32(&query, num_rows)?; + let const_lhs = Vector::constant_array(&query, num_rows)?; let decoded = decode_sorf_dict( &codes, diff --git a/vortex-tensor/src/scalar_fns/l2_denorm.rs b/vortex-tensor/src/scalar_fns/l2_denorm.rs index 673c454b56b..12fb8ca59b7 100644 --- a/vortex-tensor/src/scalar_fns/l2_denorm.rs +++ b/vortex-tensor/src/scalar_fns/l2_denorm.rs @@ -22,6 +22,7 @@ use vortex_array::arrays::ScalarFnArray; use vortex_array::arrays::ScalarFnVTable as ScalarFnArrayEncoding; use vortex_array::arrays::extension::ExtensionArrayExt; use vortex_array::arrays::fixed_size_list::FixedSizeListArrayExt; +use vortex_array::arrays::scalar_fn::ExactScalarFn; use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::arrays::scalar_fn::ScalarFnArrayView; use vortex_array::arrays::scalar_fn::plugin::ScalarFnArrayParts; @@ -113,7 +114,7 @@ impl L2Denorm { len: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - validate_l2_denorm_children(&normalized, &norms, ctx)?; + validate_l2_normalized_rows_against_norms(&normalized, Some(&norms), ctx)?; // SAFETY: We just validated that it is normalized. unsafe { Self::new_array_unchecked(normalized, norms, len) } @@ -212,7 +213,11 @@ impl ScalarFnVTable for L2Denorm { if let Some(const_norms) = norms_ref.as_opt::() { let norm_scalar = const_norms.scalar(); - vortex_ensure!(norm_scalar.dtype().is_float()); + vortex_ensure!( + norm_scalar.dtype().is_float(), + "L2Denorm constant norms must be a float scalar, got {}", + norm_scalar.dtype(), + ); if let Some(norm_value) = norm_scalar.value() { return execute_l2_denorm_constant_norms( @@ -520,8 +525,8 @@ pub(crate) fn try_build_constant_l2_denorm( let ext_dtype = input.dtype().as_extension().clone(); let storage_fsl_nullability = storage.dtype().nullability(); - // `extract_flat_elements` takes the stride-0 single-row path for `Constant` storage, so - // this is cheap and does not expand the constant to the full column length. + // `extract_flat_elements` takes the `is_constant` single-row path for `Constant` storage, + // so this is cheap and does not expand the constant to the full column length. let flat = extract_flat_elements(storage, list_size, ctx)?; let (normalized_fsl_scalar, norms_scalar) = match_each_float_ptype!(flat.ptype(), |T| { @@ -602,23 +607,17 @@ fn unit_norm_tolerance(element_ptype: PType) -> f64 { /// Validates that every valid row of `input` is already L2-normalized (either length 1 or 0). pub fn validate_l2_normalized_rows(input: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - validate_l2_normalized_rows_impl(input, None, ctx) + validate_l2_normalized_rows_against_norms(input, None, ctx) } -/// Validates that the `normalized` and `norms` children jointly satisfy the [`L2Denorm`] -/// invariants, which are: +/// Validates that `normalized` and (when supplied) the matching `norms` jointly satisfy the +/// [`L2Denorm`] invariants: /// -/// - All vectors in the normalized array have length 1 or 0. -/// - If the vector has a norm of 0, then the vector must be all 0s. -fn validate_l2_denorm_children( - normalized: &ArrayRef, - norms: &ArrayRef, - ctx: &mut ExecutionCtx, -) -> VortexResult<()> { - validate_l2_normalized_rows_impl(normalized, Some(norms), ctx) -} - -fn validate_l2_normalized_rows_impl( +/// - Every valid row of `normalized` has L2 norm `1.0` or `0.0` (within element-precision +/// tolerance). +/// - When `norms` is supplied, every stored norm is non-negative and any row whose stored +/// norm is `0.0` is exactly the zero vector in `normalized`. +fn validate_l2_normalized_rows_against_norms( normalized: &ArrayRef, norms: Option<&ArrayRef>, ctx: &mut ExecutionCtx, @@ -697,6 +696,51 @@ fn validate_l2_normalized_rows_impl( Ok(()) } +/// Classification of a binary operand pair by which side (if any) is wrapped in [`L2Denorm`]. +/// +/// Symmetric binary tensor operators (e.g. [`CosineSimilarity`], [`InnerProduct`]) have identical +/// fast paths for "only the lhs is denormalized" and "only the rhs is denormalized", and a separate +/// fast path for "both are denormalized". Rather than hand-rolling the commutative swap at every +/// call site, callers classify their operands with [`Self::classify`] and pattern-match on the +/// returned variant. +/// +/// [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity +/// [`InnerProduct`]: crate::scalar_fns::inner_product::InnerProduct +pub(crate) enum DenormOrientation<'a> { + /// Both operands are [`ExactScalarFn`] arrays. + Both { + lhs: &'a ArrayRef, + rhs: &'a ArrayRef, + }, + /// Exactly one operand is an [`ExactScalarFn`]; the other is plain. + One { + denorm: &'a ArrayRef, + plain: &'a ArrayRef, + }, + /// Neither operand is an [`ExactScalarFn`]. + Neither, +} + +impl<'a> DenormOrientation<'a> { + /// Classify `(lhs, rhs)` by which side (if any) is wrapped in [`L2Denorm`]. + pub(crate) fn classify(lhs: &'a ArrayRef, rhs: &'a ArrayRef) -> Self { + let lhs_denorm = lhs.is::>(); + let rhs_denorm = rhs.is::>(); + match (lhs_denorm, rhs_denorm) { + (true, true) => Self::Both { lhs, rhs }, + (true, false) => Self::One { + denorm: lhs, + plain: rhs, + }, + (false, true) => Self::One { + denorm: rhs, + plain: lhs, + }, + (false, false) => Self::Neither, + } + } +} + #[cfg(test)] mod tests { @@ -719,23 +763,18 @@ mod tests { use vortex_array::dtype::DType; use vortex_array::dtype::Nullability; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; use vortex_array::extension::datetime::Date; use vortex_array::extension::datetime::TimeUnit; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; - use vortex_buffer::Buffer; use vortex_error::VortexResult; - use crate::fixed_shape::FixedShapeTensor; - use crate::fixed_shape::FixedShapeTensorMetadata; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::scalar_fns::l2_denorm::validate_l2_normalized_rows; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; use crate::utils::test_helpers::constant_tensor_array; - use crate::utils::test_helpers::constant_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -747,20 +786,6 @@ mod tests { result.into_array().execute(&mut ctx) } - fn integer_tensor_array(shape: &[usize], elements: &[i32]) -> VortexResult { - let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - - let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); - let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn non_tensor_extension_array() -> VortexResult { let storage = PrimitiveArray::from_iter([1i32, 2]).into_array(); let ext_dtype = @@ -768,16 +793,6 @@ mod tests { Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - fn f16_vector_array(dim: u32, elements: &[f32]) -> VortexResult { - let row_count = elements.len() / dim as usize; - let values: Vec<_> = elements.iter().copied().map(half::f16::from_f32).collect(); - let elems: ArrayRef = Buffer::copy_from(values.as_slice()).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - fn tensor_snapshot(array: ArrayRef) -> VortexResult<(DType, Vec, Vec)> { let mut ctx = SESSION.create_execution_ctx(); let ext: ExtensionArray = array.execute(&mut ctx)?; @@ -866,7 +881,7 @@ mod tests { #[test] fn l2_denorm_rejects_integer_tensor_lhs() -> VortexResult<()> { - let lhs = integer_tensor_array(&[2], &[1, 2, 3, 4])?; + let lhs = tensor_array(&[2], &[1i32, 2, 3, 4])?; let rhs = PrimitiveArray::from_iter([1.0f64, 1.0]).into_array(); let mut ctx = SESSION.create_execution_ctx(); @@ -888,7 +903,7 @@ mod tests { #[test] fn validate_l2_normalized_rows_accepts_normalized_f16_input() -> VortexResult<()> { - let input = f16_vector_array(2, &[3.0, 4.0, 0.0, 0.0])?; + let input = vector_array(2, &[3.0f32, 4.0, 0.0, 0.0].map(half::f16::from_f32))?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; validate_l2_normalized_rows(&roundtrip.child_at(0).clone(), &mut ctx)?; @@ -982,7 +997,7 @@ mod tests { #[test] fn normalize_as_l2_denorm_supports_constant_vectors() -> VortexResult<()> { - let input = constant_vector_array(&[3.0, 4.0], 2)?; + let input = Vector::constant_array(&[3.0, 4.0], 2)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input.clone(), &mut ctx)?; let actual = roundtrip.into_array().execute(&mut ctx)?; @@ -996,7 +1011,7 @@ mod tests { // The constant fast path in `normalize_as_l2_denorm` must produce an `L2Denorm` whose // normalized storage and norms child are both still `ConstantArray`s. This is what // allows downstream ops (cosine similarity, inner product) to short-circuit. - let input = constant_vector_array(&[3.0, 4.0], 16)?; + let input = Vector::constant_array(&[3.0, 4.0], 16)?; let mut ctx = SESSION.create_execution_ctx(); let roundtrip = normalize_as_l2_denorm(input, &mut ctx)?; diff --git a/vortex-tensor/src/scalar_fns/l2_norm.rs b/vortex-tensor/src/scalar_fns/l2_norm.rs index 59a2e5a45b1..057bb3c841a 100644 --- a/vortex-tensor/src/scalar_fns/l2_norm.rs +++ b/vortex-tensor/src/scalar_fns/l2_norm.rs @@ -47,6 +47,7 @@ use vortex_session::VortexSession; use crate::matcher::AnyTensor; use crate::scalar_fns::l2_denorm::L2Denorm; use crate::utils::extract_flat_elements; +use crate::utils::extract_l2_denorm_children; use crate::utils::validate_tensor_float_input; /// L2 norm (Euclidean norm) of a tensor or vector column. @@ -84,7 +85,7 @@ impl ScalarFnVTable for L2Norm { type Options = EmptyOptions; fn id(&self) -> ScalarFnId { - ScalarFnId::from("vortex.tensor.l2_norm") + ScalarFnId::new("vortex.tensor.l2_norm") } fn arity(&self, _options: &Self::Options) -> Arity { @@ -139,13 +140,9 @@ impl ScalarFnVTable for L2Norm { // L2Norm(L2Denorm(normalized, norms)) is defined to read back the authoritative stored // norms. Exact callers of lossy encodings like TurboQuant opt into that storage semantics // instead of forcing a decode-and-recompute path here. - if let Some(sfn) = input_ref.as_opt::>() { - let norms = sfn - .nth_child(1) - .vortex_expect("L2Denom must have at 2 children"); - + if input_ref.is::>() { + let (_, norms) = extract_l2_denorm_children(&input_ref); vortex_ensure_eq!(norms.dtype(), &norm_dtype); - return Ok(norms); } @@ -290,6 +287,7 @@ mod tests { use crate::scalar_fns::l2_norm::L2Norm; use crate::tests::SESSION; use crate::utils::test_helpers::assert_close; + use crate::utils::test_helpers::literal_vector_array; use crate::utils::test_helpers::tensor_array; use crate::utils::test_helpers::vector_array; use crate::vector::Vector; @@ -363,27 +361,13 @@ mod tests { Ok(()) } - /// Builds a [`ConstantArray`] whose scalar is a [`Vector`] extension scalar wrapping a - /// fixed-size list of `elements`, broadcast to `len` rows. - fn constant_vector_extension_array(elements: &[f64], len: usize) -> ArrayRef { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - let ext_scalar = Scalar::extension::(EmptyMetadata, storage_scalar); - ConstantArray::new(ext_scalar, len).into_array() - } - /// A constant input whose scalar is a non-null tensor should short-circuit to a /// [`ConstantArray`] output whose scalar is the precomputed norm. Uses [`execute_until`] so /// execution stops at the [`Constant`] encoding instead of canonicalizing into a /// [`PrimitiveArray`]. #[test] fn constant_non_null_input_yields_constant_output() -> VortexResult<()> { - let input = constant_vector_extension_array(&[3.0, 4.0], 4); + let input = literal_vector_array(&[3.0f64, 4.0], 4); let scalar_fn = L2Norm::new().erased(); let result = ScalarFnArray::try_new(scalar_fn, vec![input], 4)?.into_array(); diff --git a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs index 64f92da384f..2c0ca420d8d 100644 --- a/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs +++ b/vortex-tensor/src/scalar_fns/sorf_transform/vtable.rs @@ -70,13 +70,13 @@ impl ScalarFnVTable for SorfTransform { fn fmt_sql( &self, - _options: &Self::Options, + options: &Self::Options, expr: &Expression, f: &mut Formatter<'_>, ) -> fmt::Result { write!(f, "sorf_transform(")?; expr.child(0).fmt_sql(f)?; - write!(f, ")") + write!(f, ", {options})") } fn return_dtype(&self, options: &Self::Options, arg_dtypes: &[DType]) -> VortexResult { @@ -143,9 +143,7 @@ impl ScalarFnVTable for SorfTransform { validity, 0, )?; - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::wrap_storage(fsl.into_array()) }); } @@ -214,21 +212,46 @@ pub(super) struct SorfTransformMetadata { element_ptype: i32, } +impl From<&SorfOptions> for SorfTransformMetadata { + fn from(options: &SorfOptions) -> Self { + Self { + seed: options.seed, + num_rounds: u32::from(options.num_rounds), + dimension: options.dimension, + element_ptype: options.element_ptype as i32, + } + } +} + +impl SorfTransformMetadata { + /// Rebuild the [`SorfOptions`] this metadata was serialized from, validating that the wire + /// values are in range. + fn to_options(&self) -> VortexResult { + let num_rounds = u8::try_from(self.num_rounds).map_err(|_| { + vortex_err!( + "SorfTransform num_rounds {} does not fit in u8", + self.num_rounds + ) + })?; + let options = SorfOptions { + seed: self.seed, + num_rounds, + dimension: self.dimension, + element_ptype: self.element_ptype(), + }; + validate_sorf_options(&options)?; + Ok(options) + } +} + impl ScalarFnArrayVTable for SorfTransform { fn serialize( &self, view: &ScalarFnArrayView, _session: &VortexSession, ) -> VortexResult>> { - let options = view.options; Ok(Some( - SorfTransformMetadata { - seed: options.seed, - num_rounds: u32::from(options.num_rounds), - dimension: options.dimension, - element_ptype: options.element_ptype as i32, - } - .encode_to_vec(), + SorfTransformMetadata::from(view.options).encode_to_vec(), )) } @@ -240,20 +263,9 @@ impl ScalarFnArrayVTable for SorfTransform { children: &dyn ArrayChildren, _session: &VortexSession, ) -> VortexResult> { - let metadata = SorfTransformMetadata::decode(metadata) - .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))?; - let options = SorfOptions { - seed: metadata.seed, - num_rounds: u8::try_from(metadata.num_rounds).map_err(|_| { - vortex_err!( - "SorfTransform num_rounds {} does not fit in u8", - metadata.num_rounds - ) - })?, - dimension: metadata.dimension, - element_ptype: metadata.element_ptype(), - }; - validate_sorf_options(&options)?; + let options = SorfTransformMetadata::decode(metadata) + .map_err(|e| vortex_err!("Failed to decode SorfTransformMetadata: {e}"))? + .to_options()?; // `return_dtype` sets the output FSL's nullability to the child's nullability (see // `return_dtype` above), so we read the child nullability back from the parent dtype. @@ -316,7 +328,5 @@ fn inverse_rotate_typed( let elements = PrimitiveArray::new::(output.freeze(), Validity::NonNullable); let fsl = FixedSizeListArray::try_new(elements.into_array(), dim_u32, validity, num_rows)?; - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + Vector::wrap_storage(fsl.into_array()) } diff --git a/vortex-tensor/src/utils.rs b/vortex-tensor/src/utils.rs index 76d65ce9eef..326c49cc950 100644 --- a/vortex-tensor/src/utils.rs +++ b/vortex-tensor/src/utils.rs @@ -43,6 +43,23 @@ pub fn validate_tensor_float_input(input_dtype: &DType) -> VortexResult( + op_name: &str, + lhs: &'a DType, + rhs: &DType, +) -> VortexResult> { + vortex_ensure!( + lhs.eq_ignore_nullability(rhs), + "{op_name} requires both inputs to have the same dtype, got {lhs} and {rhs}" + ); + validate_tensor_float_input(lhs) +} + /// Cast a float [`PrimitiveArray`] to a `Buffer`. /// /// Several operations in this crate (SORF transform, TurboQuant quantization) work exclusively @@ -80,12 +97,12 @@ pub fn cast_to_f32(prim: PrimitiveArray) -> VortexResult> { /// The flat primitive elements of a tensor storage array, with typed row access. /// /// This struct hides the stride detail that arises from the [`ConstantArray`] optimization: a -/// constant input materializes only a single row (stride=0), while a full array uses -/// stride=list_size. +/// constant-backed input materializes only a single row that every index reads (`is_constant = +/// true`), while a full array stores one row per index. pub struct FlatElements { elems: PrimitiveArray, - stride: usize, list_size: usize, + is_constant: bool, } impl FlatElements { @@ -96,10 +113,14 @@ impl FlatElements { } /// Returns the `i`-th row as a typed slice of length `list_size`. + /// + /// When the source was a constant-backed storage, all indices resolve to the single stored + /// row. #[must_use] pub fn row(&self, i: usize) -> &[T] { + let row_idx = if self.is_constant { 0 } else { i }; let slice = self.elems.as_slice::(); - &slice[i * self.stride..][..self.list_size] + &slice[row_idx * self.list_size..][..self.list_size] } } @@ -114,26 +135,21 @@ pub fn extract_flat_elements( list_size: usize, ctx: &mut ExecutionCtx, ) -> VortexResult { - if let Some(constant) = storage.as_opt::() { - // Rewrite the array as a length 1 array so when we canonicalize, we do not duplicate a huge - // amount of data. + // Constant-backed storage: materialize just the single stored row so canonicalization does + // not expand the array to the full column length. + let (source, is_constant) = if let Some(constant) = storage.as_opt::() { let single = ConstantArray::new(constant.scalar().clone(), 1).into_array(); - let fsl: FixedSizeListArray = single.execute(ctx)?; - let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; - return Ok(FlatElements { - elems, - stride: 0, - list_size, - }); - } + (single, true) + } else { + (storage.clone(), false) + }; - // Otherwise we have to fully expand all of the data. - let fsl: FixedSizeListArray = storage.clone().execute(ctx)?; + let fsl: FixedSizeListArray = source.execute(ctx)?; let elems: PrimitiveArray = fsl.elements().clone().execute(ctx)?; Ok(FlatElements { elems, - stride: list_size, list_size, + is_constant, }) } @@ -154,15 +170,17 @@ pub fn extract_l2_denorm_children(array: &ArrayRef) -> (ArrayRef, ArrayRef) { #[cfg(test)] pub mod test_helpers { use vortex_array::ArrayRef; + use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; use vortex_array::arrays::ExtensionArray; use vortex_array::arrays::FixedSizeListArray; + use vortex_array::arrays::PrimitiveArray; use vortex_array::dtype::DType; + use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; - use vortex_array::dtype::PType; use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; + use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -170,81 +188,85 @@ pub mod test_helpers { use crate::fixed_shape::FixedShapeTensor; use crate::fixed_shape::FixedShapeTensorMetadata; + use crate::scalar_fns::l2_denorm::L2Denorm; use crate::vector::Vector; - /// Builds a [`FixedShapeTensor`] extension array from flat f64 elements and a logical shape. + /// Builds a `FixedSizeList` storage array from flat `elements`. The row count is + /// inferred from `elements.len() / list_size`. + fn flat_fsl(elements: &[T], list_size: u32) -> ArrayRef { + let row_count = elements.len() / list_size as usize; + let elems: ArrayRef = Buffer::copy_from(elements).into_array(); + FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count).into_array() + } + + /// Builds an FSL-valued [`Scalar`] from `elements` for use as a constant query. + fn fsl_scalar>(elements: &[T]) -> Scalar { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable) + } + + /// Builds a [`FixedShapeTensor`] extension array from flat `elements` and a logical shape. /// /// The number of rows is inferred from the total element count divided by the product of the /// shape dimensions. For 0-dimensional tensors (scalar), each element is one row. - pub fn tensor_array(shape: &[usize], elements: &[f64]) -> VortexResult { + pub fn tensor_array(shape: &[usize], elements: &[T]) -> VortexResult { let list_size: u32 = shape.iter().product::().max(1).try_into().unwrap(); - let row_count = elements.len() / list_size as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, list_size, Validity::NonNullable, row_count); - + let storage = flat_fsl(elements, list_size); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = - ExtDType::::try_new(metadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array from flat f64 elements and a vector dimension size. - pub fn vector_array(dim: u32, elements: &[f64]) -> VortexResult { - let row_count = elements.len() / dim as usize; - - let elems: ArrayRef = Buffer::copy_from(elements).into_array(); - let fsl = FixedSizeListArray::new(elems, dim, Validity::NonNullable, row_count); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) + /// Builds a [`Vector`] extension array from flat `elements` and a vector dimension size. + pub fn vector_array(dim: u32, elements: &[T]) -> VortexResult { + Vector::wrap_storage(flat_fsl(elements, dim)) } /// Builds a [`FixedShapeTensor`] extension array whose storage is a [`ConstantArray`], /// representing a single query tensor broadcast to `len` rows. - pub fn constant_tensor_array( + pub fn constant_tensor_array>( shape: &[usize], - elements: &[f64], + elements: &[T], len: usize, ) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - + let storage = ConstantArray::new(fsl_scalar(elements), len).into_array(); let metadata = FixedShapeTensorMetadata::new(shape.to_vec()); let ext_dtype = ExtDType::::try_new(metadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) } - /// Builds a [`Vector`] extension array whose storage is a [`ConstantArray`], representing a - /// single query vector broadcast to `len` rows. - pub fn constant_vector_array(elements: &[f64], len: usize) -> VortexResult { - let element_dtype = DType::Primitive(PType::F64, Nullability::NonNullable); - - let children: Vec = elements - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = - Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, len).into_array(); - - let ext_dtype = - ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + /// Builds a [`ConstantArray`] whose scalar is itself a [`Vector`] extension scalar, broadcast + /// to `len` rows. This is the shape produced by an `lit(vector_scalar)` literal expression — + /// the constant lives at the extension level rather than inside the FSL storage, in contrast + /// to [`Vector::constant_array`]. + pub fn literal_vector_array>( + elements: &[T], + len: usize, + ) -> ArrayRef { + use vortex_array::extension::EmptyMetadata; + let ext_scalar = Scalar::extension::(EmptyMetadata, fsl_scalar(elements)); + ConstantArray::new(ext_scalar, len).into_array() + } - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + /// Creates an [`L2Denorm`] scalar function array from pre-normalized tensor elements and + /// matching norms. The caller must ensure every row of `normalized_elements` is unit-norm or + /// zero. + pub fn l2_denorm_array( + shape: &[usize], + normalized_elements: &[T], + norms: &[T], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let len = norms.len(); + let normalized = tensor_array(shape, normalized_elements)?; + let norms = PrimitiveArray::new(Buffer::copy_from(norms), Validity::NonNullable).into_array(); + Ok(L2Denorm::try_new_array(normalized, norms, len, ctx)?.into_array()) } /// Asserts that each element in `actual` is within `1e-10` of the corresponding `expected` diff --git a/vortex-tensor/src/vector/matcher.rs b/vortex-tensor/src/vector/matcher.rs index 0da5384f303..4e0781c225e 100644 --- a/vortex-tensor/src/vector/matcher.rs +++ b/vortex-tensor/src/vector/matcher.rs @@ -49,7 +49,7 @@ impl Matcher for AnyVector { let dimensions = *list_size; - assert!(element_dtype.is_float(), "element dtype must be primitive"); + assert!(element_dtype.is_float(), "element dtype must be float"); assert!( !element_dtype.is_nullable(), "element dtype must be non-nullable" @@ -87,6 +87,13 @@ impl VectorMatcherMetadata { pub fn dimensions(&self) -> u32 { self.dimensions } + + /// Returns the flattened element count per vector row. Always equal to + /// [`dimensions`](Self::dimensions); exists as a `usize`-typed alias that mirrors + /// [`FixedShapeTensorMatcherMetadata::list_size`](crate::fixed_shape::FixedShapeTensorMatcherMetadata::list_size). + pub fn list_size(&self) -> usize { + self.dimensions as usize + } } #[cfg(test)] diff --git a/vortex-tensor/src/vector/mod.rs b/vortex-tensor/src/vector/mod.rs index 3c6a8a8c8cc..1d0bdf4014c 100644 --- a/vortex-tensor/src/vector/mod.rs +++ b/vortex-tensor/src/vector/mod.rs @@ -3,10 +3,81 @@ //! Vector extension type for fixed-length float vectors (e.g., embeddings). +use vortex_array::ArrayRef; +use vortex_array::IntoArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::ExtensionArray; +use vortex_array::dtype::DType; +use vortex_array::dtype::NativePType; +use vortex_array::dtype::Nullability; +use vortex_array::dtype::extension::ExtDType; +use vortex_array::extension::EmptyMetadata; +use vortex_array::scalar::PValue; +use vortex_array::scalar::Scalar; +use vortex_error::VortexResult; + /// The Vector extension type. #[derive(Clone, Debug, Default, PartialEq, Eq, Hash)] pub struct Vector; +impl Vector { + /// Wrap a `FixedSizeList`-valued `storage` array in a [`Vector`] extension array. + /// + /// The storage's dtype is reused verbatim for the extension's storage dtype, so the caller + /// is responsible for having already constructed an FSL with the float element ptype and + /// non-nullable elements that [`Vector`]'s `validate_dtype` requires. + /// + /// # Errors + /// + /// Returns an error if `storage` does not satisfy [`Vector`]'s storage-dtype contract (e.g. + /// it is not a `FixedSizeList` of non-nullable floats). + pub fn wrap_storage(storage: ArrayRef) -> VortexResult { + let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); + Ok(ExtensionArray::new(ext_dtype, storage).into_array()) + } + + /// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a + /// single vector `elements` across `len` rows. + /// + /// This is the array shape that [`CosineSimilarity::try_new_array`] and similar binary tensor + /// scalar functions expect for the constant-query side of a database-vs-query scan: the inner + /// `ScalarFnArray` contract requires both children to have the same length, so the query is + /// broadcast rather than represented as a literal length-1 input. + /// + /// [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array + /// + /// # Errors + /// + /// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. + pub fn constant_array>( + elements: &[T], + len: usize, + ) -> VortexResult { + let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); + let children: Vec = elements + .iter() + .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) + .collect(); + let storage_scalar = + Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); + Self::wrap_storage(ConstantArray::new(storage_scalar, len).into_array()) + } +} + +#[cfg(test)] +mod ctor_tests { + use vortex_array::arrays::Extension; + + use super::*; + + #[test] + fn constant_array_produces_vector_extension() { + let array = Vector::constant_array(&[1.0f32, 0.0, 0.0, 0.0], 5).unwrap(); + assert_eq!(array.len(), 5); + assert!(array.as_opt::().is_some()); + } +} + mod matcher; pub use matcher::AnyVector; diff --git a/vortex-tensor/src/vector_search.rs b/vortex-tensor/src/vector_search.rs index 81a379683db..5ca319ee6c9 100644 --- a/vortex-tensor/src/vector_search.rs +++ b/vortex-tensor/src/vector_search.rs @@ -4,23 +4,14 @@ //! Reusable helpers for building brute-force vector similarity search expressions over //! [`Vector`] extension arrays. //! -//! This module exposes three small building blocks that together make it straightforward to -//! stand up a cosine-similarity-plus-threshold scan on top of a prepared data array: +//! [`build_similarity_search_tree`] wires together [`Vector::constant_array`] (which broadcasts +//! the query into the shape expected by [`CosineSimilarity`]) and +//! [`turboquant_compress`](crate::encodings::turboquant::turboquant_compress) (when the data is +//! pre-compressed) into a lazy `Binary(Gt, [CosineSimilarity(data, query), threshold])` +//! expression. //! -//! - [`compress_turboquant`] applies the canonical TurboQuant encoding pipeline -//! (`L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)`) to a raw -//! `Vector` array without requiring the caller to plumb the -//! `unstable_encodings` feature flag on the `vortex` facade. -//! - [`build_constant_query_vector`] wraps a single query vector into a -//! [`Vector`] extension array whose storage is a [`ConstantArray`] broadcast -//! across `num_rows` rows. This is the shape expected by -//! [`CosineSimilarity::try_new_array`] for the RHS of a database-vs-query scan. -//! - [`build_similarity_search_tree`] wires everything together into a lazy -//! `Binary(Gt, [CosineSimilarity(data, query), threshold])` expression. -//! -//! Executing the tree from [`build_similarity_search_tree`] into a -//! [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean per row indicating whether -//! that row's cosine similarity to the query exceeds `threshold`. +//! Executing the tree into a [`BoolArray`](vortex_array::arrays::BoolArray) yields one boolean +//! per row indicating whether that row's cosine similarity to the query exceeds `threshold`. //! //! # Example //! @@ -28,11 +19,12 @@ //! use vortex_array::{ArrayRef, VortexSessionExecute}; //! use vortex_array::arrays::BoolArray; //! use vortex_session::VortexSession; -//! use vortex_tensor::vector_search::{build_similarity_search_tree, compress_turboquant}; +//! use vortex_tensor::encodings::turboquant::{TurboQuantConfig, turboquant_compress}; +//! use vortex_tensor::vector_search::build_similarity_search_tree; //! //! fn run(session: &VortexSession, data: ArrayRef, query: &[f32]) -> anyhow::Result<()> { //! let mut ctx = session.create_execution_ctx(); -//! let data = compress_turboquant(data, &mut ctx)?; +//! let data = turboquant_compress(data, &TurboQuantConfig::default(), &mut ctx)?; //! let tree = build_similarity_search_tree(data, query, 0.8)?; //! let _matches: BoolArray = tree.execute(&mut ctx)?; //! Ok(()) @@ -40,98 +32,23 @@ //! ``` //! //! [`Vector`]: crate::vector::Vector -//! [`CosineSimilarity::try_new_array`]: crate::scalar_fns::cosine_similarity::CosineSimilarity::try_new_array +//! [`Vector::constant_array`]: crate::vector::Vector::constant_array +//! [`CosineSimilarity`]: crate::scalar_fns::cosine_similarity::CosineSimilarity use vortex_array::ArrayRef; -use vortex_array::ExecutionCtx; use vortex_array::IntoArray; use vortex_array::arrays::ConstantArray; -use vortex_array::arrays::Extension; -use vortex_array::arrays::ExtensionArray; -use vortex_array::arrays::scalar_fn::ScalarFnArrayExt; use vortex_array::builtins::ArrayBuiltins; -use vortex_array::dtype::DType; use vortex_array::dtype::NativePType; use vortex_array::dtype::Nullability; -use vortex_array::dtype::extension::ExtDType; -use vortex_array::extension::EmptyMetadata; use vortex_array::scalar::PValue; use vortex_array::scalar::Scalar; use vortex_array::scalar_fn::fns::operators::Operator; use vortex_error::VortexResult; -use vortex_error::vortex_bail; -use crate::encodings::turboquant::TurboQuantConfig; -use crate::encodings::turboquant::turboquant_encode_unchecked; use crate::scalar_fns::cosine_similarity::CosineSimilarity; -use crate::scalar_fns::l2_denorm::L2Denorm; -use crate::scalar_fns::l2_denorm::normalize_as_l2_denorm; use crate::vector::Vector; -/// Apply the canonical TurboQuant encoding pipeline to a `Vector` array. -/// -/// The returned array has the shape -/// `L2Denorm(SorfTransform(FSL(Dict(codes, centroids))), norms)` — exactly what -/// [`crate::encodings::turboquant::TurboQuantScheme`] produces when invoked through -/// `BtrBlocksCompressorBuilder::with_turboquant()`, but without requiring callers to enable -/// the `unstable_encodings` feature on the `vortex` facade. -/// -/// The input `data` must be a [`Vector`] extension array whose element type is `f32` and whose -/// dimensionality is at least -/// [`turboquant::MIN_DIMENSION`](crate::encodings::turboquant::MIN_DIMENSION). The TurboQuant -/// configuration used is [`TurboQuantConfig::default()`] (8-bit codes, 3 SORF rounds, seed 42). -/// -/// # Errors -/// -/// Returns an error if `data` is not a [`Vector`] extension array, if normalization fails, or -/// if the underlying TurboQuant encoder rejects the input shape. -pub fn compress_turboquant(data: ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { - let l2_denorm = normalize_as_l2_denorm(data, ctx)?; - let normalized = l2_denorm.child_at(0).clone(); - let norms = l2_denorm.child_at(1).clone(); - let num_rows = l2_denorm.len(); - - let Some(normalized_ext) = normalized.as_opt::() else { - vortex_bail!("normalize_as_l2_denorm must produce an Extension array child"); - }; - - let config = TurboQuantConfig::default(); - // SAFETY: `normalize_as_l2_denorm` guarantees every row is unit-norm (or zero), which is - // the invariant `turboquant_encode_unchecked` expects. - let tq = unsafe { turboquant_encode_unchecked(normalized_ext, &config, ctx) }?; - - Ok(unsafe { L2Denorm::new_array_unchecked(tq, norms, num_rows) }?.into_array()) -} - -/// Build a [`Vector`] extension array whose storage is a [`ConstantArray`] broadcasting a single -/// query vector across `num_rows` rows. -/// -/// The element type is inferred from `T` (e.g. `f32` or `f64`). This is the shape expected for -/// the RHS of a database-vs-query [`CosineSimilarity`] scan: the `ScalarFnArray` contract -/// requires both children to have the same length, so rather than hand-rolling a 1-row input we -/// broadcast the query across the whole database. -/// -/// # Errors -/// -/// Returns an error if the [`Vector`] extension dtype rejects the constructed storage dtype. -pub fn build_constant_query_vector>( - query: &[T], - num_rows: usize, -) -> VortexResult { - let element_dtype = DType::Primitive(T::PTYPE, Nullability::NonNullable); - - let children: Vec = query - .iter() - .map(|&v| Scalar::primitive(v, Nullability::NonNullable)) - .collect(); - let storage_scalar = Scalar::fixed_size_list(element_dtype, children, Nullability::NonNullable); - - let storage = ConstantArray::new(storage_scalar, num_rows).into_array(); - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, storage.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, storage).into_array()) -} - /// Build the lazy similarity-search expression tree for a prepared database array and a /// single query vector. /// @@ -163,7 +80,7 @@ pub fn build_similarity_search_tree>( threshold: T, ) -> VortexResult { let num_rows = data.len(); - let query_vec = build_constant_query_vector(query, num_rows)?; + let query_vec = Vector::constant_array(query, num_rows)?; let cosine = CosineSimilarity::try_new_array(data, query_vec, num_rows)?.into_array(); @@ -175,64 +92,16 @@ pub fn build_similarity_search_tree>( #[cfg(test)] mod tests { - use vortex_array::ArrayRef; - use vortex_array::IntoArray; use vortex_array::VortexSessionExecute; use vortex_array::arrays::BoolArray; - use vortex_array::arrays::Extension; - use vortex_array::arrays::ExtensionArray; - use vortex_array::arrays::FixedSizeListArray; - use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::bool::BoolArrayExt; - use vortex_array::dtype::extension::ExtDType; - use vortex_array::extension::EmptyMetadata; - use vortex_array::session::ArraySession; - use vortex_array::validity::Validity; - use vortex_buffer::BufferMut; use vortex_error::VortexResult; - use vortex_session::VortexSession; - use super::build_constant_query_vector; use super::build_similarity_search_tree; - use super::compress_turboquant; - use crate::vector::Vector; - - /// Build a `Vector` extension array from a flat f32 slice. Each contiguous - /// group of `DIM` values becomes one row. - fn vector_array(dim: u32, values: &[f32]) -> VortexResult { - let dim_usize = dim as usize; - assert_eq!(values.len() % dim_usize, 0); - let num_rows = values.len() / dim_usize; - - let mut buf = BufferMut::::with_capacity(values.len()); - for &v in values { - buf.push(v); - } - let elements = PrimitiveArray::new::(buf.freeze(), Validity::NonNullable); - let fsl = FixedSizeListArray::try_new( - elements.into_array(), - dim, - Validity::NonNullable, - num_rows, - )?; - - let ext_dtype = ExtDType::::try_new(EmptyMetadata, fsl.dtype().clone())?.erased(); - Ok(ExtensionArray::new(ext_dtype, fsl.into_array()).into_array()) - } - - fn test_session() -> VortexSession { - VortexSession::empty().with::() - } - - #[test] - fn constant_query_vector_has_vector_extension_dtype() -> VortexResult<()> { - let query = vec![1.0f32, 0.0, 0.0, 0.0]; - let rhs = build_constant_query_vector(&query, 5)?; - - assert_eq!(rhs.len(), 5); - assert!(rhs.as_opt::().is_some()); - Ok(()) - } + use crate::encodings::turboquant::TurboQuantConfig; + use crate::encodings::turboquant::turboquant_compress; + use crate::tests::SESSION; + use crate::utils::test_helpers::vector_array; #[test] fn similarity_search_tree_executes_to_bool_array() -> VortexResult<()> { @@ -240,7 +109,7 @@ mod tests { let data = vector_array( 3, &[ - 1.0, 0.0, 0.0, // + 1.0f32, 0.0, 0.0, // 0.0, 1.0, 0.0, // 0.0, 0.0, 1.0, // 1.0, 0.0, 0.0, // @@ -249,7 +118,7 @@ mod tests { let query = [1.0f32, 0.0, 0.0]; let tree = build_similarity_search_tree(data, &query, 0.5)?; - let mut ctx = test_session().create_execution_ctx(); + let mut ctx = SESSION.create_execution_ctx(); let result: BoolArray = tree.execute(&mut ctx)?; let bits = result.to_bit_buffer(); @@ -287,8 +156,8 @@ mod tests { } let data = vector_array(DIM, &values)?; - let mut ctx = test_session().create_execution_ctx(); - let compressed = compress_turboquant(data, &mut ctx)?; + let mut ctx = SESSION.create_execution_ctx(); + let compressed = turboquant_compress(data, &TurboQuantConfig::default(), &mut ctx)?; assert_eq!(compressed.len(), NUM_ROWS); // Build a tree with a low threshold so row 0 (cosine=1.0 exact) matches.