diff --git a/src/lib/mlxcel-core/src/cache/paged.rs b/src/lib/mlxcel-core/src/cache/paged.rs
index 1c49ba5..214d08b 100644
--- a/src/lib/mlxcel-core/src/cache/paged.rs
+++ b/src/lib/mlxcel-core/src/cache/paged.rs
@@ -994,6 +994,228 @@ impl PagedBlockPool {
         Ok(())
     }
 
+    /// Write a whole prefill's worth of K/V for one layer into the pool.
+    ///
+    /// `k_prefill` / `v_prefill` are `[1, n_kv_heads, n_new_tokens, head_dim]`
+    /// (the SDPA layout the attention path produces, identical to the primary
+    /// convention [`write_block`] accepts). The `n_new_tokens` tokens are
+    /// written starting at the sequence's current absolute tail (`layer.len`),
+    /// so this also serves the shared-prefix SUFFIX case: when called on a
+    /// sequence that already references a shared prefix, it appends only the new
+    /// (divergent) tokens.
+    ///
+    /// Steps:
+    /// 1. Validate shapes/dtype/geometry against the layer's pool metadata.
+    /// 2. [`append_tokens`] to grow the logical length and allocate the trailing
+    ///    blocks the new tokens need (these are always fresh, refcount 1).
+    /// 3. For every physical block the new tokens span, slice the matching
+    ///    `[.., tok_start..tok_end, ..]` window out of `k_prefill`/`v_prefill`
+    ///    and [`write_block`] it at the right `slot_start`. A partial first
+    ///    block (the prefix ended mid-block) and a partial last block are both
+    ///    handled by the per-block slot arithmetic.
+    ///
+    /// ## Copy-on-write of a shared partial tail block
+    ///
+    /// Block sharing in this pool is *not* always block-granular: a shared
+    /// prefix can end mid-block (e.g. a 6-token prefix with `block_size == 4`
+    /// leaves the second block half-full), and `CachePool::detach_paged`
+    /// captures that partially-filled tail block and shares it. Writing the
+    /// suffix's first tokens into that block would corrupt every sibling
+    /// sequence that still references it. So before writing into any spanned
+    /// block whose `refcount > 1`, this method performs a real copy-on-write:
+    /// it copies the block's current contents into a freshly acquired block,
+    /// repoints `state.block_ids[block_index]` at the copy, and drops this
+    /// sequence's reference to the shared original (leaving the sibling's
+    /// reference intact). Block-aligned prefixes start the suffix on a fresh
+    /// `append_tokens` block (refcount 1), so the COW check is a no-op there.
+    ///
+    /// The stored bytes are byte-identical to the dense prefill path: each
+    /// token lands at the same absolute `[.., abs, ..]` slot a dense
+    /// `slice_update` from `[0, 0, prev, 0]` would target, so a later
+    /// [`gather_visible`] returns the same bytes as the equivalent dense buffer.
+    ///
+    /// # Errors
+    ///
+    /// Returns `Err` when the layer index or the K/V shapes/dtypes are rejected,
+    /// or when [`append_tokens`], the copy-on-write fork, or [`write_block`]
+    /// fails. Those later failures occur after `append_tokens` has already run,
+    /// so the call is not atomic: on error the sequence state may already have
+    /// been grown.
+    pub fn write_prefill(
+        &mut self,
+        state: &mut PagedSequenceState,
+        layer_idx: usize,
+        k_prefill: &MlxArray,
+        v_prefill: &MlxArray,
+    ) -> Result<(), String> {
+        self.validate_layer(layer_idx)?;
+
+        // Both inputs must be the 4D SDPA layout [1, n_kv_heads, n_new, head_dim].
+        let k_shape = ffi::array_shape(k_prefill);
+        let v_shape = ffi::array_shape(v_prefill);
+        let (n_kv_heads, n_new, head_dim) = match k_shape.as_slice() {
+            [batch, n_kv_heads, n_new, head_dim] => {
+                if *batch != 1 {
+                    return Err(format!(
+                        "PagedBlockPool::write_prefill: k_prefill batch dim must be 1, got {batch} (shape {k_shape:?})"
+                    ));
+                }
+                (*n_kv_heads, *n_new, *head_dim)
+            }
+            _ => {
+                return Err(format!(
+                    "PagedBlockPool::write_prefill: k_prefill must be [1, n_kv_heads, n_new_tokens, head_dim], got shape {k_shape:?}"
+                ));
+            }
+        };
+        if v_shape != k_shape {
+            return Err(format!(
+                "PagedBlockPool::write_prefill: k_prefill shape {k_shape:?} does not match v_prefill shape {v_shape:?}"
+            ));
+        }
+        let k_dtype = ffi::array_dtype(k_prefill);
+        let v_dtype = ffi::array_dtype(v_prefill);
+        if k_dtype != v_dtype {
+            return Err(format!(
+                "PagedBlockPool::write_prefill: K dtype {k_dtype} does not match V dtype {v_dtype}"
+            ));
+        }
+        if n_new <= 0 {
+            return Ok(());
+        }
+
+        let block_size = self.layout.block_size as i32;
+        // Absolute tail before the append — the first new token lands here.
+        let first_abs = {
+            let layer = state.layer(layer_idx).ok_or_else(|| {
+                format!(
+                    "PagedBlockPool::write_prefill: layer {layer_idx} out of range for {} layers",
+                    state.layers.len()
+                )
+            })?;
+            layer.len as i32
+        };
+        let last_abs = first_abs + n_new; // exclusive
+
+        // Allocate the trailing blocks the new tokens need (and bump len). The
+        // existing partial tail block, if any, is left untouched by append.
+        self.append_tokens(state, layer_idx, n_new as usize)?;
+
+        // Walk every block the new tokens span and write its slice.
+        let first_block = (first_abs / block_size) as usize;
+        let last_block = ((last_abs - 1) / block_size) as usize;
+        for block_index in first_block..=last_block {
+            let block_start_abs = block_index as i32 * block_size;
+            let slot_start = (first_abs.max(block_start_abs) - block_start_abs) as usize;
+            let abs_begin = first_abs.max(block_start_abs);
+            let abs_end = last_abs.min(block_start_abs + block_size);
+            let n_slots = abs_end - abs_begin;
+            if n_slots <= 0 {
+                continue;
+            }
+            let tok_start = abs_begin - first_abs;
+            let tok_end = tok_start + n_slots;
+
+            // Checked lookup: a block table that does not cover this index is
+            // reported as an error instead of panicking on an out-of-range index.
+            let block_id = state.layers[layer_idx]
+                .block_ids
+                .get(block_index)
+                .copied()
+                .ok_or_else(|| {
+                    format!(
+                        "PagedBlockPool::write_prefill: layer {layer_idx} has no block at index {block_index} after append_tokens"
+                    )
+                })?;
+            // Copy-on-write the target block if a sibling sequence still shares
+            // it (only ever the partial prefix tail; fresh append blocks have
+            // refcount 1 and skip this).
+            let target_id = if self.refcount(block_id) > 1 {
+                let fresh = self.copy_on_write_block(layer_idx, block_id)?;
+                state.layers[layer_idx].block_ids[block_index] = fresh;
+                fresh
+            } else {
+                block_id
+            };
+
+            // Slice the [1, H, n_slots, D] window out of the prefill tensors.
+            let starts = [0, 0, tok_start, 0];
+            let stops = [1, n_kv_heads, tok_end, head_dim];
+            let k_slice = ffi::slice(k_prefill, &starts, &stops);
+            let v_slice = ffi::slice(v_prefill, &starts, &stops);
+
+            self.write_block(target_id, layer_idx, slot_start, &k_slice, &v_slice)?;
+        }
+        Ok(())
+    }
+
+    /// Copy the current contents of `src_block_id` into a freshly acquired block
+    /// on the same layer and return the new block id. Used by [`write_prefill`]
+    /// to fork a shared partial tail block before mutating it (copy-on-write).
+    ///
+    /// The source's full `[block_size, n_kv_heads, head_dim]` K and V slabs are
+    /// sliced out of the layer's pool tensors and written into the fresh block
+    /// at `slot_start = 0`, so the new block is a byte-identical copy (the
+    /// caller then overwrites only the divergent suffix slots). The new block
+    /// starts at refcount 1, and this helper itself releases the calling
+    /// sequence's reference to the shared original before returning, so the
+    /// caller only has to repoint its block table at the returned id.
+    fn copy_on_write_block(
+        &mut self,
+        layer_idx: usize,
+        src_block_id: PagedBlockId,
+    ) -> Result<PagedBlockId, String> {
+        let meta = self.pool_meta[layer_idx].ok_or_else(|| {
+            format!(
+                "PagedBlockPool::copy_on_write_block: layer {layer_idx} has no pool tensors to copy from"
+            )
+        })?;
+        let src_row = *self.block_rows[layer_idx]
+            .get(&src_block_id)
+            .ok_or_else(|| {
+                format!(
+                    "PagedBlockPool::copy_on_write_block: shared block {src_block_id} on layer {layer_idx} has no pool row (was it written?)"
+                )
+            })? as i32;
+        let block_size = self.layout.block_size as i32;
+
+        // Slice the source row's [block_size, H, D] slab out of K and V (the
+        // bare layout-A slab write_block accepts via its 3D convention).
+        let slab = |pool: &MlxArray| -> UniquePtr<MlxArray> {
+            let row = ffi::slice(
+                pool,
+                &[src_row, 0, 0, 0],
+                &[src_row + 1, block_size, meta.n_kv_heads, meta.head_dim],
+            );
+            ffi::reshape(&row, &[block_size, meta.n_kv_heads, meta.head_dim])
+        };
+        let k_slab = {
+            let pool_k = self.pool_k[layer_idx]
+                .as_ref()
+                .expect("pool_k present when pool_meta present");
+            slab(pool_k)
+        };
+        let v_slab = {
+            let pool_v = self.pool_v[layer_idx]
+                .as_ref()
+                .expect("pool_v present when pool_meta present");
+            slab(pool_v)
+        };
+
+        let fresh = self.acquire_block(layer_idx)?;
+        if let Err(err) = self.write_block(fresh, layer_idx, 0, &k_slab, &v_slab) {
+            // The caller never sees `fresh` on failure, so hand it back rather
+            // than leak it. This release is best-effort: its own error must not
+            // mask the write error being reported.
+            let _ = self.release_block(fresh);
+            return Err(err);
+        }
+        // Drop this sequence's reference to the shared original; the sibling
+        // that still owns it keeps it alive.
+        self.release_block(src_block_id)?;
+        Ok(fresh)
+    }
+
     /// Gather the visible K/V window for one layer of a sequence into the
     /// SDPA-ready shape `[1, n_kv_heads, visible_len, head_dim]`.
     ///
diff --git a/src/lib/mlxcel-core/src/cache/paged_detach_tests.rs b/src/lib/mlxcel-core/src/cache/paged_detach_tests.rs
index e316c04..86953fe 100644
--- a/src/lib/mlxcel-core/src/cache/paged_detach_tests.rs
+++ b/src/lib/mlxcel-core/src/cache/paged_detach_tests.rs
@@ -104,6 +104,80 @@ fn default_layout() -> PagedKvLayout {
     PagedKvLayout::uniform(2, 4, 128).unwrap()
 }
 
+/// Drive `PagedBlockPool::write_prefill` for an active sequence at the
+/// `CachePool` level. This is the split-borrow `append_paged_tokens` already
+/// uses internally; it is replicated here (the test module is a descendant of
+/// `cache.rs`, so it may touch the private `paged_pool` / `active` fields)
+/// because the live forward/scheduler wiring of `write_prefill` is #121 and is
+/// out of scope for #120. `k`/`v` are `[1, n_kv_heads, n_new, head_dim]`.
+fn write_prefill_for(
+    pool: &mut CachePool,
+    id: SequenceId,
+    layer_idx: usize,
+    k: &MlxArray,
+    v: &MlxArray,
+) -> Result<(), String> {
+    let block_pool = pool
+        .paged_pool
+        .as_mut()
+        .ok_or_else(|| "paged backend not initialized".to_string())?;
+    let state = pool
+        .active
+        .get_mut(&id)
+        .ok_or_else(|| format!("sequence {id} not found"))?
+        .paged_state_mut()
+        .ok_or_else(|| format!("sequence {id} is not paged"))?;
+    block_pool.write_prefill(state, layer_idx, k, v)
+}
+
+/// Deterministic `[1, H, n_tokens, D]` FP32 prefill block whose per-token
+/// values are distinct (encodes head/token/dim), so misplacement is caught
+/// exactly and FP32 round-trips bit-for-bit. Mirrors `paged_pool_tests::make_block`.
+fn prefill_block(base: f32, n_kv_heads: i32, n_tokens: i32, head_dim: i32) -> UniquePtr<MlxArray> {
+    let mut values = Vec::with_capacity((n_kv_heads * n_tokens * head_dim) as usize);
+    for head in 0..n_kv_heads {
+        for tok in 0..n_tokens {
+            for dim in 0..head_dim {
+                values.push(base + head as f32 * 1000.0 + tok as f32 * 10.0 + dim as f32 * 0.1);
+            }
+        }
+    }
+    crate::ffi::from_slice_f32(&values, &[1, n_kv_heads, n_tokens, head_dim])
+}
+
+/// Flatten a tensor to a row-major `Vec<f32>` (after an FP32 cast) for content
+/// comparison. Mirrors `paged_pool_tests::flatten_fp32`.
+fn flatten_fp32(arr: &MlxArray) -> Vec<f32> {
+    let a = crate::ffi::astype(arr, crate::dtype::FLOAT32);
+    crate::ffi::eval(&a);
+    crate::ffi::array_to_raw_bytes(&a)
+        .chunks_exact(4)
+        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
+        .collect()
+}
+
+/// Build the dense contiguous reference `[1, H, total, D]` by writing each
+/// `[1, H, n, D]` block at its token offset. Mirrors
+/// `paged_pool_tests::dense_reference` (no trimming; full visible window).
+fn dense_reference(
+    blocks: &[(UniquePtr<MlxArray>, usize)],
+    n_kv_heads: i32,
+    total: i32,
+    head_dim: i32,
+) -> UniquePtr<MlxArray> {
+    let mut dense = crate::ffi::zeros(&[1, n_kv_heads, total, head_dim], crate::dtype::FLOAT32);
+    for (block, offset) in blocks {
+        let n = crate::ffi::array_shape(block)[2];
+        dense = crate::ffi::slice_update(
+            &dense,
+            block,
+            &[0, 0, *offset as i32, 0],
+            &[1, n_kv_heads, *offset as i32 + n, head_dim],
+        );
+    }
+    dense
+}
+
 // ---------------------------------------------------------------------------
 // 1. PagedBlockPool refcount plumbing
 // ---------------------------------------------------------------------------
@@ -750,6 +824,141 @@ fn release_detached_paged_is_safe_after_take_no_double_release() {
     assert_eq!(pool.parked_count(), 0);
 }
 
+// ---------------------------------------------------------------------------
+// 9. write_prefill (#120) over the real detach/adopt machinery: a shared-prefix
+//    request allocates blocks ONLY for its suffix (acceptance criterion 2),
+//    verified via PagedCacheStats, and the adopted sequence reads back its own
+//    prefix + suffix correctly.
+//
+// The simultaneous two-live-sharers copy-on-write proof lives in
+// `paged_pool_tests::write_prefill_cow_forks_shared_partial_tail_block`:
+// `CachePool::adopt_paged` is move-semantics (it consumes the detached set and
+// transfers the block pins), so it cannot stand up two live sequences whose
+// block tables alias the same physical prefix blocks at the same instant — that
+// aliasing is produced with `retain_block` at the pool level. Here we exercise
+// the real detach -> park -> adopt path and prove the block-accounting half of
+// the acceptance criteria with `PagedCacheStats`.
+// ---------------------------------------------------------------------------
+
+#[test]
+fn write_prefill_after_shared_prefix_adopt_allocates_only_suffix_blocks() {
+    let n_kv_heads = 2i32;
+    let head_dim = 3i32;
+    let block_size = 4usize;
+    // Single-layer paged layout so the block accounting is unambiguous.
+    let layout = PagedKvLayout::uniform(
+        1,
+        block_size,
+        block_size * n_kv_heads as usize * head_dim as usize * 2,
+    )
+    .unwrap();
+    let model = PagedStubModel::new(layout.clone());
+    let mut pool = CachePool::new(4);
+
+    // --- Build an 8-token (block-aligned: 2 whole blocks) prefix on a seed. ---
+    let seed = pool.allocate(&model).unwrap();
+    let prefix_len = 8i32;
+    let prefix_k = prefill_block(1000.0, n_kv_heads, prefix_len, head_dim);
+    let prefix_v = prefill_block(5000.0, n_kv_heads, prefix_len, head_dim);
+    write_prefill_for(&mut pool, seed, 0, &prefix_k, &prefix_v).unwrap();
+    let prefix_blocks = pool
+        .get_paged_state(seed)
+        .unwrap()
+        .layer(0)
+        .unwrap()
+        .block_ids
+        .clone();
+    assert_eq!(prefix_blocks.len(), 2, "8 tokens => 2 whole prefix blocks");
+
+    // --- Detach + park + adopt: the consumer inherits the prefix block table
+    //     (the real shared-prefix path). ---
+    let detached = pool.detach_paged(seed).unwrap();
+    let handle = pool.park_detached_paged(detached);
+    let consumer = pool.adopt_parked_paged(&model, handle).unwrap();
+    assert_eq!(
+        pool.get_paged_state(consumer)
+            .unwrap()
+            .layer(0)
+            .unwrap()
+            .block_ids,
+        prefix_blocks,
+        "adopted consumer must reuse the prefix blocks, not fresh copies"
+    );
+
+    // Live blocks right after adoption: exactly the 2 prefix blocks.
+    let live_before = pool.paged_stats().unwrap().live_blocks;
+    assert_eq!(live_before, 2, "only the 2 shared prefix blocks are live");
+
+    // --- write_prefill a 6-token suffix (positions [8, 14)). 8 is block-
+    //     aligned, so this allocates 2 fresh suffix blocks and touches no
+    //     prefix block. ---
+    let suffix_len = 6i32;
+    let suffix_k = prefill_block(2000.0, n_kv_heads, suffix_len, head_dim);
+    let suffix_v = prefill_block(6000.0, n_kv_heads, suffix_len, head_dim);
+    write_prefill_for(&mut pool, consumer, 0, &suffix_k, &suffix_v).unwrap();
+
+    let live_after = pool.paged_stats().unwrap().live_blocks;
+    // Suffix is 6 tokens over 2 blocks; prefix blocks are untouched. So live
+    // grows by EXACTLY the suffix block count, NOT by another copy of the
+    // prefix. This is the "allocates blocks only for its suffix" criterion.
+    assert_eq!(
+        live_after - live_before,
+        2,
+        "shared-prefix request must allocate blocks only for its suffix"
+    );
+    let suffix_blocks = pool
+        .get_paged_state(consumer)
+        .unwrap()
+        .layer(0)
+        .unwrap()
+        .block_ids
+        .clone();
+    assert_eq!(suffix_blocks.len(), 4, "2 prefix + 2 suffix blocks");
+    // The prefix blocks are still the SAME physical blocks (no reallocation).
+    assert_eq!(&suffix_blocks[..2], &prefix_blocks[..]);
+
+    // --- Correctness: the consumer gathers prefix [0,8) + suffix [8,14). ---
+    let total = prefix_len + suffix_len; // 14
+    let state = pool.get_paged_state(consumer).unwrap();
+    let (gk, gv) = pool
+        .paged_pool_ref()
+        .unwrap()
+        .gather_visible(state, 0)
+        .unwrap()
+        .expect("gather must return data");
+    assert_eq!(
+        crate::ffi::array_shape(&gk),
+        vec![1, n_kv_heads, total, head_dim]
+    );
+
+    let dense_k = dense_reference(
+        &[
+            (prefill_block(1000.0, n_kv_heads, prefix_len, head_dim), 0),
+            (
+                prefill_block(2000.0, n_kv_heads, suffix_len, head_dim),
+                prefix_len as usize,
+            ),
+        ],
+        n_kv_heads,
+        total,
+        head_dim,
+    );
+    let dense_v = dense_reference(
+        &[
+            (prefill_block(5000.0, n_kv_heads, prefix_len, head_dim), 0),
+            (
+                prefill_block(6000.0, n_kv_heads, suffix_len, head_dim),
+                prefix_len as usize,
+            ),
+        ],
+        n_kv_heads,
+        total,
+        head_dim,
+    );
+    assert_eq!(flatten_fp32(&gk), flatten_fp32(&dense_k));
+    assert_eq!(flatten_fp32(&gv), flatten_fp32(&dense_v));
+}
+
 /// Free reference to DetachedPagedCacheSet to keep the import alive.
 #[allow(dead_code)]
 fn _type_alive() -> Option<DetachedPagedCacheSet> {
diff --git a/src/lib/mlxcel-core/src/cache/paged_pool_tests.rs b/src/lib/mlxcel-core/src/cache/paged_pool_tests.rs
index 2a21314..285ae68 100644
--- a/src/lib/mlxcel-core/src/cache/paged_pool_tests.rs
+++ b/src/lib/mlxcel-core/src/cache/paged_pool_tests.rs
@@ -26,6 +26,13 @@
 //! 7. `release_block` to refcount 0 frees and recycles a row with fresh data.
 //! 8. `pool_tensor_bytes` reflects allocated pool tensors.
 //! 9. Turbo4 sidecars coexist with main-K/V rows.
+//!
+//! Plus the bulk prefill writer added in #120 (Phase 3):
+//!
+//! - `write_prefill` cold round-trip is byte-identical to a dense prefill.
+//! - `write_prefill` of a block-aligned prompt round-trips.
+//! - `write_prefill` copy-on-write forks a shared partial tail block so two
+//!   sequences' suffixes never corrupt each other or the shared prefix.
 
 use super::paged::{PagedBlockId, PagedBlockPool, PagedKvLayout, PagedSequenceState};
 use super::KVCacheMode;
@@ -671,3 +678,206 @@ fn write_rejects_unknown_block_and_oob_slot() {
     let err = pool.write_block(id, 0, 3, &two, &two).unwrap_err();
     assert!(err.contains("out of bounds"), "{err}");
 }
+
+// ---------------------------------------------------------------------------
+// 11. write_prefill (#120): cold bulk write -> gather is byte-identical to dense
+// ---------------------------------------------------------------------------
+
+#[test]
+fn write_prefill_cold_round_trip_is_byte_identical_to_dense() {
+    // T = 10 with block_size 4 => 2 full blocks + a half-full third block, so
+    // the bulk write must chunk across a partial last block.
+    let block_size = 4usize;
+    let total = 10i32;
+    let mut pool = fp16_pool(block_size, 1);
+    let mut state = PagedSequenceState::new(pool.layout());
+
+    // A single [1, H, T, D] prefill whose per-token values are distinct so any
+    // mis-slotting is caught exactly (make_block encodes slot index in value).
+    let k_prefill = make_block(100.0, total);
+    let v_prefill = make_block(500.0, total);
+
+    pool.write_prefill(&mut state, 0, &k_prefill, &v_prefill)
+        .unwrap();
+
+    // The bulk write must have allocated ceil(10/4) = 3 blocks and advanced len.
+    assert_eq!(state.layer(0).unwrap().block_ids.len(), 3);
+    assert_eq!(state.layer(0).unwrap().len, total as usize);
+    assert_eq!(state.layer(0).unwrap().visible_len(), total as usize);
+
+    let (gk, gv) = pool
+        .gather_visible(&state, 0)
+        .unwrap()
+        .expect("gather must return data");
+    assert_eq!(ffi::array_shape(&gk), vec![1, H, total, D]);
+
+    // Dense reference: the same [1, H, T, D] buffer written at offset 0.
+    let dense_k = dense_reference(&[(make_block(100.0, total), 0)], total, 0, total);
+    let dense_v = dense_reference(&[(make_block(500.0, total), 0)], total, 0, total);
+    assert_eq!(flatten_fp32(&gk), flatten_fp32(&dense_k));
+    assert_eq!(flatten_fp32(&gv), flatten_fp32(&dense_v));
+}
+
+#[test]
+fn write_prefill_block_aligned_prompt_round_trips() {
+    // Block-aligned T (== 2 * block_size) takes the no-COW fresh-block path.
+    let block_size = 4usize;
+    let total = 8i32;
+    let mut pool = fp16_pool(block_size, 1);
+    let mut state = PagedSequenceState::new(pool.layout());
+
+    let k_prefill = make_block(11.0, total);
+    let v_prefill = make_block(77.0, total);
+    pool.write_prefill(&mut state, 0, &k_prefill, &v_prefill)
+        .unwrap();
+    assert_eq!(state.layer(0).unwrap().block_ids.len(), 2);
+
+    let (gk, gv) = pool
+        .gather_visible(&state, 0)
+        .unwrap()
+        .expect("gather must return data");
+    let dense_k = dense_reference(&[(make_block(11.0, total), 0)], total, 0, total);
+    let dense_v = dense_reference(&[(make_block(77.0, total), 0)], total, 0, total);
+    assert_eq!(flatten_fp32(&gk), flatten_fp32(&dense_k));
+    assert_eq!(flatten_fp32(&gv), flatten_fp32(&dense_v));
+}
+
+// ---------------------------------------------------------------------------
+// 12. write_prefill copy-on-write: a shared partial tail block is forked so two
+//     sequences' suffixes never corrupt each other or the shared prefix.
+//
+// This is the PagedBlockPool-level COW proof. Two sequence states share the
+// same block table (the second adopts the first's block_ids with a refcount
+// bump via retain_block, exactly as CachePool::adopt_paged does), where the
+// last shared block is PARTIALLY filled (the prefix ends mid-block). Each
+// sequence then write_prefills a DIFFERENT suffix; the partial tail block is
+// refcount==2 at suffix time, so write_prefill must copy-on-write it for the
+// writer rather than mutating the block the sibling still references.
+// ---------------------------------------------------------------------------
+
+#[test]
+fn write_prefill_cow_forks_shared_partial_tail_block() {
+    let block_size = 4usize;
+    let prefix_len = 6i32; // 2 blocks; second block half-full (slots 0,1).
+    let mut pool = fp16_pool(block_size, 1);
+
+    // --- Build the shared prefix on sequence A. ---
+    let mut state_a = PagedSequenceState::new(pool.layout());
+    let prefix_k = make_block(1000.0, prefix_len);
+    let prefix_v = make_block(5000.0, prefix_len);
+    pool.write_prefill(&mut state_a, 0, &prefix_k, &prefix_v)
+        .unwrap();
+    let prefix_blocks = state_a.layer(0).unwrap().block_ids.clone();
+    assert_eq!(prefix_blocks.len(), 2);
+    let tail_block = prefix_blocks[1];
+    assert_eq!(pool.refcount(tail_block), 1);
+
+    // --- Sequence B adopts the SAME block table, pinning every prefix block
+    //     (this is what CachePool::adopt_paged does for a shared prefix). ---
+    let mut state_b = PagedSequenceState::new(pool.layout());
+    {
+        let layer_b = state_b.layer_mut(0).unwrap();
+        layer_b.block_ids = prefix_blocks.clone();
+        layer_b.len = prefix_len as usize;
+        layer_b.logical_start = 0;
+    }
+    for &id in &prefix_blocks {
+        pool.retain_block(id).unwrap();
+    }
+    // The partial tail block is now shared by both sequences.
+    assert_eq!(pool.refcount(tail_block), 2);
+
+    // --- Each sequence writes a DIFFERENT 5-token suffix (positions [6, 11)).
+    //     The first suffix block is the shared partial tail (slots 2,3), which
+    //     must be copy-on-written for each writer. ---
+    let suffix_a_k = make_block(2000.0, 5);
+    let suffix_a_v = make_block(6000.0, 5);
+    pool.write_prefill(&mut state_a, 0, &suffix_a_k, &suffix_a_v)
+        .unwrap();
+
+    let suffix_b_k = make_block(3000.0, 5);
+    let suffix_b_v = make_block(7000.0, 5);
+    pool.write_prefill(&mut state_b, 0, &suffix_b_k, &suffix_b_v)
+        .unwrap();
+
+    // --- COW accounting. A wrote first: at that moment the tail was shared
+    //     (refcount 2), so A copy-on-wrote it to a fresh block and released its
+    //     reference to the original (2 -> 1). B wrote second: by then B was the
+    //     SOLE owner of the original tail (refcount 1), so the in-place write is
+    //     safe and no fork is needed — B keeps the original block. This is
+    //     exactly the refcount-driven COW contract: copy only while shared. ---
+    let a_tail = state_a.layer(0).unwrap().block_ids[1];
+    let b_tail = state_b.layer(0).unwrap().block_ids[1];
+    assert_ne!(
+        a_tail, tail_block,
+        "A wrote while the tail was shared, so it must have forked a fresh copy"
+    );
+    assert_eq!(
+        b_tail, tail_block,
+        "B wrote while it was the sole owner of the tail, so it keeps the block in place"
+    );
+    assert_ne!(a_tail, b_tail, "A and B must hold independent tail blocks");
+    assert_eq!(
+        pool.refcount(tail_block),
+        1,
+        "the original tail is now solely owned by B"
+    );
+    assert_eq!(
+        pool.refcount(a_tail),
+        1,
+        "A's forked tail copy is solely owned by A"
+    );
+    // The first shared block (fully inside the prefix) is never written, so it
+    // is still shared by both sequences.
+    assert_eq!(pool.refcount(prefix_blocks[0]), 2);
+
+    // --- Correctness: each sequence gathers its OWN shared-prefix + own suffix,
+    //     proving the two suffixes did not corrupt each other or the prefix. ---
+    // Sequence A: prefix tokens [0,6) (base 1000) then suffix tokens [6,11)
+    // (base 2000, i.e. dense-token index 6 carries suffix slot 0).
+    let total = 11i32;
+    let (gk_a, gv_a) = pool.gather_visible(&state_a, 0).unwrap().expect("gather A");
+    let dense_a_k = dense_reference(
+        &[
+            (make_block(1000.0, prefix_len), 0),
+            (make_block(2000.0, 5), prefix_len as usize),
+        ],
+        total,
+        0,
+        total,
+    );
+    let dense_a_v = dense_reference(
+        &[
+            (make_block(5000.0, prefix_len), 0),
+            (make_block(6000.0, 5), prefix_len as usize),
+        ],
+        total,
+        0,
+        total,
+    );
+    assert_eq!(flatten_fp32(&gk_a), flatten_fp32(&dense_a_k));
+    assert_eq!(flatten_fp32(&gv_a), flatten_fp32(&dense_a_v));
+
+    // Sequence B: same prefix, DIFFERENT suffix (base 3000 / 7000).
+    let (gk_b, gv_b) = pool.gather_visible(&state_b, 0).unwrap().expect("gather B");
+    let dense_b_k = dense_reference(
+        &[
+            (make_block(1000.0, prefix_len), 0),
+            (make_block(3000.0, 5), prefix_len as usize),
+        ],
+        total,
+        0,
+        total,
+    );
+    let dense_b_v = dense_reference(
+        &[
+            (make_block(5000.0, prefix_len), 0),
+            (make_block(7000.0, 5), prefix_len as usize),
+        ],
+        total,
+        0,
+        total,
+    );
+    assert_eq!(flatten_fp32(&gk_b), flatten_fp32(&dense_b_k));
+    assert_eq!(flatten_fp32(&gv_b), flatten_fp32(&dense_b_v));
+}
diff --git a/src/lib/mlxcel-core/src/ffi_tests.rs b/src/lib/mlxcel-core/src/ffi_tests.rs
index cb00df0..0314cd4 100644
--- a/src/lib/mlxcel-core/src/ffi_tests.rs
+++ b/src/lib/mlxcel-core/src/ffi_tests.rs
@@ -743,6 +743,172 @@ fn test_pooled_paged_decode_batch_of_two() {
     );
 }
 
+/// End-to-end prefill -> decode parity (#120 acceptance criterion 1).
+///
+/// Writes a T-token prompt into the pool via the BULK `write_prefill` writer
+/// (the #120 path) and into a parallel dense `[1, H, T, D]` buffer, then runs N
+/// decode steps appending one fresh token to each store (pool via `write_block`,
+/// dense via concat) and compares the pooled vs dense fallback attention each
+/// step. This proves prefill-write + decode-read is end-to-end correct: the
+/// bulk prefill must store K/V byte-identically to the dense prefill so the
+/// gather-then-SDPA decode matches the dense-slice decode. The pooled sequence
+/// is forced onto NON-CONTIGUOUS physical rows by a spacer so the gather
+/// genuinely reorders fragmented blocks.
+#[test]
+fn test_pooled_prefill_then_decode_matches_dense() {
+    use crate::cache::{PagedBlockPool, PagedSequenceState};
+
+    const PROMPT: i32 = 13; // non-block-aligned prompt (spans 4 blocks @ bs 4)
+    const STEPS: usize = 24; // >= 20 decode steps
+    let n_kv_heads: i32 = 4;
+    let head_dim: i32 = 8;
+    let block_size = 4usize;
+    let layer_idx = 0usize;
+    let scale = 1.0f32 / (head_dim as f32).sqrt();
+
+    let mut pool = PagedBlockPool::new(pooled_layout(block_size, 1, n_kv_heads, head_dim));
+    let mut target = PagedSequenceState::new(pool.layout());
+    // Spacer claims a physical row up front so the target's prefill blocks are
+    // not a dense ascending run.
+    let mut spacer = PagedSequenceState::new(pool.layout());
+    pool.append_tokens(&mut spacer, layer_idx, block_size)
+        .unwrap();
+    {
+        let spacer_block = *spacer.layer(layer_idx).unwrap().block_ids.last().unwrap();
+        let sk = from_slice_f32(
+            &pooled_pseudo_f32(0xABCD, (n_kv_heads * block_size as i32 * head_dim) as usize),
+            &[1, n_kv_heads, block_size as i32, head_dim],
+        );
+        let sv = from_slice_f32(
+            &pooled_pseudo_f32(0xDCBA, (n_kv_heads * block_size as i32 * head_dim) as usize),
+            &[1, n_kv_heads, block_size as i32, head_dim],
+        );
+        pool.write_block(spacer_block, layer_idx, 0, &sk, &sv)
+            .unwrap();
+    }
+
+    // Build the prompt as one [1, H, PROMPT, D] prefill tensor (and an identical
+    // dense buffer) by concatenating per-token K/V along axis 2.
+    let mut dense_k: Option<UniquePtr<MlxArray>> = None;
+    let mut dense_v: Option<UniquePtr<MlxArray>> = None;
+    for t in 0..PROMPT {
+        let kt = from_slice_f32(
+            &pooled_pseudo_f32(t as u64 + 7, (n_kv_heads * head_dim) as usize),
+            &[1, n_kv_heads, 1, head_dim],
+        );
+        let vt = from_slice_f32(
+            &pooled_pseudo_f32(
+                (t as u64 + 7).wrapping_mul(5),
+                (n_kv_heads * head_dim) as usize,
+            ),
+            &[1, n_kv_heads, 1, head_dim],
+        );
+        dense_k = Some(match dense_k.take() {
+            None => kt,
+            Some(prev) => concatenate(&prev, &kt, 2),
+        });
+        dense_v = Some(match dense_v.take() {
+            None => vt,
+            Some(prev) => concatenate(&prev, &vt, 2),
+        });
+    }
+    assert_eq!(
+        array_shape(dense_k.as_ref().unwrap()),
+        vec![1, n_kv_heads, PROMPT, head_dim]
+    );
+
+    // BULK prefill write into the pool (the #120 path). The dense buffers double
+    // as the prefill input here — passed by reference, they keep growing as the
+    // dense decode reference below.
+    pool.write_prefill(
+        &mut target,
+        layer_idx,
+        dense_k.as_ref().unwrap(),
+        dense_v.as_ref().unwrap(),
+    )
+    .unwrap();
+    assert_eq!(target.layer(layer_idx).unwrap().len, PROMPT as usize);
+    // Prefill spanned ceil(13/4)=4 blocks; the spacer holds row 0, so the
+    // target's physical rows start at 1 (fragmented relative to a 0-based run).
+    assert_eq!(target.layer(layer_idx).unwrap().block_ids.len(), 4);
+
+    let mut max_rms = 0.0f32;
+
+    for step in 0..STEPS {
+        let t = PROMPT + step as i32; // absolute token index being appended
+        let visible_len = t + 1;
+
+        // Append one fresh decode token to the pooled target.
+        pool.append_tokens(&mut target, layer_idx, 1).unwrap();
+        let block_ids = target.layer(layer_idx).unwrap().block_ids.clone();
+        let slot = (t as usize) % block_size;
+        let block_index = (t as usize) / block_size;
+        let k_tok = from_slice_f32(
+            &pooled_pseudo_f32(t as u64 + 1000, (n_kv_heads * head_dim) as usize),
+            &[1, n_kv_heads, 1, head_dim],
+        );
+        let v_tok = from_slice_f32(
+            &pooled_pseudo_f32(
+                (t as u64 + 1000).wrapping_mul(3),
+                (n_kv_heads * head_dim) as usize,
+            ),
+            &[1, n_kv_heads, 1, head_dim],
+        );
+        pool.write_block(block_ids[block_index], layer_idx, slot, &k_tok, &v_tok)
+            .unwrap();
+
+        // Grow the dense reference identically.
+        dense_k = Some(concatenate(dense_k.as_ref().unwrap(), &k_tok, 2));
+        dense_v = Some(concatenate(dense_v.as_ref().unwrap(), &v_tok, 2));
+        let dk = dense_k.as_ref().unwrap();
+        let dv = dense_v.as_ref().unwrap();
+        assert_eq!(array_shape(dk), vec![1, n_kv_heads, visible_len, head_dim]);
+
+        // Fresh query for this step.
+        let q = from_slice_f32(
+            &pooled_pseudo_f32(
+                (step as u64).wrapping_mul(0x1000_0001) + 99,
+                (n_kv_heads * head_dim) as usize,
+            ),
+            &[1, n_kv_heads, 1, head_dim],
+        );
+
+        let states: [&PagedSequenceState; 1] = [&target];
+        let pooled_out = crate::layers::paged_decode_attention_pooled_fallback(
+            &q, &pool, &states, layer_idx, scale,
+        )
+        .unwrap();
+
+        let metadata = crate::cache::PagedDecodeMetadata::from_visible_lengths(
+            &[visible_len],
+            block_size as i32,
+        )
+        .unwrap();
+        let cache_keys = vec![dk.as_ref().unwrap() as *const MlxArray];
+        let cache_values = vec![dv.as_ref().unwrap() as *const MlxArray];
+        let dense_out = crate::layers::paged_decode_attention_dense_fallback(
+            &q,
+            &cache_keys,
+            &cache_values,
+            &metadata,
+            scale,
+        )
+        .unwrap();
+
+        let rms = pooled_rms(
+            &flatten_f32_local(&pooled_out),
+            &flatten_f32_local(&dense_out),
+        );
+        assert!(
+            rms < 5e-3,
+            "prefill->decode step {step} (abs token {t}): pooled vs dense RMS {rms} exceeded 5e-3"
+        );
+        max_rms = max_rms.max(rms);
+    }
+
+    println!("test_pooled_prefill_then_decode_matches_dense: max RMS = {max_rms:e}");
+}
+
 /// Flatten any tensor to a row-major `Vec<f32>` (after an FP32 cast). Local to
 /// these pooled tests; mirrors the pool-test `flatten_fp32` helper.
 fn flatten_f32_local(arr: &MlxArray) -> Vec<f32> {