diff --git a/.claude/skills/slayer-query.md b/.claude/skills/slayer-query.md index 5d15f72..c0dc06f 100644 --- a/.claude/skills/slayer-query.md +++ b/.claude/skills/slayer-query.md @@ -53,7 +53,9 @@ filters=[ **Boolean logic**: `and`, `or`, `not` within a single string -**Functions**: `contains(col, 'val')`, `starts_with(col, 'val')`, `ends_with(col, 'val')`, `between(col, 'a', 'b')`. Filters on measures are automatically routed to HAVING. +**Pattern matching**: `like` and `not like` operators (e.g., `"name like '%acme%'"`, `"name not like '%test%'"`). Filters on measures are automatically routed to HAVING. + +**Filtering on computed columns**: filters can reference field names from `fields` (e.g., `"rev_change < 0"`) or contain inline transform expressions (e.g., `"last(change(revenue)) < 0"`). These are applied as post-filters on the outer query. ## Executing diff --git a/CLAUDE.md b/CLAUDE.md index 4be561a..4a67b40 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -72,6 +72,7 @@ poetry run ruff check slayer/ tests/ - Dimension/measure SQL uses bare column names (e.g., `"amount"`); `${TABLE}` for complex expressions - Queries support `fields` — list of `{"formula": "...", "name": "...", "label": "..."}` parsed by `slayer/core/formula.py`. `label` is an optional human-readable display name (also supported on `ColumnRef` and `TimeDimension`) - Available formula functions: cumsum, time_shift, change, change_pct, rank, last (FIRST_VALUE window), lag, lead. time_shift, change, and change_pct always use self-join CTEs (no edge NULLs, gap-safe). time_shift uses row-number-based join without granularity, date-arithmetic-based with granularity. lag/lead use LAG/LEAD window functions directly (more efficient but produce NULLs at edges) +- Filters can reference computed field names or contain inline transform expressions (e.g., `"change(revenue) > 0"`, `"last(change(revenue)) < 0"`). 
These are auto-extracted as hidden fields and applied as post-filters on the outer query - Functions needing time ordering use resolution chain: query main_time_dimension -> query time_dimensions (if exactly 1) -> model default_time_dimension -> error - SlayerModel has optional `default_time_dimension` field for time-dependent formula resolution - SQLite dialect uses STRFTIME instead of DATE_TRUNC (handled automatically by sqlglot) diff --git a/README.md b/README.md index e5aeb9a..1b8476e 100644 --- a/README.md +++ b/README.md @@ -270,7 +270,12 @@ Filters use simple formula strings — no verbose JSON objects: "filters": ["status == 'completed' or status == 'pending'"] ``` -**Functions**: `contains(col, 'val')`, `starts_with(col, 'val')`, `ends_with(col, 'val')`, `between(col, 'a', 'b')`. Filters on measures (e.g., `"count > 10"`) are automatically routed to HAVING. +**Pattern matching**: `like` and `not like` operators (e.g., `"name like '%acme%'"`, `"name not like '%test%'"`). Filters on measures (e.g., `"count > 10"`) are automatically routed to HAVING. + +**Computed column filters**: filters can reference field names or contain inline transform expressions. 
These are applied as post-filters after all transforms are computed: +```json +"filters": ["change(revenue_sum) > 0", "last(change(revenue_sum)) < 0"] +``` ## Auto-Ingestion diff --git a/docs/concepts/formulas.md b/docs/concepts/formulas.md index 791de16..285c909 100644 --- a/docs/concepts/formulas.md +++ b/docs/concepts/formulas.md @@ -41,10 +41,10 @@ Functions apply window operations to measures: | `time_shift(x, -n)` | Value N periods back | Self-join CTE on row number | | `time_shift(x, 1)` | Next period's value | Self-join CTE on row number | | `time_shift(x, offset, gran)` | Value from a different calendar time bucket | Self-join CTE on date arithmetic | -| `change(x)` | Difference from previous period | Self-join CTE (current - previous) | -| `change_pct(x)` | Percentage change from previous period | Self-join CTE ((current - previous) / previous) | | `lag(x, n)` | Value N rows back (window function) | `LAG(x, n) OVER (ORDER BY time)` | | `lead(x, n)` | Value N rows ahead (window function) | `LEAD(x, n) OVER (ORDER BY time)` | +| `change(x)` | Difference from previous period | Self-join CTE (current - previous) | +| `change_pct(x)` | Percentage change from previous period | Self-join CTE ((current - previous) / previous) | | `rank(x)` | Ranking by value (descending) | `RANK() OVER (ORDER BY x DESC)` | | `last(x)` | Most recent time bucket's value | `FIRST_VALUE(x) OVER (ORDER BY time DESC ...)` | @@ -102,6 +102,8 @@ Filter formulas define conditions for the query. They go in the `filters` parame | `in` | `"status in ('active', 'pending')"` | | `is None` | `"discount is None"` (IS NULL) | | `is not None` | `"discount is not None"` (IS NOT NULL) | +| `like` | `"name like '%acme%'"` | +| `not like` | `"name not like '%test%'"` | ### Boolean Logic @@ -116,14 +118,39 @@ Use `and`, `or`, `not` within a single filter string: Multiple entries in the `filters` list are combined with AND. 
-### Filter Functions +### Filtering on Computed Columns -| Function | Example | SQL | -|----------|---------|-----| -| `contains(col, val)` | `"contains(name, 'acme')"` | `name LIKE '%acme%'` | -| `starts_with(col, val)` | `"starts_with(name, 'A')"` | `name LIKE 'A%'` | -| `ends_with(col, val)` | `"ends_with(email, '.com')"` | `email LIKE '%.com'` | -| `between(col, low, high)` | `"between(amount, 100, 500)"` | `amount BETWEEN 100 AND 500` | +Filters can reference names of computed fields — transforms and arithmetic expressions defined in `fields`. These are applied as post-filters on the outer query, after all transforms are computed. Note: bare measure renames (e.g., `{"formula": "count", "name": "n"}`) are not post-filterable by name; use the original measure name instead. + +```json +{ + "fields": [ + {"formula": "revenue"}, + {"formula": "change(revenue)", "name": "rev_change"} + ], + "filters": ["rev_change < 0"] +} +``` + +This returns only rows where revenue decreased from the previous period. + +Transform expressions can also be used **directly in filters** without defining them as fields first: + +```json +{ + "filters": ["last(change(revenue)) < 0"] +} +``` + +This keeps only rows where the most recent period's revenue change is negative — useful for queries like "show me monthly data, but only for metrics that are declining." The transform is auto-extracted as a hidden field and applied as a post-filter. 
+ +Post-filters can be combined with regular filters — base filters (on dimensions/measures) are applied in the inner query, post-filters on the outer wrapper: + +```json +{ + "filters": ["status == 'completed'", "change(revenue) > 0"] +} +``` --- diff --git a/docs/index.md b/docs/index.md index 5a9d7ed..634af6f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -7,7 +7,8 @@ A lightweight, open-source semantic layer by [MotleyAI](https://github.com/motle ## Key Features - **Agent-first design** — MCP, Python SDK, and REST API interfaces -- **Datasource-agnostic** — Postgres, MySQL, BigQuery, Snowflake, and more via sqlglot +- **Datasource-agnostic** — first-class support for Postgres, MySQL, ClickHouse, and SQLite; additional support for Snowflake, BigQuery, Oracle, Redshift, DuckDB, and more via sqlglot +- **`fields` API** — derived metrics with formulas, transforms (`cumsum`, `time_shift`, `change`), and inline transform filters - **Auto-ingestion with rollup joins** — Connect to a DB, introspect schema, generate denormalized models with FK-based LEFT JOINs automatically - **Incremental model editing** — Add/remove measures and dimensions without replacing the full model - **Lightweight** — Minimal dependencies, easy to set up and extend diff --git a/examples/embedded/verify.py b/examples/embedded/verify.py index c6669bb..1409968 100644 --- a/examples/embedded/verify.py +++ b/examples/embedded/verify.py @@ -150,15 +150,15 @@ def check(name, condition): cumvals = [r["orders.cumulative"] for r in result.data] check("cumsum non-decreasing", all(a <= b for a, b in zip(cumvals, cumvals[1:]))) - # Lag + # time_shift (row-based, previous period) result = engine.execute(query=SlayerQuery( model="orders", time_dimensions=[{"dimension": {"name": "created_at"}, "granularity": "month"}], fields=[Field(formula="count"), Field(formula="time_shift(count, -1)", name="prev")], order=[{"column": {"name": "created_at"}, "direction": "asc"}], )) - check("lag first month is null", 
result.data[0]["orders.prev"] is None) - check("lag second month = first month count", result.data[1]["orders.prev"] == result.data[0]["orders.count"]) + check("time_shift first month is null", result.data[0]["orders.prev"] is None) + check("time_shift second month = first month count", result.data[1]["orders.prev"] == result.data[0]["orders.count"]) # Change result = engine.execute(query=SlayerQuery( diff --git a/poetry.lock b/poetry.lock index b394564..ab060cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2165,10 +2165,9 @@ pytest = ">=8.2" name = "python-dateutil" version = "2.9.0.post0" description = "Extensions to the standard Python datetime module" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["main"] -markers = "extra == \"client\" or extra == \"all\" or extra == \"docs\"" files = [ {file = "python-dateutil-2.9.0.post0.tar.gz", hash = "sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3"}, {file = "python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427"}, @@ -2548,10 +2547,9 @@ files = [ name = "six" version = "1.17.0" description = "Python 2 and 3 compatibility utilities" -optional = true +optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" groups = ["main"] -markers = "extra == \"client\" or extra == \"all\" or extra == \"docs\"" files = [ {file = "six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274"}, {file = "six-1.17.0.tar.gz", hash = "sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81"}, @@ -3047,4 +3045,4 @@ postgres = ["psycopg2-binary"] [metadata] lock-version = "2.1" python-versions = "^3.11" -content-hash = "308d9efe83e4e57023d4bfe83c47c226300d45b31169b6c77a046aa658c044cc" +content-hash = "0e53d8abd9ee0a46e5e075fe78f3446f8dbda42931528a458e67852b1b2b30cc" diff --git a/pyproject.toml b/pyproject.toml index 
67c605e..bac7d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ python = "^3.11" sqlglot = ">=20.0" sqlalchemy = ">=2.0" pydantic = ">=2.0" +python-dateutil = ">=2.8" pyyaml = ">=6.0" fastapi = ">=0.100" uvicorn = ">=0.20" diff --git a/slayer/core/enums.py b/slayer/core/enums.py index e5298d0..491bde4 100644 --- a/slayer/core/enums.py +++ b/slayer/core/enums.py @@ -21,6 +21,7 @@ class DataType(StrEnum): AVERAGE = "avg" MIN = "min" MAX = "max" + LAST = "last" @property def is_aggregation(self) -> bool: @@ -31,6 +32,7 @@ def is_aggregation(self) -> bool: DataType.AVERAGE, DataType.MIN, DataType.MAX, + DataType.LAST, ) @property @@ -47,6 +49,7 @@ def python_type(self) -> type: DataType.AVERAGE: float, DataType.MIN: float, DataType.MAX: float, + DataType.LAST: float, }[self] diff --git a/slayer/core/formula.py b/slayer/core/formula.py index c001714..6ef4f8e 100644 --- a/slayer/core/formula.py +++ b/slayer/core/formula.py @@ -194,8 +194,8 @@ def _parse_literal(node: ast.AST, original: str) -> Any: # Filter parsing # --------------------------------------------------------------------------- -# String filter functions (no Python operator equivalent) -FILTER_FUNCTIONS = {"contains", "starts_with", "ends_with", "between"} +# Internal filter functions (used after pre-processing operators like `like`) +FILTER_FUNCTIONS = {"__like__", "__notlike__"} @dataclass @@ -208,6 +208,41 @@ class ParsedFilter: sql: str # e.g., "status = 'completed'" columns: List[str] # Column names referenced is_having: bool = False # True if this is a HAVING filter (aggregate condition) + is_post_filter: bool = False # True if this references a computed column (transform/expression) + + +def _preprocess_like(formula: str) -> str: + """Convert `like` and `not like` operators to internal function calls for AST parsing. 
+ + "name like '%acme%'" → "__like__(name, '%acme%')" + "name not like '%acme%'" → "__notlike__(name, '%acme%')" + """ + import re + # Skip if already preprocessed (contains __like__ or __notlike__) + if "__like__" in formula or "__notlike__" in formula: + return formula + formula = re.sub( + r'\b(\w+)\s+not\s+like\s+', + r'__notlike__(\1, ', + formula, flags=re.IGNORECASE, + ) + # Close the parenthesis: find the string argument and close after it + formula = re.sub( + r'(__notlike__\([^,]+,\s*\'[^\']*\')', + r'\1)', + formula, + ) + formula = re.sub( + r'\b(\w+)\s+like\s+', + r'__like__(\1, ', + formula, flags=re.IGNORECASE, + ) + formula = re.sub( + r'(__like__\([^,]+,\s*\'[^\']*\')', + r'\1)', + formula, + ) + return formula def parse_filter(formula: str) -> ParsedFilter: @@ -222,13 +257,13 @@ def parse_filter(formula: str) -> ParsedFilter: "status in ('a', 'b', 'c')" → WHERE status IN ('a', 'b', 'c') "status is None" → WHERE status IS NULL "status is not None" → WHERE status IS NOT NULL - "contains(name, 'acme')" → WHERE name LIKE '%acme%' - "starts_with(name, 'A')" → WHERE name LIKE 'A%' - "ends_with(name, 'Inc')" → WHERE name LIKE '%Inc' - "between(created_at, '2024-01-01', '2024-12-31')" → WHERE created_at BETWEEN '...' AND '...' 
+ "name like '%acme%'" → WHERE name LIKE '%acme%' + "name not like '%test%'" → WHERE name NOT LIKE '%test%' """ + # Pre-process `like` / `not like` operators into internal function calls + processed = _preprocess_like(formula) try: - tree = ast.parse(formula, mode="eval") + tree = ast.parse(processed, mode="eval") except SyntaxError as e: raise ValueError(f"Invalid filter syntax: {formula!r} — {e}") @@ -296,26 +331,30 @@ def _filter_node_to_sql(node: ast.AST, original: str, columns: list[str]) -> str elts = [_filter_node_to_sql(e, original, columns) for e in node.elts] return f"({', '.join(elts)})" - # Function call → contains, starts_with, ends_with, between + # Arithmetic expression (e.g., change / revenue in a filter LHS) + if isinstance(node, ast.BinOp): + op_map = { + ast.Add: "+", ast.Sub: "-", ast.Mult: "*", + ast.Div: "/", ast.Mod: "%", ast.Pow: "**", + } + op_str = op_map.get(type(node.op)) + if op_str is None: + raise ValueError(f"Unsupported arithmetic operator in filter: {original!r}") + left = _filter_node_to_sql(node.left, original, columns) + right = _filter_node_to_sql(node.right, original, columns) + return f"{left} {op_str} {right}" + + # Internal function calls for like/not like operators if isinstance(node, ast.Call) and isinstance(node.func, ast.Name): func_name = node.func.id - if func_name == "contains" and len(node.args) >= 2: + if func_name == "__like__" and len(node.args) >= 2: col = _filter_node_to_sql(node.args[0], original, columns) val = _get_string_arg(node.args[1], original) - return f"{col} LIKE '%{val}%'" - elif func_name == "starts_with" and len(node.args) >= 2: + return f"{col} LIKE '{val}'" + elif func_name == "__notlike__" and len(node.args) >= 2: col = _filter_node_to_sql(node.args[0], original, columns) val = _get_string_arg(node.args[1], original) - return f"{col} LIKE '{val}%'" - elif func_name == "ends_with" and len(node.args) >= 2: - col = _filter_node_to_sql(node.args[0], original, columns) - val = 
_get_string_arg(node.args[1], original) - return f"{col} LIKE '%{val}'" - elif func_name == "between" and len(node.args) >= 3: - col = _filter_node_to_sql(node.args[0], original, columns) - low = _filter_node_to_sql(node.args[1], original, columns) - high = _filter_node_to_sql(node.args[2], original, columns) - return f"{col} BETWEEN {low} AND {high}" + return f"{col} NOT LIKE '{val}'" raise ValueError(f"Unknown filter function '{func_name}' in: {original!r}") raise ValueError(f"Unsupported filter syntax: {original!r}") diff --git a/slayer/core/query.py b/slayer/core/query.py index ec1a0ee..25d8db1 100644 --- a/slayer/core/query.py +++ b/slayer/core/query.py @@ -107,7 +107,7 @@ class SlayerQuery(BaseModel): def snap_to_whole_periods(self) -> "SlayerQuery": """Adjust date filters to align with period boundaries when whole_periods_only=True. - For each time dimension with a granularity, adds a between() filter + For each time dimension with a granularity, adds a date range filter to exclude the current incomplete period if no date filter exists. """ if not self.whole_periods_only or not self.time_dimensions: diff --git a/slayer/engine/query_engine.py b/slayer/engine/query_engine.py index 9e6fb4b..1a46319 100644 --- a/slayer/engine/query_engine.py +++ b/slayer/engine/query_engine.py @@ -6,6 +6,7 @@ import logging from typing import Any, Dict, List, Optional +from slayer.core.enums import DataType from slayer.core.models import DatasourceConfig, SlayerModel from slayer.core.query import SlayerQuery from slayer.engine.enriched import ( @@ -134,13 +135,29 @@ def _enrich( )) # Resolve time dimension for transforms that need ORDER BY time. - # Resolution chain: query main_time_dimension → query time_dimensions (if exactly 1) → model default. + # Resolution chain: + # 1. query.main_time_dimension (explicit override) + # 2. First time dimension in query.time_dimensions (groupby) + # 3. First time dimension referenced in filters + # 4. 
model.default_time_dimension resolved_time_alias = None if query.main_time_dimension: resolved_time_alias = f"{model.name}.{query.main_time_dimension}" if resolved_time_alias is None and time_dimensions: - if len(time_dimensions) == 1: - resolved_time_alias = time_dimensions[0].alias + resolved_time_alias = time_dimensions[0].alias + if resolved_time_alias is None and query.filters: + # Check if any filter references a time/timestamp/date dimension + time_dim_names = { + d.name for d in model.dimensions + if d.type in (DataType.TIMESTAMP, DataType.DATE) + } + for f_str in query.filters: + for td_name in time_dim_names: + if td_name in f_str: + resolved_time_alias = f"{model.name}.{td_name}" + break + if resolved_time_alias: + break if resolved_time_alias is None and model.default_time_dimension: resolved_time_alias = f"{model.name}.{model.default_time_dimension}" @@ -229,6 +246,18 @@ def _flatten_spec(spec, field_name: str) -> str: return alias elif isinstance(spec, TransformField): + # Validate: nesting a self-join transform inside another is not supported + # (e.g., change(time_shift(x)) — the outer's shifted CTE can't replay the inner) + _self_join = {"time_shift", "change", "change_pct"} + if (spec.transform in _self_join + and isinstance(spec.inner, TransformField) + and spec.inner.transform in _self_join): + raise ValueError( + f"Nesting '{spec.transform}' around '{spec.inner.transform}' is not supported. " + f"Both use self-join CTEs. Try wrapping with a window function instead " + f"(e.g., cumsum, lag)." 
+ ) + # Flatten inner first inner_name = f"_inner_{field_name}" if isinstance(spec.inner, MeasureRef): @@ -250,6 +279,11 @@ def _flatten_spec(spec, field_name: str) -> str: if len(spec.args) >= 2: granularity = str(spec.args[1]) + # change/change_pct look backward by default (like LAG), + # so negate offset for self-join semantics + if spec.transform in ("change", "change_pct") and not spec.args: + offset = -1 + _add_transform( name=field_name, transform=spec.transform, measure_alias=inner_alias, offset=offset, granularity=granularity, @@ -265,8 +299,27 @@ def _flatten_spec(spec, field_name: str) -> str: if isinstance(spec, MeasureRef): _ensure_measure(spec.name) + # If the measure has type=last, auto-wrap with last() transform + measure_def = model.get_measure(spec.name) + if measure_def and measure_def.type == DataType.LAST: + # Rename the base measure's alias to an internal name + # so it doesn't collide with the transform's output alias + base_alias = f"{query.model}.{spec.name}" + internal_alias = f"{query.model}._base_{spec.name}" + for m in measures: + if m.alias == base_alias: + m.alias = internal_alias + known_aliases[spec.name] = internal_alias + _add_transform( + name=field_name, transform="last", + measure_alias=internal_alias, offset=1, + ) + if field.label: + for t in enriched_transforms: + if t.alias == f"{query.model}.{field_name}": + t.label = field.label # Apply label to the measure if provided - if field.label: + elif field.label: for m in measures: if m.name == spec.name: m.label = field.label @@ -282,6 +335,19 @@ def _flatten_spec(spec, field_name: str) -> str: if t.alias == alias: t.label = field.label + # Pre-process filters: extract inline transform expressions + # (e.g., "last(change(revenue)) < 0" → hidden field + rewritten filter) + processed_filters = [] + ft_counter = [0] # Shared counter across all filters for unique _ftN names + for f_str in (query.filters or []): + rewritten, extra_fields = 
SlayerQueryEngine._extract_filter_transforms( + f_str, counter=ft_counter, + ) + for name, formula in extra_fields: + spec = parse_formula(formula) + _flatten_spec(spec, name) + processed_filters.append(rewritten) + return EnrichedQuery( model_name=model.name, sql_table=model.sql_table, @@ -292,8 +358,10 @@ def _flatten_spec(spec, field_name: str) -> str: expressions=enriched_expressions, transforms=enriched_transforms, filters=SlayerQueryEngine._classify_filters( - filters=[parse_filter(f) for f in (query.filters or [])], + filters=[parse_filter(f) for f in processed_filters], measure_names={m.name for m in measures}, + computed_names={t.name for t in enriched_transforms} + | {e.name for e in enriched_expressions}, ), order=query.order, limit=query.limit, @@ -301,10 +369,74 @@ def _flatten_spec(spec, field_name: str) -> str: ) @staticmethod - def _classify_filters(filters: list, measure_names: set) -> list: - """Classify filters as WHERE or HAVING based on whether they reference measures.""" + def _extract_filter_transforms(filter_str: str, + counter: list[int] = None) -> tuple[str, list[tuple[str, str]]]: + """Extract transform function calls from a filter string. + + Returns (rewritten_filter, [(name, formula), ...]) where transform + calls are replaced with generated field names. + + Args: + counter: Shared mutable counter [n] for unique _ftN names across + multiple filter strings. If None, starts at 0. 
+ + Example: "last(change(revenue)) < 0" + → ("_ft0 < 0", [("_ft0", "last(change(revenue))")]) + """ + import ast as _ast + from slayer.core.formula import ALL_TRANSFORMS, _preprocess_like + + if counter is None: + counter = [0] + + # Pre-process `like`/`not like` operators so ast.parse doesn't fail + preprocessed = _preprocess_like(filter_str) + try: + tree = _ast.parse(preprocessed, mode="eval") + except SyntaxError: + return filter_str, [] + + transforms: list[tuple[str, str]] = [] + + def _replace(node): + if isinstance(node, _ast.Call) and isinstance(node.func, _ast.Name) and node.func.id in ALL_TRANSFORMS: + name = f"_ft{counter[0]}" + counter[0] += 1 + formula = _ast.unparse(node) + transforms.append((name, formula)) + return _ast.Name(id=name, ctx=_ast.Load()) + # Recurse into child nodes + if isinstance(node, _ast.BinOp): + node.left = _replace(node.left) + node.right = _replace(node.right) + elif isinstance(node, _ast.UnaryOp): + node.operand = _replace(node.operand) + elif isinstance(node, _ast.Compare): + node.left = _replace(node.left) + node.comparators = [_replace(c) for c in node.comparators] + elif isinstance(node, _ast.BoolOp): + node.values = [_replace(v) for v in node.values] + return node + + modified = _replace(tree.body) + if not transforms: + return filter_str, [] + return _ast.unparse(modified), transforms + + @staticmethod + def _classify_filters(filters: list, measure_names: set, + computed_names: set = None) -> list: + """Classify filters as WHERE, HAVING, or post-filter. + + Post-filters reference computed columns (transforms/expressions) and + are applied as a WHERE on an outer wrapper around the final query. 
+ """ + computed_names = computed_names or set() for f in filters: - f.is_having = any(col in measure_names for col in f.columns) + if any(col in computed_names for col in f.columns): + f.is_post_filter = True + elif any(col in measure_names for col in f.columns): + f.is_having = True return filters def _resolve_datasource(self, model: SlayerModel) -> DatasourceConfig: diff --git a/slayer/mcp/server.py b/slayer/mcp/server.py index 096413c..be2cd09 100644 --- a/slayer/mcp/server.py +++ b/slayer/mcp/server.py @@ -166,8 +166,10 @@ def query( dimensions: List of dimension names to group by, e.g. ["status", "region"]. filters: Filter conditions as formula strings. Examples: "status == 'completed'", "amount > 100", "status in ('a', 'b')", "status is None", - "contains(name, 'acme')". Filters on measures are automatically routed to HAVING. + "name like '%acme%'". Filters on measures are automatically routed to HAVING. Supports and/or: "status == 'a' or status == 'b'". + Filters can also reference computed field names or contain inline transforms: + "change(revenue) > 0", "last(change(revenue)) < 0". time_dimensions: Time grouping. Format: {"dimension": "created_at", "granularity": "day|week|month|quarter|year", "date_range": ["2024-01-01", "2024-12-31"]}. order: Sorting. Format: {"column": "field_name", "direction": "asc|desc"}. limit: Max rows to return. diff --git a/slayer/sql/generator.py b/slayer/sql/generator.py index 20ec3ec..16ea927 100644 --- a/slayer/sql/generator.py +++ b/slayer/sql/generator.py @@ -5,6 +5,7 @@ query engine's _enrich() step. """ +import copy import logging from typing import Optional @@ -23,6 +24,8 @@ DataType.AVERAGE: "AVG", DataType.MIN: "MIN", DataType.MAX: "MAX", + DataType.LAST: "MAX", # Base aggregation for `last` type; the actual "most recent" logic + # is handled by auto-adding a last() transform during enrichment } # Transforms that use self-join CTEs instead of window functions. 
@@ -63,8 +66,118 @@ def generate(self, enriched: EnrichedQuery) -> str: # Wrap base query as CTE, compute expressions/transforms in outer SELECT return self._generate_with_computed(enriched=enriched, base_sql=base_sql) - def _generate_base(self, enriched: EnrichedQuery) -> str: - """Generate the base SELECT (measures, dimensions, filters).""" + def _generate_shifted_base(self, enriched: EnrichedQuery, transform, + calendar_join: bool = False) -> str: + """Generate a base query with date ranges shifted for a self-join transform. + + Instead of copying the base CTE (which has the original date filter and + would miss data outside that range), this generates a fresh query against + the source table with adjusted date ranges so the shifted CTE contains + the data needed for the join. + + When calendar_join is True, the raw timestamps are also shifted by -offset + inside the DATE_TRUNC so that the aggregated time buckets align with the + base query's buckets. This allows a simple equality join (no date arithmetic + in the ON clause). 
+ """ + # Determine the shift: use transform's granularity, or fall back to + # the query's time dimension granularity for row-based transforms + gran = transform.granularity + offset = transform.offset + if not gran: + # Row-based: use the time dimension's granularity + for td in enriched.time_dimensions: + if td.alias == transform.time_alias: + gran = td.granularity.value + break + if not gran: + gran = "month" # Shouldn't happen — transforms require a time dim + + # Create a copy of enriched with shifted date ranges and (optionally) + # shifted time dimension expressions + shifted = copy.deepcopy(enriched) + + # Shift date ranges if present + has_date_ranges = any( + td.date_range and len(td.date_range) == 2 + for td in enriched.time_dimensions + ) + if has_date_ranges: + for td in shifted.time_dimensions: + if td.date_range and len(td.date_range) == 2: + td.date_range = [ + self._shift_date(date=td.date_range[0], offset=offset, granularity=gran), + self._shift_date(date=td.date_range[1], offset=offset, granularity=gran), + ] + + # For calendar joins, pass the time offset so _generate_base shifts raw + # timestamps before DATE_TRUNC. This makes aggregated buckets align with + # the base query's buckets → simple equality join. + time_offset = None + if calendar_join: + time_offset = (-offset, gran) + + return self._generate_base(enriched=shifted, time_offset=time_offset) + + def _build_time_offset_expr(self, col_expr: exp.Expression, offset: int, + granularity: str) -> exp.Expression: + """Apply a time offset to a column expression (dialect-aware). + + Used to shift raw timestamps before DATE_TRUNC in shifted CTEs so that + aggregated time buckets align with the base query's buckets. 
+ """ + unit_map = {"year": "YEAR", "month": "MONTH", "day": "DAY", + "quarter": "MONTH", "week": "WEEK", "hour": "HOUR", + "minute": "MINUTE", "second": "SECOND"} + unit = unit_map.get(granularity, granularity.upper()) + val = offset * 3 if granularity == "quarter" else offset + + if self.dialect == "sqlite": + sqlite_units = {"YEAR": "years", "MONTH": "months", "DAY": "days", + "WEEK": "days", "HOUR": "hours", "MINUTE": "minutes", + "SECOND": "seconds"} + sqlite_unit = sqlite_units.get(unit, unit.lower() + "s") + sqlite_val = val * 7 if granularity == "week" else val + col_sql = col_expr.sql(dialect="sqlite") + return sqlglot.parse_one( + f"DATE({col_sql}, '{sqlite_val} {sqlite_unit}')", dialect="sqlite" + ) + + # Standard SQL: col + INTERVAL 'N' UNIT + interval_str = f"INTERVAL '{val}' {unit}" + col_sql = col_expr.sql(dialect=self.dialect) + return sqlglot.parse_one(f"{col_sql} + {interval_str}", dialect=self.dialect) + + @staticmethod + def _shift_date(date: str, offset: int, granularity: str) -> str: + """Shift a date string by offset units of granularity.""" + from datetime import datetime, timedelta + from dateutil.relativedelta import relativedelta + + dt = datetime.strptime(date[:10], "%Y-%m-%d") + gran_map = { + "year": relativedelta(years=offset), + "quarter": relativedelta(months=offset * 3), + "month": relativedelta(months=offset), + "week": timedelta(weeks=offset), + "day": timedelta(days=offset), + "hour": timedelta(hours=offset), + "minute": timedelta(minutes=offset), + "second": timedelta(seconds=offset), + } + delta = gran_map.get(granularity, relativedelta(months=offset)) + shifted = dt + delta + return shifted.strftime("%Y-%m-%d") + + def _generate_base(self, enriched: EnrichedQuery, + time_offset: Optional[tuple[int, str]] = None) -> str: + """Generate the base SELECT (measures, dimensions, filters). + + Args: + time_offset: Optional (offset, granularity) to shift raw timestamps + before DATE_TRUNC. 
Used by shifted CTEs so aggregated buckets + align with the base query for simple equality joins. + """ from_clause = self._build_from_clause(enriched=enriched) select_columns = [] @@ -77,6 +190,12 @@ def _generate_base(self, enriched: EnrichedQuery) -> str: for td in enriched.time_dimensions: col_expr = self._resolve_sql(sql=td.sql, name=td.name, model_name=td.model_name) + # Apply time offset before DATE_TRUNC (for shifted CTEs) + if time_offset is not None: + offset_val, offset_gran = time_offset + col_expr = self._build_time_offset_expr( + col_expr=col_expr, offset=offset_val, granularity=offset_gran, + ) col_expr = self._build_date_trunc(col_expr=col_expr, granularity=td.granularity) select_columns.append(col_expr.as_(td.alias)) group_by_columns.append(col_expr) @@ -140,8 +259,6 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str) -> str pending_expressions = list(enriched.expressions) pending_transforms = list(enriched.transforms) layer_num = 0 - has_self_joins = False # Track if any self-join was emitted (for ORDER BY qualification) - while pending_expressions or pending_transforms: layer_num += 1 prev_cte = ctes[-1][0] @@ -180,11 +297,27 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str) -> str # Now emit each self-join transform as its own CTE layer for t in deferred_self_joins: - has_self_joins = True src_cte = ctes[-1][0] - # Add ROW_NUMBER if row-based and not already present - if not t.granularity: + # Determine effective join granularity: + # - If transform has explicit granularity (calendar-based), use it + # - If no granularity (row-based) but date ranges are shifted, + # use the time dimension's granularity for calendar join + # - If no granularity and no date ranges, use row-number join + has_date_ranges = any( + td.date_range and len(td.date_range) == 2 + for td in enriched.time_dimensions + ) + join_granularity = t.granularity + if not join_granularity and has_date_ranges: + # Use query's time 
dimension granularity for calendar-based join + for td in enriched.time_dimensions: + if td.alias == t.time_alias: + join_granularity = td.granularity.value + break + + # Add ROW_NUMBER if using row-number join + if not join_granularity: time_col = f'"{t.time_alias}"' all_cols = ", ".join(f'"{a}"' for a in sorted(available_aliases)) rn_cte = f"{src_cte}_rn" @@ -192,13 +325,35 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str) -> str ctes.append((rn_cte, rn_sql)) src_cte = rn_cte + # Generate shifted CTE as a fresh base query with adjusted date ranges. + # For calendar joins, also shift raw timestamps so buckets align. + is_calendar = join_granularity is not None + shift_base_name = f"shifted_base_{t.name}" shift_name = f"shifted_{t.name}" - # Build the self-join CTE: src LEFT JOIN shifted ON condition → result column - time_col = f'"{t.time_alias}"' - join_cond = self._build_time_shift_join( - left_table=src_cte, right_table=shift_name, - time_col=time_col, offset=t.offset, granularity=t.granularity, + shifted_sql = self._generate_shifted_base( + enriched=enriched, transform=t, calendar_join=is_calendar, ) + ctes.append((shift_base_name, shifted_sql)) + + # For row-number joins, add ROW_NUMBER to the shifted CTE too + if not is_calendar: + time_col = f'"{t.time_alias}"' + shift_base_cols = ", ".join(f'"{a}"' for a in sorted(base_aliases)) + shift_rn_sql = f"SELECT {shift_base_cols}, ROW_NUMBER() OVER (ORDER BY {time_col}) AS _rn FROM {shift_base_name}" + ctes.append((shift_name, shift_rn_sql)) + else: + ctes.append((shift_name, f"SELECT * FROM {shift_base_name}")) + + # Build the self-join CTE: src LEFT JOIN shifted ON condition + time_col = f'"{t.time_alias}"' + if is_calendar: + # Calendar join: simple equality (shifted timestamps are already aligned) + join_cond = f'{src_cte}.{time_col} = {shift_name}.{time_col}' + else: + # Row-number join + join_cond = self._build_row_number_join( + left_table=src_cte, right_table=shift_name, 
offset=t.offset, + ) col_sql = self._build_self_join_column( transform=t.transform, left_table=src_cte, right_table=shift_name, measure_alias=t.measure_alias, @@ -211,8 +366,6 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str) -> str f"LEFT JOIN {shift_name}\n" f" ON {join_cond}" ) - # The shifted source CTE is a copy of src_cte - ctes.append((shift_name, f"SELECT * FROM {src_cte}")) ctes.append((join_layer, join_sql)) available_aliases.add(t.alias) added_this_layer.append(t.alias) @@ -262,6 +415,28 @@ def _generate_with_computed(self, enriched: EnrichedQuery, base_sql: str) -> str if enriched.offset is not None: sql += f"\nOFFSET {enriched.offset}" + # Apply post-filters (filters referencing computed columns) + post_filters = [f for f in enriched.filters if f.is_post_filter] + if post_filters: + import re + model = enriched.model_name + conditions = [] + for f in post_filters: + qualified_sql = f.sql + for col_name in dict.fromkeys(f.columns): + qualified_sql = re.sub( + rf'(? 
str: - """Build a JOIN condition for time_shift (row-based or calendar-based).""" - if granularity is None: - # Row-based: join on ROW_NUMBER offset - return f"{left_table}._rn + {offset} = {right_table}._rn" - if self.dialect == "sqlite": - # SQLite: DATE(col, 'N months') for date arithmetic - unit_map = {"year": "years", "month": "months", "day": "days", - "quarter": "months", "week": "days"} - unit = unit_map.get(granularity, granularity + "s") - multiplier = 3 if granularity == "quarter" else 7 if granularity == "week" else 1 - val = offset * multiplier - return f"DATE({left_table}.{time_col}, '{val} {unit}') = {right_table}.{time_col}" - # Standard SQL date arithmetic with dialect-specific syntax - unit_map = {"year": "YEAR", "month": "MONTH", "day": "DAY", - "quarter": "MONTH", "week": "WEEK"} - unit = unit_map.get(granularity, granularity.upper()) - val = offset * 3 if granularity == "quarter" else offset - right_col = f"{right_table}.{time_col}" - left_col = f"{left_table}.{time_col}" - if self.dialect == "bigquery": - return f"{left_col} = DATE_ADD({right_col}, INTERVAL {val} {unit})" - elif self.dialect in ("snowflake", "redshift"): - return f"{left_col} = DATEADD('{unit}', {val}, {right_col})" - elif self.dialect == "clickhouse": - return f"{left_col} = DATE_ADD({unit}, {val}, {right_col})" - elif self.dialect in ("trino", "presto"): - return f"{left_col} = DATE_ADD('{unit}', {val}, {right_col})" - elif self.dialect in ("databricks", "spark"): - return f"{left_col} = DATEADD({unit}, {val}, {right_col})" - elif self.dialect == "tsql": - return f"{left_col} = DATEADD({unit}, {val}, {right_col})" - # Postgres / MySQL / DuckDB — standard INTERVAL syntax - return f"{left_col} = {right_col} + INTERVAL '{val}' {unit}" + @staticmethod + def _build_row_number_join(left_table: str, right_table: str, offset: int) -> str: + """Build a row-number-based JOIN condition for row-based self-join transforms.""" + return f"{left_table}._rn + {offset} = {right_table}._rn" 
def _apply_order_limit(self, select: exp.Select, enriched: EnrichedQuery) -> exp.Select: """Apply ORDER BY, LIMIT, OFFSET to a select expression.""" @@ -495,6 +639,9 @@ def _build_where_and_having( import re model = enriched.model_name for f in enriched.filters: + # Post-filters are applied later, on the outer wrapper + if f.is_post_filter: + continue # Qualify column names with model name (deduplicate, word boundary, skip already qualified) qualified_sql = f.sql for col_name in dict.fromkeys(f.columns): # deduplicate preserving order diff --git a/tests/test_integration.py b/tests/test_integration.py index 4b9754e..df35a09 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -107,6 +107,7 @@ def integration_env(tmp_path): measures=[ Measure(name="count", type=DataType.COUNT), Measure(name="total_amount", sql="amount", type=DataType.SUM), + Measure(name="latest_amount", sql="amount", type=DataType.LAST), ], ) storage.save_model(orders_model) @@ -289,11 +290,12 @@ def test_cumsum_change_identity(integration_env): assert response.row_count == 3 assert "orders.cumsum_change" in response.columns - # With self-join change, the first row's change is NULL (no previous period), - # cumsum(NULL) = 0 in SQLite. The identity cumsum(change(x)) == x - x[0] - # holds for all rows (including the first, where it equals 0). 
+ # First row: change is NULL (no previous period), cumsum(NULL) = NULL + assert response.data[0]["orders.cumsum_change"] is None + + # Remaining rows: cumsum(change(x)) == x - x[0] first_count = response.data[0]["orders.count"] - for row in response.data: + for row in response.data[1:]: assert row["orders.cumsum_change"] == row["orders.count"] - first_count @@ -403,3 +405,333 @@ def test_time_shift_calendar_based(integration_env): assert response.data[1]["orders.prev_month"] == pytest.approx(300.0) # Mar's previous month is Feb assert response.data[2]["orders.prev_month"] == pytest.approx(125.0) + + +def test_time_shift_with_date_range(integration_env): + """time_shift with date_range should fetch shifted data from outside the filtered range.""" + engine = integration_env + + # Query only March, but ask for previous month's value (February) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + date_range=["2025-03-01", "2025-03-31"], + )], + fields=[ + Field(formula="total_amount"), + Field(formula="time_shift(total_amount, -1, 'month')", name="prev_month"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # Only March in the result (date filter) + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == pytest.approx(325.0) + # Previous month (February) should be fetched from the DB, not NULL + assert response.data[0]["orders.prev_month"] == pytest.approx(125.0) + + +def test_change_with_date_range(integration_env): + """change() with date_range should fetch previous period from outside the filtered range.""" + engine = integration_env + + # Query only March, change should compare to February + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + 
date_range=["2025-03-01", "2025-03-31"], + )], + fields=[ + Field(formula="total_amount"), + Field(formula="change(total_amount)", name="amount_change"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 1 + # March(325) - February(125) = 200 + assert response.data[0]["orders.amount_change"] == pytest.approx(200.0) + + +def test_change_pct_with_date_range(integration_env): + """change_pct() with date_range should compute correct percentage from shifted data.""" + engine = integration_env + + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + date_range=["2025-03-01", "2025-03-31"], + )], + fields=[ + Field(formula="total_amount"), + Field(formula="change_pct(total_amount)", name="pct"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 1 + # (325 - 125) / 125 = 1.6 + assert response.data[0]["orders.pct"] == pytest.approx(1.6) + + +def test_multiple_date_range_shifts(integration_env): + """Multiple self-join transforms with different offsets should each get correct shifted data.""" + engine = integration_env + + # Query Feb only, ask for both previous (Jan) and next (Mar) month + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + date_range=["2025-02-01", "2025-02-28"], + )], + fields=[ + Field(formula="total_amount"), + Field(formula="time_shift(total_amount, -1, 'month')", name="prev"), + Field(formula="time_shift(total_amount, 1, 'month')", name="next"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == 
pytest.approx(125.0) + # Jan = 300 + assert response.data[0]["orders.prev"] == pytest.approx(300.0) + # Mar = 325 + assert response.data[0]["orders.next"] == pytest.approx(325.0) + + +def test_forward_row_shift_with_date_range(integration_env): + """time_shift(x, 1) (forward, row-based) with date_range should fetch the next period.""" + engine = integration_env + + # Query Feb only, ask for the next period's value (March) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + date_range=["2025-02-01", "2025-02-28"], + )], + fields=[ + Field(formula="total_amount"), + Field(formula="time_shift(total_amount, 1)", name="next_period"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == pytest.approx(125.0) + # Next period (March) should be fetched from DB = 325 + assert response.data[0]["orders.next_period"] == pytest.approx(325.0) + + +def test_post_filter_on_change(integration_env): + """Filter on a computed column (change) should only return matching rows.""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # change values: Jan=NULL, Feb=125-300=-175, Mar=325-125=200 + # Filter: change < 0 → only February + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="total_amount"), + Field(formula="change(total_amount)", name="amount_change"), + ], + filters=["amount_change < 0"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # Only February should remain (change = -175) + assert response.row_count == 1 + assert response.data[0]["orders.amount_change"] == pytest.approx(-175.0) + assert 
response.data[0]["orders.total_amount"] == pytest.approx(125.0) + + +def test_post_filter_with_base_filter(integration_env): + """Post-filter and base filter should both be applied correctly.""" + engine = integration_env + + # Without base filter: Jan(300), Feb(125), Mar(325) + # change: Jan=NULL, Feb=-175, Mar=200 + # Post-filter: amount_change > 0 → only March + # Base filter: status != 'cancelled' → excludes order 4 (cancelled, 75, Feb) + # Without cancelled: Jan(300), Feb(50), Mar(325) + # change: Jan=NULL, Feb=50-300=-250, Mar=325-50=275 + # Post-filter: amount_change > 0 → only March + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="total_amount"), + Field(formula="change(total_amount)", name="amount_change"), + ], + filters=["status != 'cancelled'", "amount_change > 0"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # Only March (non-cancelled=325, change=275) + assert response.row_count == 1 + assert response.data[0]["orders.amount_change"] == pytest.approx(275.0) + + +def test_inline_transform_filter(integration_env): + """Transform expressions can be used directly in filters (auto-extracted as hidden fields).""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # change: Jan=NULL, Feb=-175, Mar=200 + # Filter: change(total_amount) < 0 → only February + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[Field(formula="total_amount")], + filters=["change(total_amount) < 0"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == pytest.approx(125.0) + + +def 
test_inline_last_change_filter(integration_env): + """last(change(x)) in filter: keep rows only if the most recent period's change matches.""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # change: Jan=NULL, Feb=-175, Mar=200 + # last(change) = 200 (March's change, broadcast to all rows) + # Filter: last(change(total_amount)) > 0 → all rows pass (200 > 0) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[Field(formula="total_amount")], + filters=["last(change(total_amount)) > 0"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # last(change) = 200 > 0, so all 3 rows pass + assert response.row_count == 3 + + # Now filter for < 0 → no rows pass (last change is 200) + query2 = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[Field(formula="total_amount")], + filters=["last(change(total_amount)) < 0"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response2 = engine.execute(query2) + assert response2.row_count == 0 + + +def test_arithmetic_transform_filter(integration_env): + """Arithmetic expressions with transforms in filters: change(x) / x > threshold.""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # change: Jan=NULL, Feb=-175, Mar=200 + # change / total_amount: Jan=NULL, Feb=-175/125=-1.4, Mar=200/325≈0.615 + # Filter: change(total_amount) / total_amount > 0 → only March + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[Field(formula="total_amount")], + filters=["change(total_amount) / total_amount > 0"], + 
order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # Only March passes (positive change ratio) + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == pytest.approx(325.0) + + + def test_transform_on_filter_rhs(integration_env): + """Transform expressions work on the RHS of filters too.""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # time_shift(total_amount, -1): Jan=NULL, Feb=300, Mar=125 + # Filter: total_amount > time_shift(total_amount, -1) → months where value increased + # Jan: 300 > NULL → NULL (filtered out), Feb: 125 > 300 → false, Mar: 325 > 125 → true + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[Field(formula="total_amount")], + filters=["total_amount > time_shift(total_amount, -1)"], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + # Only March (325 > 125) + assert response.row_count == 1 + assert response.data[0]["orders.total_amount"] == pytest.approx(325.0) + + + def test_last_measure_type(integration_env): + """A measure with type=last should return the most recent time bucket's value.""" + engine = integration_env + + # 3 months: Jan(300), Feb(125), Mar(325) + # latest_amount has type=last, so querying it as a bare measure + # should auto-wrap with last() and return Mar's value (300, since the base aggregation is MAX(amount)) for all rows + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), + granularity=TimeGranularity.MONTH, + )], + fields=[ + Field(formula="total_amount"), + Field(formula="latest_amount"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + response = engine.execute(query) + + assert response.row_count == 3 + # latest_amount should be the same (most recent) value
for all rows + # Base agg is MAX(amount), March has max single order = 300 + latest_vals = [r["orders.latest_amount"] for r in response.data] + assert len(set(latest_vals)) == 1 # All rows have the same value + assert latest_vals[0] == pytest.approx(300.0) # March's MAX(amount) diff --git a/tests/test_integration_postgres.py b/tests/test_integration_postgres.py index 53ca929..ce4b5aa 100644 --- a/tests/test_integration_postgres.py +++ b/tests/test_integration_postgres.py @@ -8,9 +8,9 @@ from pytest_postgresql import factories -from slayer.core.enums import DataType +from slayer.core.enums import DataType, TimeGranularity from slayer.core.models import DatasourceConfig, Dimension, Measure, SlayerModel -from slayer.core.query import SlayerQuery +from slayer.core.query import ColumnRef, Field, OrderItem, SlayerQuery, TimeDimension from slayer.engine.ingestion import ingest_datasource from slayer.engine.query_engine import SlayerQueryEngine from slayer.storage.yaml_storage import YAMLStorage @@ -222,6 +222,88 @@ def test_composite_filter(self, pg_env: SlayerQueryEngine) -> None: result = pg_env.execute(query=query) assert result.data[0]["orders.count"] == 5 # 3 completed + 2 pending + def test_time_shift_with_date_range(self, pg_env: SlayerQueryEngine) -> None: + """time_shift with date_range should fetch shifted data from outside the filtered range.""" + # Query only March, ask for previous month (February) + # Seed: Jan(300), Feb(200), Mar(375) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[ + Field(formula="total"), + Field(formula="time_shift(total, -1, 'month')", name="prev_month"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + result = pg_env.execute(query=query) + assert result.row_count == 1 + assert float(result.data[0]["orders.total"]) == pytest.approx(375.0) 
+ # Previous month (Feb) fetched from DB, not NULL + assert float(result.data[0]["orders.prev_month"]) == pytest.approx(200.0) + + def test_change_with_date_range(self, pg_env: SlayerQueryEngine) -> None: + """change() with date_range should fetch previous period from outside the filtered range.""" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[ + Field(formula="total"), + Field(formula="change(total)", name="amount_change"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + result = pg_env.execute(query=query) + assert result.row_count == 1 + # Mar(375) - Feb(200) = 175 + assert float(result.data[0]["orders.amount_change"]) == pytest.approx(175.0) + + def test_change_pct_with_date_range(self, pg_env: SlayerQueryEngine) -> None: + """change_pct() with date_range should compute correct percentage from shifted data.""" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[ + Field(formula="total"), + Field(formula="change_pct(total)", name="pct"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + result = pg_env.execute(query=query) + assert result.row_count == 1 + # (375 - 200) / 200 = 0.875 + assert float(result.data[0]["orders.pct"]) == pytest.approx(0.875) + + def test_multiple_date_range_shifts(self, pg_env: SlayerQueryEngine) -> None: + """Multiple self-join transforms with different offsets should each get correct data.""" + # Query Feb only, ask for both previous (Jan) and next (Mar) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-02-01", 
"2024-02-29"], + )], + fields=[ + Field(formula="total"), + Field(formula="time_shift(total, -1, 'month')", name="prev"), + Field(formula="time_shift(total, 1, 'month')", name="next"), + ], + order=[OrderItem(column=ColumnRef(name="created_at"), direction="asc")], + ) + result = pg_env.execute(query=query) + assert result.row_count == 1 + assert float(result.data[0]["orders.total"]) == pytest.approx(200.0) + assert float(result.data[0]["orders.prev"]) == pytest.approx(300.0) # Jan + assert float(result.data[0]["orders.next"]) == pytest.approx(375.0) # Mar + @pytest.fixture def pg_ingest_env(postgresql): diff --git a/tests/test_sql_generator.py b/tests/test_sql_generator.py index de0d40e..5929792 100644 --- a/tests/test_sql_generator.py +++ b/tests/test_sql_generator.py @@ -161,7 +161,7 @@ def test_contains_filter(self, generator: SQLGenerator, orders_model: SlayerMode query = SlayerQuery( model="orders", fields=[Field(formula="count")], - filters=["contains(status, 'act')"], + filters=["status like '%act%'"], ) sql = _generate(generator, query, orders_model) assert "LIKE" in sql @@ -199,10 +199,11 @@ def test_date_range_filter(self, generator: SQLGenerator, orders_model: SlayerMo query = SlayerQuery( model="orders", fields=[Field(formula="count")], - filters=["between(created_at, '2024-01-01', '2024-06-30')"], + filters=["created_at >= '2024-01-01' and created_at <= '2024-06-30'"], ) sql = _generate(generator, query, orders_model) - assert "BETWEEN" in sql + assert ">=" in sql + assert "<=" in sql class TestMeasureTypes: @@ -432,6 +433,22 @@ def test_last(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: assert "FIRST_VALUE(" in sql assert "DESC" in sql + def test_last_measure_type(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """A measure with type=last should auto-wrap with last() transform.""" + orders_model.default_time_dimension = "created_at" + orders_model.measures.append(Measure(name="balance", sql="balance", 
type=DataType.LAST)) + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="balance")], + ) + sql = _generate(generator, query, orders_model) + # Should auto-generate FIRST_VALUE (last() transform) + assert "FIRST_VALUE(" in sql + assert "DESC" in sql + # Base aggregation should use MAX + assert "MAX(" in sql + def test_time_shift(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: orders_model.default_time_dimension = "created_at" query = SlayerQuery( @@ -444,6 +461,160 @@ def test_time_shift(self, generator: SQLGenerator, orders_model: SlayerModel) -> assert "LEFT JOIN" in sql assert "INTERVAL" in sql + def test_time_shift_shifted_date_range(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Calendar time_shift with date_range should shift the filter in the shifted CTE.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[Field(formula="revenue"), Field(formula="time_shift(revenue, -1, 'month')", name="rev_prev")], + ) + sql = _generate(generator, query, orders_model) + # Base CTE should have original date range + assert "2024-03-01" in sql + assert "2024-03-31" in sql + # Shifted CTE should have date range shifted back by 1 month + assert "2024-02-01" in sql + assert "2024-02-29" in sql + + def test_time_shift_yoy_shifted_date_range(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Year-over-year time_shift should shift the date range by 1 year.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", 
"2024-03-31"], + )], + fields=[Field(formula="revenue"), Field(formula="time_shift(revenue, -1, 'year')", name="rev_yoy")], + ) + sql = _generate(generator, query, orders_model) + # Shifted CTE should query March 2023 + assert "2023-03-01" in sql + assert "2023-03-31" in sql + + def test_change_shifted_date_range(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Row-based change with date_range should shift the filter using query's time granularity.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[Field(formula="revenue"), Field(formula="change(revenue)", name="rev_change")], + ) + sql = _generate(generator, query, orders_model) + # change looks back 1 period — shifted CTE should query February + assert "2024-02-01" in sql + assert "2024-02-29" in sql + + def test_no_date_range_no_shift(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Without a date_range, shifted CTE should still be a valid base query (no date filter).""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="revenue"), Field(formula="time_shift(revenue, -1, 'month')", name="rev_prev")], + ) + sql = _generate(generator, query, orders_model) + # Both base and shifted CTEs should query the source table without date filters + assert "shifted_base_" in sql + assert "BETWEEN" not in sql + + def test_forward_time_shift_with_date_range(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Forward time_shift(x, 1, 'month') with date_range should shift the filter forward.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + 
model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH, + date_range=["2024-03-01", "2024-03-31"], + )], + fields=[Field(formula="revenue"), Field(formula="time_shift(revenue, 1, 'month')", name="rev_next")], + ) + sql = _generate(generator, query, orders_model) + # Shifted CTE should query April (1 month forward) + assert "2024-04-01" in sql + assert "2024-04-30" in sql + + def test_quarter_date_shift(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """time_shift with quarter granularity should shift the date range by 3 months.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension( + dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.QUARTER, + date_range=["2024-07-01", "2024-09-30"], + )], + fields=[Field(formula="revenue"), Field(formula="time_shift(revenue, -1, 'quarter')", name="prev_q")], + ) + sql = _generate(generator, query, orders_model) + # Q3 2024 shifted back 1 quarter = Q2 2024 + assert "2024-04-01" in sql + assert "2024-06-30" in sql + + def test_nested_self_join_raises(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Nesting self-join transforms (e.g., change(time_shift(x))) should raise.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="revenue"), Field(formula="change(time_shift(revenue, -1, 'year'))", name="x")], + ) + with pytest.raises(ValueError, match="Nesting.*not supported"): + _generate(generator, query, orders_model) + + def test_post_filter_on_computed_column(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Filters on computed columns should be applied as post-filter wrapper.""" + orders_model.default_time_dimension = "created_at" + 
query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="revenue"), Field(formula="change(revenue)", name="rev_change")], + filters=["rev_change < 0"], + ) + sql = _generate(generator, query, orders_model) + # Should wrap in a post-filter SELECT + assert "_filtered" in sql + assert '"orders.rev_change" < 0' in sql + + def test_inline_transform_filter(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Transform expressions in filters should be auto-extracted as hidden fields.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="revenue")], + filters=["last(change(revenue)) < 0"], + ) + sql = _generate(generator, query, orders_model) + # Should have the hidden transform columns + assert "FIRST_VALUE" in sql # last() + assert "shifted_" in sql # change() via self-join + # Should have post-filter wrapper + assert "_filtered" in sql + assert "< 0" in sql + + def test_mixed_base_and_post_filters(self, generator: SQLGenerator, orders_model: SlayerModel) -> None: + """Base filters and post-filters should coexist correctly.""" + orders_model.default_time_dimension = "created_at" + query = SlayerQuery( + model="orders", + time_dimensions=[TimeDimension(dimension=ColumnRef(name="created_at"), granularity=TimeGranularity.MONTH)], + fields=[Field(formula="revenue"), Field(formula="change(revenue)", name="rev_change")], + filters=["status == 'completed'", "rev_change > 0"], + ) + sql = _generate(generator, query, orders_model) + # Base filter should be in the inner WHERE + assert "'completed'" in sql + # Post-filter should be in the outer wrapper + assert '"orders.rev_change" > 0' in sql + assert "_filtered" in sql + def test_transform_without_time_raises(self, 
generator: SQLGenerator, orders_model: SlayerModel) -> None: """Transforms requiring time should fail if no time dimension available.""" query = SlayerQuery( @@ -618,7 +789,7 @@ def test_date_trunc(self, dialect: str, orders_model: SlayerModel) -> None: @pytest.mark.parametrize("dialect", ALL_DIALECTS) def test_calendar_time_shift(self, dialect: str, orders_model: SlayerModel) -> None: - """Calendar-based time_shift should produce dialect-appropriate date arithmetic.""" + """Calendar-based time_shift should produce dialect-appropriate date arithmetic in shifted CTE.""" gen = SQLGenerator(dialect=dialect) query = SlayerQuery( model="orders", @@ -628,16 +799,10 @@ def test_calendar_time_shift(self, dialect: str, orders_model: SlayerModel) -> N sql = _generate(gen, query, orders_model) assert "shifted_" in sql assert "LEFT JOIN" in sql - # Dialect-specific date arithmetic + # Join should be simple equality (timestamp shift is inside the shifted CTE) + # Dialect-specific date arithmetic should appear in the shifted CTE's SELECT/GROUP BY sql_upper = sql.upper() if dialect == "sqlite": assert "DATE(" in sql_upper - elif dialect in ("bigquery", "clickhouse", "databricks", "spark", "tsql"): - assert "DATE_ADD(" in sql_upper or "DATEADD(" in sql_upper - elif dialect in ("snowflake", "redshift"): - assert "DATEADD(" in sql_upper - elif dialect in ("trino", "presto"): - assert "DATE_ADD(" in sql_upper else: - # Postgres, MySQL, DuckDB — INTERVAL syntax assert "INTERVAL" in sql_upper