@@ -1540,7 +1540,7 @@ impl MaterializingSortMergeJoinStream
/// gathers columns across sources. A null-row sentinel at source index 0
/// handles null right indices (unmatched streamed rows).
fn materialize_right_columns(
&self,
&mut self,
matched_chunks: &[(usize, UInt64Array, UInt64Array)],
total_matched_rows: usize,
) -> Result<Vec<ArrayRef>> {
@@ -1555,6 +1555,19 @@ impl MaterializingSortMergeJoinStream
matched_chunks.iter().map(|c| &c.2 as &dyn Array).collect();
as_uint64_array(&compute::concat(&refs)?)?.clone()
};

// Scoped reservation: its `Drop` returns any grown bytes to the pool on
// every exit path, including early `?` errors.
let mut spill_reservation = self.reservation.new_empty();
if matches!(
&self.buffered_data.batches[first_batch_idx].batch,
BufferedBatchState::Spilled(_)
) {
spill_reservation
.grow(self.buffered_data.batches[first_batch_idx].size_estimation);
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size() + spill_reservation.size());
}

return fetch_right_columns_by_idxs(
&self.buffered_data,
first_batch_idx,
@@ -1588,24 +1601,33 @@ impl MaterializingSortMergeJoinStream
}

let num_right_cols = self.buffered_schema.fields().len();
let mut right_columns = Vec::with_capacity(num_right_cols);

// Read each source batch once (spilled batches require disk I/O).
let source_data: Vec<Option<RecordBatch>> = source_batches
.iter()
.map(|&idx| {
let bb = &self.buffered_data.batches[idx];
match &bb.batch {
BufferedBatchState::InMemory(batch) => Some(batch.clone()),
BufferedBatchState::Spilled(spill_file) => {
let file = BufReader::new(File::open(spill_file.path()).ok()?);
let reader = StreamReader::try_new(file, None).ok()?;
reader.into_iter().next()?.ok()
}
// Track memory for each spilled batch at the point of deserialization
// so the pool reflects actual usage as it grows.
let mut spill_reservation = self.reservation.new_empty();
let mut source_data: Vec<Option<RecordBatch>> =
Vec::with_capacity(source_batches.len());
for &idx in &source_batches {
let bb = &self.buffered_data.batches[idx];
match &bb.batch {
BufferedBatchState::InMemory(batch) => {
source_data.push(Some(batch.clone()));
}
})
.collect();
BufferedBatchState::Spilled(spill_file) => {
spill_reservation.grow(bb.size_estimation);
self.join_metrics
.peak_mem_used()
.set_max(self.reservation.size() + spill_reservation.size());

let file = BufReader::new(File::open(spill_file.path())?);
Contributor:
It looks like spill_read_mem is only shrunk after all file reads and all interleave calls succeed. Once self.reservation.grow(batch_mem) succeeds, any later `?` (from File::open, StreamReader::try_new, next().transpose(), or interleave) can return early before the shrink happens at line 1651.

That leaves the reservation inflated until the stream is dropped, which breaks the grow/shrink accounting invariant on error paths and can leave the memory pool reporting stale reserved memory after a failed poll.

Could we make this temporary read-back reservation scoped/RAII-based, or otherwise guarantee that shrink runs on every return path after a successful grow?

Contributor (Author):
@kosiew Thanks for highlighting this. Switched from manual grow/shrink on self.reservation to a scoped MemoryReservation via self.reservation.new_empty().
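
For reference, a minimal sketch of that scoped pattern, assuming DataFusion's MemoryReservation API (new_empty, grow, and a Drop impl that returns bytes to the pool). The read_spilled_batch helper is a hypothetical stand-in for the File::open/StreamReader read path:

```rust
use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::memory_pool::MemoryReservation;

// Hypothetical stand-in for the spill-file read path (File::open +
// StreamReader); only here to give the sketch something fallible to call.
fn read_spilled_batch() -> Result<RecordBatch> {
    unimplemented!()
}

/// Scoped read-back accounting: the child reservation shares the parent's
/// pool, and its `Drop` returns the grown bytes on every exit path,
/// including an early `?` return.
fn read_with_scoped_reservation(
    parent: &MemoryReservation,
    batch_size: usize,
) -> Result<RecordBatch> {
    let mut scoped = parent.new_empty(); // same pool, zero bytes reserved
    scoped.grow(batch_size); // account for the batch about to be read back
    let batch = read_spilled_batch()?; // an error here still drops `scoped`
    Ok(batch) // success path: `scoped` drops here too, balancing the pool
}
```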

let reader = StreamReader::try_new(file, None)?;
source_data.push(reader.into_iter().next().transpose()?);
}
}
}

let mut right_columns = Vec::with_capacity(num_right_cols);
for col_idx in 0..num_right_cols {
let dtype = self.buffered_schema.field(col_idx).data_type();
let null_array = new_null_array(dtype, 1);
@@ -1624,7 +1646,6 @@ impl MaterializingSortMergeJoinStream
}
}
}

right_columns.push(interleave(&source_arrays, &interleave_indices)?);
}

datafusion/physical-plan/src/joins/sort_merge_join/tests.rs (223 additions, 0 deletions)
@@ -4724,3 +4724,226 @@ async fn spill_filtered_boundary_loses_outer_rows() -> Result<()>

Ok(())
}

/// Verifies that `peak_mem_used` reflects spill read-back memory during
/// output materialization (multi-source path).
///
/// When spilled buffered batches are read back from disk to produce join
/// output, a scoped `MemoryReservation` (via `new_empty()`) tracks the
/// transient memory. Its `Drop` guarantees the pool is balanced on every
/// exit path — normal return or early `?` error.
#[tokio::test]
async fn spill_read_back_memory_accounting() -> Result<()> {
use arrow::array::Array;

let left_batch = build_table_i32(
("a1", &vec![0, 1]),
("b1", &vec![1, 1]),
("c1", &vec![4, 5]),
);
let size_estimation = left_batch.get_array_memory_size()
+ Int32Array::from(vec![1, 1]).get_array_memory_size()
+ 2usize.next_power_of_two() * size_of::<usize>()
+ size_of::<std::ops::Range<usize>>()
+ size_of::<usize>();

// Memory limit too small for a full batch — forces spilling.
let memory_limit = size_estimation / 2;

// All rows share the same join key (b=1) to force multiple buffered
// batches in the same key group — triggering spill read-back during
// output materialization.
let left_batches: Vec<RecordBatch> = (0..4)
.map(|i| {
build_table_i32(
("a1", &vec![i * 2, i * 2 + 1]),
("b1", &vec![1, 1]),
("c1", &vec![100 + i, 101 + i]),
)
})
.collect();
let left = build_table_from_batches(left_batches);

let right_batches: Vec<RecordBatch> = (0..4)
.map(|i| {
build_table_i32(
("a2", &vec![i * 2, i * 2 + 1]),
("b2", &vec![1, 1]),
("c2", &vec![200 + i, 201 + i]),
)
})
.collect();
let right = build_table_from_batches(right_batches);

let on = vec![(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];

let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
)
.build_arc()?;

let session_config = SessionConfig::default().with_batch_size(50);
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(session_config)
.with_runtime(Arc::clone(&runtime)),
);

let join = join_with_options(
Arc::clone(&left),
Arc::clone(&right),
on.clone(),
Inner,
sort_options,
NullEquality::NullEqualsNothing,
)?;

let stream = join.execute(0, task_ctx)?;
let result = common::collect(stream).await.unwrap();

assert!(!result.is_empty(), "Expected non-empty join result");

let metrics = join.metrics().unwrap();
assert!(
metrics.spill_count().unwrap() > 0,
"Expected spilling to occur"
);

// peak_mem_used should reflect the spill read-back: when buffered
// batches are read from disk during output materialization, grow()
// temporarily reserves size_estimation. This pushes peak above what
// join_arrays_mem alone would show.
let peak_mem = metrics
.sum_by_name("peak_mem_used")
.map(|m| m.as_usize())
.unwrap_or(0);
assert!(
peak_mem >= size_estimation,
"peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
because spill read-back temporarily loads full batch into memory"
);

// All memory must be released (grow/shrink balanced)
assert_eq!(
runtime.memory_pool.reserved(),
0,
"All memory should be released after join completes"
);

Ok(())
}

/// Verifies spill read-back memory tracking for the single-source path.
///
/// When only ONE buffered batch exists for a key group and it's spilled,
/// `fetch_right_columns_by_idxs` reads it back. A scoped `MemoryReservation`
/// (via `new_empty()`) tracks the transient memory and releases it on drop.
#[tokio::test]
async fn spill_read_back_single_source() -> Result<()> {
use arrow::array::Array;

let left_batch = build_table_i32(
("a1", &vec![0, 1]),
("b1", &vec![1, 1]),
("c1", &vec![4, 5]),
);
let size_estimation = left_batch.get_array_memory_size()
+ Int32Array::from(vec![1, 1]).get_array_memory_size()
+ 2usize.next_power_of_two() * size_of::<usize>()
+ size_of::<std::ops::Range<usize>>()
+ size_of::<usize>();

// Memory limit too small for a full batch — forces spilling.
let memory_limit = size_estimation / 2;

// Multiple distinct keys so each key group has exactly ONE buffered batch.
// This ensures the single-source path is exercised.
let left_batches: Vec<RecordBatch> = (0..4)
.map(|i| {
build_table_i32(
("a1", &vec![i * 2, i * 2 + 1]),
("b1", &vec![i, i]),
("c1", &vec![100 + i, 101 + i]),
)
})
.collect();
let left = build_table_from_batches(left_batches);

// One batch per key — each key group has single source
let right_batches: Vec<RecordBatch> = (0..4)
.map(|i| {
build_table_i32(
("a2", &vec![i * 2, i * 2 + 1]),
("b2", &vec![i, i]),
("c2", &vec![200 + i, 201 + i]),
)
})
.collect();
let right = build_table_from_batches(right_batches);

let on = vec![(
Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("b2", &right.schema())?) as _,
)];
let sort_options = vec![SortOptions::default(); on.len()];

let runtime = RuntimeEnvBuilder::new()
.with_memory_limit(memory_limit, 1.0)
.with_disk_manager_builder(
DiskManagerBuilder::default().with_mode(DiskManagerMode::OsTmpDirectory),
)
.build_arc()?;

let session_config = SessionConfig::default().with_batch_size(50);
let task_ctx = Arc::new(
TaskContext::default()
.with_session_config(session_config)
.with_runtime(Arc::clone(&runtime)),
);

let join = join_with_options(
Arc::clone(&left),
Arc::clone(&right),
on.clone(),
Inner,
sort_options,
NullEquality::NullEqualsNothing,
)?;

let stream = join.execute(0, task_ctx)?;
let result = common::collect(stream).await.unwrap();

assert!(!result.is_empty(), "Expected non-empty join result");

let metrics = join.metrics().unwrap();
assert!(
metrics.spill_count().unwrap() > 0,
"Expected spilling to occur"
);

// peak_mem_used should reflect the single-batch read-back
let peak_mem = metrics
.sum_by_name("peak_mem_used")
.map(|m| m.as_usize())
.unwrap_or(0);
assert!(
peak_mem >= size_estimation,
"peak_mem_used ({peak_mem}) should be >= size_estimation ({size_estimation}) \
because single-source spill read-back loads full batch"
);

// All memory must be released
assert_eq!(
runtime.memory_pool.reserved(),
0,
"All memory should be released after join completes"
);

Ok(())
}