From d7e07aab2d78acd5d5fec514a3090ec73faceb0a Mon Sep 17 00:00:00 2001 From: Abanoub Doss Date: Thu, 9 Apr 2026 22:38:47 -0500 Subject: [PATCH] fix(vortex-array): preserve operand width in DecimalValue checked arithmetic (#7022) Signed-off-by: Abanoub Doss --- .../src/scalar/typed_view/decimal/dvalue.rs | 40 +++++------ .../src/scalar/typed_view/decimal/tests.rs | 72 +++++++++++++++++-- 2 files changed, 84 insertions(+), 28 deletions(-) diff --git a/vortex-array/src/scalar/typed_view/decimal/dvalue.rs b/vortex-array/src/scalar/typed_view/decimal/dvalue.rs index 97eb38f07cd..e8f0d78a4b2 100644 --- a/vortex-array/src/scalar/typed_view/decimal/dvalue.rs +++ b/vortex-array/src/scalar/typed_view/decimal/dvalue.rs @@ -19,6 +19,19 @@ use crate::dtype::NativeDecimalType; use crate::dtype::ToI256; use crate::dtype::i256; use crate::match_each_decimal_value; +use crate::match_each_decimal_value_type; + +/// Performs a checked binary operation at the wider of the two operand types. +macro_rules! checked_binary_op { + ($self:expr, $other:expr, $op:path) => {{ + let target = $self.decimal_type().max($other.decimal_type()); + match_each_decimal_value_type!(target, |T| { + let a: T = $self.cast()?; + let b: T = $other.cast()?; + Some(DecimalValue::from($op(&a, &b)?)) + }) + }}; +} /// A decimal value that can be stored in various integer widths. /// @@ -127,43 +140,24 @@ impl DecimalValue { value_i256 > min_value && value_i256 < max_value } - /// Helper function to perform a checked binary operation on two decimal values. - /// - /// Both values are upcast to i256 before the operation, and the result is returned as I256. - fn checked_binary_op(&self, other: &Self, op: F) -> Option - where - F: FnOnce(i256, i256) -> Option, - { - let self_upcast = match_each_decimal_value!(self, |v| { - v.to_i256() - .vortex_expect("upcast to i256 must always succeed") - }); - let other_upcast = match_each_decimal_value!(other, |v| { - v.to_i256() - .vortex_expect("upcast to i256 must always succeed") - }); - - op(self_upcast, other_upcast).map(DecimalValue::I256) - } - /// Checked addition. Returns `None` on overflow. pub fn checked_add(&self, other: &Self) -> Option { - self.checked_binary_op(other, |a, b| a.checked_add(&b)) + checked_binary_op!(self, other, CheckedAdd::checked_add) } /// Checked subtraction. Returns `None` on overflow. pub fn checked_sub(&self, other: &Self) -> Option { - self.checked_binary_op(other, |a, b| a.checked_sub(&b)) + checked_binary_op!(self, other, CheckedSub::checked_sub) } /// Checked multiplication. Returns `None` on overflow. pub fn checked_mul(&self, other: &Self) -> Option { - self.checked_binary_op(other, |a, b| a.checked_mul(&b)) + checked_binary_op!(self, other, CheckedMul::checked_mul) } /// Checked division. Returns `None` on overflow or division by zero. pub fn checked_div(&self, other: &Self) -> Option { - self.checked_binary_op(other, |a, b| a.checked_div(&b)) + checked_binary_op!(self, other, CheckedDiv::checked_div) } } diff --git a/vortex-array/src/scalar/typed_view/decimal/tests.rs b/vortex-array/src/scalar/typed_view/decimal/tests.rs index fa6ccfad96b..b9b3b11673c 100644 --- a/vortex-array/src/scalar/typed_view/decimal/tests.rs +++ b/vortex-array/src/scalar/typed_view/decimal/tests.rs @@ -994,7 +994,7 @@ fn test_decimal_value_checked_add() { let a = DecimalValue::I64(100); let b = DecimalValue::I64(200); let result = a.checked_add(&b).unwrap(); - assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); + assert_eq!(result, DecimalValue::I64(300)); } #[test] @@ -1002,7 +1002,7 @@ fn test_decimal_value_checked_sub() { let a = DecimalValue::I64(500); let b = DecimalValue::I64(200); let result = a.checked_sub(&b).unwrap(); - assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); + assert_eq!(result, DecimalValue::I64(300)); } #[test] @@ -1010,7 +1010,7 @@ fn test_decimal_value_checked_mul() { let a = DecimalValue::I32(50); let b = DecimalValue::I32(10); let result = a.checked_mul(&b).unwrap(); - assert_eq!(result, DecimalValue::I256(i256::from_i128(500))); + assert_eq!(result, DecimalValue::I32(500)); } #[test] @@ -1018,7 +1018,7 @@ fn test_decimal_value_checked_div() { let a = DecimalValue::I64(1000); let b = DecimalValue::I64(10); let result = a.checked_div(&b).unwrap(); - assert_eq!(result, DecimalValue::I256(i256::from_i128(100))); + assert_eq!(result, DecimalValue::I64(100)); } #[test] @@ -1035,7 +1035,69 @@ fn test_decimal_value_mixed_types() { let a = DecimalValue::I8(10); let b = DecimalValue::I128(20); let result = a.checked_add(&b).unwrap(); - assert_eq!(result, DecimalValue::I256(i256::from_i128(30))); + assert_eq!(result, DecimalValue::I128(30)); +} + +#[test] +fn test_checked_ops_preserve_type() { + // Operations should return the wider of the two operand types, not unconditionally upcast to I256 + let add = DecimalValue::I32(5) + .checked_add(&DecimalValue::I32(3)) + .unwrap(); + assert_eq!(add.decimal_type(), DecimalType::I32); + + let sub = DecimalValue::I64(10) + .checked_sub(&DecimalValue::I64(3)) + .unwrap(); + assert_eq!(sub.decimal_type(), DecimalType::I64); + + let mul = DecimalValue::I8(2) + .checked_mul(&DecimalValue::I8(3)) + .unwrap(); + assert_eq!(mul.decimal_type(), DecimalType::I8); + + let div = DecimalValue::I128(10) + .checked_div(&DecimalValue::I128(2)) + .unwrap(); + assert_eq!(div.decimal_type(), DecimalType::I128); + + let add_i256 = DecimalValue::I256(i256::from_i128(1)) + .checked_add(&DecimalValue::I256(i256::from_i128(2))) + .unwrap(); + assert_eq!(add_i256.decimal_type(), DecimalType::I256); +} + +#[test] +fn test_checked_ops_mixed_types_use_wider() { + let add = DecimalValue::I8(1) + .checked_add(&DecimalValue::I64(2)) + .unwrap(); + assert_eq!(add.decimal_type(), DecimalType::I64); + + let sub = DecimalValue::I32(10) + .checked_sub(&DecimalValue::I128(3)) + .unwrap(); + assert_eq!(sub.decimal_type(), DecimalType::I128); +} + +#[test] +fn test_checked_ops_overflow_at_target_width() { + assert_eq!( + DecimalValue::I8(i8::MAX).checked_add(&DecimalValue::I8(1)), + None + ); + assert_eq!( + DecimalValue::I16(i16::MIN).checked_sub(&DecimalValue::I16(1)), + None + ); + assert_eq!( + DecimalValue::I32(i32::MAX).checked_mul(&DecimalValue::I32(2)), + None + ); + assert_eq!( + DecimalValue::I8(i8::MIN).checked_div(&DecimalValue::I8(-1)), + None + ); } #[test]