diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs index 5d23046ec7726..9bcc749c23dce 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/materializing_stream.rs @@ -1540,7 +1540,7 @@ impl MaterializingSortMergeJoinStream { /// gathers columns across sources. A null-row sentinel at source index 0 /// handles null right indices (unmatched streamed rows). fn materialize_right_columns( - &self, + &mut self, matched_chunks: &[(usize, UInt64Array, UInt64Array)], total_matched_rows: usize, ) -> Result<Vec<ArrayRef>> { @@ -1555,6 +1555,19 @@ impl MaterializingSortMergeJoinStream { matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect(); as_uint64_array(&compute::concat(&refs)?)?.clone() }; + + let mut spill_reservation = self.reservation.new_empty(); + if matches!( + &self.buffered_data.batches[first_batch_idx].batch, + BufferedBatchState::Spilled(_) + ) { + spill_reservation .grow(self.buffered_data.batches[first_batch_idx].size_estimation); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size() + spill_reservation.size()); + } + return fetch_right_columns_by_idxs( + &self.buffered_data, + first_batch_idx, @@ -1588,24 +1601,33 @@ impl MaterializingSortMergeJoinStream { } let num_right_cols = self.buffered_schema.fields().len(); - let mut right_columns = Vec::with_capacity(num_right_cols); // Read each source batch once (spilled batches require disk I/O). 
- let source_data: Vec<Option<RecordBatch>> = source_batches - .iter() - .map(|&idx| { - let bb = &self.buffered_data.batches[idx]; - match &bb.batch { - BufferedBatchState::InMemory(batch) => Some(batch.clone()), - BufferedBatchState::Spilled(spill_file) => { - let file = BufReader::new(File::open(spill_file.path()).ok()?); - let reader = StreamReader::try_new(file, None).ok()?; - reader.into_iter().next()?.ok() - } + // Track memory for each spilled batch at the point of deserialization + // so the pool reflects actual usage as it grows. + let mut spill_reservation = self.reservation.new_empty(); + let mut source_data: Vec<Option<RecordBatch>> = + Vec::with_capacity(source_batches.len()); + for &idx in &source_batches { + let bb = &self.buffered_data.batches[idx]; + match &bb.batch { + BufferedBatchState::InMemory(batch) => { + source_data.push(Some(batch.clone())); } - }) - .collect(); + BufferedBatchState::Spilled(spill_file) => { + spill_reservation.grow(bb.size_estimation); + self.join_metrics + .peak_mem_used() + .set_max(self.reservation.size() + spill_reservation.size()); + + let file = BufReader::new(File::open(spill_file.path())?); + let reader = StreamReader::try_new(file, None)?; + source_data.push(reader.into_iter().next().transpose()?); + } + } + } + let mut right_columns = Vec::with_capacity(num_right_cols); for col_idx in 0..num_right_cols { let dtype = self.buffered_schema.field(col_idx).data_type(); let null_array = new_null_array(dtype, 1); @@ -1624,7 +1646,6 @@ impl MaterializingSortMergeJoinStream { } } } - right_columns.push(interleave(&source_arrays, &interleave_indices)?); } diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs index bc34c351c5e21..c4377b3189ff7 100644 --- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs +++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs @@ -4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()> { Ok(()) } + +/// 
Verifies that `peak_mem_used` reflects spill read-back memory during +/// output materialization (multi-source path). +/// +/// When spilled buffered batches are read back from disk to produce join +/// output, a scoped `MemoryReservation` (via `new_empty()`) tracks the +/// transient memory. Its `Drop` guarantees the pool is balanced on every +/// exit path — normal return or early `?` error. +#[tokio::test] +async fn spill_read_back_memory_accounting() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // All rows share the same join key (b=1) to force multiple buffered + // batches in the same key group — triggering spill read-back during + // output materialization. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![1, 1]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the spill read-back: when buffered + // batches are read from disk during output materialization, grow() + // temporarily reserves size_estimation. This pushes peak above what + // join_arrays_mem alone would show. + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because spill read-back temporarily loads full batch into memory" + ); + + // All memory must be released (grow/shrink balanced) + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +} + +/// Verifies spill read-back memory tracking for the single-source path. +/// +/// When only ONE buffered batch exists for a key group and it's spilled, +/// `fetch_right_columns_by_idxs` reads it back. A scoped `MemoryReservation` +/// (via `new_empty()`) tracks the transient memory and releases it on drop. 
+#[tokio::test] +async fn spill_read_back_single_source() -> Result<()> { + use arrow::array::Array; + + let left_batch = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let size_estimation = left_batch.get_array_memory_size() + + Int32Array::from(vec![1, 1]).get_array_memory_size() + + 2usize.next_power_of_two() * size_of::<usize>() + + size_of::<Range<usize>>() + + size_of::<usize>(); + + // Memory limit too small for a full batch — forces spilling. + let memory_limit = size_estimation / 2; + + // Multiple distinct keys so each key group has exactly ONE buffered batch. + // This ensures the single-source path is exercised. + let left_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a1", &vec![i * 2, i * 2 + 1]), + ("b1", &vec![i, i]), + ("c1", &vec![100 + i, 101 + i]), + ) + }) + .collect(); + let left = build_table_from_batches(left_batches); + + // One batch per key — each key group has single source + let right_batches: Vec<RecordBatch> = (0..4) + .map(|i| { + build_table_i32( + ("a2", &vec![i * 2, i * 2 + 1]), + ("b2", &vec![i, i]), + ("c2", &vec![200 + i, 201 + i]), + ) + }) + .collect(); + let right = build_table_from_batches(right_batches); + + let on = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?) as _, + Arc::new(Column::new_with_schema("b2", &right.schema())?) 
as _, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let runtime = RuntimeEnvBuilder::new() + .with_memory_limit(memory_limit, 1.0) + .with_disk_manager_builder( + DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory), + ) + .build_arc()?; + + let session_config = SessionConfig::default().with_batch_size(50); + let task_ctx = Arc::new( + TaskContext::default() + .with_session_config(session_config) + .with_runtime(Arc::clone(&runtime)), + ); + + let join = join_with_options( + Arc::clone(&left), + Arc::clone(&right), + on.clone(), + Inner, + sort_options, + NullEquality::NullEqualsNothing, + )?; + + let stream = join.execute(0, task_ctx)?; + let result = common::collect(stream).await.unwrap(); + + assert!(!result.is_empty(), "Expected non-empty join result"); + + let metrics = join.metrics().unwrap(); + assert!( + metrics.spill_count().unwrap() > 0, + "Expected spilling to occur" + ); + + // peak_mem_used should reflect the single-batch read-back + let peak_mem = metrics + .sum_by_name("peak_mem_used") + .map(|m| m.as_usize()) + .unwrap_or(0); + assert!( + peak_mem >= size_estimation, + "peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \ + because single-source spill read-back loads full batch" + ); + + // All memory must be released + assert_eq!( + runtime.memory_pool.reserved(), + 0, + "All memory should be released after join completes" + ); + + Ok(()) +}