Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ API Changes

* GITHUB#15502: Add `count()` method to `FilterWeight` (Prudhvi Godithi)

* GITHUB#15621: Add validation to prevent zero vectors in KNN fields (Vigya Sharma)

New Features
---------------------
* GITHUB#15328: VectorSimilarityFunction.getValues() now implements doubleVal allowing its
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ public void testFloatVectorFails() throws IOException {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
doc.add(new KnnFloatVectorField("f", new float[4], VectorSimilarityFunction.DOT_PRODUCT));
doc.add(
new KnnFloatVectorField(
"f", new float[] {1f, 0f, 0f, 1f}, VectorSimilarityFunction.DOT_PRODUCT));
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc));
e.getMessage().contains("HnswBitVectorsFormat only supports BYTE encoding");
Expand All @@ -68,8 +70,7 @@ public void testIndexAndSearchBitVectors() throws IOException {
new byte[] {(byte) 0b10101110, (byte) 0b01010111},
new byte[] {(byte) 0b11111000, (byte) 0b00001111},
new byte[] {(byte) 0b11001100, (byte) 0b00110011},
new byte[] {(byte) 0b11111111, (byte) 0b00000000},
new byte[] {(byte) 0b00000000, (byte) 0b00000000}
new byte[] {(byte) 0b11111111, (byte) 0b00000000}
};
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil;

/**
* A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
Expand Down Expand Up @@ -97,6 +98,9 @@ public static FieldType createFieldType(
public KnnByteVectorField(
String name, byte[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
if (VectorUtil.isZeroVector(vector) == true) {
throw new IllegalArgumentException("zero vector not allowed for vector field value");
}
fieldsData = vector; // null-check done above
}

Expand Down Expand Up @@ -138,6 +142,9 @@ public KnnByteVectorField(String name, byte[] vector, FieldType fieldType) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
if (VectorUtil.isZeroVector(vector) == true) {
throw new IllegalArgumentException("zero vector not allowed for vector field value");
}
fieldsData = vector;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ public static Query newVectorQuery(String field, float[] queryVector, int k) {
public KnnFloatVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
if (VectorUtil.isZeroVector(vector) == true) {
throw new IllegalArgumentException("zero vector not allowed for vector field value");
}
fieldsData = VectorUtil.checkFinite(vector); // null check done above
}

Expand Down Expand Up @@ -139,6 +142,9 @@ public KnnFloatVectorField(String name, float[] vector, FieldType fieldType) {
throw new IllegalArgumentException(
"The number of vector dimensions does not match the field type");
}
if (VectorUtil.isZeroVector(vector) == true) {
throw new IllegalArgumentException("zero vector not allowed for vector field value");
}
fieldsData = VectorUtil.checkFinite(vector);
}

Expand Down
20 changes: 20 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,26 @@ public static float[] checkFinite(float[] v) {
return v;
}

/** Returns true if all dimensions of provided vector are zero, false otherwise. */
public static boolean isZeroVector(float[] v) {
for (float value : v) {
if (value != 0) {
return false;
}
}
return true;
}

/** Returns true if all dimensions of provided vector are zero, false otherwise. */
public static boolean isZeroVector(byte[] v) {
for (float value : v) {
if (value != 0) {
return false;
}
}
return true;
}

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.index.RandomCodec;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.TestVectorUtil;
import org.hamcrest.MatcherAssert;

/** Basic tests of PerFieldDocValuesFormat */
Expand Down Expand Up @@ -232,14 +233,14 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
});
try (IndexWriter writer = new IndexWriter(directory, iwc)) {
Document doc1 = new Document();
doc1.add(new KnnFloatVectorField("field1", new float[33]));
doc1.add(new KnnFloatVectorField("field1", TestVectorUtil.randomVector(33)));
Exception exc =
expectThrows(IllegalArgumentException.class, () -> writer.addDocument(doc1));
assertTrue(exc.getMessage().contains("vector's dimensions must be <= [32]"));

Document doc2 = new Document();
doc2.add(new KnnFloatVectorField("field1", new float[32]));
doc2.add(new KnnFloatVectorField("field2", new float[33]));
doc2.add(new KnnFloatVectorField("field1", TestVectorUtil.randomVector(32)));
doc2.add(new KnnFloatVectorField("field2", TestVectorUtil.randomVector(33)));
writer.addDocument(doc2);
}

Expand Down
30 changes: 29 additions & 1 deletion lucene/core/src/test/org/apache/lucene/document/TestField.java
Original file line number Diff line number Diff line change
Expand Up @@ -698,17 +698,45 @@ public void testKnnVectorField() throws Exception {
try (Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document();
byte[] b = new byte[5];
byte[] empty = new byte[5];
IllegalArgumentException zeroError =
expectThrows(
IllegalArgumentException.class, () -> new KnnByteVectorField("binaryZeroErr", empty));
assertTrue(zeroError.getMessage().contains("zero vector not allowed"));

byte[] b = new byte[] {1, 1, 1, 1, 1};
KnnByteVectorField field =
new KnnByteVectorField("binary", b, VectorSimilarityFunction.EUCLIDEAN);
assertNull(field.binaryValue());
assertArrayEquals(b, field.vectorValue());

expectThrows(
IllegalArgumentException.class,
() -> new KnnFloatVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
zeroError =
expectThrows(
IllegalArgumentException.class,
() ->
new KnnByteVectorField(
"float", new byte[] {0, 0, 0, 0, 0}, (FieldType) field.fieldType()));
assertTrue(zeroError.getMessage().contains("zero vector not allowed"));
zeroError =
expectThrows(
IllegalArgumentException.class,
() -> new KnnFloatVectorField("zerovec", new float[] {0, 0, 0, 0}));
assertTrue(zeroError.getMessage().contains("zero vector not allowed"));

float[] vector = new float[] {1, 2};
Field field2 = new KnnFloatVectorField("float", vector);
assertNull(field2.binaryValue());
zeroError =
expectThrows(
IllegalArgumentException.class,
() ->
new KnnFloatVectorField(
"float", new float[] {0, 0}, (FieldType) field2.fieldType()));
assertTrue(zeroError.getMessage().contains("zero vector not allowed"));

doc.add(field);
doc.add(field2);
w.addDocument(doc);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.lucene.tests.util.LuceneTestCase.SuppressFileSystems;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/** Test that the default codec detects mismatched checksums at open or checkIntegrity time. */
@SuppressFileSystems("ExtrasFS")
Expand Down Expand Up @@ -69,7 +70,7 @@ public void test() throws Exception {
doc.add(pointNumber);
Field dvNumber = new NumericDocValuesField("long", 0L);
doc.add(dvNumber);
KnnFloatVectorField vector = new KnnFloatVectorField("vector", new float[16]);
KnnFloatVectorField vector = new KnnFloatVectorField("vector", TestVectorUtil.randomVector(16));
doc.add(vector);

for (int i = 0; i < 100; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.lucene.tests.util.LuceneTestCase.SuppressFileSystems;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.TestVectorUtil;

/** Test that a plain default detects index file truncation early (on opening a reader). */
@SuppressFileSystems("ExtrasFS")
Expand Down Expand Up @@ -82,7 +83,7 @@ private void doTest(boolean cfs) throws Exception {
doc.add(pointNumber);
Field dvNumber = new NumericDocValuesField("long", 0L);
doc.add(dvNumber);
KnnFloatVectorField vector = new KnnFloatVectorField("vector", new float[16]);
KnnFloatVectorField vector = new KnnFloatVectorField("vector", TestVectorUtil.randomVector(16));
doc.add(vector);

for (int i = 0; i < 100; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.lucene.util.SameThreadExecutorService;
import org.apache.lucene.util.StringHelper;
import org.apache.lucene.util.SuppressForbidden;
import org.apache.lucene.util.TestVectorUtil;
import org.apache.lucene.util.Version;

public class TestConcurrentMergeScheduler extends LuceneTestCase {
Expand Down Expand Up @@ -109,7 +110,7 @@ public Executor getIntraMergeExecutor(MergePolicy.OneMerge merge) {
IndexWriter writer = new IndexWriter(directory, iwc);
Document doc = new Document();
Field idField = newStringField("id", "", Field.Store.YES);
KnnFloatVectorField knnField = new KnnFloatVectorField("knn", new float[] {0.0f, 0.0f});
KnnFloatVectorField knnField = new KnnFloatVectorField("knn", TestVectorUtil.randomVector(2));
doc.add(idField);
// Add knn float vectors to test parallel merge
doc.add(knnField);
Expand Down Expand Up @@ -244,7 +245,7 @@ public void testNoWaitClose() throws IOException {
Directory directory = newDirectory();
Document doc = new Document();
Field idField = newStringField("id", "", Field.Store.YES);
KnnFloatVectorField knnField = new KnnFloatVectorField("knn", new float[] {0.0f, 0.0f});
KnnFloatVectorField knnField = new KnnFloatVectorField("knn", TestVectorUtil.randomVector(2));
doc.add(idField);
doc.add(knnField);
IndexWriterConfig iwc =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,8 @@ private void indexData(IndexWriter iw) throws IOException {
for (int i = 0; i < values.length; i++) {
// System.out.printf("%d: (%d, %d)\n", i, index % n, index / n);
int x = index % n, y = index / n;
values[i] = new float[] {x, y};
// avoid zero vectors
values[i] = new float[] {x + 1e-5f, y + 1e-5f};
index = (index + stepSize) % (n * n);
add(iw, i, values[i]);
if (i == 13) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public void testSortOnAddIndicesRandom() throws IOException {
doc.add(
new SortedSetDocValuesField("sorted_set_dv", new BytesRef(Integer.toString(docId))));
if (dense || docId % 2 == 0) {
doc.add(new KnnFloatVectorField("vector", new float[] {(float) docId}));
doc.add(new KnnFloatVectorField("vector", new float[] {(float) docId + 1e-6f}));
}
doc.add(new NumericDocValuesField("foo", random().nextInt(20)));

Expand Down
Loading