Summary
Merge mlxcel's two independent KV-cache subsystems, the server-side prompt-prefix radix cache and the core-side paged block pool, into one unified, refcounted, copy-on-write KV store in the style of vLLM PagedAttention and SGLang RadixAttention. End state: physical KV lives in a shared pool of fixed-size blocks; a radix tree indexes those blocks by token prefix; concurrent requests that share a prefix point at the same physical blocks (refcount, no copy); and a gathering attention path reads scattered blocks at decode time.
Motivation
Today prefix reuse and paged batching cannot run at the same time. BatchScheduler::try_adopt_cached_prefix returns None whenever the active decode backend is Paged (src/server/batch/scheduler.rs), and docs/en/prompt_cache.md documents this as a hard limitation ("Paged backend adopt/donate disabled ... Dense backend only"). A batching server therefore gets either cross-request prefix reuse (dense backend) or paged memory management, never both. With paged batching on, a multi-tenant server pays full prefill for every concurrent request that shares a system prompt and stores a duplicate copy of that prefix KV per sequence.
A unified store fixes both: a shared prefix is stored once and reused across concurrent requests without re-prefill, and block-granular allocation removes per-sequence contiguous-buffer over-reservation.
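A rough block count makes the sharing claim concrete. The sketch assumes block-granular copy-on-write. The full blocks of a shared prefix are stored once. Each sequence gets a private copy of the prefix's partial tail block on its first append. The helper names are hypothetical. The inputs are illustrative only, not measurements: 4 concurrent requests, a 1000-token system prompt, 200 private tokens each, and block size 32.

```rust
/// Number of fixed-size blocks needed for `tokens` tokens (ceiling division).
fn blocks_for(tokens: usize, block_size: usize) -> usize {
    (tokens + block_size - 1) / block_size
}

/// Every sequence stores the full prefix itself (no cross-request sharing).
fn unshared_blocks(n_seq: usize, prefix: usize, suffix: usize, block_size: usize) -> usize {
    n_seq * blocks_for(prefix + suffix, block_size)
}

/// Full prefix blocks are stored once and refcounted; the partial tail block of
/// the prefix (if any) is copied per sequence on its first append (copy-on-write).
fn shared_blocks(n_seq: usize, prefix: usize, suffix: usize, block_size: usize) -> usize {
    let shared_full = prefix / block_size;
    let private_tokens = prefix % block_size + suffix;
    shared_full + n_seq * blocks_for(private_tokens, block_size)
}

fn main() {
    // Illustrative inputs only: 4 requests, 1000-token prompt, 200 private tokens, block size 32.
    let (n_seq, prefix, suffix, bs) = (4, 1000, 200, 32);
    println!(
        "unshared: {} blocks, shared: {} blocks",
        unshared_blocks(n_seq, prefix, suffix, bs),
        shared_blocks(n_seq, prefix, suffix, bs)
    );
}
```

With these inputs the shared layout needs 59 blocks versus 152 without sharing. The difference is (n_seq - 1) times the number of full prefix blocks, so it grows with both concurrency and prefix length.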
Apple Silicon note: unified memory removes the CPU/GPU swap-out motivation behind parts of vLLM's design, so the win here is memory sharing and prefill avoidance, not host offload. The main risk is decode-time gather cost on MLX without a fused kernel; Phase 0 measured this before committing to the fused-kernel work.
Current architecture
Dense KVCache (cache.rs, KVCache): pre-allocated contiguous per-sequence K/V buffers with slice_update appends. Default mode KVCacheMode::Fp16. Default storage for normal generation.
Paged block pool (cache/paged.rs: PagedBlockPool / PagedSequenceState / PagedKvLayout): refcounted physical-block allocator with free lists. In Fp16/Int8 modes it tracks block-table metadata only; the actual K/V tensors still live in dense KVCache placeholders (see the PagedBlockPool doc comment). Default block size 32.
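Paged decode attention (paged_decode_attention_dense_compat in cpp/mlx_cxx_bridge.cpp + layers.rs): wired into llama3/llama4/qwen3/qwen3.5/gemma3 and their VLM wrappers via use_native_paged_kernel. It slices per-sequence dense buffers by a logical identity block table built from visible lengths (PagedDecodeMetadata::from_visible_lengths), concatenates the blocks, and calls fused SDPA. The pool's real physical block ids are not threaded into the kernel.
Prompt-prefix radix cache (src/server/prompt_cache/): path-compressed, token-indexed radix trie (trie.rs) keyed by BLAKE3 digest, storing whole detached dense KV snapshots (DetachedCacheSet, cache/detach.rs). Adopt/donate happens at request boundaries. APC adds block-hash-chain matching on top.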
Backend selection (effective_decode_storage_backend, scheduler.rs): default Auto resolves to Paged when max_batch_size > 1 && supports_batching && supports_paged_decode_backend, otherwise Dense.
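The Auto resolution rule, together with the prefix-adopt guard described under Motivation, can be restated as a small sketch. Backend, Caps, resolve_auto and prefix_adopt_allowed_today are hypothetical names for illustration, not the actual scheduler.rs types. Only the condition itself comes from the description above.

```rust
/// Hypothetical stand-in for the scheduler's decode storage choice.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Backend {
    Dense,
    Paged,
}

/// Hypothetical bundle of the three inputs the Auto rule looks at.
struct Caps {
    max_batch_size: usize,
    supports_batching: bool,
    supports_paged_decode_backend: bool,
}

/// Auto resolves to Paged only when batching is both configured and supported.
fn resolve_auto(caps: &Caps) -> Backend {
    if caps.max_batch_size > 1 && caps.supports_batching && caps.supports_paged_decode_backend {
        Backend::Paged
    } else {
        Backend::Dense
    }
}

/// Today's guard: prefix adopt is skipped under Paged (this epic removes it).
fn prefix_adopt_allowed_today(backend: Backend) -> bool {
    backend == Backend::Dense
}

fn main() {
    let server = Caps { max_batch_size: 8, supports_batching: true, supports_paged_decode_backend: true };
    let backend = resolve_auto(&server);
    println!("{backend:?}, prefix adopt allowed: {}", prefix_adopt_allowed_today(backend));
}
```

A batching server therefore lands on Paged with prefix adopt off. That combination is the limitation this epic removes.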
What already exists (foundation)
The hardest correctness primitive is already built. cache/paged_detach.rs (DetachedPagedCacheSet, CachePool::detach_paged / adopt_paged) implements refcount-pinned block sharing with automatic copy-on-write: detach pins every physical block via a refcount bump, adopt can hand the same blocks to a new sequence, and the first append_paged_tokens after sharing pulls a fresh block instead of mutating a shared one. Turbo4 quantization sidecars are already per-page (PagedTurboPageSidecars) and round-trip through detach/adopt. The FFI exposes take / gather / scatter / slice_update / concatenate, and custom Metal kernel infrastructure exists under src/lib/mlx-cpp/turbo/.
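The share-then-copy-on-write behaviour described above can be sketched with a toy allocator. This is not the cache/paged_detach.rs code. BlockPool, Seq, share and append_token are hypothetical names, and K/V payloads are omitted; a real copy would move the occupied rows of the tail block. The free list is a simple LIFO stack.

```rust
type BlockId = usize;

/// Toy refcounted block allocator (metadata only; no K/V tensors).
struct BlockPool {
    block_size: usize,
    refcount: Vec<u32>,
    free: Vec<BlockId>,
}

impl BlockPool {
    fn new(num_blocks: usize, block_size: usize) -> Self {
        // Reverse so that pop() hands out block 0 first.
        let free = (0..num_blocks).rev().collect();
        Self { block_size, refcount: vec![0; num_blocks], free }
    }

    fn alloc(&mut self) -> Option<BlockId> {
        let id = self.free.pop()?;
        self.refcount[id] = 1;
        Some(id)
    }

    fn retain(&mut self, id: BlockId) {
        self.refcount[id] += 1;
    }

    fn release(&mut self, id: BlockId) {
        self.refcount[id] -= 1;
        if self.refcount[id] == 0 {
            self.free.push(id);
        }
    }
}

/// A sequence's block table plus its logical token count.
#[derive(Clone, Debug)]
struct Seq {
    blocks: Vec<BlockId>,
    len: usize,
}

/// Pin every block of `seq` and return a second handle to the same blocks (no copy).
fn share(pool: &mut BlockPool, seq: &Seq) -> Seq {
    for &b in &seq.blocks {
        pool.retain(b);
    }
    seq.clone()
}

/// Append one token. A full tail block means a fresh block; a partially filled
/// tail block that is shared is replaced by a private copy before writing.
fn append_token(pool: &mut BlockPool, seq: &mut Seq) -> Option<()> {
    if seq.len % pool.block_size == 0 {
        seq.blocks.push(pool.alloc()?);
    } else {
        let tail = *seq.blocks.last()?;
        if pool.refcount[tail] > 1 {
            let fresh = pool.alloc()?;
            // A real pool would copy the occupied rows of `tail` into `fresh` here.
            pool.release(tail);
            *seq.blocks.last_mut()? = fresh;
        }
    }
    seq.len += 1;
    Some(())
}

fn main() {
    let mut pool = BlockPool::new(8, 4);
    let mut a = Seq { blocks: Vec::new(), len: 0 };
    for _ in 0..6 {
        append_token(&mut pool, &mut a).expect("pool exhausted");
    }
    let mut b = share(&mut pool, &a); // b points at a's blocks [0, 1]
    append_token(&mut pool, &mut b).expect("pool exhausted"); // copy-on-write of block 1
    println!("a = {:?}, b = {:?}", a.blocks, b.blocks);
}
```

After the shared append, both sequences still point at block 0 (refcount 2). Only the partially filled tail was duplicated.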
What is missing
Physical KV tensors are dense per-sequence, not stored in the shared block pool.
The decode kernel uses an identity block table over dense buffers; it does not gather scattered physical blocks by real block id.
The radix tree stores dense snapshots and is wired only to the dense backend; it does not index physical blocks, and prefix adopt is disabled under the paged backend.
The paged path is decode-only; prefill still writes to dense buffers.
No global block-budget admission/eviction tying the radix tree to the pool.
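The radix-tree gap can be made concrete with a minimal block-granular prefix index whose nodes hold physical block ids instead of dense snapshots. It is a simplification, not SGLang's RadixAttention. Each edge is exactly one full block of token ids, with no node splitting and no eviction, and partial tail blocks are not indexed. BlockRadix and its methods are hypothetical names. A caller that matches a prefix would then bump the refcount of every returned block before handing the blocks to a new sequence.

```rust
use std::collections::HashMap;

/// One trie node; the edge leading here is exactly one full block of token ids.
#[derive(Default)]
struct Node {
    block: Option<usize>, // physical block id holding this edge's KV
    children: HashMap<Vec<u32>, Node>,
}

/// Hypothetical block-granular prefix index over physical block ids.
struct BlockRadix {
    block_size: usize,
    root: Node,
}

impl BlockRadix {
    fn new(block_size: usize) -> Self {
        Self { block_size, root: Node::default() }
    }

    /// Record that the full blocks of `tokens` live in `blocks` (one id per block).
    /// A trailing partial block is not indexed; existing entries are kept.
    fn insert(&mut self, tokens: &[u32], blocks: &[usize]) {
        let block_size = self.block_size;
        let mut node = &mut self.root;
        for (chunk, &id) in tokens.chunks_exact(block_size).zip(blocks) {
            node = node.children.entry(chunk.to_vec()).or_default();
            if node.block.is_none() {
                node.block = Some(id);
            }
        }
    }

    /// Physical block ids covering the longest full-block prefix of `tokens`.
    fn match_prefix(&self, tokens: &[u32]) -> Vec<usize> {
        let mut ids = Vec::new();
        let mut node = &self.root;
        for chunk in tokens.chunks_exact(self.block_size) {
            match node.children.get(chunk) {
                Some(child) => {
                    ids.extend(child.block);
                    node = child;
                }
                None => break,
            }
        }
        ids
    }
}

fn main() {
    let mut index = BlockRadix::new(4);
    let system_prompt: Vec<u32> = (0..10).collect(); // 2 full blocks + 2 tail tokens
    index.insert(&system_prompt, &[7, 3, 9]);
    let mut request = system_prompt[..8].to_vec();
    request.extend([100, 101, 102, 103]);
    println!("{:?}", index.match_prefix(&request)); // blocks of the shared prefix
}
```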
Phased plan
Phase 0 outcome (ADR 0001)
Phase 0 (#117) is complete and merged (PR #145). It added examples/page_gather_microbench.rs and docs/adr/0001-paged-attention-gather-vs-fused-kernel.md, which lock in the decisions the later phases inherit:
Pool tensor layout: layout A, [num_blocks, block_size, n_kv_heads, head_dim] per layer, is ~2.1x faster on gather+SDPA than the head-split layout. Keep the default block size of 32.
Attention strategy: (A) gather-then-SDPA (take + reshape + transpose + fused SDPA, existing FFI) for Phases 1-5. MLX fuses the gather into the SDPA read, so the overhead is ~15% at <=4k single-sequence context.
Fused Metal kernel (Phase 6: Fused Metal paged-attention kernel #123) is deferred to the long-context (>=16k) or batched-decode regime, where the gather overhead becomes material (~56% at 16k single-seq; ~48% at 1k under batch 4, growing to 2-3x). Batch amplifies the cost more than context length does.
Append discipline: pool writes must reassign the pool tensor so MLX donates the buffer (O(block) in place); the microbench's O(pool) slice_update numbers are the no-donation upper bound.
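Under layout A, the gather-then-SDPA path reduces to index arithmetic plus a transpose. The sketch works on a flat f32 buffer for a single layer and mirrors the take + reshape + transpose sequence. In layout A each token's [n_kv_heads, head_dim] slab is contiguous, so the gather copies whole rows selected through the block table. PoolLayout, gather and to_head_major are hypothetical names. The real path operates on MLX arrays through the existing FFI and then calls fused SDPA, which is not modelled here.

```rust
/// Hypothetical per-layer pool shape: [num_blocks, block_size, n_kv_heads, head_dim].
struct PoolLayout {
    block_size: usize,
    n_kv_heads: usize,
    head_dim: usize,
}

impl PoolLayout {
    /// Flat index of element (block, row, head, d) in a row-major layout-A buffer.
    fn offset(&self, block: usize, row: usize, head: usize, d: usize) -> usize {
        ((block * self.block_size + row) * self.n_kv_heads + head) * self.head_dim + d
    }
}

/// "take": copy each visible token's contiguous [n_kv_heads, head_dim] slab,
/// located through the block table, into a dense [len, n_kv_heads, head_dim] buffer.
fn gather(pool: &[f32], l: &PoolLayout, block_table: &[usize], len: usize) -> Vec<f32> {
    let row = l.n_kv_heads * l.head_dim;
    let mut out = Vec::with_capacity(len * row);
    for t in 0..len {
        let start = l.offset(block_table[t / l.block_size], t % l.block_size, 0, 0);
        out.extend_from_slice(&pool[start..start + row]);
    }
    out
}

/// "transpose": [len, H, D] -> [H, len, D], the head-major shape SDPA consumes.
fn to_head_major(x: &[f32], len: usize, h: usize, d: usize) -> Vec<f32> {
    let mut out = vec![0.0; x.len()];
    for t in 0..len {
        for hh in 0..h {
            for dd in 0..d {
                out[(hh * len + t) * d + dd] = x[(t * h + hh) * d + dd];
            }
        }
    }
    out
}

fn main() {
    // 3 blocks of 2 tokens, 1 KV head, head_dim 1; value = block * 10 + row.
    let layout = PoolLayout { block_size: 2, n_kv_heads: 1, head_dim: 1 };
    let pool = vec![0.0, 1.0, 10.0, 11.0, 20.0, 21.0];
    let k = gather(&pool, &layout, &[2, 0], 3); // logical tokens live in blocks 2, then 0
    println!("{k:?}");
}
```

The block table here lists real physical ids out of order, as the unified store will need. The current identity table is the special case [0, 1, 2, ...].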
Definition of Done
A batching server runs prefix reuse and paged block management at the same time (the scheduler.rs Paged guard on try_adopt_cached_prefix is removed and replaced by paged adopt/donate).
Two concurrent requests sharing a system prompt store that prefix's KV once and neither re-prefills it.
Decode output matches the dense-backend reference within RMS < 5e-3 across the supported model families.
Hybrid SSM models keep the existing carve-out; correctness is unchanged for them.
Benchmarks quantify memory saved and any decode throughput delta vs the dense backend on Apple Silicon.
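The RMS criterion can be pinned down as a helper. This sketch assumes the tolerance applies to the root-mean-square of the elementwise difference between the paged and dense-backend outputs for the same input. The issue does not specify which tensor is compared (logits or attention output). rms_diff and matches_reference are hypothetical names.

```rust
/// Root-mean-square of the elementwise difference, accumulated in f64.
fn rms_diff(a: &[f32], b: &[f32]) -> f64 {
    assert_eq!(a.len(), b.len(), "outputs must have the same shape");
    assert!(!a.is_empty(), "cannot compare empty outputs");
    let sum_sq: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let d = f64::from(x) - f64::from(y);
            d * d
        })
        .sum();
    (sum_sq / a.len() as f64).sqrt()
}

/// Definition-of-Done threshold.
const RMS_TOLERANCE: f64 = 5e-3;

/// True when the paged output is within tolerance of the dense reference.
fn matches_reference(paged: &[f32], dense: &[f32]) -> bool {
    rms_diff(paged, dense) < RMS_TOLERANCE
}

fn main() {
    let dense = [0.25f32, -1.5, 3.0];
    let paged = [0.2501f32, -1.5002, 3.0];
    println!(
        "rms = {:.2e}, ok = {}",
        rms_diff(&paged, &dense),
        matches_reference(&paged, &dense)
    );
}
```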
Non-goals (this epic)
CPU/host swap-out preemption (unified memory makes it low value).