-
Notifications
You must be signed in to change notification settings - Fork 1.9k
feat(lib): add Unlimiformer-style kNN attention retrieval helpers #830
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
kayo09
wants to merge
3
commits into
supermemoryai:main
Choose a base branch
from
kayo09:feature/unlimiformer-knn-retrieval
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| import { describe, expect, test } from "bun:test" | ||
| import { | ||
| attentionScore, | ||
| attentionScores, | ||
| rankItemsByAttentionTopK, | ||
| topKAttentionKeys, | ||
| topKAttentionKeysMultiHead, | ||
| } from "./unlimiformer" | ||
|
|
||
| const unit = (a: number, b: number, c: number) => { | ||
| const v = [a, b, c] | ||
| const norm = Math.hypot(a, b, c) | ||
| return v.map((x) => x / norm) | ||
| } | ||
|
|
||
| describe("topKAttentionKeys", () => { | ||
| test("returns empty when k is zero or keys empty", () => { | ||
| const q = unit(1, 0, 0) | ||
| expect(topKAttentionKeys(q, [], 3)).toEqual([]) | ||
| expect(topKAttentionKeys(q, [unit(1, 0, 0)], 0)).toEqual([]) | ||
| }) | ||
|
|
||
| test("orders by dot product / cosine on unit vectors", () => { | ||
| const q = unit(1, 0, 0) | ||
| const k0 = unit(1, 0, 0) | ||
| const k1 = unit(0, 1, 0) | ||
| const k2 = unit(-1, 0, 0) | ||
| const keys = [k1, k2, k0] | ||
|
|
||
| const top = topKAttentionKeys(q, keys, 2) | ||
| expect(top.map((t) => t.index)).toEqual([2, 0]) | ||
| expect(top[0]?.score).toBeGreaterThan( | ||
| top[1]?.score ?? Number.NEGATIVE_INFINITY, | ||
| ) | ||
| }) | ||
|
|
||
| test("caps k at number of keys", () => { | ||
| const q = unit(1, 0, 0) | ||
| const keys = [unit(1, 0, 0), unit(0, 1, 0)] | ||
| const top = topKAttentionKeys(q, keys, 10) | ||
| expect(top).toHaveLength(2) | ||
| }) | ||
|
|
||
| test("skips keys whose dimension does not match the query (no throw)", () => { | ||
| const q = unit(1, 0, 0) | ||
| const keys = [unit(1, 0, 0), [1, 0], unit(0, 1, 0)] | ||
| const top = topKAttentionKeys(q, keys, 5) | ||
| expect(top.map((t) => t.index)).toEqual([0, 2]) | ||
| }) | ||
|
|
||
| test("returns empty when no key matches query dimension", () => { | ||
| const q = unit(1, 0, 0) | ||
| expect( | ||
| topKAttentionKeys( | ||
| q, | ||
| [ | ||
| [1, 0], | ||
| [0, 1], | ||
| ], | ||
| 3, | ||
| ), | ||
| ).toEqual([]) | ||
| }) | ||
|
|
||
| test("skips keys with NaN components without throwing", () => { | ||
| const q = unit(1, 0, 0) | ||
| const keys = [unit(1, 0, 0), [1, Number.NaN, 0], unit(0, 1, 0)] | ||
| const top = topKAttentionKeys(q, keys, 5) | ||
| expect(top.map((t) => t.index)).toEqual([0, 2]) | ||
| }) | ||
| }) | ||
|
|
||
| describe("attentionScore", () => { | ||
| test("returns NaN instead of throwing when a vector contains NaN", () => { | ||
| const k = unit(1, 0, 0) | ||
| expect(attentionScore([1, Number.NaN, 0], k)).toBeNaN() | ||
| expect(attentionScore(k, [1, Number.NaN, 0])).toBeNaN() | ||
| }) | ||
|
|
||
| test("returns NaN for non-number or non-finite components", () => { | ||
| const k = unit(1, 0, 0) | ||
| const stringSlot = [1, "x", 0] as unknown as number[] | ||
| expect(attentionScore(stringSlot, k)).toBeNaN() | ||
| expect(attentionScore([Number.POSITIVE_INFINITY, 0, 0], k)).toBeNaN() | ||
| }) | ||
| }) | ||
|
|
||
| describe("attentionScores", () => { | ||
| test("matches per-key attentionScore", () => { | ||
| const q = unit(1, 1, 0) | ||
| const keys = [unit(1, 0, 0), unit(0, 1, 0)] | ||
| const scores = attentionScores(q, keys) | ||
| expect(scores).toHaveLength(2) | ||
| }) | ||
|
|
||
| test("uses NaN when key dimension mismatches query", () => { | ||
| const q = unit(1, 0, 0) | ||
| const scores = attentionScores(q, [unit(1, 0, 0), [1, 0]]) | ||
| expect(scores).toHaveLength(2) | ||
| expect(Number.isFinite(scores[0] ?? Number.NaN)).toBe(true) | ||
| expect(scores[1]).toBeNaN() | ||
| }) | ||
| }) | ||
|
|
||
| describe("topKAttentionKeysMultiHead", () => { | ||
| test("runs independent top-k per query", () => { | ||
| const keys = [unit(1, 0, 0), unit(0, 1, 0), unit(0, 0, 1)] | ||
| const q0 = unit(1, 0, 0) | ||
| const q1 = unit(0, 1, 0) | ||
| const out = topKAttentionKeysMultiHead([q0, q1], keys, 1) | ||
| expect(out[0]?.[0]?.index).toBe(0) | ||
| expect(out[1]?.[0]?.index).toBe(1) | ||
| }) | ||
| }) | ||
|
|
||
| describe("rankItemsByAttentionTopK", () => { | ||
| test("maps back to original indices and skips bad embeddings", () => { | ||
| const items = [ | ||
| { id: "a", e: unit(1, 0, 0) }, | ||
| { id: "b", e: null as number[] | null }, | ||
| { id: "c", e: unit(0, 1, 0) }, | ||
| ] | ||
| const q = unit(0, 1, 0) | ||
| const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2) | ||
| expect(ranked[0]?.item.id).toBe("c") | ||
| expect(ranked[0]?.originalIndex).toBe(2) | ||
| }) | ||
|
|
||
| test("skips items whose embedding length does not match the query", () => { | ||
| const items = [ | ||
| { id: "wide", e: [0.1, 0.2, 0.3, 0.4] }, | ||
| { id: "ok", e: unit(0, 1, 0) }, | ||
| ] | ||
| const q = unit(0, 1, 0) | ||
| const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2) | ||
| expect(ranked).toHaveLength(1) | ||
| expect(ranked[0]?.item.id).toBe("ok") | ||
| }) | ||
|
|
||
| test("returns empty when query embedding is non-finite", () => { | ||
| const items = [{ id: "a", e: unit(1, 0, 0) }] | ||
| expect( | ||
| rankItemsByAttentionTopK([Number.NaN, 0, 0], items, (x) => x.e, 2), | ||
| ).toEqual([]) | ||
| }) | ||
|
|
||
| test("skips items with non-finite embeddings", () => { | ||
| const items = [ | ||
| { id: "bad", e: [1, Number.NaN, 0] }, | ||
| { id: "ok", e: unit(0, 1, 0) }, | ||
| ] | ||
| const q = unit(0, 1, 0) | ||
| const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2) | ||
| expect(ranked).toHaveLength(1) | ||
| expect(ranked[0]?.item.id).toBe("ok") | ||
| }) | ||
| }) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,149 @@ | ||
| /** | ||
| * Unlimiformer-style kNN retrieval (Bertsch et al., 2023). | ||
| * | ||
| * @see https://arxiv.org/abs/2305.01625 — cross-attention is approximated by retrieving | ||
| * the top-k keys under dot-product scores. In this codebase, embeddings are treated as | ||
| * key/query vectors; for L2-normalized vectors, dot product equals cosine similarity, | ||
| * matching the ranking used elsewhere in `@repo/lib/similarity`. | ||
| */ | ||
|
|
||
| import { cosineSimilarity } from "./similarity" | ||
|
|
||
| /** True when every entry is a finite number (empty arrays allowed). */ | ||
| const isFiniteEmbeddingVector = (v: number[]): boolean => | ||
| v.every((x) => typeof x === "number" && Number.isFinite(x)) | ||
|
|
||
/**
 * One retrieved entry from top-k attention: `index` is the key's position in
 * the input `keys` array, `score` its dot-product/cosine attention score.
 */
export type AttentionTopK = {
  index: number
  score: number
}
|
|
||
| /** | ||
| * Dot-product attention score between one query vector and one key vector. | ||
| * For normalized embeddings this matches cosine similarity. | ||
| * | ||
| * Returns `NaN` when `query` and `key` have different lengths (e.g. mixed embedding | ||
| * models), or when either vector contains non-finite values (`NaN`, `±Infinity`), so | ||
| * callers avoid throwing from `cosineSimilarity`. | ||
| */ | ||
| export const attentionScore = (query: number[], key: number[]): number => { | ||
| if (query.length !== key.length) { | ||
| return Number.NaN | ||
| } | ||
| if (!isFiniteEmbeddingVector(query) || !isFiniteEmbeddingVector(key)) { | ||
| return Number.NaN | ||
| } | ||
| return cosineSimilarity(query, key) | ||
| } | ||
|
|
||
| /** | ||
| * Attention scores for `query` against every row in `keys`, aligned by index. | ||
| * Entries are `NaN` when a key length does not match the query, or when either vector | ||
| * has non-finite components. | ||
| */ | ||
| export const attentionScores = (query: number[], keys: number[][]): number[] => | ||
| keys.map((key) => attentionScore(query, key)) | ||
|
|
||
| /** | ||
| * Retrieve the top-k keys by attention score (Unlimiformer's kNN over key index). | ||
| * Results are sorted by descending score. | ||
| */ | ||
| export const topKAttentionKeys = ( | ||
| query: number[], | ||
| keys: number[][], | ||
| k: number, | ||
| ): AttentionTopK[] => { | ||
| if (k <= 0 || keys.length === 0) { | ||
| return [] | ||
| } | ||
|
|
||
| const scored: AttentionTopK[] = keys.flatMap((key, index) => { | ||
| const score = attentionScore(query, key) | ||
| if (!Number.isFinite(score)) { | ||
| return [] | ||
| } | ||
| return [{ index, score }] | ||
| }) | ||
|
|
||
| if (scored.length === 0) { | ||
| return [] | ||
| } | ||
|
|
||
| const effectiveK = Math.min(k, scored.length) | ||
| scored.sort((a, b) => b.score - a.score) | ||
| return scored.slice(0, effectiveK) | ||
| } | ||
|
|
||
| /** | ||
| * Per-head top-k retrieval: each query vector gets its own top-k over the same key set, | ||
| * analogous to multi-head cross-attention with separate query projections. | ||
| */ | ||
| export const topKAttentionKeysMultiHead = ( | ||
| queries: number[][], | ||
| keys: number[][], | ||
| k: number, | ||
| ): AttentionTopK[][] => queries.map((q) => topKAttentionKeys(q, keys, k)) | ||
|
|
||
/**
 * A ranked result: the original `item`, its position in the caller's input
 * array (`originalIndex`, so results can be mapped back even after skipped
 * items), and its attention `score` against the query embedding.
 */
export type RankedItem<T> = {
  item: T
  originalIndex: number
  score: number
}
|
|
||
| /** | ||
| * Rank arbitrary items that carry embeddings, returning the top-k by attention score. | ||
| * Items with missing or empty embeddings, or embeddings whose length does not match | ||
| * `queryEmbedding`, are skipped. | ||
| */ | ||
| export const rankItemsByAttentionTopK = <T>( | ||
| queryEmbedding: number[], | ||
| items: readonly T[], | ||
| getEmbedding: (item: T) => number[] | null | undefined, | ||
| k: number, | ||
| ): RankedItem<T>[] => { | ||
| if (k <= 0 || items.length === 0) { | ||
| return [] | ||
| } | ||
|
|
||
| if (!isFiniteEmbeddingVector(queryEmbedding)) { | ||
| return [] | ||
| } | ||
|
|
||
| const packed: Array<{ item: T; originalIndex: number; embedding: number[] }> = | ||
| [] | ||
|
|
||
| for (let i = 0; i < items.length; i++) { | ||
| const item = items[i] | ||
| if (item === undefined) continue | ||
| const embedding = getEmbedding(item) | ||
| if ( | ||
| embedding && | ||
| embedding.length > 0 && | ||
| embedding.length === queryEmbedding.length && | ||
| isFiniteEmbeddingVector(embedding) | ||
| ) { | ||
| packed.push({ item, originalIndex: i, embedding }) | ||
| } | ||
| } | ||
|
|
||
| if (packed.length === 0) { | ||
| return [] | ||
| } | ||
|
|
||
| const keys = packed.map((p) => p.embedding) | ||
| const top = topKAttentionKeys(queryEmbedding, keys, k) | ||
|
|
||
| return top.flatMap(({ index: keyIndex, score }) => { | ||
| const row = packed[keyIndex] | ||
| if (!row) { | ||
| return [] | ||
| } | ||
| return [ | ||
| { | ||
| item: row.item, | ||
| originalIndex: row.originalIndex, | ||
| score, | ||
| }, | ||
| ] | ||
| }) | ||
| } | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.