diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index ddeb9b0870a1..b709317ca1d7 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -51,6 +51,7 @@ use datafusion_functions_aggregate_common::utils::DecimalAverager; use datafusion_macros::user_doc; use log::debug; use std::fmt::Debug; +use std::marker::PhantomData; use std::mem::{size_of, size_of_val}; use std::sync::Arc; @@ -125,6 +126,18 @@ impl Default for Avg { } } +fn avg_sum_data_type(data_type: &DataType) -> DataType { + match data_type { + DataType::Decimal32(_, scale) => { + DataType::Decimal64(DECIMAL64_MAX_PRECISION, *scale) + } + DataType::Decimal64(_, scale) => { + DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale) + } + data_type => data_type.clone(), + } +} + impl AggregateUDFImpl for Avg { fn name(&self) -> &str { "avg" @@ -178,7 +191,6 @@ impl AggregateUDFImpl for Avg { match (data_type, acc_args.return_type()) { // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())), - ( Decimal32(_, scale), Decimal32(target_precision, target_scale), @@ -187,7 +199,7 @@ impl AggregateUDFImpl for Avg { *target_precision, *target_scale, ))), - ( + ( Decimal64(_, scale), Decimal64(target_precision, target_scale), ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( @@ -203,7 +215,6 @@ impl AggregateUDFImpl for Avg { *target_precision, *target_scale, ))), - ( Decimal256(_, scale), Decimal256(target_precision, target_scale), @@ -212,7 +223,6 @@ impl AggregateUDFImpl for Avg { *target_precision, *target_scale, ))), - (dt, return_type) => exec_err!( "AVG(DISTINCT) for ({} --> {}) not supported", dt, @@ -223,50 +233,48 @@ impl AggregateUDFImpl for Avg { match (&data_type, acc_args.return_type()) { (Float64, Float64) => Ok(Box::::default()), ( - Decimal32(sum_precision, sum_scale), + Decimal32(_sum_precision, sum_scale), Decimal32(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + ) => Ok(Box::new(DecimalAvgAccumulator::< + Decimal32Type, + Decimal64Type, + >::new( + *sum_scale, + DECIMAL64_MAX_PRECISION, + *target_precision, + *target_scale, + ))), ( - Decimal64(sum_precision, sum_scale), + Decimal64(_sum_precision, sum_scale), Decimal64(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + ) => Ok(Box::new(DecimalAvgAccumulator::< + Decimal64Type, + Decimal128Type, + >::new( + *sum_scale, + DECIMAL128_MAX_PRECISION, + *target_precision, + *target_scale, + ))), ( Decimal128(sum_precision, sum_scale), Decimal128(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + ) => Ok(Box::new(DecimalAvgAccumulator::::new( + *sum_scale, + *sum_precision, + *target_precision, + *target_scale, + ))), ( Decimal256(sum_precision, sum_scale), Decimal256(target_precision, target_scale), - ) => Ok(Box::new(DecimalAvgAccumulator:: { - sum: None, - count: 0, - sum_scale: *sum_scale, - sum_precision: *sum_precision, - target_precision: *target_precision, - target_scale: *target_scale, - })), + ) => Ok(Box::new(DecimalAvgAccumulator::::new( + *sum_scale, + *sum_precision, + *target_precision, + *target_scale, + ))), (Duration(time_unit), Duration(result_unit)) => { Ok(Box::new(DurationAvgAccumulator { @@ -314,17 +322,14 @@ impl AggregateUDFImpl for Avg { .into(), ]) } else { + let sum_data_type = avg_sum_data_type(args.input_fields[0].data_type()); Ok(vec![ Field::new( format_state_name(args.name, "count"), DataType::UInt64, true, ), - Field::new( - format_state_name(args.name, "sum"), - args.input_fields[0].data_type().clone(), - true, - ), + Field::new(format_state_name(args.name, "sum"), sum_data_type, true), ] .into_iter() .map(Arc::new) @@ -356,7 +361,7 @@ impl AggregateUDFImpl for Avg { match (data_type, args.return_field.data_type()) { (Float64, Float64) => { Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, + data_type.clone(), args.return_field.data_type(), |sum: f64, count: u64| Ok(sum / count as f64), ))) @@ -365,17 +370,27 @@ impl AggregateUDFImpl for Avg { Decimal32(_sum_precision, sum_scale), Decimal32(target_precision, target_scale), ) => { - let decimal_averager = DecimalAverager::::try_new( + let decimal_averager = DecimalAverager::::try_new( *sum_scale, *target_precision, *target_scale, )?; - let avg_fn = - move |sum: i32, count: u64| decimal_averager.avg(sum, count as i32); + let avg_fn = move |sum: i64, count: u64| { + let avg = decimal_averager.avg(sum, count as i64)?; + if let Ok(avg) = i32::try_from(avg) { + Ok(avg) + } else { + exec_err!("Arithmetic Overflow in AvgAccumulator") + } + }; - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal32Type, + _, + Decimal64Type, + >::new( + avg_sum_data_type(data_type), args.return_field.data_type(), avg_fn, ))) @@ -384,17 +399,27 @@ impl AggregateUDFImpl for Avg { Decimal64(_sum_precision, sum_scale), Decimal64(target_precision, target_scale), ) => { - let decimal_averager = DecimalAverager::::try_new( + let decimal_averager = DecimalAverager::::try_new( *sum_scale, *target_precision, *target_scale, )?; - let avg_fn = - move |sum: i64, count: u64| decimal_averager.avg(sum, count as i64); + let avg_fn = move |sum: i128, count: u64| { + let avg = decimal_averager.avg(sum, count as i128)?; + if let Ok(avg) = i64::try_from(avg) { + Ok(avg) + } else { + exec_err!("Arithmetic Overflow in AvgAccumulator") + } + }; - Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, + Ok(Box::new(AvgGroupsAccumulator::< + Decimal64Type, + _, + Decimal128Type, + >::new( + avg_sum_data_type(data_type), args.return_field.data_type(), avg_fn, ))) @@ -413,7 +438,7 @@ impl AggregateUDFImpl for Avg { move |sum: i128, count: u64| decimal_averager.avg(sum, count as i128); Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, + data_type.clone(), args.return_field.data_type(), avg_fn, ))) @@ -434,7 +459,7 @@ impl AggregateUDFImpl for Avg { }; Ok(Box::new(AvgGroupsAccumulator::::new( - data_type, + data_type.clone(), args.return_field.data_type(), avg_fn, ))) @@ -448,7 +473,7 @@ impl AggregateUDFImpl for Avg { DurationSecondType, _, >::new( - data_type, + data_type.clone(), args.return_type(), avg_fn, ))), @@ -456,7 +481,7 @@ impl AggregateUDFImpl for Avg { DurationMillisecondType, _, >::new( - data_type, + data_type.clone(), args.return_type(), avg_fn, ))), @@ -464,7 +489,7 @@ impl AggregateUDFImpl for Avg { DurationMicrosecondType, _, >::new( - data_type, + data_type.clone(), args.return_type(), avg_fn, ))), @@ -472,7 +497,7 @@ impl AggregateUDFImpl for Avg { DurationNanosecondType, _, >::new( - data_type, + data_type.clone(), args.return_type(), avg_fn, ))), @@ -567,24 +592,83 @@ impl Accumulator for AvgAccumulator { } } -/// An accumulator to compute the average for decimals +/// An accumulator to compute the average for decimals. +/// +/// `I` is the input (and output) decimal type. `S` is a possibly wider decimal +/// type used to accumulate the sum so the running total does not overflow +/// (e.g. `Decimal32` values are summed as `Decimal64`). #[derive(Debug)] -struct DecimalAvgAccumulator { - sum: Option, +struct DecimalAvgAccumulator +where + I: DecimalType + ArrowNumericType + Debug, + S: DecimalType + ArrowNumericType + Debug, + I::Native: Into, +{ + sum: Option, count: u64, sum_scale: i8, sum_precision: u8, target_precision: u8, target_scale: i8, + _phantom: PhantomData, } -impl Accumulator for DecimalAvgAccumulator { +impl DecimalAvgAccumulator +where + I: DecimalType + ArrowNumericType + Debug, + S: DecimalType + ArrowNumericType + Debug, + I::Native: Into, +{ + fn new( + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, + ) -> Self { + Self { + sum: None, + count: 0, + sum_scale, + sum_precision, + target_precision, + target_scale, + _phantom: PhantomData, + } + } +} + +fn decimal_sum_as(values: &PrimitiveArray) -> Option +where + I: DecimalType + ArrowNumericType, + S: DecimalType + ArrowNumericType, + I::Native: Into, +{ + let mut sum: Option = None; + + for value in values.iter().flatten() { + let value = value.into(); + sum = Some(match sum { + Some(sum) => sum.add_wrapping(value), + None => value, + }); + } + + sum +} + +impl Accumulator for DecimalAvgAccumulator +where + I: DecimalType + ArrowNumericType + Debug, + S: DecimalType + ArrowNumericType + Debug, + I::Native: Into, + S::Native: TryInto, +{ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); self.count += (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { - let v = self.sum.get_or_insert_with(T::Native::default); + if let Some(x) = decimal_sum_as::(values) { + let v = self.sum.get_or_insert_with(S::Native::default); self.sum = Some(v.add_wrapping(x)); } Ok(()) @@ -599,19 +683,24 @@ impl Accumulator for DecimalAvgAccumu } else { self.sum .map(|v| { - DecimalAverager::::try_new( + let avg = DecimalAverager::::try_new( self.sum_scale, self.target_precision, self.target_scale, )? - .avg(v, T::Native::from_usize(self.count as usize).unwrap()) + .avg(v, S::Native::from_usize(self.count as usize).unwrap())?; + if let Ok(avg) = avg.try_into() { + Ok(avg) + } else { + exec_err!("Arithmetic Overflow in AvgAccumulator") + } }) .transpose()? }; - ScalarValue::new_primitive::( + ScalarValue::new_primitive::( v, - &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + &I::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), ) } @@ -622,9 +711,9 @@ impl Accumulator for DecimalAvgAccumu fn state(&mut self) -> Result> { Ok(vec![ ScalarValue::from(self.count), - ScalarValue::new_primitive::( + ScalarValue::new_primitive::( self.sum, - &T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), + &S::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale), )?, ]) } @@ -634,16 +723,16 @@ impl Accumulator for DecimalAvgAccumu self.count += sum(states[0].as_primitive::()).unwrap_or_default(); // sums are summed - if let Some(x) = sum(states[1].as_primitive::()) { - let v = self.sum.get_or_insert_with(T::Native::default); + if let Some(x) = sum(states[1].as_primitive::()) { + let v = self.sum.get_or_insert_with(S::Native::default); self.sum = Some(v.add_wrapping(x)); } Ok(()) } fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); self.count -= (values.len() - values.null_count()) as u64; - if let Some(x) = sum(values) { + if let Some(x) = decimal_sum_as::(values) { self.sum = Some(self.sum.unwrap().sub_wrapping(x)); } Ok(()) @@ -764,12 +853,17 @@ impl Accumulator for DurationAvgAccumulator { /// Stores values as native types, and does overflow checking /// /// F: Function that calculates the average value from a sum of -/// T::Native and a total count +/// S::Native and a total count +/// +/// `I` is the input (and output) type. `S` is a possibly wider type used to +/// accumulate the sum so it does not overflow. #[derive(Debug)] -struct AvgGroupsAccumulator +struct AvgGroupsAccumulator where - T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + I: ArrowNumericType + Send, + S: ArrowNumericType + Send, + I::Native: Into, + F: Fn(S::Native, u64) -> Result + Send + 'static, { /// The type of the internal sum sum_data_type: DataType, @@ -781,41 +875,53 @@ where counts: Vec, /// Sums per group, stored as the native type - sums: Vec, + sums: Vec, /// Track nulls in the input / filters null_state: NullState, /// Function that computes the final average (value / count) avg_fn: F, + + _phantom: PhantomData, } -impl AvgGroupsAccumulator +impl AvgGroupsAccumulator where - T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + I: ArrowNumericType + Send, + S: ArrowNumericType + Send, + I::Native: Into, + F: Fn(S::Native, u64) -> Result + Send + 'static, { - pub fn new(sum_data_type: &DataType, return_data_type: &DataType, avg_fn: F) -> Self { + pub fn new( + sum_data_type: impl Into, + return_data_type: &DataType, + avg_fn: F, + ) -> Self { + let sum_data_type = sum_data_type.into(); debug!( "AvgGroupsAccumulator ({}, sum type: {sum_data_type}) --> {return_data_type}", - std::any::type_name::() + std::any::type_name::() ); Self { return_data_type: return_data_type.clone(), - sum_data_type: sum_data_type.clone(), + sum_data_type, counts: vec![], sums: vec![], null_state: NullState::new(), avg_fn, + _phantom: PhantomData, } } } -impl GroupsAccumulator for AvgGroupsAccumulator +impl GroupsAccumulator for AvgGroupsAccumulator where - T: ArrowNumericType + Send, - F: Fn(T::Native, u64) -> Result + Send + 'static, + I: ArrowNumericType + Send, + S: ArrowNumericType + Send, + I::Native: Into, + F: Fn(S::Native, u64) -> Result + Send + 'static, { fn update_batch( &mut self, @@ -825,11 +931,11 @@ where total_num_groups: usize, ) -> Result<()> { assert_eq!(values.len(), 1, "single argument to update_batch"); - let values = values[0].as_primitive::(); + let values = values[0].as_primitive::(); // increment counts, update sums self.counts.resize(total_num_groups, 0); - self.sums.resize(total_num_groups, T::default_value()); + self.sums.resize(total_num_groups, S::default_value()); self.null_state.accumulate( group_indices, values, @@ -838,7 +944,7 @@ where |group_index, new_value| { // SAFETY: group_index is guaranteed to be in bounds let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; - *sum = sum.add_wrapping(new_value); + *sum = sum.add_wrapping(new_value.into()); self.counts[group_index] += 1; }, @@ -859,10 +965,10 @@ where // don't evaluate averages with null inputs to avoid errors on null values - let array: PrimitiveArray = if let Some(nulls) = &nulls + let array: PrimitiveArray = if let Some(nulls) = &nulls && nulls.null_count() > 0 { - let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) + let mut builder = PrimitiveBuilder::::with_capacity(nulls.len()) .with_data_type(self.return_data_type.clone()); let iter = sums.into_iter().zip(counts).zip(nulls.iter()); @@ -875,7 +981,7 @@ where } builder.finish() } else { - let averages: Vec = sums + let averages: Vec = sums .into_iter() .zip(counts) .map(|(sum, count)| (self.avg_fn)(sum, count)) @@ -895,7 +1001,7 @@ where let counts = UInt64Array::new(counts.into(), nulls.clone()); // zero copy let sums = emit_to.take_needed(&mut self.sums); - let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy + let sums = PrimitiveArray::::new(sums.into(), nulls) // zero copy .with_data_type(self.sum_data_type.clone()); Ok(vec![ @@ -914,7 +1020,7 @@ where assert_eq!(values.len(), 2, "two arguments to merge_batch"); // first batch is counts, second is partial sums let partial_counts = values[0].as_primitive::(); - let partial_sums = values[1].as_primitive::(); + let partial_sums = values[1].as_primitive::(); // update counts with partial counts self.counts.resize(total_num_groups, 0); self.null_state.accumulate( @@ -930,13 +1036,13 @@ where ); // update sums - self.sums.resize(total_num_groups, T::default_value()); + self.sums.resize(total_num_groups, S::default_value()); self.null_state.accumulate( group_indices, partial_sums, opt_filter, total_num_groups, - |group_index, new_value: ::Native| { + |group_index, new_value: ::Native| { // SAFETY: group_index is guaranteed to be in bounds let sum = unsafe { self.sums.get_unchecked_mut(group_index) }; *sum = sum.add_wrapping(new_value); @@ -951,10 +1057,17 @@ where values: &[ArrayRef], opt_filter: Option<&BooleanArray>, ) -> Result> { - let sums = values[0] - .as_primitive::() - .clone() + let values = values[0].as_primitive::(); + let mut sums = PrimitiveBuilder::::with_capacity(values.len()) .with_data_type(self.sum_data_type.clone()); + for value in values.iter() { + if let Some(value) = value { + sums.append_value(value.into()); + } else { + sums.append_null(); + } + } + let sums = sums.finish(); let counts = UInt64Array::from_value(1, sums.len()); let nulls = filtered_null_mask(opt_filter, &sums); @@ -971,6 +1084,253 @@ where } fn size(&self) -> usize { - self.counts.capacity() * size_of::() + self.sums.capacity() * size_of::() + self.counts.capacity() * size_of::() + + self.sums.capacity() * size_of::() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{ + Decimal32Array, Decimal64Array, Decimal128Array, Decimal256Array, + DurationMicrosecondArray, DurationMillisecondArray, DurationNanosecondArray, + DurationSecondArray, Float64Array, + }; + use arrow::datatypes::Schema; + + struct AvgCase { + name: &'static str, + values: ArrayRef, + return_type: DataType, + sum_type: DataType, + expected: ScalarValue, + } + + fn avg_groups_accumulator( + input_type: &DataType, + return_type: &DataType, + ) -> Result> { + let schema = Schema::empty(); + let expr_field = Arc::new(Field::new("a", input_type.clone(), true)); + let return_field = Arc::new(Field::new("avg", return_type.clone(), true)); + + Avg::new().create_groups_accumulator(AccumulatorArgs { + return_field, + schema: &schema, + expr_fields: &[expr_field], + ignore_nulls: false, + order_bys: &[], + is_distinct: false, + name: "avg", + is_reversed: false, + exprs: &[], + }) + } + + fn avg_accumulator( + input_type: &DataType, + return_type: &DataType, + ) -> Result> { + let schema = Schema::empty(); + let expr_field = Arc::new(Field::new("a", input_type.clone(), true)); + let return_field = Arc::new(Field::new("avg", return_type.clone(), true)); + + Avg::new().accumulator(AccumulatorArgs { + return_field, + schema: &schema, + expr_fields: &[expr_field], + ignore_nulls: false, + order_bys: &[], + is_distinct: false, + name: "avg", + is_reversed: false, + exprs: &[], + }) + } + + fn avg_state_fields( + input_type: &DataType, + return_type: &DataType, + ) -> Result> { + let input_field = Arc::new(Field::new("a", input_type.clone(), true)); + let return_field = Arc::new(Field::new("avg", return_type.clone(), true)); + + Avg::new().state_fields(StateFieldsArgs { + name: "avg", + input_fields: &[input_field], + return_field, + ordering_fields: &[], + is_distinct: false, + }) + } + + fn avg_cases() -> Result> { + const ROWS: usize = 21_476; + const DECIMAL32_VALUE: i32 = 99_999; + const DECIMAL64_ROWS: usize = 92_235; + const DECIMAL64_VALUE: i64 = 99_999_999_999; + + Ok(vec![ + AvgCase { + name: "float64", + values: Arc::new(Float64Array::from(vec![10.0, 20.0])), + return_type: DataType::Float64, + sum_type: DataType::Float64, + expected: ScalarValue::Float64(Some(15.0)), + }, + AvgCase { + name: "decimal32", + values: Arc::new( + Decimal32Array::from(vec![Some(DECIMAL32_VALUE); ROWS]) + .with_precision_and_scale(5, 0)?, + ), + return_type: DataType::Decimal32(9, 4), + sum_type: DataType::Decimal64(18, 0), + expected: ScalarValue::Decimal32(Some(DECIMAL32_VALUE * 10_000), 9, 4), + }, + AvgCase { + name: "decimal64", + values: Arc::new( + Decimal64Array::from(vec![Some(DECIMAL64_VALUE); DECIMAL64_ROWS]) + .with_precision_and_scale(11, 0)?, + ), + return_type: DataType::Decimal64(15, 4), + sum_type: DataType::Decimal128(38, 0), + expected: ScalarValue::Decimal64(Some(DECIMAL64_VALUE * 10_000), 15, 4), + }, + AvgCase { + name: "decimal128", + values: Arc::new( + Decimal128Array::from(vec![10_i128, 20_i128]) + .with_precision_and_scale(20, 0)?, + ), + return_type: DataType::Decimal128(24, 4), + sum_type: DataType::Decimal128(20, 0), + expected: ScalarValue::Decimal128(Some(150_000), 24, 4), + }, + AvgCase { + name: "decimal256", + values: Arc::new( + Decimal256Array::from(vec![i256::from_i128(10), i256::from_i128(20)]) + .with_precision_and_scale(50, 0)?, + ), + return_type: DataType::Decimal256(54, 4), + sum_type: DataType::Decimal256(50, 0), + expected: ScalarValue::Decimal256(Some(i256::from_i128(150_000)), 54, 4), + }, + AvgCase { + name: "duration_second", + values: Arc::new(DurationSecondArray::from(vec![10, 20])), + return_type: DataType::Duration(TimeUnit::Second), + sum_type: DataType::Duration(TimeUnit::Second), + expected: ScalarValue::DurationSecond(Some(15)), + }, + AvgCase { + name: "duration_millisecond", + values: Arc::new(DurationMillisecondArray::from(vec![10, 20])), + return_type: DataType::Duration(TimeUnit::Millisecond), + sum_type: DataType::Duration(TimeUnit::Millisecond), + expected: ScalarValue::DurationMillisecond(Some(15)), + }, + AvgCase { + name: "duration_microsecond", + values: Arc::new(DurationMicrosecondArray::from(vec![10, 20])), + return_type: DataType::Duration(TimeUnit::Microsecond), + sum_type: DataType::Duration(TimeUnit::Microsecond), + expected: ScalarValue::DurationMicrosecond(Some(15)), + }, + AvgCase { + name: "duration_nanosecond", + values: Arc::new(DurationNanosecondArray::from(vec![10, 20])), + return_type: DataType::Duration(TimeUnit::Nanosecond), + sum_type: DataType::Duration(TimeUnit::Nanosecond), + expected: ScalarValue::DurationNanosecond(Some(15)), + }, + ]) + } + + #[test] + fn avg_accumulator_state_types_match_state_fields() -> Result<()> { + for case in avg_cases()? { + let input_type = case.values.data_type(); + let state_fields = avg_state_fields(input_type, &case.return_type)?; + let mut acc = avg_accumulator(input_type, &case.return_type)?; + acc.update_batch(std::slice::from_ref(&case.values))?; + let state = acc.state()?; + + assert_eq!( + &state[0].data_type(), + state_fields[0].data_type(), + "{}", + case.name + ); + assert_eq!( + &state[1].data_type(), + state_fields[1].data_type(), + "{}", + case.name + ); + } + + Ok(()) + } + + #[test] + fn avg_accumulator_evaluate() -> Result<()> { + for case in avg_cases()? { + let input_type = case.values.data_type(); + let mut acc = avg_accumulator(input_type, &case.return_type)?; + acc.update_batch(std::slice::from_ref(&case.values))?; + + assert_eq!(acc.evaluate()?, case.expected, "{}", case.name); + } + + Ok(()) + } + + #[test] + fn avg_groups_state_types_match_state_fields() -> Result<()> { + for case in avg_cases()? { + let input_type = case.values.data_type(); + let state_fields = avg_state_fields(input_type, &case.return_type)?; + let acc = avg_groups_accumulator(input_type, &case.return_type)?; + let state = acc.convert_to_state(std::slice::from_ref(&case.values), None)?; + + assert_eq!( + state_fields[0].data_type(), + &DataType::UInt64, + "{}", + case.name + ); + assert_eq!(state_fields[1].data_type(), &case.sum_type, "{}", case.name); + assert_eq!(state[0].data_type(), &DataType::UInt64, "{}", case.name); + assert_eq!(state[1].data_type(), &case.sum_type, "{}", case.name); + } + + Ok(()) + } + + #[test] + fn avg_groups_convert_to_state_roundtrip() -> Result<()> { + for case in avg_cases()? { + let input_type = case.values.data_type(); + let partial = avg_groups_accumulator(input_type, &case.return_type)?; + let mut final_acc = avg_groups_accumulator(input_type, &case.return_type)?; + let state = + partial.convert_to_state(std::slice::from_ref(&case.values), None)?; + final_acc.merge_batch(&state, &vec![0; case.values.len()], None, 1)?; + + let result = final_acc.evaluate(EmitTo::All)?; + assert_eq!(result.data_type(), &case.return_type, "{}", case.name); + assert_eq!( + ScalarValue::try_from_array(result.as_ref(), 0)?, + case.expected, + "{}", + case.name + ); + } + + Ok(()) } } diff --git a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs index 2c549422d654..31c9d4bd8c06 100644 --- a/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs +++ b/datafusion/sqllogictest/src/engines/datafusion_engine/normalize.rs @@ -271,6 +271,8 @@ pub fn convert_schema_to_types(columns: &Fields) -> Vec { DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => DFColumnType::Float, DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => { diff --git a/datafusion/sqllogictest/test_files/decimal.slt b/datafusion/sqllogictest/test_files/decimal.slt index d9eac8492814..2b9caf254d68 100644 --- a/datafusion/sqllogictest/test_files/decimal.slt +++ b/datafusion/sqllogictest/test_files/decimal.slt @@ -1151,12 +1151,12 @@ Infinity Float64 query error Arrow error: Arithmetic overflow: Unsupported exp value SELECT power(2::decimal(38, 0), 5000000000) -query ?T +query RT SELECT power(arrow_cast(2, 'Decimal32(5, 0)'), 4), arrow_typeof(power(arrow_cast(2, 'Decimal32(5, 0)'), 4)); ---- 16 Decimal32(5, 0) -query ?T +query RT SELECT power(arrow_cast(2, 'Decimal64(5, 0)'), 4), arrow_typeof(power(arrow_cast(2, 'Decimal64(5, 0)'), 4)); ---- 16 Decimal64(5, 0) @@ -1268,3 +1268,12 @@ ORDER BY c1; statement ok DROP TABLE decimal_div_mismatch; + +query RT +select avg(d), arrow_typeof(avg(d)) +from ( + select arrow_cast(99999.0, 'Decimal32(5, 0)') as d + from generate_series(1, 21476) +) t; +---- +99999.0000 Decimal32(9, 4) diff --git a/datafusion/sqllogictest/test_files/spark/math/round.slt b/datafusion/sqllogictest/test_files/spark/math/round.slt index 91c5bdf0506f..247c7f2e0b11 100644 --- a/datafusion/sqllogictest/test_files/spark/math/round.slt +++ b/datafusion/sqllogictest/test_files/spark/math/round.slt @@ -281,24 +281,24 @@ SELECT round(arrow_cast(42, 'UInt32'), 2::int); # --- Decimal32 --- # round(decimal32, 0) — round to integer -query ? +query R SELECT round(arrow_cast(2.5, 'Decimal32(9, 1)'), 0::int); ---- 3.0 -query ? +query R SELECT round(arrow_cast(-2.5, 'Decimal32(9, 1)'), 0::int); ---- -3.0 # round(decimal32, 2) -query ? +query R SELECT round(arrow_cast(2.345, 'Decimal32(9, 3)'), 2::int); ---- 2.350 # round(decimal32) default scale = 0 -query ? +query R SELECT round(arrow_cast(3.5, 'Decimal32(9, 1)')); ---- 4.0 @@ -306,24 +306,24 @@ SELECT round(arrow_cast(3.5, 'Decimal32(9, 1)')); # --- Decimal64 --- # round(decimal64, 0) — round to integer -query ? +query R SELECT round(arrow_cast(2.5, 'Decimal64(18, 1)'), 0::int); ---- 3.0 -query ? +query R SELECT round(arrow_cast(-2.5, 'Decimal64(18, 1)'), 0::int); ---- -3.0 # round(decimal64, 2) -query ? +query R SELECT round(arrow_cast(2.345, 'Decimal64(18, 3)'), 2::int); ---- 2.350 # round(decimal64) default scale = 0 -query ? +query R SELECT round(arrow_cast(3.5, 'Decimal64(18, 1)')); ---- 4.0