diff --git a/vortex-array/src/scalar_fn/fns/mask/mod.rs b/vortex-array/src/scalar_fn/fns/mask/mod.rs index eaf83307b35..2a41e3d1610 100644 --- a/vortex-array/src/scalar_fn/fns/mask/mod.rs +++ b/vortex-array/src/scalar_fn/fns/mask/mod.rs @@ -3,6 +3,7 @@ mod kernel; use std::fmt::Formatter; +use std::sync::Arc; pub use kernel::*; use vortex_error::VortexExpect; @@ -30,8 +31,12 @@ use crate::scalar_fn::Arity; use crate::scalar_fn::ChildName; use crate::scalar_fn::EmptyOptions; use crate::scalar_fn::ExecutionArgs; +use crate::scalar_fn::ReduceCtx; +use crate::scalar_fn::ReduceNode; +use crate::scalar_fn::ReduceNodeRef; use crate::scalar_fn::ScalarFnId; use crate::scalar_fn::ScalarFnVTable; +use crate::scalar_fn::ScalarFnVTableExt; use crate::scalar_fn::SimplifyCtx; use crate::scalar_fn::fns::literal::Literal; @@ -111,6 +116,43 @@ impl ScalarFnVTable for Mask { execute_canonical(input, mask_array, ctx) } + fn reduce( + &self, + options: &Self::Options, + node: &dyn ReduceNode, + ctx: &dyn ReduceCtx, + ) -> VortexResult> { + _ = options; + let input = node.child(0); + let Some(input_scalar_fn) = input.scalar_fn() else { + return Ok(None); + }; + + // The null-sensitivity property is exactly whether this rewrite is valid. + if input_scalar_fn.signature().is_null_sensitive() { + return Ok(None); + } + + // Zero-arity scalar functions (e.g. literals) have no children to push the mask into. + if input.child_count() == 0 { + return Ok(None); + } + + let mask = node.child(1); + let mut masked_children = Vec::with_capacity(input.child_count()); + for child_idx in 0..input.child_count() { + let masked_child = ctx.new_node( + Mask.bind(EmptyOptions), + &[input.child(child_idx), Arc::clone(&mask)], + )?; + masked_children.push(masked_child); + } + + Ok(Some( + ctx.new_node(input_scalar_fn.clone(), &masked_children)?, + )) + } + fn simplify( &self, _options: &Self::Options, @@ -193,12 +235,27 @@ fn execute_canonical( mod test { use vortex_error::VortexExpect; + use super::Mask; + use crate::IntoArray; + use crate::arrays::BoolArray; + use crate::arrays::ScalarFnVTable; + use crate::arrays::scalar_fn::ScalarFnArrayExt; + use crate::arrays::scalar_fn::ScalarFnFactoryExt; use crate::dtype::DType; + use crate::dtype::Nullability; use crate::dtype::Nullability::Nullable; use crate::dtype::PType; + use crate::dtype::StructFields; + use crate::expr::col; + use crate::expr::is_null; use crate::expr::lit; use crate::expr::mask; + use crate::expr::not; + use crate::optimizer::ArrayOptimizer; use crate::scalar::Scalar; + use crate::scalar_fn::EmptyOptions; + use crate::scalar_fn::fns::is_null::IsNull; + use crate::scalar_fn::fns::not::Not; #[test] fn test_simplify() { @@ -219,4 +276,90 @@ mod test { let expected_null_expr = lit(Scalar::null(DType::Primitive(PType::U32, Nullable))); assert_eq!(&simplified_false, &expected_null_expr); } + + #[test] + fn test_reduce_pushdown_expression_not() { + let scope = DType::Struct( + StructFields::new( + ["bool1", "m"].into(), + vec![ + DType::Bool(Nullability::NonNullable), + DType::Bool(Nullability::NonNullable), + ], + ), + Nullability::NonNullable, + ); + + let expr = mask(not(col("bool1")), col("m")); + let reduced = expr.optimize(&scope).vortex_expect("optimize"); + + let expected = not(mask(col("bool1"), col("m"))); + assert_eq!(reduced, expected); + } + + #[test] + fn test_reduce_no_pushdown_expression_null_sensitive() { + let scope = DType::Struct( + StructFields::new( + ["bool1", "m"].into(), + vec![ + DType::Bool(Nullability::NonNullable), + DType::Bool(Nullability::NonNullable), + ], + ), + Nullability::NonNullable, + ); + + let expr = mask(is_null(col("bool1")), col("m")); + let reduced = expr.optimize(&scope).vortex_expect("optimize"); + assert_eq!(reduced, expr); + } + + #[test] + fn test_reduce_pushdown_array_not() { + let values = BoolArray::from_iter([true, false, true]).into_array(); + let mask_values = BoolArray::from_iter([true, false, true]).into_array(); + + let not_array = Not + .try_new_array(values.len(), EmptyOptions, [values]) + .vortex_expect("not array"); + let mask_array = Mask + .try_new_array(mask_values.len(), EmptyOptions, [not_array, mask_values]) + .vortex_expect("mask array"); + + let reduced = mask_array.optimize().vortex_expect("optimize"); + let reduced_sfn = reduced + .as_opt::() + .vortex_expect("expected scalar fn root"); + assert!(reduced_sfn.scalar_fn().is::()); + + let child = reduced_sfn.child_at(0); + let child_sfn = child + .as_opt::() + .vortex_expect("expected masked child"); + assert!(child_sfn.scalar_fn().is::()); + } + + #[test] + fn test_reduce_no_pushdown_array_null_sensitive() { + let values = BoolArray::from_iter([true, false, true]).into_array(); + let mask_values = BoolArray::from_iter([true, false, true]).into_array(); + + let is_null_array = IsNull + .try_new_array(values.len(), EmptyOptions, [values]) + .vortex_expect("is_null array"); + let mask_array = Mask + .try_new_array( + mask_values.len(), + EmptyOptions, + [is_null_array, mask_values], + ) + .vortex_expect("mask array"); + + let reduced = mask_array.optimize().vortex_expect("optimize"); + let reduced_sfn = reduced + .as_opt::() + .vortex_expect("expected scalar fn root"); + assert!(reduced_sfn.scalar_fn().is::()); + } }