Add `PagedBlockPool::write_prefill`, the Phase 3 (#120) bulk prefill writer for the unified paged KV cache (epic #116). It writes a whole prefill's worth of K/V for one layer into the layout-A pool established in #118: it computes the sequence's current absolute tail, calls `append_tokens` to allocate the trailing blocks, then splits the `[1, n_kv_heads, n_new_tokens, head_dim]` SDPA-layout input into block-sized `write_block` calls, handling a partial first block (when the prefix ended mid-block) and a partial last block. Because the new tokens are written at the absolute tail, the same call also serves the shared-prefix suffix case: appending only the divergent tokens after a shared prefix is in place. The dtype is preserved (no `astype` on K/V), matching the rest of the pool path.
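The per-block slot arithmetic described above can be sketched as a standalone planning function. This is a minimal illustration, not the code in `paged.rs`: `BlockChunk` and `plan_prefill_chunks` are hypothetical names, and the sketch only computes which slots each chunk would cover; it does not touch tensors or call `write_block`.

```rust
/// One contiguous write into a single block.
/// Hypothetical type, for illustration only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct BlockChunk {
    /// Logical block number, i.e. the index into the sequence's block table.
    block_index: usize,
    /// First slot inside the block that this chunk writes.
    slot_start: usize,
    /// Offset of the chunk's first token within the new-token input.
    src_start: usize,
    /// Number of tokens in this chunk.
    len: usize,
}

/// Split `n_new_tokens`, written at absolute position `prev_len`, into
/// per-block chunks. The first chunk is partial when `prev_len` is not a
/// multiple of `block_size`; the last chunk is partial when the new tail
/// does not land on a block boundary.
fn plan_prefill_chunks(prev_len: usize, n_new_tokens: usize, block_size: usize) -> Vec<BlockChunk> {
    assert!(block_size > 0, "block_size must be non-zero");
    let mut chunks = Vec::new();
    let mut written = 0;
    while written < n_new_tokens {
        let abs = prev_len + written; // absolute token position in the sequence
        let slot_start = abs % block_size;
        // Fill the rest of the current block, or stop early if the input runs out.
        let len = (block_size - slot_start).min(n_new_tokens - written);
        chunks.push(BlockChunk {
            block_index: abs / block_size,
            slot_start,
            src_start: written,
            len,
        });
        written += len;
    }
    chunks
}
```

For a 6-token prefix with `block_size` 4 and 7 new tokens, this plan has three chunks: 2 tokens into slots 2..4 of logical block 1 (the partial first block), 4 tokens filling logical block 2, and 1 token into slot 0 of logical block 3 (the partial last block).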
Shared-prefix copy-on-write: block sharing in this pool is not always block-granular — `CachePool::detach_paged` captures a partially-filled tail block and shares it, so a 6-token prefix with block_size 4 leaves the second block half-full and shared. Before writing into any spanned block whose refcount > 1, `write_prefill` performs a real copy-on-write via the new `copy_on_write_block` helper: it slices the shared block's current `[block_size, n_kv_heads, head_dim]` slab out of the pool, writes it into a freshly acquired block, repoints the sequence's block_ids at the copy, and drops this sequence's reference to the shared original. The sibling that still references the original keeps it intact. Block-aligned prefixes start the suffix on a fresh append block, so the COW check is a no-op there. The stored bytes are byte-identical to the dense prefill path, so a later `gather_visible` returns the same bytes as the equivalent dense `slice_update` buffer.
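The fork-before-write rule can be illustrated with a toy pool. This is a sketch under loud assumptions: `ToyPool` and `make_writable` are hypothetical, blocks are plain `Vec<f32>` slabs rather than pool tensors, and there is no per-layer storage. It only shows the refcount logic: copy the slab, repoint the block table, drop the reference to the original.

```rust
/// Toy block pool. Hypothetical: the real `PagedBlockPool` stores
/// per-layer K/V tensors, not `Vec<f32>` slabs.
struct ToyPool {
    blocks: Vec<Vec<f32>>, // one slab per physical block
    refcounts: Vec<u32>,   // number of sequences referencing each block
    free: Vec<usize>,      // ids of unowned blocks
}

impl ToyPool {
    fn new(n_blocks: usize, slab_len: usize) -> Self {
        ToyPool {
            blocks: vec![vec![0.0; slab_len]; n_blocks],
            refcounts: vec![0; n_blocks],
            free: (0..n_blocks).rev().collect(),
        }
    }

    /// Take a free block with refcount 1, or `None` if the pool is exhausted.
    fn acquire(&mut self) -> Option<usize> {
        let id = self.free.pop()?;
        self.refcounts[id] = 1;
        Some(id)
    }

    /// Add a sharer to an already-owned block.
    fn retain(&mut self, id: usize) {
        self.refcounts[id] += 1;
    }

    /// Drop one reference; the block returns to the free list at zero.
    fn release(&mut self, id: usize) {
        self.refcounts[id] -= 1;
        if self.refcounts[id] == 0 {
            self.free.push(id);
        }
    }

    /// Ensure `block_ids[idx]` is exclusively owned before a write.
    /// A sole owner keeps its block in place; a sharer gets a fresh copy.
    fn make_writable(&mut self, block_ids: &mut [usize], idx: usize) -> Option<usize> {
        let src = block_ids[idx];
        if self.refcounts[src] <= 1 {
            return Some(src); // not shared: the COW check is a no-op
        }
        let dst = self.acquire()?;
        let slab = self.blocks[src].clone(); // copy the shared slab as it is now
        self.blocks[dst] = slab;
        block_ids[idx] = dst; // repoint this sequence at the copy
        self.release(src); // the sibling's reference stays intact
        Some(dst)
    }
}
```

With two sequences sharing one half-full block, the first writer forks a copy and the second writer, now the sole owner of the original, writes in place. Each ends up with the shared prefix plus its own suffix, and both blocks have refcount 1.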
This is the pool-layer prefill-write capability only. The model forward sees dense `KVCache` and cannot reach the pool until #121 wires a paged-aware cache mode and scheduler, so `write_prefill` is additive machinery exercised by tests, the same pattern as #118/#119. No model, forward, generate.rs, or C++ files are touched.
Tests:
- `write_prefill_cold_round_trip_is_byte_identical_to_dense` and `write_prefill_block_aligned_prompt_round_trips` (paged_pool_tests): a bulk prefill (including a non-block-aligned length spanning a partial last block) gathers byte-identical to a dense buffer built by `slice_update`.
- `write_prefill_cow_forks_shared_partial_tail_block` (paged_pool_tests): two sequences share a partially-filled tail block (via `retain_block`, as adopt does); each writes a different suffix; asserts the writer-while-shared forks a fresh copy, the writer-while-sole-owner keeps the block in place, refcounts are correct, and each sequence gathers its own prefix + own suffix byte-identically (no cross-corruption).
- `write_prefill_after_shared_prefix_adopt_allocates_only_suffix_blocks` (paged_detach_tests): over the real detach -> park -> adopt path, a suffix `write_prefill` allocates blocks only for the suffix (verified via `PagedCacheStats.live_blocks`) and the consumer reads back prefix + suffix correctly.
- `test_pooled_prefill_then_decode_matches_dense` (ffi_tests): bulk-prefills a 13-token prompt into the pool and a dense buffer, runs 24 decode steps appending one token to each, and compares the pooled vs dense fallback attention each step over fragmented physical rows; max RMS is 0 (FP32 byte-exact).
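The byte-identity property these tests check can be sketched without the real pool. The following is an illustration only: `paged_write` and `paged_gather` are hypothetical helpers, a token is a flat row of `width` floats rather than a `[n_kv_heads, head_dim]` slice, and the write loops per token for clarity where the real writer issues block-sized writes. The point is that a token at absolute position `abs` always lands in slot `abs % block_size` of the block backing logical block `abs / block_size`, so gathering through the block table reproduces the dense buffer bit for bit even when the physical blocks are fragmented.

```rust
/// Write `tokens` (rows of `width` floats) at the absolute tail `prev_len`.
/// `block_table[i]` is the physical block backing logical block `i`.
/// Hypothetical layout, for illustration only.
fn paged_write(
    pool: &mut [Vec<f32>],
    block_table: &[usize],
    block_size: usize,
    width: usize,
    prev_len: usize,
    tokens: &[f32],
) {
    assert!(width > 0 && tokens.len() % width == 0, "tokens must be whole rows");
    for (i, row) in tokens.chunks(width).enumerate() {
        let abs = prev_len + i; // absolute token position
        let phys = block_table[abs / block_size];
        let slot = abs % block_size;
        pool[phys][slot * width..(slot + 1) * width].copy_from_slice(row);
    }
}

/// Read the first `len` tokens back in logical order through the block table.
fn paged_gather(
    pool: &[Vec<f32>],
    block_table: &[usize],
    block_size: usize,
    width: usize,
    len: usize,
) -> Vec<f32> {
    let mut out = Vec::with_capacity(len * width);
    for abs in 0..len {
        let phys = block_table[abs / block_size];
        let slot = abs % block_size;
        out.extend_from_slice(&pool[phys][slot * width..(slot + 1) * width]);
    }
    out
}
```

Writing a 13-token sequence as a 6-token prefix followed by a 7-token suffix through a non-contiguous block table such as `[5, 2, 7, 0]`, then gathering 13 tokens, returns the same bit patterns as the dense source buffer.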
Summary

Add `PagedBlockPool::write_prefill`, the Phase 3 (#120) bulk prefill writer for the unified paged KV cache (epic #116), with real shared-prefix copy-on-write. It writes a whole prefill's worth of K/V for one layer into the layout-A pool established in #118, and writing at the sequence's absolute tail means the same call serves the shared-prefix suffix case.

What changed
- `src/lib/mlxcel-core/src/cache/paged.rs`
  - `write_prefill(&mut self, state, layer_idx, k_prefill, v_prefill)`: validates the `[1, n_kv_heads, n_new_tokens, head_dim]` SDPA-layout inputs (dtype preserved, no `astype`), calls `append_tokens` to allocate trailing blocks, then chunks the new tokens into block-sized `write_block` calls, handling a partial first block (prefix ended mid-block) and a partial last block via per-block slot arithmetic. Reuses `write_block`'s donation discipline.
  - `copy_on_write_block(&mut self, layer_idx, src_block_id)`: forks a shared block. Block sharing here is not always block-granular (`CachePool::detach_paged` captures a partially-filled tail block and shares it), so before writing into any spanned block whose `refcount > 1`, `write_prefill` slices the shared block's current slab out of the pool, writes it into a freshly acquired block, repoints `block_ids`, and drops this sequence's reference to the original (leaving the sibling's intact). Block-aligned prefixes start the suffix on a fresh append block, so the COW check is a no-op there.
- `src/lib/mlxcel-core/src/cache/paged_pool_tests.rs`: cold round-trip byte-identity, block-aligned round-trip, and the pool-level two-sharer COW no-corruption proof.
- `src/lib/mlxcel-core/src/cache/paged_detach_tests.rs`: a `CachePool`-level test over the real `detach_paged` -> `park` -> `adopt_paged` path asserting suffix-only block allocation via `PagedCacheStats`.
- `src/lib/mlxcel-core/src/ffi_tests.rs`: end-to-end bulk-prefill -> 24-step decode parity vs the dense fallback over fragmented physical rows.

Numerical result

The stored bytes are byte-identical to the dense prefill path (each token lands at the same absolute slot a dense `slice_update` from `[0, 0, prev, 0]` targets). The end-to-end prefill→decode parity test measures max RMS = 0 (FP32 byte-exact) over 24 decode steps after a 13-token bulk prefill.

Deferred (NOT in this PR)

- Wiring `write_prefill` into the live model forward / a paged `KVCacheMode` / the scheduler, and the dense-vs-pool fast-path decision → Phase 4: Radix-trie and block-pool unification (scheduler wiring) #121. The model forward, `generate.rs`, and `src/models/*` are untouched; the pool is not reachable from the live forward until #121, so this is additive pool-layer machinery exercised by tests (same pattern as Phase 1: Global block-pool tensor storage #118 / Phase 2: Paged decode attention over real block tables #119).
- No `*.cpp`/`*.h` changes here.

Test plan

- `cargo test --lib -p mlxcel-core write_prefill --features metal,accelerate` (4 new `write_prefill` tests pass)
- `cargo test --lib -p mlxcel-core test_pooled_prefill_then_decode_matches_dense --features metal,accelerate` (max RMS = 0)
- `cargo test --lib -p mlxcel-core pooled_paged --features metal,accelerate` (#119 suite, no regression)
- `cargo test --lib -p mlxcel-core paged_pool --features metal,accelerate` / `paged_detach` (#118 + detach suites, no regression)
- `cargo clippy -p mlxcel-core --tests --features metal,accelerate -- -D warnings` (clean)
- `cargo fmt -p mlxcel-core --check` (clean)

Closes #120