
feat: exact-prefix snapshot cache for hybrid-SSM / linear-attention multi-turn reuse #140

@inureyes


Summary

The unified KV-cache epic #116 merges the dense radix prompt cache and the paged block pool into one refcounted, copy-on-write paged store. Its sub-issue #124 deliberately excludes the hybrid-SSM and linear-attention families from block sharing, because their recurrent hidden state cannot be reconstructed from a token-prefix hash: sharing KV blocks alone would restore the attention layers and leave the recurrent layers stale. That carve-out is correct, but it leaves those families doing a full cold prefill on every turn, even in a long multi-turn chat that reuses the same prefix.

This issue adds an orthogonal, opt-in exact-prefix snapshot cache for exactly those families. Instead of sharing KV blocks, it snapshots the full recurrent state (conv_state plus ssm_state, and the gated-delta / linear-attention state) at turn end and restores it on an exact token-prefix match. It does not touch block sharing, so the #124 exclusion stays in force and stays correct.

Motivation

The families currently excluded from prefix reuse are the APC opt-out list in src/server/prompt_cache/hybrid_ssm.rs: jamba, mamba, mamba2, nemotron_h, gated_delta, kimi_linear, qwen3_next, falcon_mamba, longcat_flash, rwkv7, recurrent_gemma. They all advertise supports_batching() == false (for example src/models/mamba.rs:676, mamba2.rs:913, rwkv7.rs:996, jamba.rs:1335, nemotron_h.rs:2748), which routes them to SequenceStateBackend::ModelOwned through the default sequence_state_layout() (src/lib/mlxcel-core/src/generate.rs:274). They are then rejected from the prompt-cache donate path at src/server/batch/scheduler.rs:851 (backend != DenseKvCache -> return). Net effect: cold prefill on every turn.
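The Rust sketch below makes the gate concrete. It places the donate decision described above (backend != DenseKvCache -> return) next to the routing this issue proposes. SequenceStateBackend, DenseKvCache and ModelOwned are the names used in this issue. DonateTarget, both functions, the snapshot_enabled flag and the string family check are hypothetical. They illustrate the control flow only and are not the scheduler's real code.

```rust
/// The two backend kinds named in this issue; the real enum may have more variants.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SequenceStateBackend {
    DenseKvCache,
    ModelOwned,
}

/// Hypothetical destination for a finished turn's state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DonateTarget {
    /// Dense KV donated to the radix prompt cache (today's only reuse path).
    RadixKv,
    /// Proposed: full recurrent-state snapshot in an exact-match bucket.
    SsmSnapshot,
    /// Nothing is kept; the next turn cold-prefills.
    Skip,
}

/// The APC opt-out families listed in src/server/prompt_cache/hybrid_ssm.rs.
const SNAPSHOT_FAMILIES: &[&str] = &[
    "jamba", "mamba", "mamba2", "nemotron_h", "gated_delta", "kimi_linear",
    "qwen3_next", "falcon_mamba", "longcat_flash", "rwkv7", "recurrent_gemma",
];

/// Current gate: any non-dense backend returns before donating.
pub fn donate_target_today(backend: SequenceStateBackend) -> DonateTarget {
    if backend != SequenceStateBackend::DenseKvCache {
        return DonateTarget::Skip;
    }
    DonateTarget::RadixKv
}

/// Proposed gate (hypothetical): opted-in recurrent families take the snapshot path.
pub fn donate_target_proposed(
    backend: SequenceStateBackend,
    family: &str,
    snapshot_enabled: bool,
) -> DonateTarget {
    let is_recurrent = SNAPSHOT_FAMILIES.iter().any(|f| *f == family);
    match backend {
        SequenceStateBackend::DenseKvCache => DonateTarget::RadixKv,
        SequenceStateBackend::ModelOwned if snapshot_enabled && is_recurrent => {
            DonateTarget::SsmSnapshot
        }
        SequenceStateBackend::ModelOwned => DonateTarget::Skip,
    }
}
```

Keeping the feature opt-in means the proposed gate falls back to today's Skip behaviour whenever snapshot_enabled is false.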

Snapshot and restore is cheap relative to recomputation. The recurrent state is a small, fixed-size set of tensors per layer (conv_state plus ssm_state) whose size does not depend on prefix length, so copying it costs far less than re-running prefill over a long prefix. The model code already has rollback-snapshot machinery for speculative decoding (GdnRollbackSnapshot in src/models/qwen3_5.rs), which shows that this state can already be captured and restored.
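A rough size estimate shows what "fixed-size" means in practice. The sketch assumes a Mamba-1-style layer at batch size 1. Its conv_state is assumed to hold d_conv - 1 past inputs over d_inner channels, and its ssm_state to be d_inner x d_state. The layouts in src/models/mamba.rs may differ; some implementations keep d_conv columns. Both function names are hypothetical.

```rust
/// Bytes for one Mamba-1-style layer snapshot at batch size 1, assuming
/// conv_state = [d_inner, d_conv - 1] and ssm_state = [d_inner, d_state].
pub fn mamba_layer_snapshot_bytes(
    d_inner: usize,
    d_conv: usize,
    d_state: usize,
    elem_bytes: usize,
) -> usize {
    let conv_elems = d_inner * (d_conv - 1);
    let ssm_elems = d_inner * d_state;
    (conv_elems + ssm_elems) * elem_bytes
}

/// Whole-model snapshot size. There is no prompt-length parameter:
/// the state has the same size after 10 tokens or 100,000.
pub fn mamba_snapshot_bytes(
    n_layers: usize,
    d_model: usize,
    expand: usize,
    d_conv: usize,
    d_state: usize,
    elem_bytes: usize,
) -> usize {
    n_layers * mamba_layer_snapshot_bytes(d_model * expand, d_conv, d_state, elem_bytes)
}
```

Take assumed Mamba-2.8B-like dimensions: 64 layers, d_model 2560, expand 2, d_conv 4 and d_state 16. In fp32 this gives 24,903,680 bytes (about 25 MB), or half that in bf16. The figure is the same for any prefix length, because no token count enters the formula.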

Design

Two properties keep this orthogonal to #116 and compatible with #124:

  1. No block sharing. The SSM and linear families stay excluded from the unified block-shared radix store; the #124 carve-out (Hybrid SSM exclusion and multimodal adopt policy under the unified cache) is unchanged. The snapshot cache is a separate exact-match bucket, and a snapshot is never shared across concurrent sequences: it is a full-state copy restored into one sequence at a time.
  2. Exact prefix only. Recurrent state has no longest-prefix structure (you cannot truncate an SSM state to an arbitrary earlier token), so the snapshot bucket is keyed by an exact full-prompt token hash, not by the radix trie.
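The Rust sketch below shows one way to read "exact prefix only". Each snapshot is indexed by the length and hash of the exact token sequence it covers. A lookup hashes the new prompt's prefix at each stored length and accepts only an exact match. ExactPrefixIndex and exact_token_hash are hypothetical names. The issue only specifies an exact full-prompt token hash, so the per-length probe is an illustrative assumption. std's DefaultHasher is adequate for an in-process cache, but its algorithm is not guaranteed stable across Rust releases.

```rust
use std::collections::hash_map::DefaultHasher;
use std::collections::{BTreeSet, HashMap};
use std::hash::{Hash, Hasher};

/// Hash of an exact token sequence. Slice hashing includes the length,
/// so [1, 2] and [1, 2, 0] hash differently.
pub fn exact_token_hash(tokens: &[u32]) -> u64 {
    let mut h = DefaultHasher::new();
    tokens.hash(&mut h);
    h.finish()
}

/// Hypothetical exact-prefix index: (length, hash) of a snapshotted token
/// sequence -> snapshot id. No trie, no partial matches.
#[derive(Default)]
pub struct ExactPrefixIndex {
    by_hash: HashMap<(usize, u64), u64>,
    lengths: BTreeSet<usize>,
}

impl ExactPrefixIndex {
    pub fn insert(&mut self, snapshotted: &[u32], snapshot_id: u64) {
        let key = (snapshotted.len(), exact_token_hash(snapshotted));
        self.by_hash.insert(key, snapshot_id);
        self.lengths.insert(snapshotted.len());
    }

    /// Returns (snapshot_id, tokens_covered) for the longest stored sequence that
    /// is exactly a prefix of `prompt`; the caller prefills only prompt[covered..].
    pub fn lookup(&self, prompt: &[u32]) -> Option<(u64, usize)> {
        for &len in self.lengths.range(..=prompt.len()).rev() {
            let key = (len, exact_token_hash(&prompt[..len]));
            if let Some(&id) = self.by_hash.get(&key) {
                return Some((id, len));
            }
        }
        None
    }
}
```

A prompt that extends the snapshotted turn hits and prefills only the tail. A prompt that differs anywhere inside the snapshotted tokens misses entirely, because there is no partial-trie match to fall back on. A production version would also compare the actual tokens on a hit to rule out hash collisions.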

Tasks

1. Snapshot/restore primitive (core, models)

  • Add snapshot() and restore() to the recurrent cache types: MambaCache (src/models/mamba.rs:137), Mamba2Cache (mamba2.rs:174), Rwkv7Cache (rwkv7.rs:78), the per-layer hybrid caches JambaLayerCache (jamba.rs:241) and NemotronLayerCache (nemotron_h.rs:375), and GatedDeltaCache used by Qwen3.5 / Qwen3-Next (src/models/gated_delta.rs, src/models/qwen3_5.rs). Serialize conv_state plus ssm_state (and the gated-delta recurrent state) per layer.
  • Round-trip unit test per cache type: snapshot at token N, restore, generate token N+1 versus the cold path, asserting bit-identical output for greedy decoding.
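A toy recurrent cache can illustrate the round-trip property in the second bullet. This Rust sketch is not MambaCache or any real cache type. LayerState, RecurrentSnapshot, ToyCache and its step() update rule are hypothetical stand-ins that use plain Vec<f32> instead of MLX arrays. The sketch shows the shape of snapshot()/restore() and the parity check: restore a snapshot taken after token N, step once, and the result must match the cold path bit for bit.

```rust
/// One layer's recurrent state. Real caches hold MLX arrays; plain Vec<f32> here.
#[derive(Clone, Debug, PartialEq)]
pub struct LayerState {
    pub conv_state: Vec<f32>,
    pub ssm_state: Vec<f32>,
}

/// Hypothetical shape for the proposed snapshot()/restore() methods.
pub trait RecurrentSnapshot {
    type Snapshot: Clone;
    fn snapshot(&self) -> Self::Snapshot;
    fn restore(&mut self, snap: &Self::Snapshot);
}

/// Toy recurrent cache: a sliding conv window plus a decaying state per layer.
#[derive(Clone, Debug)]
pub struct ToyCache {
    pub layers: Vec<LayerState>,
}

impl ToyCache {
    pub fn new(n_layers: usize, conv_width: usize, state_dim: usize) -> Self {
        let layer = LayerState {
            conv_state: vec![0.0; conv_width],
            ssm_state: vec![0.0; state_dim],
        };
        ToyCache { layers: vec![layer; n_layers] }
    }

    /// Consume one token, update every layer's state, and return a greedy "next token".
    pub fn step(&mut self, token: u32) -> u32 {
        let mut x = token as f32;
        for layer in &mut self.layers {
            // Shift the conv window left and append the newest input.
            layer.conv_state.rotate_left(1);
            *layer.conv_state.last_mut().expect("conv_width > 0") = x;
            let conv: f32 = layer.conv_state.iter().sum();
            // Linear recurrence: decay the old state and mix in the conv output.
            for (i, s) in layer.ssm_state.iter_mut().enumerate() {
                *s = 0.9 * *s + 0.01 * (i as f32 + 1.0) * conv;
            }
            x = layer.ssm_state.iter().sum();
        }
        (x.abs() as u32) % 1000
    }
}

impl RecurrentSnapshot for ToyCache {
    type Snapshot = Vec<LayerState>;

    /// Deep copy of every layer's state.
    fn snapshot(&self) -> Vec<LayerState> {
        self.layers.clone()
    }

    /// Overwrite the live state; panics if the layer count differs.
    fn restore(&mut self, snap: &Vec<LayerState>) {
        assert_eq!(snap.len(), self.layers.len(), "snapshot layer count mismatch");
        self.layers.clone_from_slice(snap);
    }
}
```

snapshot() is a deep copy, so a restored state is independent of the donor sequence. That is what lets one stored snapshot be restored into one sequence at a time without cross-sequence sharing.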

2. SSM-typed exact-match bucket (server, prompt_cache)

  • Add an SSM snapshot bucket to PromptCacheStore (src/server/prompt_cache/store.rs, beside entries at :55 and tries at :61) keyed by { model_id, lora, template, multimodal_digest, session_key, exact_prompt_token_hash }, reusing the existing key components in src/server/prompt_cache/key.rs.
  • Give it LRU and TTL config separate from the dense KV bucket (evictions_lru / evictions_ttl at store.rs:67-68): snapshots are small, so they can keep a larger count and a longer TTL.
  • Account snapshot byte size in the store capacity checks.
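The Rust sketch below is a minimal bucket with its own LRU cap, TTL and byte accounting. It assumes a logical millisecond clock passed in by the caller, which keeps eviction deterministic. SnapshotBucket and SnapshotBucketConfig are hypothetical. The generic K stands in for a key struct built from the components listed above. The linear LRU scan is for clarity only; a real store would keep an ordered index. The counters echo the evictions_lru / evictions_ttl naming in store.rs but are not that code.

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical limits, configured separately from the dense KV bucket.
pub struct SnapshotBucketConfig {
    pub max_entries: usize,
    pub ttl_ms: u64,
    pub max_bytes: usize,
}

struct Entry<V> {
    value: V,
    bytes: usize,
    last_used_ms: u64,
}

pub struct SnapshotBucket<K, V> {
    cfg: SnapshotBucketConfig,
    entries: HashMap<K, Entry<V>>,
    pub bytes_in_use: usize,
    pub evictions_lru: u64,
    pub evictions_ttl: u64,
}

impl<K: Hash + Eq + Clone, V> SnapshotBucket<K, V> {
    pub fn new(cfg: SnapshotBucketConfig) -> Self {
        SnapshotBucket { cfg, entries: HashMap::new(), bytes_in_use: 0, evictions_lru: 0, evictions_ttl: 0 }
    }

    /// Exact-key lookup; a hit refreshes the entry's LRU/TTL timestamp.
    pub fn get(&mut self, key: &K, now_ms: u64) -> Option<&V> {
        self.expire(now_ms);
        let e = self.entries.get_mut(key)?;
        e.last_used_ms = now_ms;
        Some(&e.value)
    }

    /// Insert a snapshot, evicting LRU entries until both the count and byte caps fit.
    /// Returns false if the snapshot alone exceeds the byte cap.
    pub fn insert(&mut self, key: K, value: V, bytes: usize, now_ms: u64) -> bool {
        if bytes > self.cfg.max_bytes {
            return false;
        }
        self.expire(now_ms);
        if let Some(old) = self.entries.remove(&key) {
            self.bytes_in_use -= old.bytes;
        }
        while !self.entries.is_empty()
            && (self.entries.len() >= self.cfg.max_entries
                || self.bytes_in_use + bytes > self.cfg.max_bytes)
        {
            self.evict_lru();
        }
        self.bytes_in_use += bytes;
        self.entries.insert(key, Entry { value, bytes, last_used_ms: now_ms });
        true
    }

    /// Drop entries idle for longer than the TTL.
    fn expire(&mut self, now_ms: u64) {
        let ttl = self.cfg.ttl_ms;
        let before = self.entries.len();
        let mut freed = 0usize;
        self.entries.retain(|_, e| {
            let keep = now_ms.saturating_sub(e.last_used_ms) <= ttl;
            if !keep {
                freed += e.bytes;
            }
            keep
        });
        self.bytes_in_use -= freed;
        self.evictions_ttl += (before - self.entries.len()) as u64;
    }

    /// Evict the least recently used entry (linear scan; fine for a sketch).
    fn evict_lru(&mut self) {
        let oldest = self.entries.iter().min_by_key(|(_, e)| e.last_used_ms).map(|(k, _)| k.clone());
        if let Some(k) = oldest {
            if let Some(e) = self.entries.remove(&k) {
                self.bytes_in_use -= e.bytes;
                self.evictions_lru += 1;
            }
        }
    }
}
```

Because bytes_in_use is updated on every insert, replacement and eviction, the same figure can feed the store-wide capacity check the third bullet asks for.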

3. Decouple batching from adoption (server, models)

4. Hybrid composition (Jamba, Nemotron-H, Qwen3.5)

  • For attention-plus-SSM hybrids, snapshot only the recurrent layers; let the attention layers ride the #116 unified block-shared path once it lands, and compose the two halves. (Jamba: JambaLayerCache::{Attention, Mamba} at jamba.rs:241; Nemotron-H: NemotronLayerCache::{Attention, Mamba} at nemotron_h.rs:375; Qwen3.5: the full-attention versus is_linear_layer split at qwen3_5.rs:174.)
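For the hybrid case, the Rust sketch below snapshots only the recurrent layers and leaves attention layers to the block store. LayerCache is a hypothetical enum shaped like JambaLayerCache::{Attention, Mamba}, with a placeholder attention payload. RecurrentLayerSnapshot, snapshot_recurrent and restore_recurrent are illustrative names. Each entry records its layer index, so restore can reject a snapshot that does not line up with the model's layer layout.

```rust
/// Hypothetical per-layer cache, shaped like JambaLayerCache::{Attention, Mamba}.
#[derive(Clone, Debug, PartialEq)]
pub enum LayerCache {
    /// Placeholder: attention KV would come from the #116 block store, not the snapshot.
    Attention { kv_len: usize },
    Mamba { conv_state: Vec<f32>, ssm_state: Vec<f32> },
}

/// Recurrent state of one layer, tagged with its position in the layer stack.
#[derive(Clone, Debug, PartialEq)]
pub struct RecurrentLayerSnapshot {
    pub layer: usize,
    pub conv_state: Vec<f32>,
    pub ssm_state: Vec<f32>,
}

/// Copy out only the recurrent layers; attention layers are skipped.
pub fn snapshot_recurrent(layers: &[LayerCache]) -> Vec<RecurrentLayerSnapshot> {
    layers
        .iter()
        .enumerate()
        .filter_map(|(layer, cache)| match cache {
            LayerCache::Mamba { conv_state, ssm_state } => Some(RecurrentLayerSnapshot {
                layer,
                conv_state: conv_state.clone(),
                ssm_state: ssm_state.clone(),
            }),
            LayerCache::Attention { .. } => None,
        })
        .collect()
}

/// Write recurrent state back; attention layers are never touched.
pub fn restore_recurrent(
    layers: &mut [LayerCache],
    snap: &[RecurrentLayerSnapshot],
) -> Result<(), String> {
    // Validate every entry first so a mismatched snapshot leaves the cache untouched.
    for s in snap {
        match layers.get(s.layer) {
            Some(LayerCache::Mamba { .. }) => {}
            Some(LayerCache::Attention { .. }) => {
                return Err(format!("layer {} is an attention layer", s.layer))
            }
            None => return Err(format!("layer {} is out of range", s.layer)),
        }
    }
    for s in snap {
        if let Some(LayerCache::Mamba { conv_state, ssm_state }) = layers.get_mut(s.layer) {
            *conv_state = s.conv_state.clone();
            *ssm_state = s.ssm_state.clone();
        }
    }
    Ok(())
}
```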

5. Chunked checkpointing for long conversations (optional follow-up)

  • Periodic snapshot checkpoints every N tokens during decode; on prefix divergence, restore the nearest checkpoint and replay the diverging tail, bounded by a per-session checkpoint cap. Worth doing only after basic snapshot reuse is proven.
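The Rust sketch below shows the checkpoint lookup with an ordered map from token position to state. It finds the common prefix with the new prompt, takes the latest checkpoint at or before it, and leaves the rest to replay. CheckpointLog and common_prefix_len are hypothetical. The drop-oldest policy for the per-session cap is an assumption; a real policy might instead thin checkpoints evenly.

```rust
use std::collections::BTreeMap;

/// Hypothetical per-session checkpoint log: token position -> state after that many tokens.
pub struct CheckpointLog<S> {
    interval: usize,
    cap: usize,
    checkpoints: BTreeMap<usize, S>,
}

impl<S: Clone> CheckpointLog<S> {
    pub fn new(interval: usize, cap: usize) -> Self {
        CheckpointLog { interval, cap, checkpoints: BTreeMap::new() }
    }

    /// Called after `pos` tokens have been consumed; stores a copy every `interval` tokens.
    pub fn maybe_checkpoint(&mut self, pos: usize, state: &S) {
        if self.interval == 0 || pos == 0 || pos % self.interval != 0 {
            return;
        }
        self.checkpoints.insert(pos, state.clone());
        // Enforce the per-session cap by dropping the oldest checkpoint (assumed policy).
        while self.checkpoints.len() > self.cap {
            let oldest = *self.checkpoints.keys().next().expect("non-empty");
            self.checkpoints.remove(&oldest);
        }
    }

    /// Latest checkpoint taken at or before `prefix_len`, i.e. inside the shared prefix.
    pub fn nearest(&self, prefix_len: usize) -> Option<(usize, &S)> {
        self.checkpoints.range(..=prefix_len).next_back().map(|(&pos, s)| (pos, s))
    }
}

/// Number of leading tokens two sequences share.
pub fn common_prefix_len(a: &[u32], b: &[u32]) -> usize {
    a.iter().zip(b).take_while(|(x, y)| x == y).count()
}
```

Take a 12-token history checkpointed every 4 tokens with a cap of 2. A new prompt that diverges at token 10 restores the checkpoint at position 8, and tokens 8 onward of the new prompt are replayed.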

6. Snapshot compression (optional)

  • Compress snapshot payloads through the existing TurboQuant KV pipeline (src/lib/mlxcel-core/src/cache/turbo/, see docs/turbo-kv-cache.md) to allow a larger TTL and capacity. The public API is quantize_* / dequantize_*.
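TurboQuant's exact signatures are not given here beyond quantize_* / dequantize_*. The sketch therefore uses a generic symmetric absmax int8 quantizer as a stand-in; it is not the TurboQuant algorithm. It shows the size/accuracy trade for an f32 state tensor, and all names are hypothetical. Any lossy scheme changes the restored state, so compressed snapshots cannot be guaranteed to meet the bit-identical greedy-parity criterion. That check would need to target uncompressed snapshots.

```rust
/// Per-tensor symmetric int8 quantization (a generic stand-in, not TurboQuant).
pub struct QuantizedTensor {
    pub scale: f32,
    pub data: Vec<i8>,
}

pub fn quantize_absmax_i8(x: &[f32]) -> QuantizedTensor {
    // Map the largest magnitude to 127; an all-zero tensor gets scale 1.0.
    let max_abs = x.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
    let data = x
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    QuantizedTensor { scale, data }
}

pub fn dequantize_absmax_i8(q: &QuantizedTensor) -> Vec<f32> {
    q.data.iter().map(|&v| v as f32 * q.scale).collect()
}

/// Stored bytes: one byte per element plus the f32 scale.
pub fn payload_bytes(q: &QuantizedTensor) -> usize {
    q.data.len() + std::mem::size_of::<f32>()
}
```

The payload shrinks from 4 bytes to 1 byte per element, plus one scale. The per-element error is bounded by half the scale.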

7. Tests, metrics, docs

  • Multi-turn E2E (extend tests/prompt_cache_e2e.rs) on a pure-SSM model (Mamba) and a hybrid (Jamba or Nemotron-H): snapshot hit on turn 2 with decode parity versus the cold path.
  • A per-path hit/miss counter for the snapshot bucket on the prompt-cache metrics endpoint (src/server/prompt_cache/metrics.rs).
  • Document the snapshot-cache path next to the unified-cache section in docs/turbo-kv-cache.md and docs/en/prompt_cache.md.
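The Rust sketch below implements per-path hit/miss counters with std atomics, so the request path can record without taking a lock. PathCounters, PromptCacheMetrics, the path labels and the Prometheus-style metric names are all hypothetical. The issue does not specify what metrics.rs exposes or in what format.

```rust
use std::sync::atomic::{AtomicU64, Ordering};

/// Hit/miss counters for one reuse path.
#[derive(Default)]
pub struct PathCounters {
    hits: AtomicU64,
    misses: AtomicU64,
}

impl PathCounters {
    pub fn record(&self, hit: bool) {
        let counter = if hit { &self.hits } else { &self.misses };
        counter.fetch_add(1, Ordering::Relaxed);
    }

    /// (hits, misses); Relaxed is enough for independent monotonic counters.
    pub fn read(&self) -> (u64, u64) {
        (self.hits.load(Ordering::Relaxed), self.misses.load(Ordering::Relaxed))
    }
}

/// Hypothetical container: one counter pair per reuse path.
#[derive(Default)]
pub struct PromptCacheMetrics {
    pub dense_kv: PathCounters,
    pub ssm_snapshot: PathCounters,
}

impl PromptCacheMetrics {
    /// Prometheus-style text rendering with a `path` label (format is an assumption).
    pub fn render(&self) -> String {
        let mut out = String::new();
        for (path, c) in [("dense_kv", &self.dense_kv), ("ssm_snapshot", &self.ssm_snapshot)] {
            let (h, m) = c.read();
            out.push_str(&format!("prompt_cache_hits_total{{path=\"{path}\"}} {h}\n"));
            out.push_str(&format!("prompt_cache_misses_total{{path=\"{path}\"}} {m}\n"));
        }
        out
    }
}
```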

Acceptance criteria

  • A multi-turn chat on a Mamba model, and on a Jamba or Nemotron-H hybrid, reuses recurrent state on turn 2 instead of cold-prefilling, with greedy-decode output bit-identical to the cold path.
  • The hybrid-SSM and linear families stay excluded from block sharing: the #124 carve-out (Hybrid SSM exclusion and multimodal adopt policy under the unified cache) is unchanged, and no cross-sequence or cross-modal state corruption is possible.
  • Snapshot bucket memory is bounded by its own LRU and TTL, and snapshot byte size is counted in store capacity.
  • A per-path snapshot hit/miss metric is exposed, and the docs describe the snapshot path.

Relationship to the unified-cache epic

This is complementary to #116 and #124, not a replacement. #116 gives block-shared reuse to the dense and paged-capable families; this issue gives snapshot-based reuse to the recurrent families that #124 deliberately leaves out of block sharing. Pure-SSM models (Mamba, Mamba2, RWKV-7) can be built independently of #116. The hybrid models (Jamba, Nemotron-H, Qwen3.5) compose best after #116 Phase 4 (#121) lands so their attention layers use block sharing.

Non-goals

Metadata

Assignees: no one assigned

Labels

  • area:core (mlxcel-core: MLX FFI, primitives, KV cache, layers)
  • area:inference (generation, sampling, decoding, incl. speculative, DRY)
  • area:models (model architectures, weights, loading, metadata)
  • priority:medium
  • type:enhancement (new features, capabilities, or significant additions)
