Commit 22a0f6d

Add tests
1 parent 1bc5188 commit 22a0f6d

1 file changed: +173 -0 lines changed

apis/python/test/test_ingestion.py

Lines changed: 173 additions & 0 deletions
@@ -2010,3 +2010,176 @@ def test_ivf_flat_taskgraph_query(tmp_path):
        queries, k=k, nprobe=nprobe, nthreads=8, mode=Mode.LOCAL, num_partitions=10
    )
    assert accuracy(result, gt_i) > MINIMUM_ACCURACY


def test_dimensions_parameter_override(tmp_path):
    """
    Test the dimensions parameter functionality with TileDB array input.

    This test verifies that the dimensions parameter can override the dimensions
    detected from the source array, which is useful for handling cases where the
    source array has an artificially large domain (e.g., due to TileDBSOMA;
    https://github.com/TileDB-Inc/TileDB-Vector-Search/issues/564).
    """
    # Create test data.
    actual_dimensions = 64
    nb = 1000
    nq = 10
    k = 5

    # Create random test vectors with the actual dimensions.
    test_vectors = np.random.rand(nb, actual_dimensions).astype(np.float32)
    queries = np.random.rand(nq, actual_dimensions).astype(np.float32)

    # Create a TileDB array with an artificially large domain (simulating the problem).
    source_uri = os.path.join(tmp_path, "source_array")
    large_domain_value = 100000

    # Create a schema with a large dimension domain.
    schema = tiledb.ArraySchema(
        domain=tiledb.Domain(
            tiledb.Dim(
                name="__dim_0",
                domain=(0, large_domain_value),
                tile=1000,
                dtype="int32",
            ),
            tiledb.Dim(
                name="__dim_1",
                domain=(0, large_domain_value),
                tile=actual_dimensions,
                dtype="int32",
            ),
        ),
        attrs=[
            tiledb.Attr(name="values", dtype="float32", var=False, nullable=False),
        ],
        cell_order="col-major",
        tile_order="col-major",
        capacity=10000,
        sparse=False,
    )

    # Create the array and write the test data.
    tiledb.Array.create(source_uri, schema)
    with tiledb.open(source_uri, "w") as A:
        A[0:nb, 0:actual_dimensions] = test_vectors

    # Test ingestion with the dimensions parameter override.
    # Without the override, the large domain would be detected as 100001 dimensions;
    # with the override, we explicitly set it to the actual dimensions (64).
    index_uri = os.path.join(tmp_path, "test_index")

    index = ingest(
        index_type="FLAT",
        index_uri=index_uri,
        source_uri=source_uri,
        source_type="TILEDB_ARRAY",
        dimensions=actual_dimensions,  # Override the detected large dimensions.
        size=nb,
    )

    # Verify the index was created successfully.
    assert index is not None
    index.vacuum()

    # Verify the index works correctly with queries.
    distances, indices = index.query(queries, k=k)

    # Basic sanity checks.
    assert distances.shape == (nq, k)
    assert indices.shape == (nq, k)
    assert np.all(indices >= 0)
    assert np.all(indices < nb)

    # Verify that dimensions=-1 (or not passing it at all) detects the large
    # dimensions but creates an unusable index. This demonstrates the problem
    # that the dimensions parameter is meant to solve.
    index_uri_2 = os.path.join(tmp_path, "test_index_2")

    # Create with explicit dimensions=-1; this uses the large detected dimensions.
    # Index creation will succeed, but queries will fail due to a dimension mismatch.
    index_2 = ingest(
        index_type="FLAT",
        index_uri=index_uri_2,
        source_uri=source_uri,
        source_type="TILEDB_ARRAY",
        dimensions=-1,  # Uses the detected large dimensions (100001).
        size=nb,
    )

    assert index_2 is not None
    index_2.vacuum()

    # Verify that the index was created with the large detected dimensions.
    assert index_2.dimensions == large_domain_value + 1  # 100001 dimensions

    # Verify that queries fail due to the dimension mismatch.
    # This demonstrates why the dimensions parameter override is needed.
    with pytest.raises(Exception) as exc_info:
        index_2.query(queries, k=k)
    assert (
        "A query in queries has 64 dimensions, but the indexed data had 100001 dimensions"
        in str(exc_info.value)
    )  # Should contain a dimension mismatch error.


def test_dimensions_parameter_with_numpy_input(tmp_path):
    """
    Test the dimensions parameter with numpy input vectors.

    This ensures that when input_vectors is provided as a numpy array,
    the dimensions parameter is either ignored or validated correctly.
    """
    # Create test data.
    nb = 100
    actual_dimensions = 32
    nq = 5
    k = 3

    # Create random test vectors.
    input_vectors = np.random.rand(nb, actual_dimensions).astype(np.float32)
    queries = np.random.rand(nq, actual_dimensions).astype(np.float32)

    # Ingest with numpy input and the dimensions parameter
    # (it should be ignored since input_vectors is provided).
    index_uri = os.path.join(tmp_path, "test_numpy_index")

    # When input_vectors is provided, the dimensions parameter should not affect
    # the detected dimensions, but the function should still accept it without error.
    index = ingest(
        index_type="FLAT",
        index_uri=index_uri,
        input_vectors=input_vectors,
        dimensions=999,  # Should be ignored since input_vectors is provided.
    )

    # Verify the index was created successfully.
    assert index is not None
    index.vacuum()

    # Test that queries work correctly.
    distances, indices = index.query(queries, k=k)

    # Basic sanity checks.
    assert distances.shape == (nq, k)
    assert indices.shape == (nq, k)
    assert np.all(indices >= 0)
    assert np.all(indices < nb)

    # Verify that the dimensions parameter doesn't cause issues with the default behavior.
    index_uri_2 = os.path.join(tmp_path, "test_numpy_index_2")

    # Test without the dimensions parameter (default behavior).
    index_2 = ingest(
        index_type="FLAT",
        index_uri=index_uri_2,
        input_vectors=input_vectors,
        # No dimensions parameter; should work as before.
    )

    assert index_2 is not None
    index_2.vacuum()

    # Verify both indexes produce similar results.
    distances_2, indices_2 = index_2.query(queries, k=k)
    assert distances_2.shape == (nq, k)
    assert indices_2.shape == (nq, k)
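
For reference, a minimal standalone sketch of the workflow the new tests exercise, assuming the same ingest API that test_ingestion.py already imports (e.g. from tiledb.vector_search.ingestion import ingest); the URIs and sizes below are placeholders, not part of this commit:

import numpy as np
from tiledb.vector_search.ingestion import ingest

# Dense source array whose domain is artificially large (e.g. a TileDBSOMA export).
# Both URIs are hypothetical placeholders.
source_uri = "/tmp/source_array"
index_uri = "/tmp/flat_index"

# Pass dimensions= to override the dimensionality detected from the array domain.
index = ingest(
    index_type="FLAT",
    index_uri=index_uri,
    source_uri=source_uri,
    source_type="TILEDB_ARRAY",
    dimensions=64,  # actual vector width, not the much larger value implied by the domain
    size=1000,      # number of vectors to ingest
)

# Query as usual; result shapes are (num_queries, k).
queries = np.random.rand(10, 64).astype(np.float32)
distances, indices = index.query(queries, k=5)

The added tests can be selected with pytest's -k filter, e.g. pytest apis/python/test/test_ingestion.py -k dimensions_parameter.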
