
[Epic] Unified KV cache: radix-tree-indexed paged attention with shared physical blocks #116

@inureyes

Summary

Merge mlxcel's two independent KV-cache subsystems, the server-side prompt-prefix radix cache and the core-side paged block pool, into one unified, refcounted, copy-on-write KV store in the style of vLLM PagedAttention and SGLang RadixAttention. End state: physical KV lives in a shared pool of fixed-size blocks; a radix tree indexes those blocks by token prefix; concurrent requests that share a prefix point at the same physical blocks (refcount, no copy); and a gathering attention path reads scattered blocks at decode time.

Motivation

Today prefix reuse and paged batching cannot run at the same time. BatchScheduler::try_adopt_cached_prefix returns None whenever the active decode backend is Paged (src/server/batch/scheduler.rs), and docs/en/prompt_cache.md documents this as a hard limitation ("Paged backend adopt/donate disabled ... Dense backend only"). A batching server therefore gets either cross-request prefix reuse (dense backend) or paged memory management, never both. With paged batching on, a multi-tenant server pays full prefill for every concurrent request that shares a system prompt and stores a duplicate copy of that prefix KV per sequence.

A unified store fixes both: a shared prefix is stored once and reused across concurrent requests without re-prefill, and block-granular allocation removes per-sequence contiguous-buffer over-reservation.
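A back-of-the-envelope sketch makes the memory claim concrete. It assumes fp16 K/V (2 bytes per element), an illustrative model shape, and a dense backend that reserves a fixed `capacity` of tokens per sequence up front. It also assumes that only full prefix blocks stay shared, while the partially filled last prefix block ends up copied per sequence by copy-on-write. `KvShape`, `blocks_for`, `dense_bytes` and `paged_bytes` are hypothetical helpers, not mlxcel APIs.

```rust
/// Hypothetical model shape used only for KV sizing (not an mlxcel type).
pub struct KvShape {
    pub n_layers: usize,
    pub n_kv_heads: usize,
    pub head_dim: usize,
    pub bytes_per_elem: usize, // 2 for fp16
}

impl KvShape {
    /// Bytes of K plus V for one token across all layers.
    pub fn bytes_per_token(&self) -> usize {
        2 * self.n_layers * self.n_kv_heads * self.head_dim * self.bytes_per_elem
    }
}

/// Number of fixed-size blocks needed to hold `tokens` tokens (ceiling division).
pub fn blocks_for(tokens: usize, block_size: usize) -> usize {
    (tokens + block_size - 1) / block_size
}

/// Dense backend (assumed): every sequence reserves `capacity` tokens and
/// keeps its own copy of any shared prefix.
pub fn dense_bytes(shape: &KvShape, n_seqs: usize, capacity: usize) -> usize {
    n_seqs * capacity * shape.bytes_per_token()
}

/// Unified paged store (assumed): full prefix blocks are stored once; each
/// sequence owns the partial last prefix block (copied on write) plus its suffix.
pub fn paged_bytes(
    shape: &KvShape,
    n_seqs: usize,
    prefix: usize,
    suffix: usize,
    block_size: usize,
) -> usize {
    let shared_blocks = prefix / block_size; // full blocks, shared by refcount
    let tail = prefix % block_size; // partial prefix block, private after COW
    let private_blocks = blocks_for(tail + suffix, block_size);
    (shared_blocks + n_seqs * private_blocks) * block_size * shape.bytes_per_token()
}
```

For four sequences sharing a 1000-token prefix with 200 private tokens each and block size 32, this sketch counts 31 shared blocks plus 7 private blocks per sequence, 59 blocks in total, independent of how large a per-sequence dense capacity would have been.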

Apple Silicon note: unified memory removes the CPU/GPU swap-out motivation behind parts of vLLM's design, so the win here is memory sharing and prefill avoidance, not host offload. The main risk is decode-time gather cost on MLX without a fused kernel; Phase 0 measures this before committing to the fused-kernel work.

Current architecture

  1. Dense KVCache (cache.rs, KVCache): pre-allocated, contiguous per-sequence K/V buffers appended via slice_update. The default mode is KVCacheMode::Fp16, and this is the default storage for normal generation.
  2. Paged block pool (cache/paged.rs: PagedBlockPool / PagedSequenceState / PagedKvLayout): refcounted physical-block allocator with free lists. In Fp16/Int8 modes it tracks block-table metadata only; the actual K/V tensors still live in dense KVCache placeholders (see the PagedBlockPool doc comment). Default block size 32.
  3. Paged decode attention (paged_decode_attention_dense_compat in cpp/mlx_cxx_bridge.cpp + layers.rs): wired into llama3/llama4/qwen3/qwen3.5/gemma3 and their VLM wrappers via use_native_paged_kernel. It slices per-sequence dense buffers by a logical identity block table built from visible lengths (PagedDecodeMetadata::from_visible_lengths), concatenates the blocks, and calls fused SDPA. The pool's real physical block ids are not threaded into the kernel.
  4. Radix prefix cache (src/server/prompt_cache/): path-compressed token-indexed radix trie (trie.rs) keyed by BLAKE3 digest, storing whole detached dense KV snapshots (DetachedCacheSet, cache/detach.rs). Adopt/donate at request boundaries. APC adds block-hash-chain matching on top.
  5. Backend selection (effective_decode_storage_backend, scheduler.rs): default Auto resolves to Paged when max_batch_size > 1 && supports_batching && supports_paged_decode_backend, otherwise Dense.
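Item 3's "logical identity block table built from visible lengths" can be pictured as follows. This is a minimal sketch, assuming each sequence's table simply enumerates block indices 0..ceil(len / block_size) into that sequence's own dense buffer. `identity_block_tables` and `block_span` are hypothetical and may differ from what `PagedDecodeMetadata::from_visible_lengths` actually produces.

```rust
/// Per-sequence logical block tables plus each sequence's visible length.
#[derive(Debug, PartialEq)]
pub struct IdentityTables {
    pub tables: Vec<Vec<usize>>,
    pub lengths: Vec<usize>,
}

/// Sequence `s` with `len` visible tokens gets blocks [0, 1, ..., ceil(len / bs) - 1].
/// These ids index the sequence's *own* dense buffer, not a shared physical pool,
/// which is why the pool's real block ids never reach the kernel today.
pub fn identity_block_tables(visible_lengths: &[usize], block_size: usize) -> IdentityTables {
    assert!(block_size > 0, "block_size must be positive");
    let tables: Vec<Vec<usize>> = visible_lengths
        .iter()
        .map(|&len| (0..(len + block_size - 1) / block_size).collect::<Vec<usize>>())
        .collect();
    IdentityTables { tables, lengths: visible_lengths.to_vec() }
}

/// Token range [start, end) covered by logical block `b` of a `len`-token sequence.
pub fn block_span(b: usize, len: usize, block_size: usize) -> (usize, usize) {
    let start = b * block_size;
    (start, (start + block_size).min(len))
}
```

In the unified design, the same table shape would instead carry real physical block ids drawn from the shared pool.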

What already exists (foundation)

The hardest correctness primitive is already built. cache/paged_detach.rs (DetachedPagedCacheSet, CachePool::detach_paged / adopt_paged) implements refcount-pinned block sharing with automatic copy-on-write: detach pins every physical block via a refcount bump, adopt can hand the same blocks to a new sequence, and the first append_paged_tokens after sharing pulls a fresh block instead of mutating a shared one. Turbo4 quantization sidecars are already per-page (PagedTurboPageSidecars) and round-trip through detach/adopt. The FFI exposes take / gather / scatter / slice_update / concatenate, and custom Metal kernel infrastructure exists under src/lib/mlx-cpp/turbo/.
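The refcount-pinning and copy-on-write rule above can be illustrated with a toy allocator. This is a simplified sketch, not the `CachePool` / `DetachedPagedCacheSet` API: `ToyPool`, `SeqBlocks`, `share`, `append_token` and `release` are hypothetical names. Blocks carry no tensor data, so the slot copy that a real COW would perform is reduced to a comment.

```rust
/// Toy physical-block allocator with per-block refcounts and a free list.
pub struct ToyPool {
    pub block_size: usize,
    refcount: Vec<u32>,
    free: Vec<usize>,
}

/// A sequence's view of the pool: its block table and token count.
#[derive(Clone, Debug)]
pub struct SeqBlocks {
    pub table: Vec<usize>,
    pub len: usize,
}

impl ToyPool {
    pub fn new(num_blocks: usize, block_size: usize) -> Self {
        // The free list pops from the end, so reverse it to hand out block 0 first.
        Self { block_size, refcount: vec![0; num_blocks], free: (0..num_blocks).rev().collect() }
    }

    pub fn refcount(&self, block: usize) -> u32 {
        self.refcount[block]
    }

    fn alloc(&mut self) -> usize {
        let b = self.free.pop().expect("pool exhausted");
        self.refcount[b] = 1;
        b
    }

    /// Detach/adopt analogue: pin every block with a refcount bump; nothing is copied.
    pub fn share(&mut self, seq: &SeqBlocks) -> SeqBlocks {
        for &b in &seq.table {
            self.refcount[b] += 1;
        }
        seq.clone()
    }

    /// Append one token. A shared, partially filled last block is never written
    /// in place: the writer switches to a fresh block (copy-on-write).
    pub fn append_token(&mut self, seq: &mut SeqBlocks) {
        if seq.len % self.block_size == 0 {
            // No block yet, or the last block is full: start a new one.
            let b = self.alloc();
            seq.table.push(b);
        } else {
            let last = *seq.table.last().unwrap();
            if self.refcount[last] > 1 {
                let fresh = self.alloc();
                // A real pool would copy the filled slots of `last` into `fresh` here.
                self.refcount[last] -= 1;
                *seq.table.last_mut().unwrap() = fresh;
            }
        }
        seq.len += 1;
    }

    /// Drop one reference per block; blocks reaching zero return to the free list.
    pub fn release(&mut self, seq: SeqBlocks) {
        for b in seq.table {
            self.refcount[b] -= 1;
            if self.refcount[b] == 0 {
                self.free.push(b);
            }
        }
    }
}
```

After `share`, both sequences point at the same physical blocks. The first append by either one moves only that writer onto a new tail block, so the other sequence's view is never mutated.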

What is missing

  • Physical KV tensors are dense per-sequence, not stored in the shared block pool.
  • The decode kernel uses an identity block table over dense buffers; it does not gather scattered physical blocks by real block id.
  • The radix tree stores dense snapshots and is wired only to the dense backend; it does not index physical blocks, and prefix adopt is disabled under the paged backend.
  • The paged path is decode-only: prefill still writes to dense buffers.
  • No global block-budget admission/eviction tying the radix tree to the pool.
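For the missing piece where the radix tree indexes physical blocks instead of dense snapshots, one possible shape is a trie whose edges are full, block-aligned token chunks and whose nodes record the physical block id holding that chunk's KV. This is a hypothetical sketch (`BlockRadix`, `insert`, `match_prefix`). It keys nodes by raw token chunks rather than the BLAKE3 digests the real trie uses, and it has no path compression, refcounting, or eviction.

```rust
use std::collections::HashMap;

/// One node per full block of tokens; `block` is the physical block storing its KV.
#[derive(Default)]
struct Node {
    block: Option<usize>,
    children: HashMap<Vec<u32>, Node>,
}

/// Toy block-granular prefix index.
pub struct BlockRadix {
    block_size: usize,
    root: Node,
}

impl BlockRadix {
    pub fn new(block_size: usize) -> Self {
        Self { block_size, root: Node::default() }
    }

    /// Record that the full blocks of `tokens` live in `blocks`, in order.
    /// A trailing partial block is ignored; an existing mapping is kept (first writer wins).
    pub fn insert(&mut self, tokens: &[u32], blocks: &[usize]) {
        let mut node = &mut self.root;
        for (chunk, &block) in tokens.chunks_exact(self.block_size).zip(blocks) {
            node = node.children.entry(chunk.to_vec()).or_default();
            node.block.get_or_insert(block);
        }
    }

    /// Physical block ids covering the longest block-aligned cached prefix of `tokens`.
    pub fn match_prefix(&self, tokens: &[u32]) -> Vec<usize> {
        let mut node = &self.root;
        let mut matched = Vec::new();
        for chunk in tokens.chunks_exact(self.block_size) {
            match node.children.get(chunk) {
                Some(child) => {
                    matched.extend(child.block);
                    node = child;
                }
                None => break,
            }
        }
        matched
    }
}
```

Under this design, a new request would pin the matched blocks with refcount bumps (as detach/adopt already does) and prefill only the tokens after the match.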

Phased plan

Phase 0 outcome (ADR 0001)

Phase 0 (#117) is complete and merged (PR #145). It added examples/page_gather_microbench.rs and docs/adr/0001-paged-attention-gather-vs-fused-kernel.md, which lock the decisions the later phases inherit:

  • Pool tensor layout: layout A [num_blocks, block_size, n_kv_heads, head_dim] (per-layer), ~2.1x faster on gather+SDPA than the head-split layout. Keep the default block size 32.
  • Attention strategy: (A) gather-then-SDPA (take + reshape + transpose + fused SDPA, existing FFI) for Phases 1-5. MLX fuses the gather into the SDPA read, so the overhead is ~15% at <=4k single-sequence context.
  • Fused Metal kernel (Phase 6: Fused Metal paged-attention kernel #123) is deferred to the long-context (>=16k) or batched-decode regime, where the gather overhead becomes material (~56% at 16k single-seq; ~48% at 1k under batch 4, growing to 2-3x). Batch amplifies the cost more than context length does.
  • Append discipline: pool writes must reassign the pool tensor so MLX donates the buffer (O(block) in place); the microbench's O(pool) slice_update numbers are the no-donation upper bound.
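To make the layout-A indexing concrete, here is a plain-Rust reference gather (no MLX) over one layer's flat pool laid out as `[num_blocks, block_size, n_kv_heads, head_dim]`. It assembles a sequence's logical `[seq_len, n_kv_heads, head_dim]` view from its physical block table, which is the role the `take` + reshape step plays before SDPA in strategy (A). `PoolShape` and `gather_layout_a` are hypothetical helpers for illustration only. A real implementation would issue one MLX `take` over the block axis rather than loop per token.

```rust
/// Shape of one layer's K (or V) pool in layout A.
pub struct PoolShape {
    pub num_blocks: usize,
    pub block_size: usize,
    pub n_kv_heads: usize,
    pub head_dim: usize,
}

impl PoolShape {
    /// Flat row-major offset of element [block, slot, head, d].
    pub fn offset(&self, block: usize, slot: usize, head: usize, d: usize) -> usize {
        ((block * self.block_size + slot) * self.n_kv_heads + head) * self.head_dim + d
    }
}

/// Gather the first `seq_len` tokens of a sequence into a contiguous
/// [seq_len, n_kv_heads, head_dim] buffer by following its block table.
pub fn gather_layout_a(pool: &[f32], shape: &PoolShape, table: &[usize], seq_len: usize) -> Vec<f32> {
    let expected = shape.num_blocks * shape.block_size * shape.n_kv_heads * shape.head_dim;
    assert_eq!(pool.len(), expected, "pool buffer does not match shape");
    let row = shape.n_kv_heads * shape.head_dim; // one token's data across all heads
    let mut out = Vec::with_capacity(seq_len * row);
    for t in 0..seq_len {
        let block = table[t / shape.block_size]; // logical block -> physical block
        let slot = t % shape.block_size;
        // In layout A all heads of a token are adjacent, so each token is one slice copy.
        let start = shape.offset(block, slot, 0, 0);
        out.extend_from_slice(&pool[start..start + row]);
    }
    out
}
```

Per the (A) recipe, the gathered buffer would then be reshaped and transposed to the heads-major layout that SDPA consumes.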

Definition of Done

  • A batching server runs prefix reuse and paged block management at the same time (the scheduler.rs Paged guard on try_adopt_cached_prefix is removed and replaced by paged adopt/donate).
  • Two concurrent requests sharing a system prompt store that prefix's KV once and neither re-prefills it.
  • Decode output matches the dense-backend reference within RMS < 5e-3 across the supported model families.
  • Hybrid SSM models keep the existing carve-out; correctness is unchanged for them.
  • Benchmarks quantify memory saved and any decode throughput delta vs the dense backend on Apple Silicon.
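The RMS criterion in the Definition of Done could be checked with a helper along these lines. The issue does not pin down the exact metric, so this assumes the root-mean-square of the element-wise difference between paged and dense outputs (for example, logits flattened to f32). `rms_diff` and `matches_dense` are hypothetical names.

```rust
/// Root-mean-square of the element-wise difference between two equal-length outputs.
/// Accumulates in f64 to avoid precision loss on long vectors.
pub fn rms_diff(paged: &[f32], dense: &[f32]) -> f64 {
    assert_eq!(paged.len(), dense.len(), "outputs must have the same shape");
    if paged.is_empty() {
        return 0.0;
    }
    let sum_sq: f64 = paged
        .iter()
        .zip(dense)
        .map(|(&p, &d)| {
            let diff = p as f64 - d as f64;
            diff * diff
        })
        .sum();
    (sum_sq / paged.len() as f64).sqrt()
}

/// Assumed acceptance check: paged output within RMS < 5e-3 of the dense reference.
pub fn matches_dense(paged: &[f32], dense: &[f32]) -> bool {
    rms_diff(paged, dense) < 5e-3
}
```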

Non-goals (this epic)

  • CPU/host swap-out preemption (unified memory makes it low value).
  • Chunked prefill scheduling (possible follow-up).
  • Cross-process / multi-node cache sharing beyond the existing distributed serde path.

Metadata

Labels

  • area:architecture (Architecture and code structure changes)
  • area:core (mlxcel-core: MLX FFI, primitives, KV cache, layers)
  • priority:high (High priority)
  • status:ready (Ready to be worked on)
  • type:enhancement (New features, capabilities, or significant additions)
