Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions packages/lib/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
"version": "0.0.0",
"private": true,
"type": "module",
"scripts": {
"test": "bun test"
},
"exports": {
"./*": "./*"
},
Expand Down
157 changes: 157 additions & 0 deletions packages/lib/unlimiformer.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import { describe, expect, test } from "bun:test"
import {
attentionScore,
attentionScores,
rankItemsByAttentionTopK,
topKAttentionKeys,
topKAttentionKeysMultiHead,
} from "./unlimiformer"

/** Builds a 3-d unit vector pointing in the direction (a, b, c). */
const unit = (a: number, b: number, c: number) => {
  const magnitude = Math.hypot(a, b, c)
  return [a / magnitude, b / magnitude, c / magnitude]
}

describe("topKAttentionKeys", () => {
  test("returns empty when k is zero or keys empty", () => {
    const query = unit(1, 0, 0)
    expect(topKAttentionKeys(query, [], 3)).toEqual([])
    expect(topKAttentionKeys(query, [unit(1, 0, 0)], 0)).toEqual([])
  })

  test("orders by dot product / cosine on unit vectors", () => {
    const query = unit(1, 0, 0)
    // Index 2 is parallel to the query, index 0 orthogonal, index 1 opposite.
    const keys = [unit(0, 1, 0), unit(-1, 0, 0), unit(1, 0, 0)]

    const result = topKAttentionKeys(query, keys, 2)
    expect(result.map((entry) => entry.index)).toEqual([2, 0])
    const [best, runnerUp] = result
    expect(best?.score).toBeGreaterThan(
      runnerUp?.score ?? Number.NEGATIVE_INFINITY,
    )
  })

  test("caps k at number of keys", () => {
    const result = topKAttentionKeys(
      unit(1, 0, 0),
      [unit(1, 0, 0), unit(0, 1, 0)],
      10,
    )
    expect(result).toHaveLength(2)
  })

  test("skips keys whose dimension does not match the query (no throw)", () => {
    const query = unit(1, 0, 0)
    const keys = [unit(1, 0, 0), [1, 0], unit(0, 1, 0)]
    const survivors = topKAttentionKeys(query, keys, 5).map(
      (entry) => entry.index,
    )
    expect(survivors).toEqual([0, 2])
  })

  test("returns empty when no key matches query dimension", () => {
    const twoDimKeys = [
      [1, 0],
      [0, 1],
    ]
    expect(topKAttentionKeys(unit(1, 0, 0), twoDimKeys, 3)).toEqual([])
  })

  test("skips keys with NaN components without throwing", () => {
    const query = unit(1, 0, 0)
    const keys = [unit(1, 0, 0), [1, Number.NaN, 0], unit(0, 1, 0)]
    const survivors = topKAttentionKeys(query, keys, 5).map(
      (entry) => entry.index,
    )
    expect(survivors).toEqual([0, 2])
  })
})

describe("attentionScore", () => {
  test("returns NaN instead of throwing when a vector contains NaN", () => {
    const key = unit(1, 0, 0)
    const withNaN = [1, Number.NaN, 0]
    expect(attentionScore(withNaN, key)).toBeNaN()
    expect(attentionScore(key, withNaN)).toBeNaN()
  })

  test("returns NaN for non-number or non-finite components", () => {
    const key = unit(1, 0, 0)
    // Simulates a corrupted embedding that slipped past the type system.
    const mixed = [1, "x", 0] as unknown as number[]
    expect(attentionScore(mixed, key)).toBeNaN()
    expect(attentionScore([Number.POSITIVE_INFINITY, 0, 0], key)).toBeNaN()
  })
})

describe("attentionScores", () => {
test("matches per-key attentionScore", () => {
const q = unit(1, 1, 0)
const keys = [unit(1, 0, 0), unit(0, 1, 0)]
const scores = attentionScores(q, keys)
expect(scores).toHaveLength(2)
})

test("uses NaN when key dimension mismatches query", () => {
const q = unit(1, 0, 0)
const scores = attentionScores(q, [unit(1, 0, 0), [1, 0]])
expect(scores).toHaveLength(2)
expect(Number.isFinite(scores[0] ?? Number.NaN)).toBe(true)
expect(scores[1]).toBeNaN()
})
})

describe("topKAttentionKeysMultiHead", () => {
  test("runs independent top-k per query", () => {
    const keys = [unit(1, 0, 0), unit(0, 1, 0), unit(0, 0, 1)]
    // Each query is parallel to a different key; top-1 should differ per query.
    const queries = [unit(1, 0, 0), unit(0, 1, 0)]
    const perQuery = topKAttentionKeysMultiHead(queries, keys, 1)
    expect(perQuery[0]?.[0]?.index).toBe(0)
    expect(perQuery[1]?.[0]?.index).toBe(1)
  })
})

describe("rankItemsByAttentionTopK", () => {
  test("maps back to original indices and skips bad embeddings", () => {
    const items = [
      { id: "a", e: unit(1, 0, 0) },
      { id: "b", e: null as number[] | null },
      { id: "c", e: unit(0, 1, 0) },
    ]
    const ranked = rankItemsByAttentionTopK(
      unit(0, 1, 0),
      items,
      (candidate) => candidate.e,
      2,
    )
    expect(ranked[0]?.item.id).toBe("c")
    expect(ranked[0]?.originalIndex).toBe(2)
  })

  test("skips items whose embedding length does not match the query", () => {
    const items = [
      { id: "wide", e: [0.1, 0.2, 0.3, 0.4] },
      { id: "ok", e: unit(0, 1, 0) },
    ]
    const ranked = rankItemsByAttentionTopK(
      unit(0, 1, 0),
      items,
      (candidate) => candidate.e,
      2,
    )
    expect(ranked).toHaveLength(1)
    expect(ranked[0]?.item.id).toBe("ok")
  })

  test("returns empty when query embedding is non-finite", () => {
    const items = [{ id: "a", e: unit(1, 0, 0) }]
    const badQuery = [Number.NaN, 0, 0]
    expect(
      rankItemsByAttentionTopK(badQuery, items, (candidate) => candidate.e, 2),
    ).toEqual([])
  })

  test("skips items with non-finite embeddings", () => {
    const items = [
      { id: "bad", e: [1, Number.NaN, 0] },
      { id: "ok", e: unit(0, 1, 0) },
    ]
    const ranked = rankItemsByAttentionTopK(
      unit(0, 1, 0),
      items,
      (candidate) => candidate.e,
      2,
    )
    expect(ranked).toHaveLength(1)
    expect(ranked[0]?.item.id).toBe("ok")
  })
})
149 changes: 149 additions & 0 deletions packages/lib/unlimiformer.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/**
* Unlimiformer-style kNN retrieval (Bertsch et al., 2023).
*
* @see https://arxiv.org/abs/2305.01625 — cross-attention is approximated by retrieving
* the top-k keys under dot-product scores. In this codebase, embeddings are treated as
* key/query vectors; for L2-normalized vectors, dot product equals cosine similarity,
* matching the ranking used elsewhere in `@repo/lib/similarity`.
*/

import { cosineSimilarity } from "./similarity"

/**
 * Checks that a vector holds only finite numbers; an empty vector passes.
 * The `typeof` guard also rejects non-number entries that arrive through
 * unsound casts of external data.
 */
const isFiniteEmbeddingVector = (vector: number[]): boolean => {
  for (const component of vector) {
    if (typeof component !== "number" || !Number.isFinite(component)) {
      return false
    }
  }
  return true
}

/** One retrieved key: its position in the `keys` array plus its attention score. */
export type AttentionTopK = {
  // Index into the caller-supplied keys array.
  index: number
  // Dot-product / cosine attention score against the query.
  score: number
}

/**
 * Attention score (dot product) between a single query and a single key.
 * Because embeddings here are L2-normalized, this is exactly cosine similarity.
 *
 * Instead of throwing from `cosineSimilarity`, yields `NaN` whenever the two
 * vectors have different lengths (e.g. embeddings from different models) or
 * either vector contains a non-finite component (`NaN`, `±Infinity`).
 */
export const attentionScore = (query: number[], key: number[]): number => {
  const comparable =
    query.length === key.length &&
    isFiniteEmbeddingVector(query) &&
    isFiniteEmbeddingVector(key)
  return comparable ? cosineSimilarity(query, key) : Number.NaN
}

/**
 * Scores `query` against every key, returning one score per key in order.
 * A slot is `NaN` for a length mismatch with the query or for vectors with
 * non-finite components (see `attentionScore`).
 */
export const attentionScores = (
  query: number[],
  keys: number[][],
): number[] => {
  const scores: number[] = []
  for (const key of keys) {
    scores.push(attentionScore(query, key))
  }
  return scores
}

/**
 * Unlimiformer-style kNN over the key index: keep the k keys with the highest
 * attention score against `query`, sorted by descending score. Keys whose
 * score is non-finite (length mismatch, NaN/Infinity components) are dropped.
 */
export const topKAttentionKeys = (
  query: number[],
  keys: number[][],
  k: number,
): AttentionTopK[] => {
  if (k <= 0 || keys.length === 0) {
    return []
  }

  // Collect only the keys with a usable (finite) score.
  const candidates: AttentionTopK[] = []
  keys.forEach((key, index) => {
    const score = attentionScore(query, key)
    if (Number.isFinite(score)) {
      candidates.push({ index, score })
    }
  })

  // Stable sort by descending score, then trim to at most k entries.
  candidates.sort((a, b) => b.score - a.score)
  return candidates.slice(0, Math.min(k, candidates.length))
}

/**
 * Multi-head variant: runs `topKAttentionKeys` once per query vector over the
 * shared key set, mirroring multi-head cross-attention where each head has its
 * own query projection. Output is aligned with `queries` by index.
 */
export const topKAttentionKeysMultiHead = (
  queries: number[][],
  keys: number[][],
  k: number,
): AttentionTopK[][] => {
  const perQuery: AttentionTopK[][] = []
  for (const query of queries) {
    perQuery.push(topKAttentionKeys(query, keys, k))
  }
  return perQuery
}

/** An item selected by top-k attention, tagged with where it came from. */
export type RankedItem<T> = {
  item: T
  // Position of `item` in the caller's original items array (pre-filtering).
  originalIndex: number
  // Attention score of the item's embedding against the query embedding.
  score: number
}

/**
 * Rank arbitrary items that carry embeddings, returning the top-k by attention
 * score against `queryEmbedding`, highest first. Each result keeps the item's
 * index in the original `items` array.
 *
 * Skipped without throwing: items whose embedding is missing, empty, of a
 * different length than the query, or containing non-finite components. A
 * non-finite query embedding yields an empty result.
 */
export const rankItemsByAttentionTopK = <T>(
  queryEmbedding: number[],
  items: readonly T[],
  getEmbedding: (item: T) => number[] | null | undefined,
  k: number,
): RankedItem<T>[] => {
  if (
    k <= 0 ||
    items.length === 0 ||
    !isFiniteEmbeddingVector(queryEmbedding)
  ) {
    return []
  }

  // Keep only items with a usable embedding, remembering each source index so
  // top-k results can be mapped back to the caller's array.
  const usable: Array<{ item: T; originalIndex: number; embedding: number[] }> =
    []
  items.forEach((item, originalIndex) => {
    if (item === undefined) {
      return
    }
    const embedding = getEmbedding(item)
    const accepted =
      embedding != null &&
      embedding.length > 0 &&
      embedding.length === queryEmbedding.length &&
      isFiniteEmbeddingVector(embedding)
    if (accepted) {
      usable.push({ item, originalIndex, embedding })
    }
  })

  // topKAttentionKeys returns [] for an empty key set, so no early exit needed.
  return topKAttentionKeys(
    queryEmbedding,
    usable.map((entry) => entry.embedding),
    k,
  ).flatMap(({ index, score }) => {
    const entry = usable[index]
    return entry
      ? [{ item: entry.item, originalIndex: entry.originalIndex, score }]
      : []
  })
}