Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions src/lib/mlxcel-core/src/cache/paged.rs
Original file line number Diff line number Diff line change
Expand Up @@ -994,6 +994,203 @@ impl PagedBlockPool {
Ok(())
}

/// Write a whole prefill's worth of K/V for one layer into the pool.
///
/// `k_prefill` / `v_prefill` are `[1, n_kv_heads, n_new_tokens, head_dim]`
/// (the SDPA layout the attention path produces, identical to the primary
/// convention [`write_block`] accepts). The `n_new_tokens` tokens are
/// written starting at the sequence's current absolute tail (`layer.len`),
/// so this also serves the shared-prefix SUFFIX case: when called on a
/// sequence that already references a shared prefix, it appends only the new
/// (divergent) tokens.
///
/// Steps:
/// 1. Validate shapes/dtype/geometry against the layer's pool metadata.
/// 2. [`append_tokens`] to grow the logical length and allocate the trailing
/// blocks the new tokens need (these are always fresh, refcount 1).
/// 3. For every physical block the new tokens span, slice the matching
/// `[.., tok_start..tok_end, ..]` window out of `k_prefill`/`v_prefill`
/// and [`write_block`] it at the right `slot_start`. A partial first
/// block (the prefix ended mid-block) and a partial last block are both
/// handled by the per-block slot arithmetic.
///
/// ## Copy-on-write of a shared partial tail block
///
/// Block sharing in this pool is *not* always block-granular: a shared
/// prefix can end mid-block (e.g. a 6-token prefix with `block_size == 4`
/// leaves the second block half-full), and `CachePool::detach_paged`
/// captures that partially-filled tail block and shares it. Writing the
/// suffix's first tokens into that block would corrupt every sibling
/// sequence that still references it. So before writing into any spanned
/// block whose `refcount > 1`, this method performs a real copy-on-write:
/// it copies the block's current contents into a freshly acquired block,
/// repoints `state.block_ids[block_index]` at the copy, and drops this
/// sequence's reference to the shared original (leaving the sibling's
/// reference intact). Block-aligned prefixes start the suffix on a fresh
/// `append_tokens` block (refcount 1), so the COW check is a no-op there.
///
/// The stored bytes are byte-identical to the dense prefill path: each
/// token lands at the same absolute `[.., abs, ..]` slot a dense
/// `slice_update` from `[0, 0, prev, 0]` would target, so a later
/// [`gather_visible`] returns the same bytes as the equivalent dense buffer.
pub fn write_prefill(
&mut self,
state: &mut PagedSequenceState,
layer_idx: usize,
k_prefill: &MlxArray,
v_prefill: &MlxArray,
) -> Result<(), String> {
self.validate_layer(layer_idx)?;

// Both inputs must be the 4D SDPA layout [1, n_kv_heads, n_new, head_dim].
let k_shape = ffi::array_shape(k_prefill);
let v_shape = ffi::array_shape(v_prefill);
let (n_kv_heads, n_new, head_dim) = match k_shape.as_slice() {
[batch, n_kv_heads, n_new, head_dim] => {
if *batch != 1 {
return Err(format!(
"PagedBlockPool::write_prefill: k_prefill batch dim must be 1, got {batch} (shape {k_shape:?})"
));
}
(*n_kv_heads, *n_new, *head_dim)
}
_ => {
return Err(format!(
"PagedBlockPool::write_prefill: k_prefill must be [1, n_kv_heads, n_new_tokens, head_dim], got shape {k_shape:?}"
));
}
};
if v_shape != k_shape {
return Err(format!(
"PagedBlockPool::write_prefill: k_prefill shape {k_shape:?} does not match v_prefill shape {v_shape:?}"
));
}
let k_dtype = ffi::array_dtype(k_prefill);
let v_dtype = ffi::array_dtype(v_prefill);
if k_dtype != v_dtype {
return Err(format!(
"PagedBlockPool::write_prefill: K dtype {k_dtype} does not match V dtype {v_dtype}"
));
}
if n_new <= 0 {
return Ok(());
}

let block_size = self.layout.block_size as i32;
// Absolute tail before the append — the first new token lands here.
let first_abs = {
let layer = state.layer(layer_idx).ok_or_else(|| {
format!(
"PagedBlockPool::write_prefill: layer {layer_idx} out of range for {} layers",
state.layers.len()
)
})?;
layer.len as i32
};
let last_abs = first_abs + n_new; // exclusive

// Allocate the trailing blocks the new tokens need (and bump len). The
// existing partial tail block, if any, is left untouched by append.
self.append_tokens(state, layer_idx, n_new as usize)?;

// Walk every block the new tokens span and write its slice.
let first_block = (first_abs / block_size) as usize;
let last_block = ((last_abs - 1) / block_size) as usize;
for block_index in first_block..=last_block {
let block_start_abs = block_index as i32 * block_size;
let slot_start = (first_abs.max(block_start_abs) - block_start_abs) as usize;
let abs_begin = first_abs.max(block_start_abs);
let abs_end = last_abs.min(block_start_abs + block_size);
let n_slots = abs_end - abs_begin;
if n_slots <= 0 {
continue;
}
let tok_start = abs_begin - first_abs;
let tok_end = tok_start + n_slots;

// Copy-on-write the target block if a sibling sequence still shares
// it (only ever the partial prefix tail; fresh append blocks have
// refcount 1 and skip this).
let block_id = state.layers[layer_idx].block_ids[block_index];
let target_id = if self.refcount(block_id) > 1 {
let fresh = self.copy_on_write_block(layer_idx, block_id)?;
state.layers[layer_idx].block_ids[block_index] = fresh;
fresh
} else {
block_id
};

// Slice the [1, H, n_slots, D] window out of the prefill tensors.
let starts = [0, 0, tok_start, 0];
let stops = [1, n_kv_heads, tok_end, head_dim];
let k_slice = ffi::slice(k_prefill, &starts, &stops);
let v_slice = ffi::slice(v_prefill, &starts, &stops);

self.write_block(target_id, layer_idx, slot_start, &k_slice, &v_slice)?;
}
Ok(())
}

/// Copy the current contents of `src_block_id` into a freshly acquired block
/// on the same layer and return the new block id. Used by [`write_prefill`]
/// to fork a shared partial tail block before mutating it (copy-on-write).
///
/// The source's full `[block_size, n_kv_heads, head_dim]` K and V slabs are
/// sliced out of the layer's pool tensors and written into the fresh block
/// at `slot_start = 0`, so the new block is a byte-identical copy (the
/// caller then overwrites only the divergent suffix slots). The new block
/// starts at refcount 1; the caller is responsible for releasing the
/// reference to the shared original.
fn copy_on_write_block(
&mut self,
layer_idx: usize,
src_block_id: PagedBlockId,
) -> Result<PagedBlockId, String> {
let meta = self.pool_meta[layer_idx].ok_or_else(|| {
format!(
"PagedBlockPool::copy_on_write_block: layer {layer_idx} has no pool tensors to copy from"
)
})?;
let src_row = *self.block_rows[layer_idx]
.get(&src_block_id)
.ok_or_else(|| {
format!(
"PagedBlockPool::copy_on_write_block: shared block {src_block_id} on layer {layer_idx} has no pool row (was it written?)"
)
})? as i32;
let block_size = self.layout.block_size as i32;

// Slice the source row's [block_size, H, D] slab out of K and V (the
// bare layout-A slab write_block accepts via its 3D convention).
let slab = |pool: &MlxArray| -> UniquePtr<MlxArray> {
let row = ffi::slice(
pool,
&[src_row, 0, 0, 0],
&[src_row + 1, block_size, meta.n_kv_heads, meta.head_dim],
);
ffi::reshape(&row, &[block_size, meta.n_kv_heads, meta.head_dim])
};
let k_slab = {
let pool_k = self.pool_k[layer_idx]
.as_ref()
.expect("pool_k present when pool_meta present");
slab(pool_k)
};
let v_slab = {
let pool_v = self.pool_v[layer_idx]
.as_ref()
.expect("pool_v present when pool_meta present");
slab(pool_v)
};

let fresh = self.acquire_block(layer_idx)?;
self.write_block(fresh, layer_idx, 0, &k_slab, &v_slab)?;
// Drop this sequence's reference to the shared original; the sibling
// that still owns it keeps it alive.
self.release_block(src_block_id)?;
Ok(fresh)
}

/// Gather the visible K/V window for one layer of a sequence into the
/// SDPA-ready shape `[1, n_kv_heads, visible_len, head_dim]`.
///
Expand Down
Loading