Skip to content
245 changes: 197 additions & 48 deletions datafusion/functions-window/src/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! `lead` and `lag` window function implementations

use crate::utils::get_default_value_from_args;
use crate::utils::{get_scalar_value_from_args, get_signed_integer};
use arrow::datatypes::FieldRef;
use datafusion_common::arrow::array::ArrayRef;
Expand Down Expand Up @@ -58,6 +59,12 @@ get_or_init_udwf!(
WindowShift::lead
);

/// The resolved default-value argument (third argument) of `lead`/`lag`.
#[derive(Debug, Clone)]
pub enum DefaultValue {
    /// A constant default, already cast to the input expression's data type.
    Literal(ScalarValue),
    /// A non-literal default: the third argument is a `PhysicalExpr` that is
    /// evaluated per row against the input batch and passed to the evaluator
    /// as an additional values column.
    Expression,
}

/// Create an expression to represent the `lag` window function
///
/// returns value evaluated at the row that is offset rows before the current row within the partition;
Expand Down Expand Up @@ -176,19 +183,21 @@ static LAG_DOCUMENTATION: LazyLock<Documentation> = LazyLock::new(|| {
-- Example usage of the lag window function:
SELECT employee_id,
salary,
lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary
lag(salary, 1, 0) OVER (ORDER BY employee_id) AS prev_salary,
lag(salary, 1, salary) OVER (ORDER BY employee_id) AS prev_salary_or_current
FROM employees;

+-------------+--------+-------------+
| employee_id | salary | prev_salary |
+-------------+--------+-------------+
| 1 | 30000 | 0 |
| 2 | 50000 | 30000 |
| 3 | 70000 | 50000 |
| 4 | 60000 | 70000 |
+-------------+--------+-------------+
+-------------+--------+-------------+------------------------+
| employee_id | salary | prev_salary | prev_salary_or_current |
+-------------+--------+-------------+------------------------+
| 1 | 30000 | 0 | 30000 |
| 2 | 50000 | 30000 | 30000 |
| 3 | 70000 | 50000 | 50000 |
| 4 | 60000 | 70000 | 70000 |
+-------------+--------+-------------+------------------------+
```
"#)

.build()
});

Expand All @@ -214,18 +223,19 @@ SELECT
employee_id,
department,
salary,
lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary
lead(salary, 1, 0) OVER (PARTITION BY department ORDER BY salary) AS next_salary,
lead(salary, 1, salary) OVER (PARTITION BY department ORDER BY salary) AS next_salary_or_current
FROM employees;

+-------------+-------------+--------+--------------+
| employee_id | department | salary | next_salary |
+-------------+-------------+--------+--------------+
| 1 | Sales | 30000 | 50000 |
| 2 | Sales | 50000 | 70000 |
| 3 | Sales | 70000 | 0 |
| 4 | Engineering | 40000 | 60000 |
| 5 | Engineering | 60000 | 0 |
+-------------+-------------+--------+--------------+
+-------------+-------------+--------+--------------+------------------------+
| employee_id | department | salary | next_salary | next_salary_or_current |
+-------------+-------------+--------+--------------+------------------------+
| 1 | Sales | 30000 | 50000 | 50000 |
| 2 | Sales | 50000 | 70000 | 70000 |
| 3 | Sales | 70000 | 0 | 70000 |
| 4 | Engineering | 40000 | 60000 | 60000 |
| 5 | Engineering | 60000 | 0 | 60000 |
+-------------+-------------+--------+--------------+------------------------+
```
"#)
.build()
Expand All @@ -244,15 +254,27 @@ impl WindowUDFImpl for WindowShift {
&self.signature
}

/// Handles the case where `NULL` expression is passed as an
/// argument to `lead`/`lag`. The type is refined depending
/// on the default value argument.
/// Handles cases:
/// - where `NULL` expression is passed as an argument to `lead`/`lag`. The type is refined depending
/// on the default value argument.
/// - where input expression contains another expression (PhysicalExpr)
/// in this case, in later evaluate() and evaluate_all() we will have result of applying
/// this PhysicalExpr to the RecordBatch (thus, we can use it as a default value)
///
/// For more details see: <https://github.com/apache/datafusion/issues/12717>
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
parse_expr(expr_args.input_exprs(), expr_args.input_fields())
.into_iter()
.collect::<Vec<_>>()
let input_exprs = expr_args.input_exprs();
let mut result = Vec::new();

let main_expr =
parse_expr(expr_args.input_exprs(), expr_args.input_fields()).unwrap();
result.push(main_expr);

// Pushing the expression (not a literal value) to the result, so it would be executed
if input_exprs.len() >= 3 {
result.push(Arc::clone(&input_exprs[2]));
}
result
}

fn partition_evaluator(
Expand Down Expand Up @@ -391,20 +413,15 @@ fn parse_expr_field(input_fields: &[FieldRef]) -> Result<FieldRef> {
fn parse_default_value(
input_exprs: &[Arc<dyn PhysicalExpr>],
input_types: &[FieldRef],
) -> Result<ScalarValue> {
) -> Result<DefaultValue> {
let expr_field = parse_expr_field(input_types)?;
let unparsed = get_scalar_value_from_args(input_exprs, 2)?;

unparsed
.filter(|v| !v.data_type().is_null())
.map(|v| v.cast_to(expr_field.data_type()))
.unwrap_or_else(|| ScalarValue::try_from(expr_field.data_type()))
get_default_value_from_args(input_exprs, 2, &expr_field)
}

#[derive(Debug)]
struct WindowShiftEvaluator {
shift_offset: i64,
default_value: ScalarValue,
default_value: DefaultValue,
ignore_nulls: bool,
// VecDeque contains offset values that between non-null entries
non_null_offsets: VecDeque<usize>,
Expand Down Expand Up @@ -501,6 +518,89 @@ fn shift_with_default_value(
}
}

/// Shifts `array` by `offset` rows, filling the vacated positions from the
/// per-row `default_values` array (same length as `array`).
///
/// A positive `offset` is LAG (defaults fill the front); a negative `offset`
/// is LEAD (defaults fill the back). A zero offset returns the input
/// unchanged; an offset at least as large as the array yields only defaults.
fn shift_with_array_default(
    array: &ArrayRef,
    offset: i64,
    default_values: &ArrayRef,
) -> Result<ArrayRef> {
    use datafusion_common::arrow::compute::concat;

    let num_rows = array.len() as i64;
    if offset == 0 {
        return Ok(Arc::clone(array));
    }
    // `i64::MIN` is checked first so `abs()` is never called on it (it would
    // overflow); any shift spanning the whole array produces only defaults.
    if offset == i64::MIN || offset.abs() >= num_rows {
        return Ok(Arc::clone(default_values));
    }

    let shift = offset.unsigned_abs() as usize;
    let kept = array.len() - shift;
    let is_lag = offset > 0;

    // Surviving input rows: LAG keeps the prefix, LEAD keeps the suffix.
    let survivors = if is_lag {
        array.slice(0, kept)
    } else {
        array.slice(shift, kept)
    };
    // Default rows: LAG takes the first `shift` defaults, LEAD the last.
    let fillers = if is_lag {
        default_values.slice(0, shift)
    } else {
        default_values.slice(default_values.len() - shift, shift)
    };

    let joined = if is_lag {
        concat(&[fillers.as_ref(), survivors.as_ref()])
    } else {
        concat(&[survivors.as_ref(), fillers.as_ref()])
    };
    joined.map_err(|e| arrow_datafusion_err!(e))
}

/// `IGNORE NULLS` variant of `evaluate_all` for the case where the default
/// value is a per-row array (the default argument was a `PhysicalExpr`).
///
/// For each output row `id`, locates `id` among the non-null input rows and
/// steps `|offset|` non-null entries forwards (LEAD) or backwards (LAG).
/// When the stepped position falls outside the valid entries, the row takes
/// its value from `default_values[id]`.
fn evaluate_all_with_ignore_null_and_array_default(
    array: &ArrayRef,
    offset: i64,
    default_values: &ArrayRef,
    is_lag: bool,
) -> Result<ArrayRef> {
    // Indices of the non-null entries. An array with no null buffer has no
    // nulls at all, so every index is valid; unwrapping `nulls()` would
    // panic in that case.
    let valid_indices: Vec<usize> = match array.nulls() {
        Some(nulls) => nulls.valid_indices().collect(),
        None => (0..array.len()).collect(),
    };
    // LEAD walks forwards through the valid indices, LAG walks backwards.
    let direction = !is_lag;
    // Step magnitude. LEAD is invoked with a negative `offset`
    // (see `evaluate_all`, which passes `self.shift_offset` directly), so a
    // raw `offset as usize` cast would wrap to a huge value and overflow
    // `checked_add`; `unsigned_abs` gives the intended step for both signs.
    let step = offset.unsigned_abs() as usize;
    let results: Result<Vec<_>> = (0..array.len())
        .map(|id| {
            let result_index = match valid_indices.binary_search(&id) {
                // `id` is itself a non-null row, at position `pos` among the
                // valid entries.
                Ok(pos) => if direction {
                    pos.checked_add(step)
                } else {
                    pos.checked_sub(step)
                }
                .and_then(|new_pos| {
                    if new_pos < valid_indices.len() {
                        Some(valid_indices[new_pos])
                    } else {
                        None
                    }
                }),
                // `id` is a null row; `pos` is its insertion point among the
                // valid entries.
                Err(pos) => if direction {
                    pos.checked_add(step)
                } else if pos > 0 {
                    pos.checked_sub(step)
                } else {
                    None
                }
                .and_then(|new_pos| {
                    if new_pos < valid_indices.len() {
                        Some(valid_indices[new_pos])
                    } else {
                        None
                    }
                }),
            };
            match result_index {
                Some(index) => ScalarValue::try_from_array(array, index),
                // Out of range: fall back to this row's own default value.
                None => ScalarValue::try_from_array(default_values, id),
            }
        })
        .collect();
    ScalarValue::iter_to_array(results?)
}

impl PartitionEvaluator for WindowShiftEvaluator {
fn get_range(&self, idx: usize, n_rows: usize) -> Result<Range<usize>> {
if self.is_lag() {
Expand Down Expand Up @@ -640,7 +740,29 @@ impl PartitionEvaluator for WindowShiftEvaluator {
if !(idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap()))) {
ScalarValue::try_from_array(array, idx.unwrap())
} else {
Ok(self.default_value.clone())
match &self.default_value {
DefaultValue::Literal(scalar) => Ok(scalar.clone()),
DefaultValue::Expression => {
let current_row = if self.is_lag() {
range.end.saturating_sub(1)
} else {
range.start
};

values
.get(1)
.map(|defaults| {
let scalar =
ScalarValue::try_from_array(defaults, current_row)?;
if scalar.data_type() != *array.data_type() {
scalar.cast_to(array.data_type())
} else {
Ok(scalar)
}
})
.unwrap_or_else(|| ScalarValue::try_from(array.data_type()))
}
}
}
}

Expand All @@ -649,17 +771,44 @@ impl PartitionEvaluator for WindowShiftEvaluator {
values: &[ArrayRef],
_num_rows: usize,
) -> Result<ArrayRef> {
// LEAD, LAG window functions take single column, values will have size 1
// LEAD, LAG window functions take single column, values will have size:
// '1' - when default_value is a ScalarValue (or we simply did not specify it)
// '2' - when default_value is a PhysicalExpr
let value = &values[0];
if !self.ignore_nulls {
shift_with_default_value(value, self.shift_offset, &self.default_value)
} else {
evaluate_all_with_ignore_null(
value,
self.shift_offset,
&self.default_value,
self.is_lag(),
)
match &self.default_value {
DefaultValue::Literal(scalar) => {
if !self.ignore_nulls {
shift_with_default_value(value, self.shift_offset, &scalar.clone())
} else {
evaluate_all_with_ignore_null(
value,
self.shift_offset,
&scalar.clone(),
self.is_lag(),
)
}
}
DefaultValue::Expression => {
let default_array = values.get(1).cloned().unwrap_or_else(|| {
Arc::new(arrow::array::NullArray::new(value.len()))
});
let default_array = if default_array.data_type() != value.data_type() {
arrow::compute::kernels::cast::cast(&default_array, value.data_type())
.map_err(|e| arrow_datafusion_err!(e))?
} else {
default_array
};
if !self.ignore_nulls {
shift_with_array_default(value, self.shift_offset, &default_array)
} else {
evaluate_all_with_ignore_null_and_array_default(
value,
self.shift_offset,
&default_array,
self.is_lag(),
)
}
}
}
}

Expand Down Expand Up @@ -696,7 +845,7 @@ mod tests {
// LAG(2)
let lag_fn = WindowShiftEvaluator {
shift_offset: 2,
default_value: ScalarValue::Null,
default_value: DefaultValue::Literal(ScalarValue::Null),
ignore_nulls: false,
non_null_offsets: Default::default(),
};
Expand All @@ -706,7 +855,7 @@ mod tests {
// LAG(2 ignore nulls)
let lag_fn = WindowShiftEvaluator {
shift_offset: 2,
default_value: ScalarValue::Null,
default_value: DefaultValue::Literal(ScalarValue::Null),
ignore_nulls: true,
// models data received [<Some>, <Some>, <Some>, NULL, <Some>, NULL, <current row>, ...]
non_null_offsets: vec![2, 2].into(), // [1, 1, 2, 2] actually, just last 2 is used
Expand All @@ -716,7 +865,7 @@ mod tests {
// LEAD(2)
let lead_fn = WindowShiftEvaluator {
shift_offset: -2,
default_value: ScalarValue::Null,
default_value: DefaultValue::Literal(ScalarValue::Null),
ignore_nulls: false,
non_null_offsets: Default::default(),
};
Expand All @@ -726,7 +875,7 @@ mod tests {
// LEAD(2 ignore nulls)
let lead_fn = WindowShiftEvaluator {
shift_offset: -2,
default_value: ScalarValue::Null,
default_value: DefaultValue::Literal(ScalarValue::Null),
ignore_nulls: true,
// models data received [..., <current row>, NULL, <Some>, NULL, <Some>, ..]
non_null_offsets: vec![2, 2].into(),
Expand Down
28 changes: 28 additions & 0 deletions datafusion/functions-window/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::Field;
use datafusion_common::arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result, ScalarValue, exec_err};
use datafusion_physical_expr::expressions::Literal;
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
use std::sync::Arc;

use crate::lead_lag::DefaultValue;

pub(crate) fn get_signed_integer(value: &ScalarValue) -> Result<i64> {
if value.is_null() {
return Ok(0);
Expand Down Expand Up @@ -51,6 +54,31 @@ pub(crate) fn get_scalar_value_from_args(
})
}

/// Resolves the default-value argument of `lead`/`lag` (at `index` in
/// `args`) into a [`DefaultValue`].
///
/// - A literal argument is cast to `field`'s data type (an untyped `NULL`
///   literal becomes that type's typed null) and returned as
///   [`DefaultValue::Literal`].
/// - A non-literal argument is a `PhysicalExpr` that must be evaluated per
///   row later, so [`DefaultValue::Expression`] is returned.
/// - A missing argument defaults to the typed null of `field`'s data type.
pub(crate) fn get_default_value_from_args(
    args: &[Arc<dyn PhysicalExpr>],
    index: usize,
    field: &Arc<Field>,
) -> Result<DefaultValue> {
    match args.get(index) {
        Some(expr) => {
            // Downcasting a `dyn PhysicalExpr` goes through `as_any()`; the
            // trait object has no inherent `downcast_ref` method.
            if let Some(literal) = expr.as_any().downcast_ref::<Literal>() {
                let scalar = literal.value().clone();
                let scalar = if !scalar.data_type().is_null() {
                    scalar.cast_to(field.data_type())
                } else {
                    // Untyped NULL literal: use the input type's null value.
                    ScalarValue::try_from(field.data_type())
                }?;
                Ok(DefaultValue::Literal(scalar))
            } else {
                Ok(DefaultValue::Expression)
            }
        }
        None => Ok(DefaultValue::Literal(ScalarValue::try_from(
            field.data_type(),
        )?)),
    }
}

pub(crate) fn get_unsigned_integer(value: &ScalarValue) -> Result<u64> {
if value.is_null() {
return Ok(0);
Expand Down
Loading
Loading