diff --git a/.github/workflows/run-llm-integration-tests.yaml.yml b/.github/workflows/run-llm-integration-tests.yml
similarity index 85%
rename from .github/workflows/run-llm-integration-tests.yaml.yml
rename to .github/workflows/run-llm-integration-tests.yml
index c518ddace..6735204c7 100644
--- a/.github/workflows/run-llm-integration-tests.yaml.yml
+++ b/.github/workflows/run-llm-integration-tests.yml
@@ -5,10 +5,11 @@ name: Pytest LLM Tests
 
 on:
   push:
-    branches: [ main ]
+    branches: [main]
+  pull_request:
 
 env:
-  UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged
+  UV_LOCKED: true  # Assert that the `uv.lock` will remain unchanged
 
 jobs:
   container_job:
@@ -16,8 +17,8 @@ jobs:
     runs-on: ubuntu-latest
    strategy:
       matrix:
-        python-version: ['3.14']
-        postgres-version: ['17']
+        python-version: ["3.14"]
+        postgres-version: ["17"]
       fail-fast: false
     container: ubuntu:latest
     services:
@@ -54,10 +55,9 @@
           DATABASE_URI: postgresql://nwa:nwa@postgres/orchestrator-core-test
           ENVIRONMENT: TESTING
           PYTHONPATH: .
-          OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
-          SEARCH_ENABLED: "true"
+          SEARCH_ENABLED: true
         run: |
-          SEARCH_ENABLED=true uv run pytest test/integration_tests/search -v
+          uv run pytest test/integration_tests/search -v
 
       - name: "Upload coverage to Codecov"
         uses: codecov/codecov-action@v3
diff --git a/orchestrator/search/aggregations/base.py b/orchestrator/search/aggregations/base.py
index 8aa2c941e..725bb3c45 100644
--- a/orchestrator/search/aggregations/base.py
+++ b/orchestrator/search/aggregations/base.py
@@ -61,6 +61,11 @@ class TemporalGrouping(BaseModel):
         },
     )
 
+    @property
+    def alias(self) -> str:
+        """Return the SQL-friendly alias for this temporal grouping."""
+        return f"{BaseAggregation.field_to_alias(self.field)}_{self.period.value}"
+
     def get_pivot_fields(self) -> list[str]:
         """Return fields that need to be pivoted for this temporal grouping."""
         return [self.field]
@@ -83,8 +88,7 @@ def to_expression(self, pivot_cte_columns: Any) -> tuple[Label, Any, str]:
         col = getattr(pivot_cte_columns, field_alias)
 
         truncated_col = func.date_trunc(self.period.value, cast(col, TIMESTAMP(timezone=True)))
-        # Column name without prefix
-        col_name = f"{field_alias}_{self.period.value}"
+        col_name = self.alias
         select_col = truncated_col.label(col_name)
 
         return select_col, truncated_col, col_name
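Note: a minimal sketch of what the new TemporalGrouping.alias property yields. It assumes BaseAggregation.field_to_alias normalizes path separators to underscores (that helper is outside this diff):

    from orchestrator.search.aggregations import TemporalGrouping, TemporalPeriod

    grouping = TemporalGrouping(field="subscription.start_date", period=TemporalPeriod.MONTH)
    # Assuming field_to_alias("subscription.start_date") == "subscription_start_date",
    # the property returns the column name the query builder uses everywhere:
    assert grouping.alias == "subscription_start_date_month"
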
diff --git a/orchestrator/search/query/builder.py b/orchestrator/search/query/builder.py
index 60e8d1664..10d3223ee 100644
--- a/orchestrator/search/query/builder.py
+++ b/orchestrator/search/query/builder.py
@@ -25,6 +25,7 @@
 from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
 from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
 from orchestrator.search.filters import LtreeFilter
+from orchestrator.search.query.mixins import OrderDirection
 from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query
 
 
@@ -181,7 +182,8 @@ def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:
 
 
 def _build_grouping_columns(
-    query: CountQuery | AggregateQuery, pivot_cte: CTE
+    query: CountQuery | AggregateQuery,
+    pivot_cte: CTE,
 ) -> tuple[list[Any], list[Any], list[str]]:
     """Build GROUP BY columns and their SELECT columns.
 
@@ -244,6 +246,68 @@ def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CT
     return [count_agg.to_expression(pivot_cte.c.entity_id)]
 
 
+def _apply_cumulative_aggregations(
+    stmt: Select,
+    query: CountQuery | AggregateQuery,
+    group_column_names: list[str],
+    aggregation_columns: list[Label],
+) -> Select:
+    """Add cumulative aggregation columns."""
+
+    # At this point, cumulative validation has already happened at query build time
+    # in GroupingMixin.validate_grouping_constraints, so we know:
+    # temporal_group_by exists and has exactly 1 element when cumulative=True
+    if not query.cumulative or not aggregation_columns or not query.temporal_group_by:
+        return stmt
+
+    temporal_alias = query.temporal_group_by[0].alias
+
+    base_subquery = stmt.subquery()
+    partition_cols = [base_subquery.c[name] for name in group_column_names if name != temporal_alias]
+    order_col = base_subquery.c[temporal_alias]
+
+    base_columns = [base_subquery.c[col] for col in base_subquery.c.keys()]
+
+    cumulative_columns = []
+    for agg_col in aggregation_columns:
+        cumulative_alias = f"{agg_col.key}_cumulative"
+        over_kwargs: dict[str, Any] = {"order_by": order_col}
+        if partition_cols:
+            over_kwargs["partition_by"] = partition_cols
+        cumulative_expr = func.sum(base_subquery.c[agg_col.key]).over(**over_kwargs).label(cumulative_alias)
+        cumulative_columns.append(cumulative_expr)
+
+    return select(*(base_columns + cumulative_columns)).select_from(base_subquery)
+
+
+def _apply_ordering(
+    stmt: Select,
+    query: CountQuery | AggregateQuery,
+    group_column_names: list[str],
+) -> Select:
+    """Apply ordering instructions to the SELECT statement."""
+    columns_by_key = {col.key: col for col in stmt.selected_columns}
+
+    if query.order_by:
+        order_expressions = []
+        for instruction in query.order_by:
+            # Try direct field name first, then normalized alias
+            col = columns_by_key.get(instruction.field)
+            if col is None:
+                col = columns_by_key.get(BaseAggregation.field_to_alias(instruction.field))
+            if col is None:
+                raise ValueError(f"Cannot order by '{instruction.field}'; column not found.")
+            order_expressions.append(col.desc() if instruction.direction == OrderDirection.DESC else col.asc())
+        return stmt.order_by(*order_expressions)
+
+    if query.temporal_group_by:
+        # Default ordering by all grouping columns (ascending)
+        order_expressions = [columns_by_key[col_name].asc() for col_name in group_column_names]
+        return stmt.order_by(*order_expressions)
+
+    return stmt
+
+
 def build_simple_count_query(base_query: Select) -> Select:
     """Build a simple count query without grouping.
 
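Note: _apply_cumulative_aggregations wraps the grouped statement in a subquery and adds one running SUM per aggregation column; _apply_ordering then sorts the result. A standalone, runnable SQLAlchemy sketch of the same window-function wrapping (table and column names are illustrative):

    from sqlalchemy import Integer, column, func, select

    # Stand-in for the grouped aggregation statement built earlier.
    inner = select(
        column("subscription_start_date_month"),
        column("count", Integer),
    ).subquery("base")

    # Re-select every base column plus a running total ordered by the temporal bucket.
    stmt = select(
        *inner.c,
        func.sum(inner.c["count"]).over(order_by=inner.c["subscription_start_date_month"]).label("count_cumulative"),
    ).select_from(inner)

    # Compiles to roughly:
    #   SELECT base.subscription_start_date_month, base.count,
    #          sum(base.count) OVER (ORDER BY base.subscription_start_date_month) AS count_cumulative
    #   FROM (SELECT ...) AS base
    print(stmt)
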
@@ -282,7 +346,7 @@ def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Sele
     if group_cols:
         stmt = stmt.group_by(*group_cols)
 
-    if query.temporal_group_by:
-        stmt = stmt.order_by(*group_cols)
+    stmt = _apply_cumulative_aggregations(stmt, query, group_col_names, agg_cols)
+    stmt = _apply_ordering(stmt, query, group_col_names)
 
     return stmt, group_col_names
diff --git a/orchestrator/search/query/mixins.py b/orchestrator/search/query/mixins.py
index 551e623e8..c33cd670e 100644
--- a/orchestrator/search/query/mixins.py
+++ b/orchestrator/search/query/mixins.py
@@ -1,6 +1,8 @@
 import uuid
+from enum import Enum
+from typing import Self
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, model_validator
 
 from orchestrator.search.aggregations import Aggregation, TemporalGrouping
 
@@ -8,9 +10,28 @@
     "SearchMixin",
     "GroupingMixin",
     "AggregationMixin",
+    "OrderBy",
+    "OrderDirection",
 ]
 
 
+class OrderDirection(str, Enum):
+    """Sorting direction for aggregation results."""
+
+    ASC = "asc"
+    DESC = "desc"
+
+
+class OrderBy(BaseModel):
+    """Ordering descriptor for aggregation responses."""
+
+    field: str = Field(description="Grouping or aggregation field/alias to order by.")
+    direction: OrderDirection = Field(
+        default=OrderDirection.ASC,
+        description="Sorting direction (asc or desc).",
+    )
+
+
 class SearchMixin(BaseModel):
     """Mixin providing text search capability.
 
@@ -59,6 +80,37 @@ class GroupingMixin(BaseModel):
         default=None,
         description="Temporal grouping specifications (group by month, year, etc.)",
     )
+    cumulative: bool = Field(
+        default=False,
+        description="Enable cumulative aggregations when temporal grouping is present.",
+    )
+    order_by: list[OrderBy] | None = Field(
+        default=None,
+        description="Ordering instructions for grouped aggregation results.",
+    )
+
+    @model_validator(mode="after")
+    def validate_grouping_constraints(self) -> Self:
+        """Validate cross-field constraints for grouping features."""
+        if self.order_by and not self.group_by and not self.temporal_group_by:
+            raise ValueError(
+                "order_by requires at least one grouping field (group_by or temporal_group_by). "
+                "Ordering only applies to grouped aggregation results."
+            )
+
+        if self.cumulative:
+            if not self.temporal_group_by:
+                raise ValueError(
+                    "cumulative requires at least one temporal grouping (temporal_group_by). "
+                    "Cumulative aggregations compute running totals over time."
+                )
+            if len(self.temporal_group_by) > 1:
+                raise ValueError(
+                    "cumulative currently supports only a single temporal grouping. "
+                    "Multiple temporal dimensions with running totals are not yet supported."
+                )
+
+        return self
 
     def get_pivot_fields(self) -> list[str]:
         """Get all fields needed for EAV pivot from grouping.
@@ -82,7 +134,10 @@ class AggregationMixin(BaseModel):
     Used by AGGREGATE queries to define what statistics to compute.
     """
 
-    aggregations: list[Aggregation] = Field(description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)")
+    aggregations: list[Aggregation] = Field(
+        description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)",
+        min_length=1,
+    )
 
     def get_aggregation_pivot_fields(self) -> list[str]:
         """Get fields needed for EAV pivot from aggregations.
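Note: with these validators, invalid combinations now fail at model construction instead of surfacing as SQL errors. A quick illustration (error text abbreviated):

    import pydantic

    from orchestrator.search.core.types import EntityType
    from orchestrator.search.query.queries import CountQuery

    try:
        CountQuery(entity_type=EntityType.SUBSCRIPTION, cumulative=True)
    except pydantic.ValidationError as exc:
        print(exc)  # cumulative requires at least one temporal grouping (temporal_group_by). ...
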
diff --git a/orchestrator/search/query/queries.py b/orchestrator/search/query/queries.py
index 3d988dfd2..3ddd44d38 100644
--- a/orchestrator/search/query/queries.py
+++ b/orchestrator/search/query/queries.py
@@ -13,7 +13,7 @@
 
 from typing import Annotated, Any, ClassVar, Literal, Self, Union
 
-from pydantic import BaseModel, ConfigDict, Discriminator, Field
+from pydantic import BaseModel, ConfigDict, Discriminator, Field, model_validator
 
 from orchestrator.search.core.types import ActionType, EntityType
 from orchestrator.search.filters import FilterTree
@@ -112,6 +112,20 @@ class AggregateQuery(BaseQuery, GroupingMixin, AggregationMixin):
     query_type: Literal["aggregate"] = "aggregate"
     _action: ClassVar[ActionType] = ActionType.AGGREGATE
 
+    @model_validator(mode="after")
+    def validate_cumulative_aggregation_types(self) -> Self:
+        """Validate that cumulative is only used with COUNT and SUM aggregations."""
+        if self.cumulative:
+            from orchestrator.search.aggregations import AggregationType
+
+            for agg in self.aggregations:
+                if agg.type in (AggregationType.AVG, AggregationType.MIN, AggregationType.MAX):
+                    raise ValueError(
+                        f"Cumulative aggregations are not supported for {agg.type.value.upper()} aggregations. "
+                        "Cumulative only works with COUNT and SUM."
+                    )
+        return self
+
     def get_pivot_fields(self) -> list[str]:
         """Get all fields needed for EAV pivot including aggregation fields."""
         # Get grouping fields from GroupingMixin
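Note: a query that passes both this validator and the grouping constraints — COUNT and SUM stay allowed with cumulative. The price field below is illustrative; import paths follow the unit tests further down:

    from orchestrator.search.aggregations import (
        AggregationType,
        FieldAggregation,
        TemporalGrouping,
        TemporalPeriod,
    )
    from orchestrator.search.core.types import EntityType
    from orchestrator.search.query.queries import AggregateQuery

    # Monthly revenue with a running total; validates cleanly.
    AggregateQuery(
        entity_type=EntityType.SUBSCRIPTION,
        aggregations=[
            FieldAggregation(type=AggregationType.SUM, field="subscription.price", alias="total_revenue"),
        ],
        temporal_group_by=[
            TemporalGrouping(field="subscription.start_date", period=TemporalPeriod.MONTH),
        ],
        cumulative=True,
    )
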
") - index_subscription(sub, embedding, db.session) + index_subscription(sub, embedding, db.session, subscription_index=idx) db.session.commit() diff --git a/test/integration_tests/search/helpers.py b/test/integration_tests/search/helpers.py index 71a47861a..6c9d0cf74 100644 --- a/test/integration_tests/search/helpers.py +++ b/test/integration_tests/search/helpers.py @@ -140,18 +140,22 @@ def save_ground_truth(entities: list[dict], queries: list[dict]) -> None: json.dump(ground_truth, f, indent=2) -def index_subscription(subscription: SubscriptionTable, embedding: list[float], session) -> None: - """Index a single subscription into AiSearchIndex with three records. +def index_subscription( + subscription: SubscriptionTable, embedding: list[float], session, subscription_index: int = 1 +) -> None: + """Index a single subscription into AiSearchIndex with multiple records. Creates: - description field with embedding (for semantic search) - status field without embedding (for filtering) - insync field without embedding (for filtering) + - start_date field without embedding (for temporal grouping) Args: subscription: The subscription to index embedding: The embedding vector for the subscription description session: The SQLAlchemy session to use + subscription_index: Index of subscription (1-based) used to generate test dates """ # Index description with embedding index_record = AiSearchIndex( @@ -191,3 +195,19 @@ def index_subscription(subscription: SubscriptionTable, embedding: list[float], embedding=None, ) session.add(insync_record) + + # Cycle through all 12 months: index 1-12 -> Jan-Dec, 13-22 -> Jan-Oct + month = ((subscription_index - 1) % 12) + 1 + start_date = f"2024-{month:02d}-01T00:00:00" + + start_date_record = AiSearchIndex( + entity_type=EntityType.SUBSCRIPTION.value, + entity_id=subscription.subscription_id, + entity_title=subscription.description[:50], + path=Ltree("subscription.start_date"), + value=start_date, + value_type=FieldType.DATETIME.value, + content_hash=f"test_hash_start_date_{subscription.subscription_id}", + embedding=None, + ) + session.add(start_date_record) diff --git a/test/integration_tests/search/test_query_builder.py b/test/integration_tests/search/test_query_builder.py index d68603ec2..58be4161f 100644 --- a/test/integration_tests/search/test_query_builder.py +++ b/test/integration_tests/search/test_query_builder.py @@ -14,10 +14,16 @@ import pytest from orchestrator.db import db -from orchestrator.search.aggregations import AggregationType, CountAggregation +from orchestrator.search.aggregations import ( + AggregationType, + CountAggregation, + TemporalGrouping, + TemporalPeriod, +) from orchestrator.search.core.types import BooleanOperator, EntityType, FilterOp, UIType from orchestrator.search.filters import EqualityFilter, FilterTree, PathFilter from orchestrator.search.query import engine +from orchestrator.search.query.mixins import OrderBy, OrderDirection from orchestrator.search.query.queries import AggregateQuery, CountQuery, ExportQuery, SelectQuery from orchestrator.types import SubscriptionLifecycle @@ -173,6 +179,61 @@ async def test_count_with_filters(self, indexed_subscriptions): assert response.results[0].group_values["insync"] == "true" assert response.results[0].aggregations["count"] == 21 + @pytest.mark.asyncio + async def test_aggregate_with_cumulative(self, indexed_subscriptions): + """Test cumulative aggregation with temporal grouping. + + Test data: 22 subscriptions distributed across 2024 (2 per month for Jan-Oct, 1 each for Nov-Dec). 
diff --git a/test/integration_tests/search/test_query_builder.py b/test/integration_tests/search/test_query_builder.py
index d68603ec2..58be4161f 100644
--- a/test/integration_tests/search/test_query_builder.py
+++ b/test/integration_tests/search/test_query_builder.py
@@ -14,10 +14,16 @@
 import pytest
 
 from orchestrator.db import db
-from orchestrator.search.aggregations import AggregationType, CountAggregation
+from orchestrator.search.aggregations import (
+    AggregationType,
+    CountAggregation,
+    TemporalGrouping,
+    TemporalPeriod,
+)
 from orchestrator.search.core.types import BooleanOperator, EntityType, FilterOp, UIType
 from orchestrator.search.filters import EqualityFilter, FilterTree, PathFilter
 from orchestrator.search.query import engine
+from orchestrator.search.query.mixins import OrderBy, OrderDirection
 from orchestrator.search.query.queries import AggregateQuery, CountQuery, ExportQuery, SelectQuery
 from orchestrator.types import SubscriptionLifecycle
 
@@ -173,6 +179,61 @@ async def test_count_with_filters(self, indexed_subscriptions):
         assert response.results[0].group_values["insync"] == "true"
         assert response.results[0].aggregations["count"] == 21
 
+    @pytest.mark.asyncio
+    async def test_aggregate_with_cumulative(self, indexed_subscriptions):
+        """Test cumulative aggregation with temporal grouping.
+
+        Test data: 22 subscriptions distributed across 2024 (2 per month for Jan-Oct, 1 each for Nov-Dec).
+        Verifies that cumulative window functions produce correct running totals.
+        """
+        query = AggregateQuery(
+            entity_type=EntityType.SUBSCRIPTION,
+            temporal_group_by=[
+                TemporalGrouping(field="subscription.start_date", period=TemporalPeriod.MONTH),
+            ],
+            aggregations=[
+                CountAggregation(type=AggregationType.COUNT, alias="count"),
+            ],
+            cumulative=True,
+        )
+
+        response = await engine.execute_aggregation(query, db.session)
+
+        assert len(response.results) == 12, f"Should have 12 monthly groups, got {len(response.results)}"
+
+        expected_counts = [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]  # Jan through Dec
+        expected_cumulative = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 21, 22]  # Running totals
+
+        for i, result in enumerate(response.results):
+            assert (
+                result.aggregations["count"] == expected_counts[i]
+            ), f"Month {i+1}: expected count {expected_counts[i]}, got {result.aggregations['count']}"
+            assert (
+                result.aggregations["count_cumulative"] == expected_cumulative[i]
+            ), f"Month {i+1}: expected cumulative {expected_cumulative[i]}, got {result.aggregations['count_cumulative']}"
+
+    @pytest.mark.asyncio
+    async def test_count_with_ordering(self, indexed_subscriptions):
+        """Test COUNT query with ORDER BY."""
+        query = CountQuery(
+            entity_type=EntityType.SUBSCRIPTION,
+            group_by=["status"],
+            order_by=[OrderBy(field="count", direction=OrderDirection.DESC)],
+        )
+
+        response = await engine.execute_aggregation(query, db.session)
+
+        # Should be ordered by count descending (active=21, provisioning=1)
+        assert len(response.results) == 2, "Should have 2 status groups"
+        assert (
+            response.results[0].group_values["status"] == SubscriptionLifecycle.ACTIVE.value
+        ), "First should be active"
+        assert response.results[0].aggregations["count"] == 21, "Active count should be 21"
+        assert (
+            response.results[1].group_values["status"] == SubscriptionLifecycle.PROVISIONING.value
+        ), "Second should be provisioning"
+        assert response.results[1].aggregations["count"] == 1, "Provisioning count should be 1"
+
 
 class TestExportQueryBuilder:
     """Test export query execution with filters."""
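Note: the expected_cumulative list in test_aggregate_with_cumulative is simply the prefix sums of expected_counts, which can be cross-checked:

    from itertools import accumulate

    expected_counts = [2] * 10 + [1, 1]
    assert list(accumulate(expected_counts)) == [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 21, 22]
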
diff --git a/test/unit_tests/search/query/test_queries.py b/test/unit_tests/search/query/test_queries.py
index 995c5e74b..87d506ba4 100644
--- a/test/unit_tests/search/query/test_queries.py
+++ b/test/unit_tests/search/query/test_queries.py
@@ -24,6 +24,8 @@
 )
 from orchestrator.search.core.types import ActionType, EntityType
 from orchestrator.search.filters import FilterTree
+from orchestrator.search.query.builder import build_aggregation_query, build_candidate_query
+from orchestrator.search.query.mixins import OrderBy, OrderDirection
 from orchestrator.search.query.queries import AggregateQuery, CountQuery, ExportQuery, Query, SelectQuery
 
 pytestmark = pytest.mark.search
@@ -221,6 +223,110 @@ def test_aggregate_query_multiple_temporal_groupings(
     assert "subscription.end_date" in pivot_fields
 
 
+class TestAggregationBuilderFeatures:
+    """Test helper behaviors in the aggregation builder."""
+
+    @pytest.mark.parametrize(
+        "query_factory,error_match",
+        [
+            (
+                lambda: CountQuery(entity_type=EntityType.SUBSCRIPTION, cumulative=True),
+                "cumulative requires at least one temporal grouping",
+            ),
+            (
+                lambda: CountQuery(
+                    entity_type=EntityType.SUBSCRIPTION,
+                    order_by=[OrderBy(field="count", direction=OrderDirection.DESC)],
+                ),
+                "order_by requires at least one grouping field",
+            ),
+            (
+                lambda: AggregateQuery(
+                    entity_type=EntityType.SUBSCRIPTION,
+                    aggregations=[],
+                    group_by=["subscription.status"],
+                ),
+                "at least 1 item",
+            ),
+        ],
+        ids=["cumulative-needs-temporal", "order_by-needs-grouping", "aggregations-required"],
+    )
+    def test_query_validation_errors(self, query_factory, error_match):
+        """Test that query construction raises appropriate validation errors."""
+        with pytest.raises(ValidationError, match=error_match):
+            query_factory()
+
+    def test_order_by_uses_group_field(self):
+        """Order by resolves grouping field paths."""
+        query = CountQuery(
+            entity_type=EntityType.SUBSCRIPTION,
+            group_by=["subscription.product.name"],
+            order_by=[OrderBy(field="subscription.product.name", direction=OrderDirection.DESC)],
+        )
+        base_query = build_candidate_query(query)
+        stmt, _ = build_aggregation_query(query, base_query)
+        sql = str(stmt.compile())
+
+        assert "ORDER BY" in sql.upper()
+        assert "DESC" in sql.upper()
+
+    @pytest.mark.parametrize(
+        "agg_type",
+        [AggregationType.AVG, AggregationType.MIN, AggregationType.MAX],
+        ids=["avg", "min", "max"],
+    )
+    def test_cumulative_rejects_unsupported_aggregations(
+        self,
+        temporal_grouping_month: TemporalGrouping,
+        agg_type: AggregationType,
+    ):
+        """Cumulative with AVG/MIN/MAX aggregations raises a validation error at query construction.
+
+        These aggregation types are rejected because running versions (e.g., a running average
+        or a running minimum) have no clear business meaning for cumulative totals.
+        """
+        with pytest.raises(ValidationError, match=f"not supported for {agg_type.value.upper()} aggregations"):
+            AggregateQuery(
+                entity_type=EntityType.SUBSCRIPTION,
+                aggregations=[
+                    FieldAggregation(type=agg_type, field="subscription.price", alias="test_agg"),  # type: ignore[arg-type]
+                ],
+                temporal_group_by=[temporal_grouping_month],
+                cumulative=True,
+            )
+
+    def test_cumulative_allows_sum_aggregation(self, temporal_grouping_month: TemporalGrouping):
+        """Cumulative with a SUM aggregation generates the expected SQL."""
+        query = AggregateQuery(
+            entity_type=EntityType.SUBSCRIPTION,
+            aggregations=[
+                FieldAggregation(type=AggregationType.SUM, field="subscription.price", alias="total_revenue"),
+            ],
+            temporal_group_by=[temporal_grouping_month],
+            cumulative=True,
+        )
+        base_query = build_candidate_query(query)
+        stmt, _ = build_aggregation_query(query, base_query)
+
+        sql = str(stmt.compile()).lower()
+        assert "over" in sql  # Window function for cumulative
+        assert "total_revenue_cumulative" in sql  # Cumulative column alias
+
+    def test_cumulative_multiple_temporal_rejected(self, temporal_grouping_month: TemporalGrouping):
+        """Cumulative with multiple temporal groupings raises a validation error at construction time."""
+        temporal_grouping_end_date = TemporalGrouping(
+            field="subscription.end_date",
+            period=TemporalPeriod.MONTH,
+        )
+
+        with pytest.raises(ValidationError, match="supports only a single temporal grouping"):
+            CountQuery(
+                entity_type=EntityType.SUBSCRIPTION,
+                temporal_group_by=[temporal_grouping_month, temporal_grouping_end_date],
+                cumulative=True,
+            )
+
+
 class TestQueryDiscriminator:
     """Test Pydantic discriminated union for Query type."""
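Note: as the ordering tests above exercise, OrderBy.field is resolved first against the literal SELECT column keys and then against the normalized alias, so both an aggregation alias and a raw grouping path are accepted:

    from orchestrator.search.query.mixins import OrderBy, OrderDirection

    OrderBy(field="count", direction=OrderDirection.DESC)  # matches the aggregation alias directly
    OrderBy(field="subscription.product.name")             # resolved via BaseAggregation.field_to_alias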