Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 202 additions & 29 deletions datafusion/functions-nested/src/remove.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
//! [`ScalarUDFImpl`] definitions for array_remove, array_remove_n, array_remove_all functions.

use crate::utils;
use crate::utils::make_scalar_function;
use arrow::array::{
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetSizeTrait,
cast::AsArray, make_array,
Array, ArrayRef, Capacities, GenericListArray, MutableArrayData, OffsetBufferBuilder,
OffsetSizeTrait, Scalar, cast::AsArray, make_array,
};
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
use datafusion_common::{Result, exec_err, internal_err, utils::take_function_args};
use datafusion_common::{
Result, ScalarValue, exec_err, internal_err, utils::take_function_args,
};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, Documentation,
ScalarFunctionArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility,
Expand Down Expand Up @@ -113,7 +114,24 @@ impl ScalarUDFImpl for ArrayRemove {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match element_arg {
ColumnarValue::Scalar(scalar_element)
if !scalar_element.is_null()
&& !scalar_element.data_type().is_nested() =>
{
let result =
array_remove_with_scalar_args(&list_array, scalar_element, 1i64)?;
Ok(ColumnarValue::Array(result))
}
element_arg => {
let element_array = element_arg.to_array(num_rows)?;
let result = array_remove_internal(&list_array, &element_array, &[1])?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -214,7 +232,40 @@ impl ScalarUDFImpl for ArrayRemoveN {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_n_inner)(&args.args)
let [list_arg, element_arg, max_arg] =
take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match (element_arg, max_arg) {
(
ColumnarValue::Scalar(scalar_element),
ColumnarValue::Scalar(scalar_max),
) if !scalar_element.is_null() && !scalar_element.data_type().is_nested() => {
let ScalarValue::Int64(Some(n)) = scalar_max else {
// null max means no remove
return Ok(ColumnarValue::Array(list_array));
};
let result =
array_remove_with_scalar_args(&list_array, scalar_element, *n)?;
Ok(ColumnarValue::Array(result))
}
(element_arg, max_arg) => {
let element_array = element_arg.to_array(num_rows)?;
let max_array = max_arg.to_array(num_rows)?;
let max_array = as_int64_array(&max_array)?;
let arr_n = (0..max_array.len())
.map(|i| {
if max_array.is_null(i) {
0
} else {
max_array.value(i)
}
})
.collect::<Vec<_>>();
let result = array_remove_internal(&list_array, &element_array, &arr_n)?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand Down Expand Up @@ -304,7 +355,25 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
make_scalar_function(array_remove_all_inner)(&args.args)
let [list_arg, element_arg] = take_function_args(self.name(), &args.args)?;
let num_rows = args.number_rows;
let list_array = list_arg.to_array(num_rows)?;
match element_arg {
ColumnarValue::Scalar(scalar_element)
if !scalar_element.is_null()
&& !scalar_element.data_type().is_nested() =>
{
let result =
array_remove_with_scalar_args(&list_array, scalar_element, i64::MAX)?;
Ok(ColumnarValue::Array(result))
}
element_arg => {
let element_array = element_arg.to_array(num_rows)?;
let result =
array_remove_internal(&list_array, &element_array, &[i64::MAX])?;
Ok(ColumnarValue::Array(result))
}
}
}

fn aliases(&self) -> &[String] {
Expand All @@ -316,27 +385,6 @@ impl ScalarUDFImpl for ArrayRemoveAll {
}
}

fn array_remove_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove", args)?;

let arr_n = vec![1; array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element, max] = take_function_args("array_remove_n", args)?;

let arr_n = as_int64_array(max)?.values().to_vec();
array_remove_internal(array, element, &arr_n)
}

fn array_remove_all_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
let [array, element] = take_function_args("array_remove_all", args)?;

let arr_n = vec![i64::MAX; array.len()];
array_remove_internal(array, element, &arr_n)
}

fn array_remove_internal(
array: &ArrayRef,
element_array: &ArrayRef,
Expand All @@ -357,6 +405,28 @@ fn array_remove_internal(
}
}

/// Fast path for `array_remove` when the needle is a non-null, non-nested scalar.
/// Dispatches to the bulk `not_distinct` comparison kernel.
fn array_remove_with_scalar_args(
array: &ArrayRef,
scalar_needle: &ScalarValue,
max_removals: i64,
) -> Result<ArrayRef> {
match array.data_type() {
DataType::List(_) => {
let list_array = array.as_list::<i32>();
general_remove_with_scalar::<i32>(list_array, scalar_needle, max_removals)
}
DataType::LargeList(_) => {
let list_array = array.as_list::<i64>();
general_remove_with_scalar::<i64>(list_array, scalar_needle, max_removals)
}
array_type => exec_err!(
"array_remove/array_remove_n/array_remove_all does not support type '{array_type}'."
),
}
}

/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
/// of `element_array[i]`.
///
Expand Down Expand Up @@ -411,7 +481,11 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
let start = offset_window[0].to_usize().unwrap();
let end = offset_window[1].to_usize().unwrap();
// n is the number of elements to remove in this row
let n = arr_n[row_index];
let n = if arr_n.len() == 1 {
arr_n[0]
} else {
arr_n[row_index]
};

// compare each element in the list, `false` means the element matches and should be removed
let eq_array = utils::compare_element_to_list(
Expand Down Expand Up @@ -468,6 +542,105 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
)?))
}

/// For each element of `list_array[i]`, removes up to `max_removals` occurrences
/// of the scalar needle.
///
/// This is a specialized version of `general_remove` for scalar elements that
/// uses bulk comparison for better performance.
fn general_remove_with_scalar<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
scalar_needle: &ScalarValue,
max_removals: i64,
) -> Result<ArrayRef> {
if max_removals <= 0 {
return Ok(Arc::new(list_array.clone()));
}

let list_field = match list_array.data_type() {
DataType::List(field) | DataType::LargeList(field) => field,
_ => {
return exec_err!(
"Expected List or LargeList data type, got {:?}",
list_array.data_type()
);
}
};

let list_offsets = list_array.offsets();
let first_offset = list_offsets[0].to_usize().unwrap();
let last_offset = list_offsets[list_offsets.len() - 1].to_usize().unwrap();
let values_range_len = last_offset - first_offset;
let values_slice = list_array.values().slice(first_offset, values_range_len);
let original_data = values_slice.to_data();
let mut offsets = OffsetBufferBuilder::<OffsetSize>::new(list_array.len());

let mut mutable = MutableArrayData::with_capacities(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if an approach using take kernel could provide even more performance gains?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I benchmarked both take and filter kernel approaches against the current MutableArrayData::extend path. Overall, MutableArrayData performs best — on small lists (size=10) take is faster (~10%), possibly by avoiding MutableArrayData initialization overhead, but on medium/large lists (size≥100) MutableArrayData tends to win decisively (take is 60–170% slower depending on type). For variable-length types (strings), the gap appears to widen further.

One possible explanation is that take performs per-index random access for each element, whereas MutableArrayData may instead execute contiguous memcpy operations over memory regions?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But, the current benchmarks use 10% needle density (sparse removals / high retention), which is likely the most common case? For dense removal workloads the trade-offs may shift, take or filter could become competitive when there are fewer contiguous ranges to memcpy.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, especially as we're operating on list arrays anyway

vec![&original_data],
false,
Capacities::Array(original_data.len()),
);
let nulls = list_array.nulls().cloned();
let needle = scalar_needle.to_array_of_size(1)?;
let remove_mask = arrow_ord::cmp::not_distinct(&values_slice, &Scalar::new(needle))?;
let remove_bits = remove_mask.values();

for (row_index, offset_window) in list_offsets.windows(2).enumerate() {
if nulls.as_ref().is_some_and(|nulls| nulls.is_null(row_index)) {
offsets.push_length(0);
continue;
}

let start = offset_window[0].to_usize().unwrap() - first_offset;
let end = offset_window[1].to_usize().unwrap() - first_offset;
let row_len = end - start;

let row_remove_bits = remove_bits.slice(start, row_len);
let num_to_remove = row_remove_bits.count_set_bits();

if num_to_remove == 0 {
mutable.extend(0, start, end);
offsets.push_length(row_len);
continue;
}

let removals_to_apply = max_removals.min(num_to_remove as i64) as usize;

// Iterate only over the removal positions via set_indices. This is
// efficient when the number of removals is small relative to the row
// length (common case), since it skips over retained elements.
let mut removed = 0usize;
let mut copied = 0usize;
let mut prev_end = start;
for remove_pos in row_remove_bits.set_indices() {
let abs_pos = start + remove_pos;
if abs_pos > prev_end {
mutable.extend(0, prev_end, abs_pos);
copied += abs_pos - prev_end;
}
prev_end = abs_pos + 1;
removed += 1;
if removed == removals_to_apply {
break;
}
}
// Copy the remaining tail after the last removal
if prev_end < end {
mutable.extend(0, prev_end, end);
copied += end - prev_end;
}

offsets.push_length(copied);
}

let new_values = make_array(mutable.freeze());
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
Arc::clone(list_field),
offsets.finish(),
new_values,
nulls,
)?))
}

#[cfg(test)]
mod tests {
use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
Expand Down
Loading
Loading