diff --git a/python/tests/detail/test_collection_dql.py b/python/tests/detail/test_collection_dql.py
index f4804f26d..52e196e4e 100644
--- a/python/tests/detail/test_collection_dql.py
+++ b/python/tests/detail/test_collection_dql.py
@@ -24,14 +24,9 @@
     HnswQueryParam,
     IVFQueryParam,
 )
-
-
 from zvec.model.schema import FieldSchema, VectorSchema
 from zvec.extension import RrfReRanker, WeightedReRanker, QwenReRanker
 from distance_helper import *
-
-from zvec import StatusCode
-from distance_helper import *
 from fixture_helper import *
 from doc_helper import *
 from params_helper import *
@@ -305,7 +300,7 @@ def test_query_with_filter_empty(self, full_collection: Collection, doc_num):
         ids2 = set(doc.id for doc in result2)
         assert ids1 == ids2
 
-    @pytest.mark.parametrize("field_name", ["int32_field"])
+    @pytest.mark.parametrize("field_name", DEFAULT_SCALAR_FIELD_NAME.values())
     @pytest.mark.parametrize("doc_num", [10])
     def test_query_with_filter_single_condition(
         self, full_collection: Collection, doc_num, field_name
@@ -314,19 +309,53 @@ def test_query_with_filter_single_condition(
             generate_doc(i, full_collection.schema) for i in range(doc_num)
         ]
         batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert")
-        filter = field_name + " > 5"
+
+        # Pair each filter expression with the Python predicate that mirrors it,
+        # so the expected rows and the per-row checks share a single definition.
+        if field_name == "bool_field":
+            filter, matches = f"{field_name} = true", bool
+        elif field_name in ("float_field", "double_field"):
+            filter, matches = f"{field_name} > 5.0", lambda value: value > 5.0
+        elif field_name.startswith("array_"):
+            # Array fields are filtered on length: every non-empty array matches.
+            filter, matches = (
+                f"array_length({field_name}) > 0",
+                lambda value: len(value) > 0,
+            )
+        elif field_name == "string_field":
+            filter, matches = f"{field_name} != 'lcy'", lambda value: value != "lcy"
+        elif field_name in (
+            "int32_field",
+            "int64_field",
+            "uint32_field",
+            "uint64_field",
+        ):
+            filter, matches = f"{field_name} > 5", lambda value: int(value) > 5
+        else:
+            raise ValueError(f"Unsupported field type for filtering: {field_name}")
+
+        # Expected matches are read from the documents that were inserted rather
+        # than by generating every document a second time.
+        expected_doc_indices = [
+            i
+            for i, doc in enumerate(multiple_docs)
+            if matches(doc.field(field_name))
+        ]
+
         query_result = full_collection.query(filter=filter)
-        assert len(query_result) == doc_num - 6
+
+        assert len(query_result) == len(expected_doc_indices)
 
         returned_doc_ids = set()
         for doc in query_result:
             returned_doc_ids.add(doc.id)
 
-        expected_doc_ids = set(str(i) for i in range(6, doc_num))
+        expected_doc_ids = {str(i) for i in expected_doc_indices}
 
         for doc in query_result:
             assert doc.id in expected_doc_ids
-            assert int(doc.field(field_name)) > 5
+            value = doc.field(field_name)
+            assert matches(value), f"{field_name}={value!r} fails filter {filter!r}"
 
         single_querydoc_check(multiple_docs, query_result, full_collection)
 
@@ -450,6 +479,53 @@ def test_query_with_filter_parentheses(
         )
         single_querydoc_check(multiple_docs, query_result, full_collection)
 
+    @pytest.mark.parametrize(
+        "filter",
+        [
+            # Test combinations with different scalar field types using AND, OR, parentheses
+            "(int32_field > 2 AND int32_field < 8) OR bool_field = true",
+            "(float_field > 3.0 AND float_field < 7.0) AND string_field != 'exclude'",
+            "(double_field >= 1.5 OR double_field <= 8.5) AND uint32_field > 2",
+            "bool_field = false OR (int64_field > 3 AND int64_field < 9)",
+            "(string_field = 'special' OR string_field = 'test') AND int32_field > 1",
+            "(array_length(array_int32_field) > 0 AND int32_field > 5) OR bool_field = true",
+            "uint64_field > 1 AND uint64_field < 9 AND (float_field > 2.0 OR double_field < 9.0)",
+            "(bool_field = true OR string_field = 'special') AND (int32_field > 4 OR int64_field < 6)",
+            # More complex combinations covering more field types
+            "((int32_field > 1 AND int32_field < 5) OR (int64_field > 6 AND int64_field < 9)) AND bool_field = true",
+            "(uint32_field > 2 OR uint64_field < 8) AND (float_field > 1.0 AND double_field < 10.0)",
+            "(string_field != 'skip' AND int32_field >= 2) OR (bool_field = false AND double_field <= 7.5)",
+            # Additional combinations with array_length and other supported operations on array fields
+            "(array_length(array_string_field) > 0 OR int32_field > 5) AND bool_field = true",
+            "(array_length(array_int32_field) >= 1 AND float_field > 2.0) OR (array_length(array_float_field) > 0)",
+            "(array_length(array_bool_field) > 0 OR array_length(array_double_field) > 0) AND string_field != ''",
+            # Additional combinations with other supported scalar operations using range comparisons
+            "(int32_field > 1 AND int32_field < 10 OR string_field != 'exclude1') AND bool_field = true",
+            "(float_field > 1.0 AND float_field < 10.0 AND double_field > 0.5) OR (uint32_field > 5 AND uint32_field < 50)",
+            "(int64_field > 50 OR uint64_field < 1000) AND string_field != ''",
+        ],
+    )
+    # A larger document count than the sibling tests, for better coverage.
+    @pytest.mark.parametrize("doc_num", [20])
+    def test_query_with_filter_complex_combinations(
+        self, full_collection: Collection, doc_num, filter
+    ):
+        """Smoke-test compound filters: every hit must be an inserted document."""
+        multiple_docs = [
+            generate_doc(i, full_collection.schema) for i in range(doc_num)
+        ]
+        batchdoc_and_check(full_collection, multiple_docs, doc_num, operator="insert")
+        # Collect the inserted ids once so each membership check is a set lookup.
+        inserted_ids = {doc.id for doc in multiple_docs}
+
+        query_result = full_collection.query(filter=filter)
+        assert query_result is not None
+
+        for doc in query_result:
+            assert hasattr(doc, "id")
+            assert doc.id in inserted_ids
+        single_querydoc_check(multiple_docs, query_result, full_collection)
+
     @pytest.mark.parametrize(
         "filter",
         [