@@ -5,19 +5,20 @@ name: Pytest LLM Tests

on:
push:
branches: [ main ]
branches: [main]
pull_request:

env:
UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged
UV_LOCKED: true # Assert that the `uv.lock` will remain unchanged

jobs:
container_job:
name: Pytest LLM Tests Python (${{ matrix.python-version }}) Postgres (${{ matrix.postgres-version }})
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.14']
postgres-version: ['17']
python-version: ["3.14"]
postgres-version: ["17"]
fail-fast: false
container: ubuntu:latest
services:
@@ -54,10 +55,9 @@ jobs:
DATABASE_URI: postgresql://nwa:nwa@postgres/orchestrator-core-test
ENVIRONMENT: TESTING
PYTHONPATH: .
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SEARCH_ENABLED: "true"
SEARCH_ENABLED: true
run: |
SEARCH_ENABLED=true uv run pytest test/integration_tests/search -v
uv run pytest test/integration_tests/search -v

- name: "Upload coverage to Codecov"
uses: codecov/codecov-action@v3
8 changes: 6 additions & 2 deletions orchestrator/search/aggregations/base.py
@@ -61,6 +61,11 @@ class TemporalGrouping(BaseModel):
},
)

@property
def alias(self) -> str:
"""Return the SQL-friendly alias for this temporal grouping."""
return f"{BaseAggregation.field_to_alias(self.field)}_{self.period.value}"

def get_pivot_fields(self) -> list[str]:
"""Return fields that need to be pivoted for this temporal grouping."""
return [self.field]
@@ -83,8 +88,7 @@ def to_expression(self, pivot_cte_columns: Any) -> tuple[Label, Any, str]:
col = getattr(pivot_cte_columns, field_alias)
truncated_col = func.date_trunc(self.period.value, cast(col, TIMESTAMP(timezone=True)))

# Column name without prefix
col_name = f"{field_alias}_{self.period.value}"
col_name = self.alias
select_col = truncated_col.label(col_name)
return select_col, truncated_col, col_name
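Editor's note: the new alias property centralizes the column-name convention that to_expression previously built inline. A minimal sketch of the expected composition, assuming field_to_alias simply replaces the ltree dots with underscores (that helper is not shown in this diff):

    # Illustration only -- not part of the PR; field_to_alias behavior is assumed.
    def field_to_alias(field: str) -> str:
        return field.replace(".", "_")

    period_value = "month"
    assert f"{field_to_alias('subscription.start_date')}_{period_value}" == "subscription_start_date_month"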

70 changes: 67 additions & 3 deletions orchestrator/search/query/builder.py
@@ -25,6 +25,7 @@
from orchestrator.search.aggregations import AggregationType, BaseAggregation, CountAggregation
from orchestrator.search.core.types import EntityType, FieldType, FilterOp, UIType
from orchestrator.search.filters import LtreeFilter
from orchestrator.search.query.mixins import OrderDirection
from orchestrator.search.query.queries import AggregateQuery, CountQuery, Query


@@ -181,7 +182,8 @@ def _build_pivot_cte(base_query: Select, pivot_fields: list[str]) -> CTE:


def _build_grouping_columns(
query: CountQuery | AggregateQuery, pivot_cte: CTE
query: CountQuery | AggregateQuery,
pivot_cte: CTE,
) -> tuple[list[Any], list[Any], list[str]]:
"""Build GROUP BY columns and their SELECT columns.

@@ -244,6 +246,68 @@ def _build_aggregation_columns(query: CountQuery | AggregateQuery, pivot_cte: CT
return [count_agg.to_expression(pivot_cte.c.entity_id)]


def _apply_cumulative_aggregations(
stmt: Select,
query: CountQuery | AggregateQuery,
group_column_names: list[str],
aggregation_columns: list[Label],
) -> Select:
"""Add cumulative aggregation columns."""

# At this point, cumulative validation has already happened at query build time
# in GroupingMixin.validate_grouping_constraints, so we know:
# temporal_group_by exists and has exactly 1 element when cumulative=True
if not query.cumulative or not aggregation_columns or not query.temporal_group_by:
return stmt

temporal_alias = query.temporal_group_by[0].alias

base_subquery = stmt.subquery()
partition_cols = [base_subquery.c[name] for name in group_column_names if name != temporal_alias]
order_col = base_subquery.c[temporal_alias]

base_columns = [base_subquery.c[col] for col in base_subquery.c.keys()]

cumulative_columns = []
for agg_col in aggregation_columns:
cumulative_alias = f"{agg_col.key}_cumulative"
over_kwargs: dict[str, Any] = {"order_by": order_col}
if partition_cols:
over_kwargs["partition_by"] = partition_cols
cumulative_expr = func.sum(base_subquery.c[agg_col.key]).over(**over_kwargs).label(cumulative_alias)
cumulative_columns.append(cumulative_expr)

return select(*(base_columns + cumulative_columns)).select_from(base_subquery)
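Editor's note: the cumulative column is a plain SQL window aggregate over the already-grouped rows. A standalone sketch of the same pattern; the table and column names here are illustrative and not taken from this PR:

    from sqlalchemy import Integer, String, column, func, select, table

    # Hypothetical grouped result: one row per (status, month) with a pre-computed total.
    events = table("events", column("month", String), column("status", String), column("n_total", Integer))
    base = select(events.c.month, events.c.status, events.c.n_total).subquery()

    # Running total per status, ordered by month -- the same shape _apply_cumulative_aggregations builds.
    stmt = select(
        base.c.month,
        base.c.status,
        base.c.n_total,
        func.sum(base.c.n_total)
        .over(partition_by=[base.c.status], order_by=base.c.month)
        .label("n_total_cumulative"),
    ).select_from(base)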


def _apply_ordering(
stmt: Select,
query: CountQuery | AggregateQuery,
group_column_names: list[str],
) -> Select:
"""Apply ordering instructions to the SELECT statement."""
columns_by_key = {col.key: col for col in stmt.selected_columns}

if query.order_by:
order_expressions = []
for instruction in query.order_by:
# Try direct field name first, then normalized alias
col = columns_by_key.get(instruction.field)
if col is None:
col = columns_by_key.get(BaseAggregation.field_to_alias(instruction.field))
if col is None:
raise ValueError(f"Cannot order by '{instruction.field}'; column not found.")
order_expressions.append(col.desc() if instruction.direction == OrderDirection.DESC else col.asc())
return stmt.order_by(*order_expressions)

if query.temporal_group_by:
# Default ordering by all grouping columns (ascending)
order_expressions = [columns_by_key[col_name].asc() for col_name in group_column_names]
return stmt.order_by(*order_expressions)

return stmt
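Editor's note: the two-step lookup lets callers order either by a labeled output column or by the raw dotted field. A small sketch of that resolution; the names are illustrative and the field_to_alias behavior is assumed:

    # Illustration only: resolve an order_by field against the SELECT's labeled columns.
    columns_by_key = {"subscription_customer_id": "<labeled column>"}  # hypothetical label

    def resolve(field: str):
        return columns_by_key.get(field) or columns_by_key.get(field.replace(".", "_"))

    assert resolve("subscription_customer_id") == "<labeled column>"  # direct hit
    assert resolve("subscription.customer_id") == "<labeled column>"  # falls back to the normalized alias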


def build_simple_count_query(base_query: Select) -> Select:
"""Build a simple count query without grouping.

@@ -282,7 +346,7 @@ def build_aggregation_query(query: CountQuery | AggregateQuery, base_query: Sele
if group_cols:
stmt = stmt.group_by(*group_cols)

if query.temporal_group_by:
stmt = stmt.order_by(*group_cols)
stmt = _apply_cumulative_aggregations(stmt, query, group_col_names, agg_cols)
stmt = _apply_ordering(stmt, query, group_col_names)

return stmt, group_col_names
59 changes: 57 additions & 2 deletions orchestrator/search/query/mixins.py
@@ -1,16 +1,37 @@
import uuid
from enum import Enum
from typing import Self

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator

from orchestrator.search.aggregations import Aggregation, TemporalGrouping

__all__ = [
"SearchMixin",
"GroupingMixin",
"AggregationMixin",
"OrderBy",
"OrderDirection",
]


class OrderDirection(str, Enum):
"""Sorting direction for aggregation results."""

ASC = "asc"
DESC = "desc"


class OrderBy(BaseModel):
"""Ordering descriptor for aggregation responses."""

field: str = Field(description="Grouping or aggregation field/alias to order by.")
direction: OrderDirection = Field(
default=OrderDirection.ASC,
description="Sorting direction (asc or desc).",
)
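Editor's note: a quick usage sketch of the new ordering model; the field values are illustrative:

    from orchestrator.search.query.mixins import OrderBy, OrderDirection

    ob = OrderBy(field="subscription_start_date_month")  # illustrative alias
    assert ob.direction is OrderDirection.ASC  # direction defaults to ascending
    ob_desc = OrderBy(field="count", direction=OrderDirection.DESC)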


class SearchMixin(BaseModel):
"""Mixin providing text search capability.

@@ -59,6 +80,37 @@ class GroupingMixin(BaseModel):
default=None,
description="Temporal grouping specifications (group by month, year, etc.)",
)
cumulative: bool = Field(
default=False,
description="Enable cumulative aggregations when temporal grouping is present.",
)
order_by: list[OrderBy] | None = Field(
default=None,
description="Ordering instructions for grouped aggregation results.",
)

@model_validator(mode="after")
def validate_grouping_constraints(self) -> Self:
"""Validate cross-field constraints for grouping features."""
if self.order_by and not self.group_by and not self.temporal_group_by:
raise ValueError(
"order_by requires at least one grouping field (group_by or temporal_group_by). "
"Ordering only applies to grouped aggregation results."
)

if self.cumulative:
if not self.temporal_group_by:
raise ValueError(
"cumulative requires at least one temporal grouping (temporal_group_by). "
"Cumulative aggregations compute running totals over time."
)
if len(self.temporal_group_by) > 1:
raise ValueError(
"cumulative currently supports only a single temporal grouping. "
"Multiple temporal dimensions with running totals are not yet supported."
)

return self
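Editor's note: a standalone restatement of the cross-field rules enforced above, as a sketch rather than the real pydantic models:

    # Sketch only -- mirrors validate_grouping_constraints without the models.
    def grouping_is_valid(has_group_by: bool, n_temporal: int, cumulative: bool, has_order_by: bool) -> bool:
        if has_order_by and not has_group_by and n_temporal == 0:
            return False  # order_by requires some grouping
        if cumulative and n_temporal != 1:
            return False  # cumulative needs exactly one temporal grouping
        return True

    assert grouping_is_valid(False, 1, True, False)      # one temporal grouping + cumulative: OK
    assert not grouping_is_valid(False, 0, True, False)  # cumulative without temporal grouping
    assert not grouping_is_valid(False, 2, True, False)  # more than one temporal dimension
    assert not grouping_is_valid(False, 0, False, True)  # order_by with no grouping at all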

def get_pivot_fields(self) -> list[str]:
"""Get all fields needed for EAV pivot from grouping.
@@ -82,7 +134,10 @@ class AggregationMixin(BaseModel):
Used by AGGREGATE queries to define what statistics to compute.
"""

aggregations: list[Aggregation] = Field(description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)")
aggregations: list[Aggregation] = Field(
description="Aggregations to compute (SUM, AVG, MIN, MAX, COUNT)",
min_length=1,
)

def get_aggregation_pivot_fields(self) -> list[str]:
"""Get fields needed for EAV pivot from aggregations.
16 changes: 15 additions & 1 deletion orchestrator/search/query/queries.py
@@ -13,7 +13,7 @@

from typing import Annotated, Any, ClassVar, Literal, Self, Union

from pydantic import BaseModel, ConfigDict, Discriminator, Field
from pydantic import BaseModel, ConfigDict, Discriminator, Field, model_validator

from orchestrator.search.core.types import ActionType, EntityType
from orchestrator.search.filters import FilterTree
@@ -112,6 +112,20 @@ class AggregateQuery(BaseQuery, GroupingMixin, AggregationMixin):
query_type: Literal["aggregate"] = "aggregate"
_action: ClassVar[ActionType] = ActionType.AGGREGATE

@model_validator(mode="after")
def validate_cumulative_aggregation_types(self) -> Self:
"""Validate that cumulative is only used with COUNT and SUM aggregations."""
if self.cumulative:
from orchestrator.search.aggregations import AggregationType

for agg in self.aggregations:
if agg.type in (AggregationType.AVG, AggregationType.MIN, AggregationType.MAX):
raise ValueError(
f"Cumulative aggregations are not supported for {agg.type.value.upper()} aggregations. "
f"Cumulative only works with COUNT and SUM."
)
return self
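Editor's note: together with the mixin validator, cumulative now requires exactly one temporal grouping and only COUNT/SUM aggregations. A minimal sketch of the type rule; the COUNT and SUM member names are inferred from the error message above:

    from orchestrator.search.aggregations import AggregationType

    # Sketch only: aggregation types that may be combined with cumulative=True.
    cumulative_compatible = {AggregationType.COUNT, AggregationType.SUM}
    blocked = {AggregationType.AVG, AggregationType.MIN, AggregationType.MAX}
    assert cumulative_compatible.isdisjoint(blocked)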

def get_pivot_fields(self) -> list[str]:
"""Get all fields needed for EAV pivot including aggregation fields."""
# Get grouping fields from GroupingMixin
13 changes: 7 additions & 6 deletions test/integration_tests/search/conftest.py
@@ -162,11 +162,12 @@ def embedding_fixtures() -> dict[str, list[float]]:
def mock_embeddings(embedding_fixtures: dict[str, list[float]]):
"""Mock embedding API calls to return recorded embeddings.

This ensures consistent test results without calling the actual OpenAI API.
This ensures consistent test results without calling the actual API.
Only the async call (llm_aembedding) is mocked, as that is the path used during query execution.
"""

def mock_embedding_sync(model: str, input: list[str], **kwargs) -> MagicMock:
"""Mock synchronous embedding call."""
async def mock_embedding_async(model: str, input: list[str], **kwargs) -> MagicMock:
"""Mock async embedding call for query execution."""
mock_response = MagicMock()
mock_response.data = []

@@ -178,7 +179,7 @@ def mock_embedding_sync(model: str, input: list[str], **kwargs) -> MagicMock:

return mock_response

with patch("orchestrator.search.core.embedding.llm_embedding", side_effect=mock_embedding_sync):
with patch("orchestrator.search.core.embedding.llm_aembedding", side_effect=mock_embedding_async):
yield


@@ -217,12 +218,12 @@ def indexed_subscriptions(db_session, test_subscriptions, mock_embeddings, embed
the full product registry setup. The focus is on testing search ranking with
semantically meaningful descriptions.
"""
for sub in test_subscriptions:
for idx, sub in enumerate(test_subscriptions, start=1):
embedding = embedding_fixtures.get(sub.description.lower())
if embedding is None:
raise ValueError(f"No embedding found for subscription '{sub.description}' in ground_truth.json. ")

index_subscription(sub, embedding, db.session)
index_subscription(sub, embedding, db.session, subscription_index=idx)

db.session.commit()

24 changes: 22 additions & 2 deletions test/integration_tests/search/helpers.py
@@ -140,18 +140,22 @@ def save_ground_truth(entities: list[dict], queries: list[dict]) -> None:
json.dump(ground_truth, f, indent=2)


def index_subscription(subscription: SubscriptionTable, embedding: list[float], session) -> None:
"""Index a single subscription into AiSearchIndex with three records.
def index_subscription(
subscription: SubscriptionTable, embedding: list[float], session, subscription_index: int = 1
) -> None:
"""Index a single subscription into AiSearchIndex with multiple records.

Creates:
- description field with embedding (for semantic search)
- status field without embedding (for filtering)
- insync field without embedding (for filtering)
- start_date field without embedding (for temporal grouping)

Args:
subscription: The subscription to index
embedding: The embedding vector for the subscription description
session: The SQLAlchemy session to use
subscription_index: Index of subscription (1-based) used to generate test dates
"""
# Index description with embedding
index_record = AiSearchIndex(
@@ -191,3 +195,19 @@ def index_subscription(subscription: SubscriptionTable, embedding: list[float],
embedding=None,
)
session.add(insync_record)

# Cycle through all 12 months: index 1-12 -> Jan-Dec, 13-22 -> Jan-Oct
month = ((subscription_index - 1) % 12) + 1
start_date = f"2024-{month:02d}-01T00:00:00"

start_date_record = AiSearchIndex(
entity_type=EntityType.SUBSCRIPTION.value,
entity_id=subscription.subscription_id,
entity_title=subscription.description[:50],
path=Ltree("subscription.start_date"),
value=start_date,
value_type=FieldType.DATETIME.value,
content_hash=f"test_hash_start_date_{subscription.subscription_id}",
embedding=None,
)
session.add(start_date_record)
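Editor's note: a quick standalone check of the month-cycling comment above (illustration only):

    # Indices 1-12 map to Jan-Dec; 13-22 wrap back around to Jan-Oct.
    months = [((i - 1) % 12) + 1 for i in range(1, 23)]
    assert months[:12] == list(range(1, 13))
    assert months[12:] == list(range(1, 11))
    assert f"2024-{months[0]:02d}-01T00:00:00" == "2024-01-01T00:00:00"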