feat: paged prefill writer with shared-prefix copy-on-write#150

Merged: inureyes merged 1 commit into main from feat/120-paged-prefill-into-pool on Jun 3, 2026

inureyes (Member) commented Jun 3, 2026

Summary

Add PagedBlockPool::write_prefill, the Phase 3 (#120) bulk prefill writer for the unified paged KV cache (epic #116), with real shared-prefix copy-on-write. It writes a whole prefill's worth of K/V for one layer into the layout-A pool established in #118, and writing at the sequence's absolute tail means the same call serves the shared-prefix suffix case.

What changed

  • src/lib/mlxcel-core/src/cache/paged.rs
    • write_prefill(&mut self, state, layer_idx, k_prefill, v_prefill): validates the [1, n_kv_heads, n_new_tokens, head_dim] SDPA-layout inputs (dtype preserved, no astype), calls append_tokens to allocate the trailing blocks, then chunks the new tokens into block-sized write_block calls, handling a partial first block (when the prefix ended mid-block) and a partial last block via per-block slot arithmetic. It reuses write_block's donation discipline.
    • copy_on_write_block(&mut self, layer_idx, src_block_id): forks a shared block. Block sharing here is not always block-granular — CachePool::detach_paged captures a partially-filled tail block and shares it — so before writing into any spanned block whose refcount > 1, write_prefill slices the shared block's current slab out of the pool, writes it into a freshly acquired block, repoints block_ids, and drops this sequence's reference to the original (leaving the sibling's intact). Block-aligned prefixes start the suffix on a fresh append block, so the COW check is a no-op there.
  • src/lib/mlxcel-core/src/cache/paged_pool_tests.rs: cold round-trip byte-identity, block-aligned round-trip, and the pool-level two-sharer COW no-corruption proof.
  • src/lib/mlxcel-core/src/cache/paged_detach_tests.rs: a CachePool-level test over the real detach_paged -> park -> adopt_paged path asserting suffix-only block allocation via PagedCacheStats.
  • src/lib/mlxcel-core/src/ffi_tests.rs: end-to-end bulk-prefill -> 24-step decode parity vs the dense fallback over fragmented physical rows.
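The per-block slot arithmetic described for write_prefill can be illustrated in isolation. The following is a minimal sketch, not the code in paged.rs: `BlockChunk` and `plan_prefill_chunks` are hypothetical names introduced here, and the sketch only plans the chunks (which block, which slots, which input offset) without touching any tensors.

```rust
/// One contiguous write into a single block (hypothetical illustration type).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockChunk {
    /// Index into the sequence's block table (absolute position / block_size).
    pub block_index: usize,
    /// First slot inside the block that receives a token.
    pub slot_start: usize,
    /// Number of consecutive slots written in this block.
    pub n_slots: usize,
    /// Offset of the chunk's first token within the new-token input.
    pub src_offset: usize,
}

/// Split `n_new` tokens starting at absolute position `first_abs` into
/// per-block chunks. A partial first chunk appears when `first_abs` is not
/// block-aligned; a partial last chunk appears when the end is not aligned.
pub fn plan_prefill_chunks(
    first_abs: usize,
    n_new: usize,
    block_size: usize,
) -> Result<Vec<BlockChunk>, String> {
    if block_size == 0 {
        return Err("block_size must be > 0".to_string());
    }
    let last_abs = first_abs
        .checked_add(n_new)
        .ok_or_else(|| "first_abs + n_new overflows usize".to_string())?;
    let mut chunks = Vec::new();
    let mut abs = first_abs;
    while abs < last_abs {
        let block_index = abs / block_size;
        let slot_start = abs % block_size;
        // Stop at the end of this block or at the end of the input,
        // whichever comes first.
        let block_end = (block_index + 1) * block_size;
        let n_slots = block_end.min(last_abs) - abs;
        chunks.push(BlockChunk {
            block_index,
            slot_start,
            n_slots,
            src_offset: abs - first_abs,
        });
        abs += n_slots;
    }
    Ok(chunks)
}
```

For a 6-token prefix with block_size 4 followed by a 7-token suffix, `plan_prefill_chunks(6, 7, 4)` yields three chunks: slots 2..4 of block 1 (the partial first block), all of block 2, and slot 0 of block 3 (the partial last block).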

Numerical result

The stored bytes are byte-identical to the dense prefill path (each token lands at the same absolute slot a dense slice_update from [0, 0, prev, 0] targets). The end-to-end prefill→decode parity test measures max RMS = 0 (FP32 byte-exact) over 24 decode steps after a 13-token bulk prefill.
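As a rough illustration of the two checks mentioned here, the sketch below computes an RMS difference and a bitwise comparison over plain `f32` slices. Both helpers (`rms_diff`, `bytes_identical`) are hypothetical and are not the test utilities used in ffi_tests.rs; they only show what "max RMS = 0" and "byte-identical" mean for two buffers.

```rust
/// Root-mean-square difference between two equally sized buffers,
/// accumulated in f64 (hypothetical helper, not the project's test code).
pub fn rms_diff(a: &[f32], b: &[f32]) -> Result<f64, String> {
    if a.len() != b.len() {
        return Err(format!("length mismatch: {} vs {}", a.len(), b.len()));
    }
    if a.is_empty() {
        return Err("buffers must be non-empty".to_string());
    }
    let sum_sq: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = f64::from(*x) - f64::from(*y);
            d * d
        })
        .sum();
    Ok((sum_sq / a.len() as f64).sqrt())
}

/// Bitwise equality of two f32 buffers. This is stricter than RMS == 0:
/// 0.0 and -0.0 differ in their bit patterns but have zero difference.
pub fn bytes_identical(a: &[f32], b: &[f32]) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b.iter())
            .all(|(x, y)| x.to_bits() == y.to_bits())
}

/// Maximum RMS difference over a series of per-step buffer pairs.
pub fn max_rms(steps: &[(Vec<f32>, Vec<f32>)]) -> Result<f64, String> {
    let mut worst = 0.0_f64;
    for (pooled, dense) in steps {
        worst = worst.max(rms_diff(pooled, dense)?);
    }
    Ok(worst)
}
```

A parity test in this style would collect one (pooled, dense) pair per decode step and assert that `max_rms` is exactly 0.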

Deferred (NOT in this PR)

Test plan

  • cargo test --lib -p mlxcel-core write_prefill --features metal,accelerate (4 new write_prefill tests pass)
  • cargo test --lib -p mlxcel-core test_pooled_prefill_then_decode_matches_dense --features metal,accelerate (max RMS = 0)
  • cargo test --lib -p mlxcel-core pooled_paged --features metal,accelerate (#119 suite, Phase 2: paged decode attention over real block tables; no regression)
  • cargo test --lib -p mlxcel-core paged_pool --features metal,accelerate / paged_detach (#118 suite, Phase 1: global block-pool tensor storage, plus the detach suites; no regression)
  • cargo clippy -p mlxcel-core --tests --features metal,accelerate -- -D warnings (clean)
  • cargo fmt -p mlxcel-core --check (clean)

Closes #120

Add `PagedBlockPool::write_prefill`, the Phase 3 (#120) bulk prefill writer for the unified paged KV cache (epic #116). It writes a whole prefill's worth of K/V for one layer into the layout-A pool established in #118: it computes the sequence's current absolute tail, calls `append_tokens` to allocate the trailing blocks, then chunks the `[1, n_kv_heads, n_new_tokens, head_dim]` SDPA-layout input into block-sized `write_block` calls (handling a partial first block when the prefix ended mid-block and a partial last block). Writing the new tokens at the absolute tail means the same call also serves the shared-prefix suffix case: appending only the divergent tokens after a shared prefix is in place. dtype is preserved (no astype on K/V), matching the rest of the pool path.

Shared-prefix copy-on-write: block sharing in this pool is not always block-granular — `CachePool::detach_paged` captures a partially-filled tail block and shares it, so a 6-token prefix with block_size 4 leaves the second block half-full and shared. Before writing into any spanned block whose refcount > 1, `write_prefill` performs a real copy-on-write via the new `copy_on_write_block` helper: it slices the shared block's current `[block_size, n_kv_heads, head_dim]` slab out of the pool, writes it into a freshly acquired block, repoints the sequence's block_ids at the copy, and drops this sequence's reference to the shared original. The sibling that still references the original keeps it intact. Block-aligned prefixes start the suffix on a fresh append block, so the COW check is a no-op there. The stored bytes are byte-identical to the dense prefill path, so a later `gather_visible` returns the same bytes as the equivalent dense `slice_update` buffer.
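The fork-before-write rule can be illustrated with a toy pool. This is a minimal sketch under simplifying assumptions: `ToyPool` and all of its methods are hypothetical, a block is a plain `Vec<f32>` of `block_size` slots rather than a `[block_size, n_kv_heads, head_dim]` slab inside an MLX tensor, and only the refcount and repointing logic mirrors the description above.

```rust
use std::collections::HashMap;

/// Toy refcounted block pool (hypothetical; stands in for the real pool).
pub struct ToyPool {
    block_size: usize,
    slabs: HashMap<u32, Vec<f32>>,
    refcounts: HashMap<u32, u32>,
    next_id: u32,
}

impl ToyPool {
    pub fn new(block_size: usize) -> Self {
        Self { block_size, slabs: HashMap::new(), refcounts: HashMap::new(), next_id: 0 }
    }

    /// Allocate a zeroed block with refcount 1.
    pub fn acquire(&mut self) -> u32 {
        let id = self.next_id;
        self.next_id += 1;
        self.slabs.insert(id, vec![0.0; self.block_size]);
        self.refcounts.insert(id, 1);
        id
    }

    /// Add a reference, as a second sequence adopting a shared prefix would.
    pub fn retain(&mut self, id: u32) -> Result<(), String> {
        let rc = self.refcounts.get_mut(&id).ok_or_else(|| format!("unknown block {id}"))?;
        *rc += 1;
        Ok(())
    }

    /// Drop a reference; the slab is freed only when the count reaches 0.
    pub fn release(&mut self, id: u32) -> Result<(), String> {
        let rc = self.refcounts.get_mut(&id).ok_or_else(|| format!("unknown block {id}"))?;
        *rc -= 1;
        let remaining = *rc;
        if remaining == 0 {
            self.refcounts.remove(&id);
            self.slabs.remove(&id);
        }
        Ok(())
    }

    pub fn refcount(&self, id: u32) -> u32 {
        self.refcounts.get(&id).copied().unwrap_or(0)
    }

    pub fn slab(&self, id: u32) -> Option<Vec<f32>> {
        self.slabs.get(&id).cloned()
    }

    /// Write `data` into `table[block_index]` at `slot_start`, forking the
    /// block first if another sequence still references it.
    pub fn write_slots(
        &mut self,
        table: &mut [u32],
        block_index: usize,
        slot_start: usize,
        data: &[f32],
    ) -> Result<(), String> {
        if slot_start + data.len() > self.block_size {
            return Err("write exceeds block".to_string());
        }
        let src = *table.get(block_index).ok_or_else(|| "block_index out of range".to_string())?;
        if self.refcount(src) > 1 {
            // Copy the whole slab, repoint this sequence, then drop its
            // reference to the original; the sibling's copy stays intact.
            let copy = self.slab(src).ok_or_else(|| format!("unknown block {src}"))?;
            let fresh = self.acquire();
            self.slabs.insert(fresh, copy);
            table[block_index] = fresh;
            self.release(src)?;
        }
        let dst = table[block_index];
        let slab = self.slabs.get_mut(&dst).ok_or_else(|| format!("unknown block {dst}"))?;
        slab[slot_start..slot_start + data.len()].copy_from_slice(data);
        Ok(())
    }
}
```

With block_size 4 and a 6-token prefix, the second block holds two prefix tokens and is shared. The first sequence to write slots 2..4 gets a fresh copy; the second, now the sole owner, writes in place.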

This is the pool-layer prefill-write capability only. The model forward sees dense `KVCache` and cannot reach the pool until #121 wires a paged-aware cache mode and scheduler, so `write_prefill` is additive machinery exercised by tests, the same pattern as #118/#119. No model, forward, generate.rs, or C++ files are touched.

Tests:

- `write_prefill_cold_round_trip_is_byte_identical_to_dense` and `write_prefill_block_aligned_prompt_round_trips` (paged_pool_tests): a bulk prefill (including a non-block-aligned length spanning a partial last block) gathers byte-identical to a dense buffer built by `slice_update`.
- `write_prefill_cow_forks_shared_partial_tail_block` (paged_pool_tests): two sequences share a partially-filled tail block (via `retain_block`, as adopt does); each writes a different suffix; asserts the writer-while-shared forks a fresh copy, the writer-while-sole-owner keeps the block in place, refcounts are correct, and each sequence gathers its own prefix + own suffix byte-identically (no cross-corruption).
- `write_prefill_after_shared_prefix_adopt_allocates_only_suffix_blocks` (paged_detach_tests): over the real detach -> park -> adopt path, a suffix `write_prefill` allocates blocks only for the suffix (verified via `PagedCacheStats.live_blocks`) and the consumer reads back prefix + suffix correctly.
- `test_pooled_prefill_then_decode_matches_dense` (ffi_tests): bulk-prefills a 13-token prompt into the pool and a dense buffer, runs 24 decode steps appending one token to each, and compares the pooled vs dense fallback attention each step over fragmented physical rows; max RMS is 0 (FP32 byte-exact).
inureyes added the type:enhancement, area:core, area:inference, and status:review labels on Jun 3, 2026
inureyes commented Jun 3, 2026

Implementation Review Summary

Intent

Phase 3 of epic #116: add PagedBlockPool::write_prefill (bulk-write a prompt's K/V into the layout-A pool, chunked into blocks, with real shared-prefix copy-on-write of a partially-filled tail block) + tests. Additive pool-layer machinery; live-forward/scheduler wiring is the documented #121 carry-forward.

Findings Addressed

  • None — no CRITICAL/HIGH/MEDIUM defects found. No code changes were required.

Verification

  • All stated requirements implemented (write_prefill bulk writer + private copy_on_write_block; partial first/last block slot arithmetic; append_tokens before the walk; byte-identity to the dense slice_update contract)
  • No placeholder/mock code remaining (real COW; reuses write_block's donation discipline rather than reimplementing the slice_update append)
  • Integrated into project code flow: N/A by design. #120 (Phase 3: Paged prefill into the block pool) is additive pool-layer machinery; live-forward/scheduler wiring is the documented carry-forward to #121 (Phase 4: Radix-trie and block-pool unification, scheduler wiring). Model files / generate.rs / C++ are intentionally untouched, confirmed by diff.
  • Project conventions followed (Result<_, String>; only 2 .expect() in production, both invariant-guarded with documenting messages; no unwrap/panic/todo; thorough rustdoc)
  • Existing modules reused where applicable (write_block, append_tokens, acquire_block/release_block, refcount, gather_visible)
  • No unintended structural changes (exactly the 4 intended files; +782/-0)
  • Tests pass (write_prefill 4/4; test_pooled_prefill_then_decode_matches_dense max RMS = 0e0; clippy --lib clean)

COW correctness (reviewed most carefully)

  • COW triggers only while shared (refcount > 1); first writer forks, second writer (now sole owner) writes in-place — proven by write_prefill_cow_forks_shared_partial_tail_block (A forks → a_tail != tail_block refcount 1; B keeps → b_tail == tail_block refcount 2→1; prefix_blocks[0] stays 2)
  • copy_on_write_block copies the FULL [block_size, H, D] slab; the subsequent write_block overwrites only the divergent suffix slots — proven by the forked tail holding prefix base (slots 0,1) + suffix base (slots 2,3) under an independent dense reference
  • No double-free / refcount imbalance — acquire_block (fresh row) precedes release_block(src) (sibling pin keeps src row alive at refcount > 0)
  • Acceptance criterion 2 (suffix-only allocation) proven over the real detach_paged → park → adopt_parked_paged path via PagedCacheStats (live_after - live_before == 2, prefix blocks physically reused)
  • Absolute indexing (first_abs = layer.len) is correct for logical_start == 0; confirmed that no #120 test relies on a slid sequence (the slid case, logical_start > 0, is the documented carry-forward to #121)

Remaining Items

  • None.

inureyes commented Jun 3, 2026

Security & performance review — clean (no changes required)

Reviewed the Phase 3 (#120) pool-layer additions in cache/paged.rs (write_prefill + private copy_on_write_block) plus the three test files, against the COW/refcount, memory-safety, and hot-path performance criteria. No CRITICAL/HIGH/MEDIUM findings; nothing auto-fixed. Warm verification all green: cargo check --lib, cargo clippy --lib -- -D warnings, and the 4 write_prefill tests (incl. the two-sharer COW no-corruption proof) pass.

Refcount integrity across COW — sound

  • No double-free: release_block on a still-shared block (refcount 2) decrements to 1 and returns early (paged.rs:705-706), never reaching the block_rows.remove -> free_rows.push / free_lists.push recycle path (only runs at refcount 0). The sibling keeps the block live.
  • No use-after-free / aliasing: copy_on_write_block reads src_row's slab (ffi::slice + ffi::reshape, independent UniquePtrs) before acquire_block(fresh) and write_block(fresh), and release_block(src) runs only afterward. Since src is still mapped in block_rows when assign_row(fresh) runs, the fresh row can never equal src_row — read source and write target are disjoint rows. MLX slice_update also rebuilds the pool tensor rather than mutating in place, so there is no in-place hazard regardless.
  • No refcount leak: each acquire_block (refcount 1) is stored into exactly one block_ids[block_index] entry, and release_block(src) balances the fork (+1 fresh / -1 src). The original sequence's block_ids[block_index] is repointed to fresh, dropping its reference to src.
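The ordering argument in these bullets (acquire the fresh block while the source is still mapped, release the source afterward) can be checked on a toy row allocator. This is a sketch only: `RowAllocator` is hypothetical, and although the field names `block_rows` and `free_rows` echo the review text, the structure here is an assumption and not the layout in paged.rs.

```rust
use std::collections::HashMap;

/// Toy allocator mapping block ids to physical rows (hypothetical).
pub struct RowAllocator {
    block_rows: HashMap<u32, usize>,
    refcounts: HashMap<u32, u32>,
    free_rows: Vec<usize>,
    next_row: usize,
    next_id: u32,
}

impl RowAllocator {
    pub fn new() -> Self {
        Self {
            block_rows: HashMap::new(),
            refcounts: HashMap::new(),
            free_rows: Vec::new(),
            next_row: 0,
            next_id: 0,
        }
    }

    /// Assign a recycled row if one is free, otherwise grow by one row.
    /// A row that is still mapped is never on the free list, so a fresh
    /// block cannot alias a live one.
    pub fn acquire_block(&mut self) -> u32 {
        let row = match self.free_rows.pop() {
            Some(r) => r,
            None => {
                let r = self.next_row;
                self.next_row += 1;
                r
            }
        };
        let id = self.next_id;
        self.next_id += 1;
        self.block_rows.insert(id, row);
        self.refcounts.insert(id, 1);
        id
    }

    pub fn retain_block(&mut self, id: u32) -> Result<(), String> {
        let rc = self.refcounts.get_mut(&id).ok_or_else(|| format!("unknown block {id}"))?;
        *rc += 1;
        Ok(())
    }

    /// Decrement the refcount. Returns Ok(false) on the early-return path
    /// (block still referenced) and Ok(true) when the row was recycled.
    pub fn release_block(&mut self, id: u32) -> Result<bool, String> {
        let rc = self.refcounts.get_mut(&id).ok_or_else(|| format!("unknown block {id}"))?;
        *rc -= 1;
        if *rc > 0 {
            return Ok(false);
        }
        self.refcounts.remove(&id);
        if let Some(row) = self.block_rows.remove(&id) {
            self.free_rows.push(row);
        }
        Ok(true)
    }

    pub fn row_of(&self, id: u32) -> Option<usize> {
        self.block_rows.get(&id).copied()
    }
}
```

In a fork, `acquire_block` runs while the shared source is still mapped, so the fresh row differs from the source row; the later `release_block(src)` takes the early-return path because the sibling still holds a reference.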

Memory growth / leak — bounded

  • No UniquePtr<MlxArray> leak: all sliced slabs/slots are RAII-scoped. write_block uses .take() + reassign so MLX donates the old buffer (O(block)).
  • COW fork = exactly one fresh block per shared boundary block; append_tokens allocates only the trailing blocks the suffix needs (refcount 1). No unbounded pool growth, no O(pool)-per-token copy. The per-block write loop is O(n_new) total, not O(n^2).

No new unsafe

Zero unsafe in any of the four changed files. The new code uses only the cxx-generated safe wrappers (ffi::slice / slice_update / reshape / take / array_shape / array_dtype, declared as plain fn in the bridge) plus acquire_block / write_block / release_block.

Performance — no hot-path regressions

  • No eval/sync anywhere on the write path; the MLX graph stays lazy until gather/SDPA.
  • COW does one full block_size slab copy then the suffix overwrite (2 bounded writes on the forked block), once per shared boundary block. No accidental quadratic behavior or redundant materialization.
  • No astype/dtype conversion on the K/V write path — dtype is preserved end-to-end.

Integer / bounds

first_abs/last_abs/block_index/slot_start/n_slots and the i32 casts are sound for all scheduler-reachable inputs: block_size > 0 is enforced at layout construction (no div-by-zero), n_new > 0 is guaranteed before the (last_abs - 1) subtraction (no underflow), and i32 overflow would require >2.1B tokens. The i32 convention matches the surrounding code and the MLX FFI boundary.
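The preconditions listed here can be written as explicit checks. The sketch below is illustrative only: `block_span_i32` is a hypothetical helper, and it returns errors where the reviewed code relies on invariants established elsewhere (layout construction and the caller).

```rust
use std::convert::TryFrom;

/// Compute the first and last block index spanned by `n_new` tokens that
/// start at absolute position `first_abs`, as i32 values suitable for an
/// i32-based FFI boundary (hypothetical helper for illustration).
pub fn block_span_i32(
    first_abs: usize,
    n_new: usize,
    block_size: usize,
) -> Result<(i32, i32), String> {
    // Guards the divisions below against division by zero.
    if block_size == 0 {
        return Err("block_size must be > 0".to_string());
    }
    // Guards the (last_abs - 1) subtraction against underflow.
    if n_new == 0 {
        return Err("n_new must be > 0".to_string());
    }
    let last_abs = first_abs
        .checked_add(n_new)
        .ok_or_else(|| "first_abs + n_new overflows usize".to_string())?;
    let first_block = first_abs / block_size;
    let last_block = (last_abs - 1) / block_size;
    // Fails only when an index exceeds i32::MAX, i.e. beyond ~2.1B.
    let first = i32::try_from(first_block).map_err(|e| format!("first block index: {e}"))?;
    let last = i32::try_from(last_block).map_err(|e| format!("last block index: {e}"))?;
    Ok((first, last))
}
```

For example, `block_span_i32(6, 7, 4)` returns `Ok((1, 3))`, while a zero `n_new` or zero `block_size` is rejected before any arithmetic runs.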

One latent (NOT reachable today) robustness note — LOW, intentionally not fixed

write_prefill indexes block_ids[block_index] by absolute position (abs / block_size), consistent with append_tokens and gather_visible (both treat block_ids[i] as absolute slots [i*bs, (i+1)*bs) and apply logical_start only as a post-gather window slice). append_tokens, however, sizes block_ids from visible_len() = len - logical_start. These agree only while logical_start == 0. Today every production path sets PagedLayerState::logical_start = 0 (paged.rs:554/581, cache.rs construction), so this is not reachable. A future sliding-window/trim-then-prefill feature that leaves logical_start > 0 before calling write_prefill would make the top block_ids[block_index] access out-of-bounds. Out of scope for #120 (no caller produces it); flagging so #121's live wiring either keeps the absolute-vs-visible indexing reconciled or adds a debug_assert_eq!(layer.logical_start, 0) / explicit guard at the write_prefill entry.
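The mismatch and the suggested guard can be made concrete with a small model. Everything in this sketch is an assumption made for illustration: `LayerState` is a stand-in with only the two fields the note mentions, and the table-sizing formula (ceiling of visible length plus new tokens over block size) is a guess at the behavior the note describes, not a transcription of append_tokens.

```rust
/// Hypothetical stand-in for the per-layer state; only the fields the
/// note mentions.
pub struct LayerState {
    pub len: usize,
    pub logical_start: usize,
}

impl LayerState {
    pub fn visible_len(&self) -> usize {
        self.len - self.logical_start
    }
}

/// Block-table length if it is sized from the visible length
/// (assumed model of the sizing). Requires block_size > 0.
pub fn table_len_from_visible(layer: &LayerState, n_new: usize, block_size: usize) -> usize {
    (layer.visible_len() + n_new).div_ceil(block_size)
}

/// Highest block-table index touched when indexing by absolute position.
/// Requires block_size > 0 and n_new > 0.
pub fn top_absolute_block_index(layer: &LayerState, n_new: usize, block_size: usize) -> usize {
    (layer.len + n_new - 1) / block_size
}

/// Explicit entry guard in the spirit of the suggested
/// `debug_assert_eq!(layer.logical_start, 0)`, returning an error instead.
pub fn guard_write_prefill_entry(
    layer: &LayerState,
    n_new: usize,
    block_size: usize,
) -> Result<(), String> {
    if block_size == 0 || n_new == 0 {
        return Err("block_size and n_new must be > 0".to_string());
    }
    if layer.logical_start != 0 {
        return Err(format!(
            "write_prefill requires logical_start == 0, got {}",
            layer.logical_start
        ));
    }
    Ok(())
}
```

Under this model, a layer with len 10 and logical_start 4, block_size 4, and 3 new tokens gets a 3-entry table but an absolute top index of 3, which is out of bounds; with logical_start 0 the table has 4 entries and the same index is valid.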

No commits pushed; PR left at status:review.

inureyes added the status:done label and removed the status:review label on Jun 3, 2026
inureyes merged commit c6af2fd into main on Jun 3, 2026
5 checks passed
inureyes deleted the feat/120-paged-prefill-into-pool branch on June 3, 2026 15:25