From 071b9d6e792a0b31f2995c7a1b0b2325b29cdfd6 Mon Sep 17 00:00:00 2001 From: kayo09 <68217041+kayo09@users.noreply.github.com> Date: Sat, 4 Apr 2026 17:56:45 +0530 Subject: [PATCH 1/3] feat(lib): add Unlimiformer-style kNN attention retrieval helpers Implement the score-and-top-k pattern from Bertsch et al. (arXiv:2305.01625) for embedding vectors: dot-product attention scores with top-k key selection, plus multi-head and generic item ranking utilities. Add Bun tests and a test script for @repo/lib. --- packages/lib/package.json | 3 + packages/lib/unlimiformer.test.ts | 76 ++++++++++++++++++++ packages/lib/unlimiformer.ts | 115 ++++++++++++++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 packages/lib/unlimiformer.test.ts create mode 100644 packages/lib/unlimiformer.ts diff --git a/packages/lib/package.json b/packages/lib/package.json index 99b9d262..ee245167 100644 --- a/packages/lib/package.json +++ b/packages/lib/package.json @@ -3,6 +3,9 @@ "version": "0.0.0", "private": true, "type": "module", + "scripts": { + "test": "bun test" + }, "exports": { "./*": "./*" }, diff --git a/packages/lib/unlimiformer.test.ts b/packages/lib/unlimiformer.test.ts new file mode 100644 index 00000000..ae988a22 --- /dev/null +++ b/packages/lib/unlimiformer.test.ts @@ -0,0 +1,76 @@ +import { describe, expect, test } from "bun:test" +import { + attentionScores, + rankItemsByAttentionTopK, + topKAttentionKeys, + topKAttentionKeysMultiHead, +} from "./unlimiformer" + +const unit = (a: number, b: number, c: number) => { + const v = [a, b, c] + const norm = Math.hypot(a, b, c) + return v.map((x) => x / norm) +} + +describe("topKAttentionKeys", () => { + test("returns empty when k is zero or keys empty", () => { + const q = unit(1, 0, 0) + expect(topKAttentionKeys(q, [], 3)).toEqual([]) + expect(topKAttentionKeys(q, [unit(1, 0, 0)], 0)).toEqual([]) + }) + + test("orders by dot product / cosine on unit vectors", () => { + const q = unit(1, 0, 0) + const k0 = unit(1, 
0, 0) + const k1 = unit(0, 1, 0) + const k2 = unit(-1, 0, 0) + const keys = [k1, k2, k0] + + const top = topKAttentionKeys(q, keys, 2) + expect(top.map((t) => t.index)).toEqual([2, 0]) + expect(top[0]?.score).toBeGreaterThan( + top[1]?.score ?? Number.NEGATIVE_INFINITY, + ) + }) + + test("caps k at number of keys", () => { + const q = unit(1, 0, 0) + const keys = [unit(1, 0, 0), unit(0, 1, 0)] + const top = topKAttentionKeys(q, keys, 10) + expect(top).toHaveLength(2) + }) +}) + +describe("attentionScores", () => { + test("matches per-key attentionScore", () => { + const q = unit(1, 1, 0) + const keys = [unit(1, 0, 0), unit(0, 1, 0)] + const scores = attentionScores(q, keys) + expect(scores).toHaveLength(2) + }) +}) + +describe("topKAttentionKeysMultiHead", () => { + test("runs independent top-k per query", () => { + const keys = [unit(1, 0, 0), unit(0, 1, 0), unit(0, 0, 1)] + const q0 = unit(1, 0, 0) + const q1 = unit(0, 1, 0) + const out = topKAttentionKeysMultiHead([q0, q1], keys, 1) + expect(out[0]?.[0]?.index).toBe(0) + expect(out[1]?.[0]?.index).toBe(1) + }) +}) + +describe("rankItemsByAttentionTopK", () => { + test("maps back to original indices and skips bad embeddings", () => { + const items = [ + { id: "a", e: unit(1, 0, 0) }, + { id: "b", e: null as number[] | null }, + { id: "c", e: unit(0, 1, 0) }, + ] + const q = unit(0, 1, 0) + const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2) + expect(ranked[0]?.item.id).toBe("c") + expect(ranked[0]?.originalIndex).toBe(2) + }) +}) diff --git a/packages/lib/unlimiformer.ts b/packages/lib/unlimiformer.ts new file mode 100644 index 00000000..a72a0c10 --- /dev/null +++ b/packages/lib/unlimiformer.ts @@ -0,0 +1,115 @@ +/** + * Unlimiformer-style kNN retrieval (Bertsch et al., 2023). + * + * @see https://arxiv.org/abs/2305.01625 — cross-attention is approximated by retrieving + * the top-k keys under dot-product scores. 
In this codebase, embeddings are treated as + * key/query vectors; for L2-normalized vectors, dot product equals cosine similarity, + * matching the ranking used elsewhere in `@repo/lib/similarity`. + */ + +import { cosineSimilarity } from "./similarity" + +export type AttentionTopK = { + index: number + score: number +} + +/** + * Dot-product attention score between one query vector and one key vector. + * For normalized embeddings this matches cosine similarity. + */ +export const attentionScore = (query: number[], key: number[]): number => + cosineSimilarity(query, key) + +/** + * Attention scores for `query` against every row in `keys` (same dimension as query). + */ +export const attentionScores = (query: number[], keys: number[][]): number[] => + keys.map((key) => attentionScore(query, key)) + +/** + * Retrieve the top-k keys by attention score (Unlimiformer's kNN over key index). + * Results are sorted by descending score. + */ +export const topKAttentionKeys = ( + query: number[], + keys: number[][], + k: number, +): AttentionTopK[] => { + if (k <= 0 || keys.length === 0) { + return [] + } + + const effectiveK = Math.min(k, keys.length) + const scored: AttentionTopK[] = keys.map((key, index) => ({ + index, + score: attentionScore(query, key), + })) + + scored.sort((a, b) => b.score - a.score) + return scored.slice(0, effectiveK) +} + +/** + * Per-head top-k retrieval: each query vector gets its own top-k over the same key set, + * analogous to multi-head cross-attention with separate query projections. + */ +export const topKAttentionKeysMultiHead = ( + queries: number[][], + keys: number[][], + k: number, +): AttentionTopK[][] => queries.map((q) => topKAttentionKeys(q, keys, k)) + +export type RankedItem<T> = { + item: T + originalIndex: number + score: number +} + +/** + * Rank arbitrary items that carry embeddings, returning the top-k by attention score. + * Items with missing or empty embeddings are skipped. 
+ */ +export const rankItemsByAttentionTopK = <T>( + queryEmbedding: number[], + items: readonly T[], + getEmbedding: (item: T) => number[] | null | undefined, + k: number, +): RankedItem<T>[] => { + if (k <= 0 || items.length === 0) { + return [] + } + + const packed: Array<{ item: T; originalIndex: number; embedding: number[] }> = + [] + + for (let i = 0; i < items.length; i++) { + const item = items[i] + if (item === undefined) continue + const embedding = getEmbedding(item) + if (embedding && embedding.length > 0) { + packed.push({ item, originalIndex: i, embedding }) + } + } + + if (packed.length === 0) { + return [] + } + + const keys = packed.map((p) => p.embedding) + const top = topKAttentionKeys(queryEmbedding, keys, k) + + return top.flatMap(({ index: keyIndex, score }) => { + const row = packed[keyIndex] + if (!row) { + return [] + } + return [ + { + item: row.item, + originalIndex: row.originalIndex, + score, + }, + ] + }) +} From 33f47db17574254b41b51289ef8a758f59700045 Mon Sep 17 00:00:00 2001 From: kayo09 <68217041+kayo09@users.noreply.github.com> Date: Sat, 4 Apr 2026 18:14:11 +0530 Subject: [PATCH 2/3] fix(lib): handle embedding dimension mismatch in Unlimiformer retrieval Skip mismatched key/query lengths instead of throwing from cosineSimilarity; attentionScore returns NaN for length mismatch; extend tests. 
--- packages/lib/unlimiformer.test.ts | 40 ++++++++++++++++++++++++++++++ packages/lib/unlimiformer.ts | 41 +++++++++++++++++++++++-------- 2 files changed, 71 insertions(+), 10 deletions(-) diff --git a/packages/lib/unlimiformer.test.ts b/packages/lib/unlimiformer.test.ts index ae988a22..cd72921e 100644 --- a/packages/lib/unlimiformer.test.ts +++ b/packages/lib/unlimiformer.test.ts @@ -39,6 +39,27 @@ describe("topKAttentionKeys", () => { const top = topKAttentionKeys(q, keys, 10) expect(top).toHaveLength(2) }) + + test("skips keys whose dimension does not match the query (no throw)", () => { + const q = unit(1, 0, 0) + const keys = [unit(1, 0, 0), [1, 0], unit(0, 1, 0)] + const top = topKAttentionKeys(q, keys, 5) + expect(top.map((t) => t.index)).toEqual([0, 2]) + }) + + test("returns empty when no key matches query dimension", () => { + const q = unit(1, 0, 0) + expect( + topKAttentionKeys( + q, + [ + [1, 0], + [0, 1], + ], + 3, + ), + ).toEqual([]) + }) }) describe("attentionScores", () => { @@ -48,6 +69,14 @@ describe("attentionScores", () => { const scores = attentionScores(q, keys) expect(scores).toHaveLength(2) }) + + test("uses NaN when key dimension mismatches query", () => { + const q = unit(1, 0, 0) + const scores = attentionScores(q, [unit(1, 0, 0), [1, 0]]) + expect(scores).toHaveLength(2) + expect(Number.isFinite(scores[0] ?? 
Number.NaN)).toBe(true) + expect(scores[1]).toBeNaN() + }) }) describe("topKAttentionKeysMultiHead", () => { @@ -73,4 +102,15 @@ describe("rankItemsByAttentionTopK", () => { expect(ranked[0]?.item.id).toBe("c") expect(ranked[0]?.originalIndex).toBe(2) }) + + test("skips items whose embedding length does not match the query", () => { + const items = [ + { id: "wide", e: [0.1, 0.2, 0.3, 0.4] }, + { id: "ok", e: unit(0, 1, 0) }, + ] + const q = unit(0, 1, 0) + const ranked = rankItemsByAttentionTopK(q, items, (x) => x.e, 2) + expect(ranked).toHaveLength(1) + expect(ranked[0]?.item.id).toBe("ok") + }) }) diff --git a/packages/lib/unlimiformer.ts b/packages/lib/unlimiformer.ts index a72a0c10..882fae9c 100644 --- a/packages/lib/unlimiformer.ts +++ b/packages/lib/unlimiformer.ts @@ -17,12 +17,21 @@ export type AttentionTopK = { /** * Dot-product attention score between one query vector and one key vector. * For normalized embeddings this matches cosine similarity. + * + * Returns `NaN` when `query` and `key` have different lengths (e.g. mixed embedding + * models) so callers can avoid throwing from `cosineSimilarity`. */ -export const attentionScore = (query: number[], key: number[]): number => - cosineSimilarity(query, key) +export const attentionScore = (query: number[], key: number[]): number => { + if (query.length !== key.length) { + return Number.NaN + } + return cosineSimilarity(query, key) +} /** - * Attention scores for `query` against every row in `keys` (same dimension as query). + * Attention scores for `query` against every row in `keys`, aligned by index. + * Entries are `NaN` when a key length does not match the query (same embedding model + * is required for a meaningful score). 
 */ +export const attentionScores = (query: number[], keys: number[][]): number[] => keys.map((key) => attentionScore(query, key)) @@ -40,12 +49,19 @@ export const topKAttentionKeys = ( return [] } - const effectiveK = Math.min(k, keys.length) - const scored: AttentionTopK[] = keys.map((key, index) => ({ - index, - score: attentionScore(query, key), - })) + const scored: AttentionTopK[] = keys.flatMap((key, index) => { + const score = attentionScore(query, key) + if (!Number.isFinite(score)) { + return [] + } + return [{ index, score }] + }) + + if (scored.length === 0) { + return [] + } + const effectiveK = Math.min(k, scored.length) scored.sort((a, b) => b.score - a.score) return scored.slice(0, effectiveK) } @@ -68,7 +84,8 @@ export type RankedItem<T> = { /** * Rank arbitrary items that carry embeddings, returning the top-k by attention score. - * Items with missing or empty embeddings are skipped. + * Items with missing or empty embeddings, or embeddings whose length does not match + * `queryEmbedding`, are skipped. */ export const rankItemsByAttentionTopK = <T>( @@ -87,7 +104,11 @@ export const rankItemsByAttentionTopK = <T>( const item = items[i] if (item === undefined) continue const embedding = getEmbedding(item) - if (embedding && embedding.length > 0) { + if ( + embedding && + embedding.length > 0 && + embedding.length === queryEmbedding.length + ) { packed.push({ item, originalIndex: i, embedding }) } } From e0eccbdd7e2094f97f32e67fdf1da71f80a97964 Mon Sep 17 00:00:00 2001 From: kayo09 <68217041+kayo09@users.noreply.github.com> Date: Sat, 4 Apr 2026 18:24:12 +0530 Subject: [PATCH 3/3] fix(lib): guard Unlimiformer attentionScore against non-finite embeddings Return NaN when vectors contain NaN, Infinity, or non-numbers instead of throwing from cosineSimilarity. Skip non-finite query and item embeddings in rankItemsByAttentionTopK. Add tests. 
--- packages/lib/unlimiformer.test.ts | 41 +++++++++++++++++++++++++++++++ packages/lib/unlimiformer.ts | 21 +++++++++++++--- 2 files changed, 58 insertions(+), 4 deletions(-) diff --git a/packages/lib/unlimiformer.test.ts b/packages/lib/unlimiformer.test.ts index cd72921e..3c134d9f 100644 --- a/packages/lib/unlimiformer.test.ts +++ b/packages/lib/unlimiformer.test.ts @@ -1,5 +1,6 @@ import { describe, expect, test } from "bun:test" import { + attentionScore, attentionScores, rankItemsByAttentionTopK, topKAttentionKeys, @@ -60,6 +61,28 @@ describe("topKAttentionKeys", () => { ), ).toEqual([]) }) + + test("skips keys with NaN components without throwing", () => { + const q = unit(1, 0, 0) + const keys = [unit(1, 0, 0), [1, Number.NaN, 0], unit(0, 1, 0)] + const top = topKAttentionKeys(q, keys, 5) + expect(top.map((t) => t.index)).toEqual([0, 2]) + }) +}) + +describe("attentionScore", () => { + test("returns NaN instead of throwing when a vector contains NaN", () => { + const k = unit(1, 0, 0) + expect(attentionScore([1, Number.NaN, 0], k)).toBeNaN() + expect(attentionScore(k, [1, Number.NaN, 0])).toBeNaN() + }) + + test("returns NaN for non-number or non-finite components", () => { + const k = unit(1, 0, 0) + const stringSlot = [1, "x", 0] as unknown as number[] + expect(attentionScore(stringSlot, k)).toBeNaN() + expect(attentionScore([Number.POSITIVE_INFINITY, 0, 0], k)).toBeNaN() + }) }) describe("attentionScores", () => { @@ -113,4 +136,22 @@ describe("rankItemsByAttentionTopK", () => { expect(ranked).toHaveLength(1) expect(ranked[0]?.item.id).toBe("ok") }) + + test("returns empty when query embedding is non-finite", () => { + const items = [{ id: "a", e: unit(1, 0, 0) }] + expect( + rankItemsByAttentionTopK([Number.NaN, 0, 0], items, (x) => x.e, 2), + ).toEqual([]) + }) + + test("skips items with non-finite embeddings", () => { + const items = [ + { id: "bad", e: [1, Number.NaN, 0] }, + { id: "ok", e: unit(0, 1, 0) }, + ] + const q = unit(0, 1, 0) + const ranked 
= rankItemsByAttentionTopK(q, items, (x) => x.e, 2) + expect(ranked).toHaveLength(1) + expect(ranked[0]?.item.id).toBe("ok") + }) }) diff --git a/packages/lib/unlimiformer.ts b/packages/lib/unlimiformer.ts index 882fae9c..5403c67f 100644 --- a/packages/lib/unlimiformer.ts +++ b/packages/lib/unlimiformer.ts @@ -9,6 +9,10 @@ import { cosineSimilarity } from "./similarity" +/** True when every entry is a finite number (empty arrays allowed). */ +const isFiniteEmbeddingVector = (v: number[]): boolean => + v.every((x) => typeof x === "number" && Number.isFinite(x)) + export type AttentionTopK = { index: number score: number @@ -19,19 +23,23 @@ export type AttentionTopK = { * For normalized embeddings this matches cosine similarity. * * Returns `NaN` when `query` and `key` have different lengths (e.g. mixed embedding - * models) so callers can avoid throwing from `cosineSimilarity`. + * models), or when either vector contains non-finite values (`NaN`, `±Infinity`), so + * callers avoid throwing from `cosineSimilarity`. */ export const attentionScore = (query: number[], key: number[]): number => { if (query.length !== key.length) { return Number.NaN } + if (!isFiniteEmbeddingVector(query) || !isFiniteEmbeddingVector(key)) { + return Number.NaN + } return cosineSimilarity(query, key) } /** * Attention scores for `query` against every row in `keys`, aligned by index. - * Entries are `NaN` when a key length does not match the query (same embedding model - * is required for a meaningful score). + * Entries are `NaN` when a key length does not match the query, or when either vector + * has non-finite components. 
 */ export const attentionScores = (query: number[], keys: number[][]): number[] => keys.map((key) => attentionScore(query, key)) @@ -97,6 +105,10 @@ export const rankItemsByAttentionTopK = <T>( return [] } + if (!isFiniteEmbeddingVector(queryEmbedding)) { + return [] + } + const packed: Array<{ item: T; originalIndex: number; embedding: number[] }> = [] @@ -107,7 +119,8 @@ export const rankItemsByAttentionTopK = <T>( if ( embedding && embedding.length > 0 && - embedding.length === queryEmbedding.length + embedding.length === queryEmbedding.length && + isFiniteEmbeddingVector(embedding) ) { packed.push({ item, originalIndex: i, embedding }) }