Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
346 changes: 323 additions & 23 deletions datafusion/functions/src/core/getfield.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::sync::Arc;
use std::sync::{Arc, OnceLock};

use arrow::array::{
Array, BooleanArray, Capacities, MutableArrayData, Scalar, cast::AsArray, make_array,
Expand All @@ -37,6 +37,9 @@ use datafusion_expr::{
};
use datafusion_macros::user_doc;

use super::named_struct::NamedStructFunc;
use super::r#struct::StructFunc;

#[user_doc(
doc_section(label = "Other Functions"),
description = r#"Returns a field within a map or a struct with the given key.
Expand Down Expand Up @@ -249,6 +252,120 @@ fn extract_single_field(base: ColumnarValue, name: ScalarValue) -> Result<Column
}
}

/// Returns a handle to the process-wide `get_field` UDF.
///
/// The UDF is built lazily on first use and cached in a `OnceLock`; every
/// call hands out another `Arc` pointing at that same instance. Used whenever
/// simplification has to construct a new `get_field` node.
fn get_field_udf() -> Arc<ScalarUDF> {
    static SHARED: OnceLock<Arc<ScalarUDF>> = OnceLock::new();
    let udf = SHARED.get_or_init(|| {
        let implementation = GetFieldFunc::new();
        Arc::new(ScalarUDF::new_from_impl(implementation))
    });
    Arc::clone(udf)
}

/// Try to simplify a `get_field` call whose base is an inline struct
/// constructor by resolving the field access at plan time.
///
/// Handles both struct constructors:
/// * `named_struct('a', x, 'b', y)` — fields are looked up by name.
/// * `struct(x, y)` — fields are positional and named `c0`, `c1`, ...
///
/// For example:
/// * `get_field(named_struct('min', a, 'max', b), 'max')` => `b`
/// * `get_field(struct(a, b), 'c1')` => `b`
///
/// `args` is the (already flattened) argument list of the `get_field` call:
/// `[base, field_name, rest_of_path...]`. When extra path elements remain
/// after resolving the first one (`get_field(named_struct('s', inner), 's', 'k')`),
/// the resolved value is re-wrapped in a `get_field` call for the remaining
/// path so the simplifier can recurse into it on the next pass.
///
/// Returns `None` — leaving the expression untouched — whenever the rewrite
/// cannot be proven safe, e.g. a non-literal field name, a `named_struct`
/// with a non-literal field name (which might shadow the requested field at
/// runtime), or a field the constructor does not produce.
///
/// Replacing the access with the selected field expression drops the
/// expressions for the other (unaccessed) fields, so they are no longer
/// evaluated — e.g. `get_field(named_struct('a', 1/0, 'b', x), 'b')` becomes
/// `x` and the `1/0` is never evaluated. This is intentional and matches the
/// optimizer's contract for immutable expressions: a simplification may drop
/// sub-expressions whose value is not observed.
fn simplify_get_field_over_struct_constructor(args: &[Expr]) -> Option<Expr> {
// Need at least a base and one path element; `rest` holds any remaining path.
let [base, field_name, rest @ ..] = args else {
return None;
};

// The accessed field name must be a non-empty string literal.
let Expr::Literal(field_name, _) = field_name else {
return None;
};
let field_name = field_name
.try_as_str()
.flatten()
.filter(|s| !s.is_empty())?;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try_as_str().flatten() correctly handles ScalarValue::Utf8(None) here, so the guard is safe. It could still be nice to add a dedicated test for a null literal field name, something like simplify_get_field_null_field_name_left_alone, just to document the invariant explicitly.

Copy link
Copy Markdown
Contributor Author

@adriangb adriangb May 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in b2e0978 — added simplify_get_field_null_field_name_left_alone covering the ScalarValue::Utf8(None) case, asserting the expression is left unchanged.

// The base must be a scalar function call; which constructor it is gets
// decided below by inspecting the function's implementation type.
let Expr::ScalarFunction(ScalarFunction {
func,
args: ctor_args,
}) = base
else {
return None;
};

let value = if func.inner().is::<NamedStructFunc>() {
// named_struct(name1, value1, name2, value2, ...)
// An odd argument count cannot be split into (name, value) pairs.
if !ctor_args.len().is_multiple_of(2) {
return None;
}
let mut matched = None;
for pair in ctor_args.chunks_exact(2) {
// Every name must be a literal string: a non-literal name appearing
// *before* the first match could evaluate to `field_name` at runtime
// and become the real first match (Arrow's `column_by_name` returns
// the first match), so we cannot resolve the access.
//
// We conservatively bail on *any* non-literal name. Once a literal
// match has been found, a later non-literal name is in fact harmless
// — it can never precede the first match — so bailing there is a
// deliberate approximation we accept to keep this check simple, not a
// correctness requirement.
let Expr::Literal(name, _) = &pair[0] else {
return None;
};
let name = name.try_as_str().flatten()?;
// `column_by_name` resolves to the first match, so do the same.
if matched.is_none() && name == field_name {
matched = Some(&pair[1]);
}
}
// No literal name matched: the constructor does not produce this field,
// so `?` returns `None` and the expression is left alone.
matched?.clone()
} else if func.inner().is::<StructFunc>() {
// struct(value0, value1, ...) produces fields named c0, c1, ...
let index: usize = field_name.strip_prefix('c')?.parse().ok()?;
// Reject non-canonical spellings (e.g. "c01") that name no real field.
if format!("c{index}") != field_name {
return None;
}
// An index past the last constructor argument names no field.
ctor_args.get(index)?.clone()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The early return None on a dynamic field name makes sense because a later runtime name could shadow an earlier literal match. One subtle detail is that this also prevents simplification even when a matching literal field was already seen earlier in the loop. It may be worth calling out in the comment that this is a deliberate conservative approximation rather than an accidental missed optimization.

Copy link
Copy Markdown
Contributor Author

@adriangb adriangb May 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — expanded the comment to call out that bailing on a non-literal field name is a deliberate conservative approximation rather than an accidental missed optimization.

To be precise about the rationale: a non-literal name appearing after the first literal match can't change the result, since column_by_name returns the first match. The bail is only strictly required when the non-literal name appears before the first match (it could shadow it at runtime); applying it unconditionally is the conservative-for-simplicity choice. Reworded the comment accordingly in f799333 (force-pushed; new branch tip b2e0978).

} else {
// Any other function is not one of the two handled struct constructors.
return None;
};

// The resolved element was the whole access path: return the value itself.
if rest.is_empty() {
return Some(value);
}

// Remaining path elements: re-wrap as get_field(value, rest...) and let
// the simplifier resolve the rest on a subsequent pass.
let mut new_args = Vec::with_capacity(rest.len() + 1);
new_args.push(value);
new_args.extend_from_slice(rest);
Some(Expr::ScalarFunction(ScalarFunction::new_udf(
get_field_udf(),
new_args,
)))
}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The re-wrap path creates a fresh Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())) each time it builds an intermediate get_field node. Since the same construction already appears elsewhere in simplify, it might be nice to centralize this behind a small helper like get_field_udf() or a shared OnceLock. Not a big deal performance-wise, but it would make the intent and reuse a bit clearer.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in cd20ebb — added a get_field_udf() helper backed by a OnceLock and routed both construction sites (the re-wrap path here and the flatten path in simplify) through it.


impl GetFieldFunc {
pub fn new() -> Self {
Self {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This simplification can eliminate evaluation of struct fields that are never accessed, for example get_field(named_struct('a', 1/0, 'b', x), 'b') simplifying down to x. That matches the existing immutable-expression optimizer contract, but I think a short doc comment here would help future readers understand that unused field expressions are intentionally not evaluated after simplification.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 4f70d80 — added a paragraph to the function doc comment noting that replacing the access drops the other fields' expressions so they are no longer evaluated (e.g. get_field(named_struct('a', 1/0, 'b', x), 'b') => x), which matches the optimizer's immutable-expression contract.

Expand Down Expand Up @@ -479,14 +596,12 @@ impl ScalarUDFImpl for GetFieldFunc {

// Flatten all nested get_field calls in a single pass
// Pattern: get_field(get_field(get_field(base, a), b), c) => get_field(base, a, b, c)

// Collect path arguments from all nested levels
let mut path_args_stack = Vec::new();
//
// `path_args_stack` collects each level's field-name arguments,
// outermost first; it is reversed below to restore access order.
let mut path_args_stack = vec![&args[1..]];
let mut current_expr = &args[0];

// Push the outermost path arguments first
path_args_stack.push(&args[1..]);

// Walk down the chain of nested get_field calls
let base_expr = loop {
if let Expr::ScalarFunction(ScalarFunction {
Expand All @@ -506,28 +621,30 @@ impl ScalarUDFImpl for GetFieldFunc {
break current_expr;
};

// If no nested get_field calls were found, return original
if path_args_stack.len() == args.len() - 1 {
return Ok(ExprSimplifyResult::Original(args));
}
// Whether any nested get_field calls were collapsed above.
let did_flatten = path_args_stack.len() > 1;

// If we found any nested get_field calls, flatten them
// Build merged args: [base, ...all_path_args_in_correct_order]
// Build merged args: [base, ...all path args in access order].
// The stack holds path slices outermost-first, so iterate in reverse.
let mut merged_args = vec![base_expr.clone()];

// Add path args in reverse order (innermost to outermost)
// Stack is: [outermost_paths, ..., innermost_paths]
// We want: [base, innermost_paths, ..., outermost_paths]
for path_slice in path_args_stack.iter().rev() {
merged_args.extend_from_slice(path_slice);
}

Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
ScalarFunction::new_udf(
Arc::new(ScalarUDF::new_from_impl(GetFieldFunc::new())),
merged_args,
),
)))
// Resolve field accesses against an inline struct constructor:
// get_field(named_struct('min', a, 'max', b), 'max') => b
if let Some(simplified) = simplify_get_field_over_struct_constructor(&merged_args)
{
return Ok(ExprSimplifyResult::Simplified(simplified));
}

if did_flatten {
return Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
ScalarFunction::new_udf(get_field_udf(), merged_args),
)));
}

Ok(ExprSimplifyResult::Original(args))
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
Expand Down Expand Up @@ -828,4 +945,187 @@ mod tests {
let args = vec![ExpressionPlacement::Literal, ExpressionPlacement::Literal];
assert_eq!(func.placement(&args), ExpressionPlacement::KeepInPlace);
}

// --- get_field over struct constructor simplification --------------------

use datafusion_common::Column;
use datafusion_expr::simplify::SimplifyContext;

/// Build a non-null `Utf8` string literal expression holding `s`.
fn lit_str(s: &str) -> Expr {
    let value = ScalarValue::Utf8(Some(s.to_owned()));
    Expr::Literal(value, None)
}

/// Build a column reference expression from the bare column `name`.
fn col(name: &str) -> Expr {
    let column = Column::from_name(name);
    Expr::Column(column)
}

/// Wrap `udf` applied to `args` into a scalar function call expression.
fn scalar_fn(udf: ScalarUDF, args: Vec<Expr>) -> Expr {
    let func = Arc::new(udf);
    let call = ScalarFunction::new_udf(func, args);
    Expr::ScalarFunction(call)
}

/// Build `named_struct(name1, value1, name2, value2, ...)` from
/// `(name, value)` pairs; each name becomes a string literal argument.
fn named_struct(pairs: Vec<(&str, Expr)>) -> Expr {
    let mut ctor_args = Vec::with_capacity(pairs.len() * 2);
    for (name, value) in pairs {
        ctor_args.push(lit_str(name));
        ctor_args.push(value);
    }
    let udf = ScalarUDF::new_from_impl(NamedStructFunc::new());
    scalar_fn(udf, ctor_args)
}

/// Build the positional constructor call `struct(value0, value1, ...)`.
fn struct_fn(values: Vec<Expr>) -> Expr {
    let udf = ScalarUDF::new_from_impl(StructFunc::new());
    scalar_fn(udf, values)
}

/// Build a `get_field(args...)` call backed by a fresh `GetFieldFunc`.
fn get_field(args: Vec<Expr>) -> Expr {
    let udf = ScalarUDF::new_from_impl(GetFieldFunc::new());
    scalar_fn(udf, args)
}

/// Apply `GetFieldFunc::simplify` a single time to `args` and hand back the
/// rewritten expression.
///
/// Panics if simplification fails or if the arguments come back unchanged.
fn simplified(args: Vec<Expr>) -> Expr {
    let context = SimplifyContext::default();
    let outcome = GetFieldFunc::new().simplify(args, &context).unwrap();
    match outcome {
        ExprSimplifyResult::Original(args) => {
            panic!("expected the expression to be simplified, got {args:?}")
        }
        ExprSimplifyResult::Simplified(expr) => expr,
    }
}

/// Check that a single `GetFieldFunc::simplify` pass returns `args` as-is.
///
/// Panics if simplification fails, if a rewritten expression is produced, or
/// if the returned arguments differ from the ones passed in.
fn assert_not_simplified(args: Vec<Expr>) {
    let context = SimplifyContext::default();
    let outcome = GetFieldFunc::new()
        .simplify(args.clone(), &context)
        .unwrap();
    match outcome {
        ExprSimplifyResult::Simplified(expr) => {
            panic!("expected no simplification, got {expr:?}")
        }
        ExprSimplifyResult::Original(unchanged) => assert_eq!(unchanged, args),
    }
}

#[test]
fn simplify_get_field_named_struct_returns_matching_value() {
    // Accessing 'max' on named_struct('min', a, 'max', b) resolves to `b`.
    let base = named_struct(vec![("min", col("a")), ("max", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("max")]);
    assert_eq!(rewritten, col("b"));
}

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small readability nit: this test reconstructs get_field(...), immediately destructures it back into args, and then calls simplified(args). The neighboring tests already pass args directly, which feels a bit easier to follow. It may be worth switching this one to the same style for consistency.

Copy link
Copy Markdown
Contributor Author

@adriangb adriangb May 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 11a390e — the test now builds args directly and passes it to simplified(args), matching the neighboring tests.

#[test]
fn simplify_get_field_named_struct_first_field() {
    // Accessing 'min' on named_struct('min', a, 'max', b) resolves to `a`.
    let base = named_struct(vec![("min", col("a")), ("max", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("min")]);
    assert_eq!(rewritten, col("a"));
}

#[test]
fn simplify_get_field_named_struct_duplicate_names_picks_first() {
    // With two fields named 'k', the rewrite must pick the first one, the
    // same way Arrow's `column_by_name` resolves duplicates.
    let base = named_struct(vec![("k", col("a")), ("k", col("b"))]);
    let rewritten = simplified(vec![base, lit_str("k")]);
    assert_eq!(rewritten, col("a"));
}

#[test]
fn simplify_get_field_struct_positional() {
    // 'c1' is the second positional field of struct(a, b), i.e. `b`.
    let base = struct_fn(vec![col("a"), col("b")]);
    let rewritten = simplified(vec![base, lit_str("c1")]);
    assert_eq!(rewritten, col("b"));
}

#[test]
fn simplify_get_field_nested_named_struct() {
    // get_field(named_struct('s', named_struct('k', x)), 's', 'k')
    //   => get_field(named_struct('k', x), 'k')   (first pass)
    //   => x                                      (second pass)
    let inner = named_struct(vec![("k", col("x"))]);
    let args = vec![
        named_struct(vec![("s", inner.clone())]),
        lit_str("s"),
        lit_str("k"),
    ];

    // Pin the exact intermediate form: the first pass must re-wrap the
    // resolved inner struct in a `get_field` call over the remaining path,
    // not merely produce some scalar function call.
    let remaining_args = vec![inner, lit_str("k")];
    assert_eq!(simplified(args), get_field(remaining_args.clone()));

    // The second pass resolves the remaining access down to the column.
    assert_eq!(simplified(remaining_args), col("x"));
}

#[test]
fn simplify_get_field_flattens_then_resolves_named_struct() {
    // The nested call
    //   get_field(get_field(named_struct('s', named_struct('k', x)), 's'), 'k')
    // is first flattened to get_field(named_struct(...), 's', 'k'); resolving
    // 's' then leaves get_field(named_struct('k', x), 'k').
    let outer_struct = named_struct(vec![("s", named_struct(vec![("k", col("x"))]))]);
    let inner_access = get_field(vec![outer_struct, lit_str("s")]);
    let args = vec![inner_access, lit_str("k")];

    let expected = get_field(vec![named_struct(vec![("k", col("x"))]), lit_str("k")]);
    assert_eq!(simplified(args), expected);
}

#[test]
fn simplify_get_field_dynamic_field_name_left_alone() {
    // The field name is a column, not a literal, so it is unknown at plan
    // time and the access must stay as written.
    let base = named_struct(vec![("a", col("x"))]);
    let dynamic_field_name = col("field_name");
    assert_not_simplified(vec![base, dynamic_field_name]);
}

#[test]
fn simplify_get_field_null_field_name_left_alone() {
    // A NULL `Utf8` literal carries no field name; the
    // `try_as_str().flatten()` guard has to reject it and leave the
    // expression unchanged.
    let base = named_struct(vec![("a", col("x"))]);
    let null_name = Expr::Literal(ScalarValue::Utf8(None), None);
    assert_not_simplified(vec![base, null_name]);
}

#[test]
fn simplify_get_field_dynamic_struct_name_left_alone() {
    // When a named_struct field name is itself a column, it might evaluate to
    // the requested name at runtime, so no rewrite may happen.
    let ctor_args = vec![col("dynamic_name"), col("x")];
    let base = scalar_fn(ScalarUDF::new_from_impl(NamedStructFunc::new()), ctor_args);
    assert_not_simplified(vec![base, lit_str("a")]);
}

#[test]
fn simplify_get_field_missing_field_left_alone() {
    // named_struct('a', x) has no field called 'missing'.
    let base = named_struct(vec![("a", col("x"))]);
    assert_not_simplified(vec![base, lit_str("missing")]);
}

#[test]
fn simplify_get_field_non_canonical_struct_field_left_alone() {
    // `struct(...)` names its fields c0, c1, ...; the zero-padded 'c01' is
    // not one of them and must not be treated as 'c1'.
    let base = struct_fn(vec![col("a"), col("b")]);
    assert_not_simplified(vec![base, lit_str("c01")]);
}

#[test]
fn simplify_get_field_column_base_left_alone() {
    // The base is a bare column rather than a struct constructor call, so
    // there is nothing to resolve.
    let base = col("s");
    assert_not_simplified(vec![base, lit_str("a")]);
}
}
Loading
Loading