Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 143 additions & 0 deletions vortex-array/src/scalar_fn/fns/mask/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

mod kernel;
use std::fmt::Formatter;
use std::sync::Arc;

pub use kernel::*;
use vortex_error::VortexExpect;
Expand Down Expand Up @@ -30,8 +31,12 @@ use crate::scalar_fn::Arity;
use crate::scalar_fn::ChildName;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::ExecutionArgs;
use crate::scalar_fn::ReduceCtx;
use crate::scalar_fn::ReduceNode;
use crate::scalar_fn::ReduceNodeRef;
use crate::scalar_fn::ScalarFnId;
use crate::scalar_fn::ScalarFnVTable;
use crate::scalar_fn::ScalarFnVTableExt;
use crate::scalar_fn::SimplifyCtx;
use crate::scalar_fn::fns::literal::Literal;

Expand Down Expand Up @@ -111,6 +116,43 @@ impl ScalarFnVTable for Mask {
execute_canonical(input, mask_array, ctx)
}

/// Pushes a mask below its input scalar function.
///
/// Rewrites `mask(f(c0, c1, ..), m)` into `f(mask(c0, m), mask(c1, m), ..)`.
///
/// Returns `Ok(None)` (no rewrite) when the first child exposes no scalar function,
/// when that function's signature is null-sensitive, or when the input has no children.
/// Errors from `ctx.new_node` are propagated.
fn reduce(
    &self,
    options: &Self::Options,
    node: &dyn ReduceNode,
    ctx: &dyn ReduceCtx,
) -> VortexResult<Option<ReduceNodeRef>> {
    // The options value is not consulted by this rewrite.
    let _ = options;

    let input = node.child(0);
    let inner_fn = match input.scalar_fn() {
        Some(scalar_fn) => scalar_fn,
        None => return Ok(None),
    };

    // Null-sensitivity is precisely the property that decides whether the mask may be
    // moved below the function.
    if inner_fn.signature().is_null_sensitive() {
        return Ok(None);
    }

    // A zero-arity scalar function (e.g. a literal) has no children to receive the mask.
    let arity = input.child_count();
    if arity == 0 {
        return Ok(None);
    }

    // Wrap every child of the input in its own `Mask` node sharing the same mask child,
    // stopping at the first node-construction error.
    let mask = node.child(1);
    let masked_children = (0..arity)
        .map(|idx| {
            ctx.new_node(
                Mask.bind(EmptyOptions),
                &[input.child(idx), Arc::clone(&mask)],
            )
        })
        .collect::<VortexResult<Vec<_>>>()?;

    // Re-apply the original function on top of the masked children.
    let pushed_down = ctx.new_node(inner_fn.clone(), &masked_children)?;
    Ok(Some(pushed_down))
}

fn simplify(
&self,
_options: &Self::Options,
Expand Down Expand Up @@ -193,12 +235,27 @@ fn execute_canonical(
mod test {
use vortex_error::VortexExpect;

use super::Mask;
use crate::IntoArray;
use crate::arrays::BoolArray;
use crate::arrays::ScalarFnVTable;
use crate::arrays::scalar_fn::ScalarFnArrayExt;
use crate::arrays::scalar_fn::ScalarFnFactoryExt;
use crate::dtype::DType;
use crate::dtype::Nullability;
use crate::dtype::Nullability::Nullable;
use crate::dtype::PType;
use crate::dtype::StructFields;
use crate::expr::col;
use crate::expr::is_null;
use crate::expr::lit;
use crate::expr::mask;
use crate::expr::not;
use crate::optimizer::ArrayOptimizer;
use crate::scalar::Scalar;
use crate::scalar_fn::EmptyOptions;
use crate::scalar_fn::fns::is_null::IsNull;
use crate::scalar_fn::fns::not::Not;

#[test]
fn test_simplify() {
Expand All @@ -219,4 +276,90 @@ mod test {
let expected_null_expr = lit(Scalar::null(DType::Primitive(PType::U32, Nullable)));
assert_eq!(&simplified_false, &expected_null_expr);
}

/// Optimizing `mask(not(bool1), m)` should yield `not(mask(bool1, m))`.
#[test]
fn test_reduce_pushdown_expression_not() {
    // Scope: a non-nullable struct with two non-nullable boolean fields.
    let bool1_dtype = DType::Bool(Nullability::NonNullable);
    let m_dtype = DType::Bool(Nullability::NonNullable);
    let fields = StructFields::new(["bool1", "m"].into(), vec![bool1_dtype, m_dtype]);
    let scope = DType::Struct(fields, Nullability::NonNullable);

    let reduced = mask(not(col("bool1")), col("m"))
        .optimize(&scope)
        .vortex_expect("optimize");

    assert_eq!(reduced, not(mask(col("bool1"), col("m"))));
}

/// Optimizing `mask(is_null(bool1), m)` should leave the expression unchanged.
#[test]
fn test_reduce_no_pushdown_expression_null_sensitive() {
    // Scope: a non-nullable struct with two non-nullable boolean fields.
    let bool1_dtype = DType::Bool(Nullability::NonNullable);
    let m_dtype = DType::Bool(Nullability::NonNullable);
    let fields = StructFields::new(["bool1", "m"].into(), vec![bool1_dtype, m_dtype]);
    let scope = DType::Struct(fields, Nullability::NonNullable);

    let original = mask(is_null(col("bool1")), col("m"));
    let reduced = original.optimize(&scope).vortex_expect("optimize");

    assert_eq!(reduced, original);
}

/// Optimizing a `Mask` array over a `Not` array should produce a `Not` root whose
/// first child is a `Mask` scalar-fn array.
#[test]
fn test_reduce_pushdown_array_not() {
    let input = BoolArray::from_iter([true, false, true]).into_array();
    let selection = BoolArray::from_iter([true, false, true]).into_array();

    // Build mask(not(input), selection) as scalar-fn arrays.
    let input_len = input.len();
    let negated = Not
        .try_new_array(input_len, EmptyOptions, [input])
        .vortex_expect("not array");
    let selection_len = selection.len();
    let masked = Mask
        .try_new_array(selection_len, EmptyOptions, [negated, selection])
        .vortex_expect("mask array");

    let optimized = masked.optimize().vortex_expect("optimize");

    // The root must now be the `Not` function...
    let root = optimized
        .as_opt::<ScalarFnVTable>()
        .vortex_expect("expected scalar fn root");
    assert!(root.scalar_fn().is::<Not>());

    // ...and its first child must be the `Mask`.
    let first_child = root.child_at(0);
    let first_child_sfn = first_child
        .as_opt::<ScalarFnVTable>()
        .vortex_expect("expected masked child");
    assert!(first_child_sfn.scalar_fn().is::<Mask>());
}

/// Optimizing a `Mask` array over an `IsNull` array should keep `Mask` as the root.
#[test]
fn test_reduce_no_pushdown_array_null_sensitive() {
    let input = BoolArray::from_iter([true, false, true]).into_array();
    let selection = BoolArray::from_iter([true, false, true]).into_array();

    // Build mask(is_null(input), selection) as scalar-fn arrays.
    let input_len = input.len();
    let null_check = IsNull
        .try_new_array(input_len, EmptyOptions, [input])
        .vortex_expect("is_null array");
    let selection_len = selection.len();
    let masked = Mask
        .try_new_array(selection_len, EmptyOptions, [null_check, selection])
        .vortex_expect("mask array");

    let optimized = masked.optimize().vortex_expect("optimize");

    // The root scalar function must still be `Mask`.
    let root = optimized
        .as_opt::<ScalarFnVTable>()
        .vortex_expect("expected scalar fn root");
    assert!(root.scalar_fn().is::<Mask>());
}
}
Loading