-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Simplify get_field over inline struct constructors #22239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
7540e45
cd20ebb
4f70d80
f799333
11a390e
b2e0978
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| // specific language governing permissions and limitations | ||
| // under the License. | ||
|
|
||
| use std::sync::Arc; | ||
| use std::sync::{Arc, OnceLock}; | ||
|
|
||
| use arrow::array::{ | ||
| Array, BooleanArray, Capacities, MutableArrayData, Scalar, cast::AsArray, make_array, | ||
|
|
@@ -37,6 +37,9 @@ use datafusion_expr::{ | |
| }; | ||
| use datafusion_macros::user_doc; | ||
|
|
||
| use super::named_struct::NamedStructFunc; | ||
| use super::r#struct::StructFunc; | ||
|
|
||
| #[user_doc( | ||
| doc_section(label = "Other Functions"), | ||
| description = r#"Returns a field within a map or a struct with the given key. | ||
|
|
@@ -249,6 +252,120 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column | |
| } | ||
| } | ||
|
|
||
| /// The shared `get_field` UDF, reused whenever simplification needs to build a | ||
| /// fresh `get_field` node (e.g. re-wrapping the remaining access path). | ||
| fn get_field_udf() -> Arc<ScalarUDF> { | ||
| static GET_FIELD_UDF: OnceLock<Arc<ScalarUDF>> = OnceLock::new(); | ||
| Arc::clone( | ||
| GET_FIELD_UDF | ||
| .get_or_init(|| Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new()))), | ||
| ) | ||
| } | ||
|
|
||
| /// Try to simplify a `get_field` call whose base is an inline struct | ||
| /// constructor by resolving the field access at plan time. | ||
| /// | ||
| /// Handles both struct constructors: | ||
| /// * `named_struct('a', x, 'b', y)` — fields are looked up by name. | ||
| /// * `struct(x, y)` — fields are positional and named `c0`, `c1`, ... | ||
| /// | ||
| /// For example: | ||
| /// * `get_field(named_struct('min', a, 'max', b), 'max')` => `b` | ||
| /// * `get_field(struct(a, b), 'c1')` => `b` | ||
| /// | ||
| /// `args` is the (already flattened) argument list of the `get_field` call: | ||
| /// `[base, field_name, rest_of_path...]`. When extra path elements remain | ||
| /// after resolving the first one (`get_field(named_struct('s', inner), 's', 'k')`), | ||
| /// the resolved value is re-wrapped in a `get_field` call for the remaining | ||
| /// path so the simplifier can recurse into it on the next pass. | ||
| /// | ||
| /// Returns `None` — leaving the expression untouched — whenever the rewrite | ||
| /// cannot be proven safe, e.g. a non-literal field name, a `named_struct` | ||
| /// with a non-literal field name (which might shadow the requested field at | ||
| /// runtime), or a field the constructor does not produce. | ||
| /// | ||
| /// Replacing the access with the selected field expression drops the | ||
| /// expressions for the other (unaccessed) fields, so they are no longer | ||
| /// evaluated — e.g. `get_field(named_struct('a', 1/0, 'b', x), 'b')` becomes | ||
| /// `x` and the `1/0` is never evaluated. This is intentional and matches the | ||
| /// optimizer's contract for immutable expressions: a simplification may drop | ||
| /// sub-expressions whose value is not observed. | ||
| fn simplify_get_field_over_struct_constructor(args: &[Expr]) -> Option<Expr> { | ||
| let [base, field_name, rest @ ..] = args else { | ||
| return None; | ||
| }; | ||
|
|
||
| // The accessed field name must be a non-empty string literal. | ||
| let Expr::Literal(field_name, _) = field_name else { | ||
| return None; | ||
| }; | ||
| let field_name = field_name | ||
| .try_as_str() | ||
| .flatten() | ||
| .filter(|s| !s.is_empty())?; | ||
|
|
||
| let Expr::ScalarFunction(ScalarFunction { | ||
| func, | ||
| args: ctor_args, | ||
| }) = base | ||
| else { | ||
| return None; | ||
| }; | ||
|
|
||
| let value = if func.inner().is::<NamedStructFunc>() { | ||
| // named_struct(name1, value1, name2, value2, ...) | ||
| if !ctor_args.len().is_multiple_of(2) { | ||
| return None; | ||
| } | ||
| let mut matched = None; | ||
| for pair in ctor_args.chunks_exact(2) { | ||
| // Every name must be a literal string: a non-literal name appearing | ||
| // *before* the first match could evaluate to `field_name` at runtime | ||
| // and become the real first match (Arrow's `column_by_name` returns | ||
| // the first match), so we cannot resolve the access. | ||
| // | ||
| // We conservatively bail on *any* non-literal name. Once a literal | ||
| // match has been found, a later non-literal name is in fact harmless | ||
| // — it can never precede the first match — so bailing there is a | ||
| // deliberate approximation we accept to keep this check simple, not a | ||
| // correctness requirement. | ||
| let Expr::Literal(name, _) = &pair[0] else { | ||
| return None; | ||
| }; | ||
| let name = name.try_as_str().flatten()?; | ||
| // `column_by_name` resolves to the first match, so do the same. | ||
| if matched.is_none() && name == field_name { | ||
| matched = Some(&pair[1]); | ||
| } | ||
| } | ||
| matched?.clone() | ||
| } else if func.inner().is::<StructFunc>() { | ||
| // struct(value0, value1, ...) produces fields named c0, c1, ... | ||
| let index: usize = field_name.strip_prefix('c')?.parse().ok()?; | ||
| // Reject non-canonical spellings (e.g. "c01") that name no real field. | ||
| if format!("c{index}") != field_name { | ||
| return None; | ||
| } | ||
| ctor_args.get(index)?.clone() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The early
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done — expanded the comment to call out that bailing on a non-literal field name is a deliberate conservative approximation rather than an accidental missed optimization. To be precise about the rationale: a non-literal name appearing after the first literal match can't change the result, since |
||
| } else { | ||
| return None; | ||
| }; | ||
|
|
||
| if rest.is_empty() { | ||
| return Some(value); | ||
| } | ||
|
|
||
| // Remaining path elements: re-wrap as get_field(value, rest...) and let | ||
| // the simplifier resolve the rest on a subsequent pass. | ||
| let mut new_args = Vec::with_capacity(rest.len() + 1); | ||
| new_args.push(value); | ||
| new_args.extend_from_slice(rest); | ||
| Some(Expr::ScalarFunction(ScalarFunction::new_udf( | ||
| get_field_udf(), | ||
| new_args, | ||
| ))) | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The re-wrap path creates a fresh
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in cd20ebb — added a |
||
|
|
||
| impl GetFieldFunc { | ||
| pub fn new() -> Self { | ||
| Self { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This simplification can eliminate evaluation of struct fields that are never accessed, for example
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 4f70d80 — added a paragraph to the function doc comment noting that replacing the access drops the other fields' expressions so they are no longer evaluated (e.g. |
||
|
|
@@ -479,14 +596,12 @@ impl ScalarUDFImpl for GetFieldFunc { | |
|
|
||
| // Flatten all nested get_field calls in a single pass | ||
| // Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c) | ||
|
|
||
| // Collect path arguments from all nested levels | ||
| let mut path_args_stack = Vec::new(); | ||
| // | ||
| // `path_args_stack` collects each level's field-name arguments, | ||
| // outermost first; it is reversed below to restore access order. | ||
| let mut path_args_stack = vec![&args[1..]]; | ||
| let mut current_expr = &args[0]; | ||
|
|
||
| // Push the outermost path arguments first | ||
| path_args_stack.push(&args[1..]); | ||
|
|
||
| // Walk down the chain of nested get_field calls | ||
| let base_expr = loop { | ||
| if let Expr::ScalarFunction(ScalarFunction { | ||
|
|
@@ -506,28 +621,30 @@ impl ScalarUDFImpl for GetFieldFunc { | |
| break current_expr; | ||
| }; | ||
|
|
||
| // If no nested get_field calls were found, return original | ||
| if path_args_stack.len() == args.len() - 1 { | ||
| return Ok(ExprSimplifyResult::Original(args)); | ||
| } | ||
| // Whether any nested get_field calls were collapsed above. | ||
| let did_flatten = path_args_stack.len() > 1; | ||
|
|
||
| // If we found any nested get_field calls, flatten them | ||
| // Build merged args: [base, ...all_path_args_in_correct_order] | ||
| // Build merged args: [base, ...all path args in access order]. | ||
| // The stack holds path slices outermost-first, so iterate in reverse. | ||
| let mut merged_args = vec![base_expr.clone()]; | ||
|
|
||
| // Add path args in reverse order (innermost to outermost) | ||
| // Stack is: [outermost_paths, ..., innermost_paths] | ||
| // We want: [base, innermost_paths, ..., outermost_paths] | ||
| for path_slice in path_args_stack.iter().rev() { | ||
| merged_args.extend_from_slice(path_slice); | ||
| } | ||
|
|
||
| Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( | ||
| ScalarFunction::new_udf( | ||
| Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())), | ||
| merged_args, | ||
| ), | ||
| ))) | ||
| // Resolve field accesses against an inline struct constructor: | ||
| // get_field(named_struct('min', a, 'max', b), 'max') => b | ||
| if let Some(simplified) = simplify_get_field_over_struct_constructor(&merged_args) | ||
| { | ||
| return Ok(ExprSimplifyResult::Simplified(simplified)); | ||
| } | ||
|
|
||
| if did_flatten { | ||
| return Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction( | ||
| ScalarFunction::new_udf(get_field_udf(), merged_args), | ||
| ))); | ||
| } | ||
|
|
||
| Ok(ExprSimplifyResult::Original(args)) | ||
| } | ||
|
|
||
| fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||
|
|
@@ -828,4 +945,187 @@ mod tests { | |
| let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal]; | ||
| assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace); | ||
| } | ||
|
|
||
| // --- get_field over struct constructor simplification -------------------- | ||
|
|
||
| use datafusion_common::Column; | ||
| use datafusion_expr::simplify::SimplifyContext; | ||
|
|
||
| /// A non-empty string literal expression. | ||
| fn lit_str(s: &str) -> Expr { | ||
| Expr::Literal(ScalarValue::Utf8(Some(s.to_string())), None) | ||
| } | ||
|
|
||
| /// A column reference expression. | ||
| fn col(name: &str) -> Expr { | ||
| Expr::Column(Column::from_name(name)) | ||
| } | ||
|
|
||
| fn scalar_fn(udf: ScalarUDF, args: Vec<Expr>) -> Expr { | ||
| Expr::ScalarFunction(ScalarFunction::new_udf(Arc::new(udf), args)) | ||
| } | ||
|
|
||
| /// `named_struct(name1, value1, name2, value2, ...)`. | ||
| fn named_struct(pairs: Vec<(&str, Expr)>) -> Expr { | ||
| let args = pairs | ||
| .into_iter() | ||
| .flat_map(|(name, value)| [lit_str(name), value]) | ||
| .collect(); | ||
| scalar_fn(ScalarUDF::new_from_impl(NamedStructFunc::new()), args) | ||
| } | ||
|
|
||
| /// `struct(value0, value1, ...)`. | ||
| fn struct_fn(values: Vec<Expr>) -> Expr { | ||
| scalar_fn(ScalarUDF::new_from_impl(StructFunc::new()), values) | ||
| } | ||
|
|
||
| /// `get_field(args...)`. | ||
| fn get_field(args: Vec<Expr>) -> Expr { | ||
| scalar_fn(ScalarUDF::new_from_impl(GetFieldFunc::new()), args) | ||
| } | ||
|
|
||
| /// Run `GetFieldFunc::simplify` once and return the rewritten expression, | ||
| /// panicking if the input was left unchanged. | ||
| fn simplified(args: Vec<Expr>) -> Expr { | ||
| match GetFieldFunc::new() | ||
| .simplify(args, &SimplifyContext::default()) | ||
| .unwrap() | ||
| { | ||
| ExprSimplifyResult::Simplified(expr) => expr, | ||
| ExprSimplifyResult::Original(args) => { | ||
| panic!("expected the expression to be simplified, got {args:?}") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| /// Assert that `GetFieldFunc::simplify` leaves the arguments unchanged. | ||
| fn assert_not_simplified(args: Vec<Expr>) { | ||
| match GetFieldFunc::new() | ||
| .simplify(args.clone(), &SimplifyContext::default()) | ||
| .unwrap() | ||
| { | ||
| ExprSimplifyResult::Original(unchanged) => assert_eq!(unchanged, args), | ||
| ExprSimplifyResult::Simplified(expr) => { | ||
| panic!("expected no simplification, got {expr:?}") | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_named_struct_returns_matching_value() { | ||
| // get_field(named_struct('min', a, 'max', b), 'max') => b | ||
| let args = vec![ | ||
| named_struct(vec![("min", col("a")), ("max", col("b"))]), | ||
| lit_str("max"), | ||
| ]; | ||
| assert_eq!(simplified(args), col("b")); | ||
| } | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small readability nit: this test reconstructs
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done in 11a390e — the test now builds |
||
| #[test] | ||
| fn simplify_get_field_named_struct_first_field() { | ||
| // get_field(named_struct('min', a, 'max', b), 'min') => a | ||
| let args = vec![ | ||
| named_struct(vec![("min", col("a")), ("max", col("b"))]), | ||
| lit_str("min"), | ||
| ]; | ||
| assert_eq!(simplified(args), col("a")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_named_struct_duplicate_names_picks_first() { | ||
| // Arrow's `column_by_name` resolves to the first match; mirror that. | ||
| let args = vec![ | ||
| named_struct(vec![("k", col("a")), ("k", col("b"))]), | ||
| lit_str("k"), | ||
| ]; | ||
| assert_eq!(simplified(args), col("a")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_struct_positional() { | ||
| // get_field(struct(a, b), 'c1') => b | ||
| let args = vec![struct_fn(vec![col("a"), col("b")]), lit_str("c1")]; | ||
| assert_eq!(simplified(args), col("b")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_nested_named_struct() { | ||
| // get_field(named_struct('s', named_struct('k', x)), 's', 'k') | ||
| // => get_field(named_struct('k', x), 'k') (first pass) | ||
| // => x (second pass) | ||
| let args = vec![ | ||
| named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]), | ||
| lit_str("s"), | ||
| lit_str("k"), | ||
| ]; | ||
| let first_pass = simplified(args); | ||
| let Expr::ScalarFunction(ScalarFunction { args, .. }) = first_pass else { | ||
| panic!("expected a get_field call after the first pass") | ||
| }; | ||
| assert_eq!(simplified(args), col("x")); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_flattens_then_resolves_named_struct() { | ||
| // get_field(get_field(named_struct('s', named_struct('k', x)), 's'), 'k') | ||
| // flattens to get_field(named_struct(...), 's', 'k') and resolves 's'. | ||
| let args = vec![ | ||
| get_field(vec![ | ||
| named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]), | ||
| lit_str("s"), | ||
| ]), | ||
| lit_str("k"), | ||
| ]; | ||
| let expected = get_field(vec![named_struct(vec![("k", col("x"))]), lit_str("k")]); | ||
| assert_eq!(simplified(args), expected); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_dynamic_field_name_left_alone() { | ||
| // A non-literal field name cannot be resolved at plan time. | ||
| let args = vec![named_struct(vec![("a", col("x"))]), col("field_name")]; | ||
| assert_not_simplified(args); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_null_field_name_left_alone() { | ||
| // A NULL string literal field name resolves to no field, so the | ||
| // `try_as_str().flatten()` guard must leave the expression untouched. | ||
| let null_field_name = Expr::Literal(ScalarValue::Utf8(None), None); | ||
| let args = vec![named_struct(vec![("a", col("x"))]), null_field_name]; | ||
| assert_not_simplified(args); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_dynamic_struct_name_left_alone() { | ||
| // A non-literal name inside named_struct could shadow the requested | ||
| // field at runtime, so the rewrite must bail out entirely. | ||
| let named_struct_with_dynamic_name = scalar_fn( | ||
| ScalarUDF::new_from_impl(NamedStructFunc::new()), | ||
| vec![col("dynamic_name"), col("x")], | ||
| ); | ||
| let args = vec![named_struct_with_dynamic_name, lit_str("a")]; | ||
| assert_not_simplified(args); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_missing_field_left_alone() { | ||
| // The named_struct does not produce field 'missing'. | ||
| let args = vec![named_struct(vec![("a", col("x"))]), lit_str("missing")]; | ||
| assert_not_simplified(args); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_non_canonical_struct_field_left_alone() { | ||
| // 'c01' is not a real field name produced by `struct(...)`. | ||
| let args = vec![struct_fn(vec![col("a"), col("b")]), lit_str("c01")]; | ||
| assert_not_simplified(args); | ||
| } | ||
|
|
||
| #[test] | ||
| fn simplify_get_field_column_base_left_alone() { | ||
| // A plain column base is not a struct constructor. | ||
| let args = vec![col("s"), lit_str("a")]; | ||
| assert_not_simplified(args); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
try_as_str().flatten()correctly handlesScalarValue::Utf8(None)here, so the guard is safe. It could still be nice to add a dedicated test for a null literal field name, something likesimplify_get_field_null_field_name_left_alone, just to document the invariant explicitly.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done in b2e0978 — added
simplify_get_field_null_field_name_left_alonecovering theScalarValue::Utf8(None)case, asserting the expression is left unchanged.