Skip to content
56 changes: 56 additions & 0 deletions sdk/cosmos/azure-cosmos/tests/test_vector_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,47 @@
import test_config
from azure.cosmos import CosmosClient, PartitionKey

VectorPolicyTestData = {
"valid_vector_indexing_policy" : {
"indexing_policy": {
"vectorIndexes": [
{"path": "/vector1", "type": "flat"},
{"path": "/vector2", "type": "quantizedFlat", "quantizerType": "product", "quantizationByteSize": 8},
{"path": "/vector3", "type": "diskANN", "quantizerType": "product", "quantizationByteSize": 8,
"vectorIndexShardKey": ["/city"], "indexingSearchListSize": 50},
{"path": "/vector4", "type": "diskANN", "quantizerType": "spherical", "indexingSearchListSize": 50},
]
},
"vector_embedding_policy": {
"vectorEmbeddings": [
{
"path": "/vector1",
"dataType": "float32",
"dimensions": 256,
"distanceFunction": "euclidean"
},
{
"path": "/vector2",
"dataType": "int8",
"dimensions": 200,
"distanceFunction": "dotproduct"
},
{
"path": "/vector3",
"dataType": "uint8",
"dimensions": 400,
"distanceFunction": "cosine"
},
{
"path": "/vector4",
"dataType": "uint8",
"dimensions": 400,
"distanceFunction": "euclidean"
},
]
}
}
}

@pytest.mark.cosmosSearchQuery
class TestVectorPolicy(unittest.TestCase):
Expand Down Expand Up @@ -55,6 +96,21 @@ def test_create_valid_vector_embedding_policy(self):
assert properties["vectorEmbeddingPolicy"]["vectorEmbeddings"][0]["dataType"] == data_type
self.test_db.delete_container('vector_container_' + data_type)

@unittest.skip
def test_create_valid_vector_indexing_policy(self):
test_data = VectorPolicyTestData["valid_vector_indexing_policy"]
indexing_policy = test_data["indexing_policy"]
vector_embedding_policy = test_data["vector_embedding_policy"]

created_container = self.test_db.create_container(
id="container_" + str(uuid.uuid4()),
partition_key=PartitionKey(path="/id"),
vector_embedding_policy=vector_embedding_policy,
indexing_policy=indexing_policy)
properties = created_container.read()
assert properties['indexingPolicy']['vectorIndexes'] == indexing_policy['vectorIndexes']
Comment thread
allenkim0129 marked this conversation as resolved.
self.test_db.delete_container(created_container.id)

def test_create_vector_embedding_container(self):
indexing_policy = {
"vectorIndexes": [
Expand Down
17 changes: 16 additions & 1 deletion sdk/cosmos/azure-cosmos/tests/test_vector_policy_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from azure.cosmos import CosmosClient as CosmosSyncClient
from azure.cosmos import PartitionKey
from azure.cosmos.aio import CosmosClient

from test_vector_policy import VectorPolicyTestData

@pytest.mark.cosmosSearchQuery
class TestVectorPolicyAsync(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -46,6 +46,21 @@ async def asyncSetUp(self):
async def asyncTearDown(self):
await self.client.close()

@unittest.skip
async def test_create_valid_vector_indexing_policy_async(self):
test_data = VectorPolicyTestData["valid_vector_indexing_policy"]
indexing_policy = test_data["indexing_policy"]
vector_embedding_policy = test_data["vector_embedding_policy"]

created_container = await self.test_db.create_container(
id="container_" + str(uuid.uuid4()),
partition_key=PartitionKey(path="/id"),
vector_embedding_policy=vector_embedding_policy,
indexing_policy=indexing_policy)
properties = await created_container.read()
assert properties['indexingPolicy']['vectorIndexes'] == indexing_policy['vectorIndexes']
Comment thread
allenkim0129 marked this conversation as resolved.
await self.test_db.delete_container(created_container.id)

async def test_create_valid_vector_embedding_policy_async(self):
# Using valid data types
data_types = ["float32", "float16", "int8", "uint8"]
Expand Down