diff --git a/datajunction-server/datajunction_server/construction/build_v3/builder.py b/datajunction-server/datajunction_server/construction/build_v3/builder.py index 8d69ed750..77ef3dd3e 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/builder.py +++ b/datajunction-server/datajunction_server/construction/build_v3/builder.py @@ -54,6 +54,16 @@ logger = logging.getLogger(__name__) +# Engine tier preference for dialect auto-detection (fastest → slowest). +# When no dialect is specified, build_metrics_sql probes tiers in order: +# DRUID – served from a materialized cube (single Druid datasource scan) +# TRINO – served from pre-agg or source tables via Trino +# SPARK – served from pre-agg or source tables via Spark (default fallback) +# Full tier resolution (including Trino catalog engine lookup) is handled by +# resolve_dialect_and_engine_for_metrics in cube_matcher.py; the auto-detect +# logic here covers the common DRUID-vs-SPARK split. +_ENGINE_TIER_PREFERENCE = [Dialect.DRUID, Dialect.TRINO, Dialect.SPARK] + def _normalize_query_param_value(param: str, value: ast.Value | Any) -> ast.Value: """Normalize a Python value to an AST value node for query parameter substitution.""" @@ -471,9 +481,32 @@ async def build_metrics_sql( Layer 3: Derived Metrics Computes derived metrics that reference other metrics. """ - # Default to SPARK dialect if not specified + # Auto-detect dialect when none specified: probe fastest available engine tier. + # See _ENGINE_TIER_PREFERENCE for priority ordering (DRUID > TRINO > SPARK). + # Trino resolution requires a catalog engine lookup; that is handled by + # resolve_dialect_and_engine_for_metrics. Here we cover the DRUID-vs-SPARK split. if dialect is None: - dialect = Dialect.SPARK + if use_materialized: + # Probe Druid tier: look for a matching materialized cube. 
+ probe_cube = ( + matched_cube + if matched_cube is not None + else await find_matching_cube( + session, + metrics, + dimensions, + require_availability=True, + ) + ) + if probe_cube: + dialect = Dialect.DRUID + matched_cube = probe_cube # reuse below, avoids second DB round-trip + else: + dialect = ( + Dialect.SPARK + ) # no cube; Trino tier needs catalog engine lookup + else: + dialect = Dialect.SPARK # Setup context (loads nodes, decomposes metrics, adds dimensions from expressions) ctx = await setup_build_context( @@ -485,9 +518,9 @@ async def build_metrics_sql( use_materialized=use_materialized, ) - # Use materialized cube if available. + # Use materialized cube when dialect is DRUID (explicit or auto-detected above). # Use pre-resolved cube if available (avoids duplicate find_matching_cube call). - if use_materialized: + if use_materialized and dialect == Dialect.DRUID: cube = ( matched_cube if matched_cube is not None diff --git a/datajunction-server/datajunction_server/construction/build_v3/cte.py b/datajunction-server/datajunction_server/construction/build_v3/cte.py index f566414fb..b7e45791a 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cte.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cte.py @@ -7,6 +7,7 @@ from copy import deepcopy from typing import Optional +from datajunction_server.construction.build_v3.filters import extract_subscript_role from datajunction_server.construction.build_v3.materialization import ( get_table_reference_parts_with_materialization, should_use_materialized_table, @@ -185,7 +186,7 @@ def replace_component_refs_in_ast( component_aliases: Mapping from component name to (table_alias, column_name) e.g., {"unit_price_sum_abc123": ("gg0", "sum_unit_price")} """ - for col in expr_ast.find_all(ast.Column): + for col in list(expr_ast.find_all(ast.Column)): # Get the column name (might be in name.name or just name) col_name = col.name.name if col.name else None if not col_name: # 
pragma: no cover @@ -194,7 +195,6 @@ def replace_component_refs_in_ast( # Check if this column name matches a component if col_name in component_aliases: # pragma: no branch table_alias, actual_col = component_aliases[col_name] - # Replace with qualified column reference col.name = ast.Name(actual_col) # Only set table if alias is non-empty (empty = no CTE prefix) col._table = ast.Table(ast.Name(table_alias)) if table_alias else None @@ -261,15 +261,7 @@ def replace_dimension_refs_in_ast( if not base_col_name: # pragma: no cover continue - # Get the role from the index (e.g., "order") - role = None - if isinstance(subscript.index, ast.Column): - role = subscript.index.name.name if subscript.index.name else None - elif isinstance(subscript.index, ast.Name): # pragma: no cover - role = subscript.index.name # pragma: no cover - elif hasattr(subscript.index, "name"): # pragma: no cover - role = str(subscript.index.name) # type: ignore - + role = extract_subscript_role(subscript) if not role: # pragma: no cover continue @@ -1019,15 +1011,14 @@ def process_metric_combiner_expression( """ Process a metric combiner expression for final output. - This function applies the same transformations used in generate_metrics_sql - (specifically build_derived_metric_expr) to ensure consistency between - SQL generation and stored metric expressions. + Transforms a raw combiner AST into the final SQL expression by replacing + component, metric, and dimension references with qualified column refs. Used by: - build_derived_metric_expr in generate_metrics_sql - cube materialization for storing metric_expression in config - Transformations applied (in order, matching build_derived_metric_expr): + Transformations applied (in order): 1. Replace metric references (e.g., "v3.total_revenue" -> column ref) 2. Replace component references (e.g., "revenue_sum_abc123" -> column ref) 3. 
Replace dimension references (e.g., "v3.date.dateint" -> column ref) @@ -1056,7 +1047,6 @@ def process_metric_combiner_expression( expr_ast = deepcopy(combiner_ast) # Replace metric references (for derived metrics referencing other metrics) - # This must happen first, matching build_derived_metric_expr order if metric_refs: replace_metric_refs_in_ast(expr_ast, metric_refs) diff --git a/datajunction-server/datajunction_server/construction/build_v3/cube_matcher.py b/datajunction-server/datajunction_server/construction/build_v3/cube_matcher.py index fcb53ab6c..3757bc05b 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/cube_matcher.py +++ b/datajunction-server/datajunction_server/construction/build_v3/cube_matcher.py @@ -391,14 +391,55 @@ async def build_sql_from_cube( return build_sql_from_cube_impl(ctx, cube, ctx.decomposed_metrics) +def _build_mat_col_lookup(cube: NodeRevision) -> dict[str, str]: + """ + Build a mapping from short column name -> physical column name by reading + the cube's materialization config columns. + + Example entry in config["columns"]: + { + "name": "common_DOT_dimensions_DOT_time_DOT_date_DOT_dateint", # physical Druid col + "column": "dateint", # short col name + "semantic_entity": "common.dimensions.time.date.dateint", + "semantic_type": "dimension", + ... + } + + We key on ``column`` (the short name) because that is what + parse_dimension_ref().column_name returns, and it is stable across + different namespace / path representations. + + Returns {} when no materialization config is available (e.g. in tests that + set availability directly without going through the materialization pipeline), + in which case callers fall back to the short column name unchanged. 
+ """ + lookup: dict[str, str] = {} + for mat in cube.materializations or []: + for combiner in (mat.config or {}).get("combiners") or []: + for col_data in (combiner or {}).get("columns") or []: + short_name = col_data.get("column") + physical_name = col_data.get("name") + if short_name and physical_name: + lookup[short_name] = physical_name + return lookup + + def build_synthetic_grain_group( ctx: BuildContext, decomposed_metrics: dict[str, DecomposedMetricInfo], cube: NodeRevision, ) -> GrainGroupSQL: """ - Collect components from base metrics only (not derived). - V3 cube column naming always uses component.name (the hashed name) for consistency. + Build a synthetic GrainGroupSQL that reads from the cube's materialized Druid table. + + Physical column names are resolved from the cube's materialization config + (``materialization.config["columns"]``). Each entry there carries a + ``column`` key (the short column name, e.g. ``dateint``) and a ``name`` key + (the physical column name as it exists in the Druid table, e.g. + ``common_DOT_dimensions_DOT_time_DOT_date_DOT_dateint``). We key on the + short column name because that is what parse_dimension_ref().column_name + returns. When a match is found the physical name is used; otherwise we fall + back to the short name (which is correct for new-style materializations). """ all_components = [] component_aliases: dict[str, str] = {} @@ -406,8 +447,18 @@ def build_synthetic_grain_group( avail = cube.availability if not avail: # pragma: no cover raise ValueError(f"Cube {cube.name} has no availability") - table_parts = [p for p in [avail.catalog, avail.schema_, avail.table] if p] - table_name = ".".join(table_parts) + # Druid tables are referenced by the table name only (schema/catalog are not part of the ref). + # For other engines (e.g. Iceberg/Spark) we use the full catalog.schema.table path. 
+ if ctx.dialect == Dialect.DRUID: + table_name = avail.table + else: + table_name = ".".join( + p for p in [avail.catalog, avail.schema_, avail.table] if p + ) + + # short_col_name -> physical column name from the materialization config. + # Empty when no materialization config is present (tests / direct calls). + mat_col_lookup = _build_mat_col_lookup(cube) for metric_name, decomposed in decomposed_metrics.items(): # Only process BASE metrics for component alias mapping @@ -418,40 +469,46 @@ def build_synthetic_grain_group( for comp in decomposed.components: if comp.name not in component_aliases: # pragma: no branch - # Always use component.name for consistency - no special case for single-component - cube_col_name = comp.name - + cube_col_name = mat_col_lookup.get(comp.name, comp.name) component_aliases[comp.name] = cube_col_name all_components.append(comp) # Build column metadata for the synthetic grain group grain_group_columns: list[ColumnMetadata] = [] - # Build mapping from dimension ref to short column name for filter resolution + # Build mapping from dimension ref to physical column name for filter resolution. # ctx.dimensions includes both requested dimensions AND filter-only dimensions # (filter-only dimensions were added by add_dimensions_from_filters() in setup_build_context) dimension_aliases: dict[str, str] = {} # Add all dimensions (requested + filter-only). We need all dimensions - # in the cube SELECT for proper filter resolution + # in the cube SELECT for proper filter resolution. + # dim_short_names holds the alias (short name) used everywhere outside the CTE. + # dim_physical_names holds the actual column name in the Druid table (may differ). 
dim_short_names = [] + dim_physical_names = [] for dim_ref in ctx.dimensions: parsed_dim = parse_dimension_ref(dim_ref) - col_name = parsed_dim.column_name + short_name = parsed_dim.column_name if parsed_dim.role: - col_name = f"{col_name}_{parsed_dim.role}" - dim_short_names.append(col_name) - dimension_aliases[dim_ref] = col_name + short_name = f"{short_name}_{parsed_dim.role}" + physical_name = mat_col_lookup.get(parsed_dim.column_name, short_name) + dim_short_names.append(short_name) + dim_physical_names.append(physical_name) + # Use the physical column name for WHERE clause resolution: + # the WHERE is applied directly on the cube table, so we must reference + # the physical column (e.g. common_DOT_..._DOT_dateint) not the alias. + dimension_aliases[dim_ref] = physical_name grain_group_columns.append( ColumnMetadata( - name=col_name, + name=short_name, semantic_name=dim_ref, type="string", # Will be refined by generate_metrics_sql semantic_type="dimension", ), ) - # Add component columns (using cube column names from component_aliases) + # Add component columns (always use the short/hash name as both physical and alias) for comp in all_components: cube_col_name = component_aliases[comp.name] grain_group_columns.append( @@ -466,9 +523,13 @@ def build_synthetic_grain_group( # Build the synthetic query: SELECT dims, components FROM cube_table WHERE filters projection: list[ast.Column] = [] - # Add all dimension columns (requested + filter-only) - for dim_col in dim_short_names: - projection.append(ast.Column(name=ast.Name(dim_col))) + # Add all dimension columns. When the physical name differs from the short alias, + # emit "physical_name AS short_name" so the rest of the query can use the short name. 
+ for short_name, physical_name in zip(dim_short_names, dim_physical_names): + col = ast.Column(name=ast.Name(physical_name)) + if physical_name != short_name: + col = col.set_alias(ast.Name(short_name)) # type: ignore[assignment] + projection.append(col) # Add component columns (using cube column names) for comp in all_components: diff --git a/datajunction-server/datajunction_server/construction/build_v3/filters.py b/datajunction-server/datajunction_server/construction/build_v3/filters.py index 51517b120..7fd11b87b 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/filters.py +++ b/datajunction-server/datajunction_server/construction/build_v3/filters.py @@ -40,6 +40,29 @@ def parse_filter(filter_str: str) -> ast.Expression: return query.select.where +def extract_subscript_role(subscript: ast.Subscript) -> str | None: + """ + Extract the role string from a subscript index node. + + Handles the three forms that can appear as a subscript index: + - ast.Column: simple role like "order" (e.g., "v3.date.year[order]") + - ast.Name: simple role like "order" (fallback if parser produces Name instead of Column) + - ast.Lambda: multi-hop role (e.g., "v3.user[customer->home]") + + Returns the role string, or None if the index is not a recognised form. 
+ """ + # simple role like "dim.attr[order]" + if isinstance(subscript.index, ast.Column): + return subscript.index.name.name if subscript.index.name else None + # simple role like "dim.attr[order]" + if isinstance(subscript.index, ast.Name): # pragma: no cover + return subscript.index.name + # multi-hop role like "dim.attr[customer->home]" + if isinstance(subscript.index, ast.Lambda): + return str(subscript.index) + return None # pragma: no cover + + def resolve_filter_references( filter_ast: ast.Expression, column_aliases: dict[str, str], @@ -81,17 +104,7 @@ def resolve_filter_references( if not base_col_ref: continue # pragma: no cover - # Extract the role from the subscript index - role = None - if isinstance(subscript.index, ast.Column): - role = subscript.index.name.name if subscript.index.name else None - elif isinstance(subscript.index, ast.Name): # pragma: no cover - role = subscript.index.name - elif isinstance(subscript.index, ast.Lambda): - # Multi-hop role notation like "customer->home" is parsed as a Lambda node. - # Lambda.__str__ returns the canonical role string (e.g., "customer->home"). - role = str(subscript.index) - + role = extract_subscript_role(subscript) if not role: continue # pragma: no cover diff --git a/datajunction-server/datajunction_server/construction/build_v3/metrics.py b/datajunction-server/datajunction_server/construction/build_v3/metrics.py index 2809f9fed..7918347eb 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/metrics.py +++ b/datajunction-server/datajunction_server/construction/build_v3/metrics.py @@ -195,19 +195,29 @@ def get_comp_aggregability(comp_name: str) -> Aggregability: return gg.component_aggregabilities.get(comp_name, Aggregability.FULL) return decomposed.aggregability - # Handle LIMITED aggregability (COUNT DISTINCT) specially - # This can't be pre-aggregated, so we need COUNT(DISTINCT grain_col) + # Handle LIMITED aggregability (COUNT DISTINCT). 
+ # If the grain group was pre-aggregated (is_pre_aggregated=True), the wrapper CTE + # already computed COUNT(DISTINCT grain_key) and stored it as a named column. + # Emit SUM(pre_agg_col) — a no-op re-aggregation since the wrapper produces + # exactly 1 row per dimension combination. + # Otherwise fall through to COUNT(DISTINCT grain_col) against the raw CTE. if len(decomposed.components) == 1: comp = decomposed.components[0] orig_agg = get_comp_aggregability(comp.name) if orig_agg == Aggregability.LIMITED: _, col_name = comp_mappings[comp.name] - distinct_col = make_column_ref(col_name, cte_alias) + col_ref = make_column_ref(col_name, cte_alias) + if gg.is_pre_aggregated: + # Wrapper CTE already computed COUNT(DISTINCT) as a named integer column. + # Use MAX() as a no-op passthrough — there is exactly 1 row per dimension + # group so MAX = the value itself. MAX is semantically cleaner than SUM + # (which implies addition) and more widely supported than ANY_VALUE. + return ast.Function(ast.Name("MAX"), args=[col_ref]) agg_name = comp.aggregation or "COUNT" return ast.Function( ast.Name(agg_name), - args=[distinct_col], + args=[col_ref], quantifier=ast.SetQuantifier.Distinct, ) @@ -220,11 +230,97 @@ def get_comp_aggregability(comp_name: str) -> Aggregability: return expr_ast +def _build_pre_agg_wrapper_cte( + alias: str, + gg: GrainGroupSQL, +) -> tuple[ast.Query, str]: + """ + Build a pre-aggregation wrapper CTE for a LIMITED grain group. + + A LIMITED grain group CTE outputs N rows per dimension combination (one per distinct + grain key, e.g. customer_id). When FULL OUTER JOINed with other CTEs that have 1 + row per dimension combination, those rows fan out 1:N, causing SUM() to overcount. + + This wrapper collapses the N rows into 1 by applying COUNT(DISTINCT grain_key) inside + the CTE instead of in the outer SELECT. The outer SELECT can then use SUM() on the + already-computed count, which is a no-op when there's exactly 1 row per group. 
+ + Args: + alias: The raw grain group CTE alias (e.g., "page_views_enriched_0") + gg: The LIMITED grain group + + Returns: + (wrapper_cte_ast, wrapper_alias) where wrapper_alias is e.g. + "page_views_enriched_0_agg" + """ + wrapper_alias = f"{alias}_agg" + + # Dimension columns for the GROUP BY: all grain columns except the LIMITED grain keys. + # gg.grain = [dim_col_aliases..., grain_key, ...] + # We want only the user-requested dimension columns (e.g., "category"), not the + # extra grain keys (e.g., "customer_id") that are being collapsed by COUNT DISTINCT. + limited_grain_keys = { + comp.rule.level[0] + for comp in gg.components + if comp.rule and comp.rule.type == Aggregability.LIMITED and comp.rule.level + } + dim_col_names = [col for col in gg.grain if col not in limited_grain_keys] + + # Build SELECT projection: dim cols + COUNT(DISTINCT grain_key) per component + projection: list[Any] = [ + ast.Column(name=ast.Name(col_name)) for col_name in dim_col_names + ] + for comp in gg.components: + if comp.rule and comp.rule.type == Aggregability.LIMITED: # pragma: no branch + grain_col = comp.rule.level[0] if comp.rule.level else None + if not grain_col: + continue # pragma: no cover + grain_col_ref = ast.Column(name=ast.Name(grain_col)) + count_expr = ast.Function( + ast.Name("COUNT"), + args=[grain_col_ref], + quantifier=ast.SetQuantifier.Distinct, + ) + projection.append(ast.Alias(child=count_expr, alias=ast.Name(comp.name))) + + # GROUP BY the dimension columns only (not the grain key) + group_by: list[ast.Expression] = [ + ast.Column(name=ast.Name(col_name)) for col_name in dim_col_names + ] + + from_clause = ast.From( + relations=[ + ast.Relation(primary=ast.Table(name=ast.Name(alias))), + ], + ) + + wrapper_query = ast.Query( + select=ast.Select( + projection=projection, + from_=from_clause, + group_by=group_by if group_by else [], + ), + ) + wrapper_query.to_cte(ast.Name(wrapper_alias), None) + return wrapper_query, wrapper_alias + + def 
collect_and_build_ctes( grain_groups: list[GrainGroupSQL], + skip_pre_agg: bool = False, ) -> tuple[list[ast.Query], list[str]]: """ Collect shared CTEs and convert grain groups to CTEs. + + For LIMITED grain groups (COUNT DISTINCT), also emits a pre-aggregation wrapper + CTE that collapses the N-rows-per-dimension output to 1 row per dimension by + computing COUNT(DISTINCT grain_key) inside the CTE. This prevents fan-out when + FULL OUTER JOINing with FULL grain groups. + + When skip_pre_agg=True (window function case), the wrapper CTEs are skipped. + The base_metrics CTE will use COUNT(DISTINCT ...) in its own GROUP BY, which + correctly handles fan-out from the FULL OUTER JOIN without the extra wrappers. + Returns (all_cte_asts, cte_aliases). """ # Collect all inner CTEs, dedupe by original name @@ -258,7 +354,6 @@ def collect_and_build_ctes( idx = parent_index_counter.get(parent_short, 0) parent_index_counter[parent_short] = idx + 1 alias = f"{parent_short}_{idx}" - cte_aliases.append(alias) # gg.query is already an AST - no need to parse! gg_query = gg.query @@ -272,6 +367,47 @@ def collect_and_build_ctes( gg_main.to_cte(ast.Name(alias), None) all_cte_asts.append(gg_main) + # For non-merged LIMITED grain groups, add a pre-aggregation wrapper CTE. + # This collapses N rows per dimension (one per distinct grain key) into 1 row + # by computing COUNT(DISTINCT grain_key) inside the CTE, preventing fan-out + # in the FULL OUTER JOIN step. + # Skip when skip_pre_agg=True (window function case): the base_metrics CTE + # uses GROUP BY + COUNT(DISTINCT ...) directly, handling fan-out correctly. + # Also skip when there's only one grain group: no FULL OUTER JOIN means no + # fan-out, so COUNT(DISTINCT ...) in the final GROUP BY is sufficient. + # Also skip when the grain key IS the dimension (dim_col_names would be empty): + # the grain group already has 1 row per dimension so there's no fan-out. 
+ # A wrapper with no GROUP BY would produce a scalar aggregate, which is wrong. + limited_grain_keys = { + comp.rule.level[0] + for comp in gg.components + if comp.rule and comp.rule.type == Aggregability.LIMITED and comp.rule.level + } + has_non_grain_dims = any(col not in limited_grain_keys for col in gg.grain) + needs_pre_agg = ( + not skip_pre_agg + and len(grain_groups) > 1 + and not gg.is_merged + and gg.aggregability == Aggregability.LIMITED + and gg.components + and has_non_grain_dims + ) + if needs_pre_agg: + wrapper_cte, wrapper_alias = _build_pre_agg_wrapper_cte(alias, gg) + all_cte_asts.append(wrapper_cte) + # Record the pre-aggregated column name for each LIMITED component so that + # _build_metric_aggregation() can reference it by name instead of re-applying + # COUNT(DISTINCT). + for comp in gg.components: + if ( + comp.rule and comp.rule.type == Aggregability.LIMITED + ): # pragma: no branch + gg.component_aliases[comp.name] = comp.name + gg.is_pre_aggregated = True + cte_aliases.append(wrapper_alias) + else: + cte_aliases.append(alias) + return all_cte_asts, cte_aliases @@ -902,6 +1038,7 @@ def process_derived_metrics( resolver: ColumnResolver, partition_columns: list[str], window_cte_alias: str | None, + base_metric_exprs: dict[str, MetricExprInfo], intermediate_metric_names: set[str] | None = None, alias_to_dimension_node: dict[str, str] | None = None, ) -> dict[str, MetricExprInfo]: @@ -910,7 +1047,7 @@ def process_derived_metrics( Derived metrics are computed from base metrics and may include: - Window function metrics (LAG, LEAD, etc.) 
that reference base_metrics CTE - - Non-window derived metrics that reference other metric columns + - Non-window derived metrics that inline base metric expressions directly Args: ctx: Build context with metrics list and nodes @@ -919,6 +1056,8 @@ def process_derived_metrics( resolver: ColumnResolver with metric, component, and dimension refs partition_columns: Column names for PARTITION BY injection window_cte_alias: Alias of base_metrics CTE ("base_metrics" or None) + base_metric_exprs: Pre-built expressions for base metrics, used to inline + into derived metric formulas (correctly handles pre-aggregated cases) intermediate_metric_names: Optional set of intermediate derived metric names Returns: @@ -947,6 +1086,29 @@ def process_derived_metrics( if not decomposed: # pragma: no cover continue + # Check if this derived metric references any NONE-aggregability parent metrics. + # NONE-aggregability metrics (e.g. MAX_BY) cannot be safely combined into + # derived metrics — their raw-grain expressions reference columns that are + # not in scope at the derived metric's aggregation level. + parent_names = ctx.parent_map.get(metric_name, []) + none_parents = [ + p + for p in parent_names + if ctx.nodes.get(p) + and ctx.nodes.get(p).type == NodeType.METRIC # type: ignore + and decomposed_metrics.get(p) + and decomposed_metrics[p].aggregability == Aggregability.NONE + ] + if none_parents: + raise DJInvalidInputException( + f"Cannot compute derived metric '{metric_name}' because it references " + f"non-decomposable metric(s) with Aggregability.NONE: " + f"{none_parents}. " + f"Non-decomposable metrics (e.g. 
MAX_BY) cannot be combined into " + f"derived metrics — their expressions require raw-grain access that " + f"is not available at the derived metric's aggregation level.", + ) + short_name = get_short_name(metric_name) # Handle window function metrics specially @@ -963,11 +1125,21 @@ def process_derived_metrics( ) derived_cte_alias = window_cte_alias else: - expr_ast = build_derived_metric_expr( - decomposed, - resolver, - partition_columns, - alias_to_dimension_node, + # Use build_intermediate_metric_expr to inline pre-built base metric + # expressions into the derived metric formula. This correctly handles + # COUNT DISTINCT base metrics from pre-aggregated (_agg) CTEs, where + # the combiner_ast would produce COUNT(DISTINCT ...) but the pre-built + # expression already uses MAX(...) as the correct passthrough aggregation. + # Fall back to build_derived_metric_expr when some base metrics are missing + # from base_metric_exprs (e.g., NONE-aggregability metrics not in any grain group). + expr_ast = ( # type: ignore[assignment] + build_intermediate_metric_expr(ctx, metric_name, base_metric_exprs) + or build_derived_metric_expr( + decomposed, + resolver, + partition_columns, + alias_to_dimension_node, + ) ) derived_cte_alias = default_cte_alias @@ -1558,8 +1730,30 @@ def generate_metrics_sql( base_grain_groups = [gg for gg in grain_groups if not gg.is_window_grain_group] window_grain_groups = [gg for gg in grain_groups if gg.is_window_grain_group] + # Pre-detect whether there will be a base_metrics CTE (any window function metrics). + # When a base_metrics CTE is built, it uses GROUP BY + COUNT(DISTINCT ...) which + # correctly handles fan-out from FULL OUTER JOIN — so _agg wrapper CTEs are unnecessary. + # _agg wrappers are only needed when cross-grain CTEs are joined directly in the + # final SELECT without a GROUP BY that can apply COUNT(DISTINCT). 
+ all_base_metrics_precheck: set[str] = { + m for gg in base_grain_groups for m in gg.metrics + } + will_have_base_metrics_cte = ( + bool(window_grain_groups) + or bool(measures_result.window_metric_grains) + or any( + decomposed_metrics.get(m) + and has_window_function(decomposed_metrics[m].combiner_ast) + for m in ctx.metrics + if m not in all_base_metrics_precheck + ) + ) + # Convert base grain groups to CTEs (window grain groups handled separately) - all_cte_asts, cte_aliases = collect_and_build_ctes(base_grain_groups) + all_cte_asts, cte_aliases = collect_and_build_ctes( + base_grain_groups, + skip_pre_agg=will_have_base_metrics_cte, + ) # Build dimension info and projection # Filter out filter-only dimensions (they're needed for WHERE but not output) @@ -1862,6 +2056,7 @@ def collect_derived_dependencies(metric_name: str, visited: set[str]) -> None: resolver, all_dim_aliases, window_metrics_cte_alias, + base_metrics_result.metric_exprs, set(), # No intermediate derived metrics in new architecture alias_to_dimension_node, ) @@ -1915,13 +2110,29 @@ def collect_derived_dependencies(metric_name: str, visited: set[str]) -> None: applicable_dimension_filters = [] for f in dimension_filters_raw: filter_ast = parse_filter(f) - # Check if any column ref in this filter is a filter-only dimension + # Check if any column ref in this filter is a filter-only dimension. + # Must handle both plain column refs and role-qualified subscript refs + # (e.g., "v3.location.country[customer->home]"), because find_all(ast.Column) + # returns only the base Column inside the Subscript, not the full role string. 
refs_filter_only = False - for col in filter_ast.find_all(ast.Column): - full_name = get_column_full_name(col) - if full_name and full_name in ctx.filter_dimensions: - refs_filter_only = True + for subscript in filter_ast.find_all(ast.Subscript): + if not isinstance(subscript.expr, ast.Column): + continue # pragma: no cover + base_ref = get_column_full_name(subscript.expr) + if base_ref: # pragma: no branch + for fd in ctx.filter_dimensions: + fd_base = fd.split("[")[0] if "[" in fd else fd + if fd_base == base_ref: # pragma: no branch + refs_filter_only = True + break + if refs_filter_only: break + if not refs_filter_only: + for col in filter_ast.find_all(ast.Column): + full_name = get_column_full_name(col) + if full_name and full_name in ctx.filter_dimensions: + refs_filter_only = True + break if not refs_filter_only: applicable_dimension_filters.append(f) diff --git a/datajunction-server/datajunction_server/construction/build_v3/types.py b/datajunction-server/datajunction_server/construction/build_v3/types.py index 6796f4918..ad58c4a28 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/types.py +++ b/datajunction-server/datajunction_server/construction/build_v3/types.py @@ -242,6 +242,12 @@ class GrainGroupSQL: # instead of individual grain group CTEs. is_cross_fact_window: bool = False + # Pre-aggregation: True when collect_and_build_ctes() added a wrapper CTE that + # applies COUNT(DISTINCT grain_key) per requested dimension combination. + # When True, _build_metric_aggregation() should emit SUM(pre_agg_col) instead of + # COUNT(DISTINCT raw_grain_col), since the wrapper CTE already did the DISTINCT work. 
+ is_pre_aggregated: bool = False + # Scan estimation: source tables accessed during SQL generation # Populated by collect_node_ctes during CTE building scanned_sources: list[str] = field(default_factory=list) diff --git a/datajunction-server/datajunction_server/construction/build_v3/utils.py b/datajunction-server/datajunction_server/construction/build_v3/utils.py index 36200ade2..7eb9aad3c 100644 --- a/datajunction-server/datajunction_server/construction/build_v3/utils.py +++ b/datajunction-server/datajunction_server/construction/build_v3/utils.py @@ -3,6 +3,10 @@ import logging from typing import TYPE_CHECKING +from datajunction_server.construction.build_v3.filters import ( + extract_subscript_role, + parse_filter, +) from datajunction_server.database.node import Node from datajunction_server.sql.parsing import ast from datajunction_server.utils import SEPARATOR @@ -162,6 +166,37 @@ def collect_required_dimensions( return sorted(required_dims) +def _try_add_dim_to_ctx( + full_name: str, + ctx: "BuildContext", + existing_dims: set[str], + log_source: str, +) -> None: + """ + Add a fully-qualified dimension ref to ctx.dimensions if not already covered. + + Helper used by add_dimensions_from_metric_expressions to deduplicate logic. 
+ """ + # Import here to avoid circular imports + from datajunction_server.construction.build_v3.dimensions import parse_dimension_ref + + if not full_name or SEPARATOR not in full_name or full_name in existing_dims: + return + if full_name in ctx.metrics: # pragma: no cover + return + dim_ref = parse_dimension_ref(full_name) + for existing_dim in ctx.dimensions: + existing_ref = parse_dimension_ref(existing_dim) + if ( + existing_ref.node_name == dim_ref.node_name + and existing_ref.column_name == dim_ref.column_name + ): + return + logger.info("[BuildV3] Auto-adding dimension %s from %s", full_name, log_source) + ctx.dimensions.append(full_name) + existing_dims.add(full_name) + + def add_dimensions_from_metric_expressions( ctx: "BuildContext", decomposed_metrics: dict[str, "DecomposedMetricInfo"], @@ -173,42 +208,44 @@ def add_dimensions_from_metric_expressions( weren't explicitly requested by the user or marked as required_dimensions. We add them so they're included in the grain group SQL. + Also scans the original metric query's window function ORDER BY clauses, because + decomposition strips OVER clauses from derived_ast (the combiner_ast source), so + ORDER BY dimension refs like ``common.dimensions.time.date.dateint`` would otherwise + be invisible to this scan. 
+ Args: ctx: BuildContext with dimensions list to update decomposed_metrics: Dict of metric_name -> DecomposedMetricInfo with combiner ASTs """ # Import here to avoid circular imports from datajunction_server.construction.build_v3.cte import get_column_full_name - from datajunction_server.construction.build_v3.dimensions import parse_dimension_ref existing_dims = set(ctx.dimensions) - for decomposed in decomposed_metrics.values(): + for metric_name, decomposed in decomposed_metrics.items(): combiner_ast = decomposed.combiner_ast for col in combiner_ast.find_all(ast.Column): full_name = get_column_full_name(col) - if full_name and SEPARATOR in full_name and full_name not in existing_dims: - # Skip if this is a metric reference (e.g., in derived metric combiners) - # Metrics should not be added as dimensions - if full_name in ctx.metrics: - continue # pragma: no cover - - # Check if any existing dimension already covers this (node, column) - dim_ref = parse_dimension_ref(full_name) - is_covered = False - for existing_dim in ctx.dimensions: - existing_ref = parse_dimension_ref(existing_dim) - if ( - existing_ref.node_name == dim_ref.node_name - and existing_ref.column_name == dim_ref.column_name - ): - is_covered = True - break - if not is_covered: - logger.info( - f"[BuildV3] Auto-adding dimension {full_name} from metric expression", - ) - ctx.dimensions.append(full_name) - existing_dims.add(full_name) + _try_add_dim_to_ctx(full_name, ctx, existing_dims, "metric expression") + + # Also scan the original metric query for window function ORDER BY dimension refs. + # During decomposition, aggregation functions like SUM(...) OVER (ORDER BY dim) + # are replaced with component names, losing the OVER clause from derived_ast. + # So combiner_ast.find_all(ast.Column) never sees ORDER BY dimension refs. 
+ metric_node = ctx.nodes.get(metric_name) + if metric_node: # pragma: no branch + original_query = ctx.get_parsed_query(metric_node) + for func in original_query.find_all(ast.Function): + if not func.over or not func.over.order_by: + continue + for sort_item in func.over.order_by: + for col in sort_item.find_all(ast.Column): + full_name = get_column_full_name(col) + _try_add_dim_to_ctx( + full_name, + ctx, + existing_dims, + "metric window ORDER BY", + ) def add_dimensions_from_filters(ctx: "BuildContext") -> None: @@ -226,10 +263,9 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None: Args: ctx: BuildContext with filters and dimensions lists to update """ - # Import here to avoid circular imports + # Import here to avoid circular imports (cte.py imports utils.py) from datajunction_server.construction.build_v3.cte import get_column_full_name from datajunction_server.construction.build_v3.dimensions import parse_dimension_ref - from datajunction_server.construction.build_v3.filters import parse_filter if not ctx.filters: return @@ -240,16 +276,73 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None: try: filter_ast = parse_filter(filter_str) except Exception: # pragma: no cover - logger.warning(f"[BuildV3] Failed to parse filter: {filter_str}") + logger.warning("[BuildV3] Failed to parse filter: %s", filter_str) continue - # Find all column references in the filter + # Track base column refs handled via role-qualified subscript notation + # (e.g., "v3.location.country" from "v3.location.country[customer->home]") + # so we don't also add the role-less version in the Column pass below. + subscript_handled_refs: set[str] = set() + + # First pass: handle Subscript nodes for role-qualified dimension refs. + # SQL like "v3.location.country[customer->home]" is parsed as + # Subscript(Column(v3.location.country), Lambda(customer->home)). 
+ for subscript in filter_ast.find_all(ast.Subscript): + if not isinstance(subscript.expr, ast.Column): + continue # pragma: no cover + + base_col_ref = get_column_full_name(subscript.expr) + if not base_col_ref or SEPARATOR not in base_col_ref: + continue # pragma: no cover + + role = extract_subscript_role(subscript) + if role: + full_name = f"{base_col_ref}[{role}]" + else: # pragma: no cover + full_name = base_col_ref + + # Mark this base ref as handled so the Column pass skips it + subscript_handled_refs.add(base_col_ref) + + if full_name in existing_dims: + continue + + if full_name in ctx.metrics: # pragma: no cover + continue + + dim_ref = parse_dimension_ref(full_name) + is_covered = False + for existing_dim in ctx.dimensions: + existing_ref = parse_dimension_ref(existing_dim) + if ( + existing_ref.node_name == dim_ref.node_name + and existing_ref.column_name == dim_ref.column_name + and existing_ref.role == dim_ref.role + ): + is_covered = True # pragma: no cover + break # pragma: no cover + + if not is_covered: # pragma: no branch + logger.info( + "[BuildV3] Auto-adding filter-only dimension %s", + full_name, + ) + ctx.dimensions.append(full_name) + ctx.filter_dimensions.add(full_name) + existing_dims.add(full_name) + + # Second pass: handle regular Column references. + # Skip columns that were already added via the subscript pass above. 
for col in filter_ast.find_all(ast.Column): full_name = get_column_full_name(col) if not full_name or SEPARATOR not in full_name: # Simple column name (e.g., "status") - will be resolved from parent node continue + # Skip if already handled as a role-qualified subscript ref + if full_name in subscript_handled_refs: + continue + if full_name in existing_dims: # Already in dimensions, no need to add continue @@ -269,12 +362,13 @@ def add_dimensions_from_filters(ctx: "BuildContext") -> None: existing_ref.node_name == dim_ref.node_name and existing_ref.column_name == dim_ref.column_name ): - is_covered = True - break + is_covered = True # pragma: no cover + break # pragma: no cover - if not is_covered: + if not is_covered: # pragma: no branch logger.info( - f"[BuildV3] Auto-adding filter-only dimension {full_name}", + "[BuildV3] Auto-adding filter-only dimension %s", + full_name, ) ctx.dimensions.append(full_name) ctx.filter_dimensions.add(full_name) diff --git a/datajunction-server/tests/api/namespaces_test.py b/datajunction-server/tests/api/namespaces_test.py index 1a1b54aae..db314993c 100644 --- a/datajunction-server/tests/api/namespaces_test.py +++ b/datajunction-server/tests/api/namespaces_test.py @@ -79,7 +79,7 @@ async def test_list_all_namespaces( {"namespace": "different.basic.transform", "num_nodes": 1}, {"namespace": "foo.bar", "num_nodes": 26}, {"namespace": "hll", "num_nodes": 4}, - {"namespace": "v3", "num_nodes": 44}, + {"namespace": "v3", "num_nodes": 45}, ] diff --git a/datajunction-server/tests/api/nodes_test.py b/datajunction-server/tests/api/nodes_test.py index a0e315a8c..e709d85c0 100644 --- a/datajunction-server/tests/api/nodes_test.py +++ b/datajunction-server/tests/api/nodes_test.py @@ -343,6 +343,7 @@ async def test_get_nodes_with_details(client_with_examples: AsyncClient): "v3.total_revenue", "v3.total_unit_price", "v3.trailing_7d_revenue", + "v3.trailing_7d_revenue_inferred_dim", "v3.trailing_wow_revenue_change", "v3.visitor_count", 
"v3.wow_order_growth", diff --git a/datajunction-server/tests/api/preaggregations_test.py b/datajunction-server/tests/api/preaggregations_test.py index 21505d595..987ca4fb0 100644 --- a/datajunction-server/tests/api/preaggregations_test.py +++ b/datajunction-server/tests/api/preaggregations_test.py @@ -727,6 +727,10 @@ async def test_get_preagg_by_id(self, client_with_preaggs): "display_name": "Trailing 7D Revenue", "name": "v3.trailing_7d_revenue", }, + { + "display_name": "Trailing 7D Revenue Inferred Dim", + "name": "v3.trailing_7d_revenue_inferred_dim", + }, { "display_name": "Trailing Wow Revenue Change", "name": "v3.trailing_wow_revenue_change", diff --git a/datajunction-server/tests/construction/build_v3/cte_test.py b/datajunction-server/tests/construction/build_v3/cte_test.py index 9f35a9e4b..a5a73d847 100644 --- a/datajunction-server/tests/construction/build_v3/cte_test.py +++ b/datajunction-server/tests/construction/build_v3/cte_test.py @@ -3,6 +3,7 @@ from datajunction_server.construction.build_v3.cte import ( get_column_full_name, inject_partition_by_into_windows, + process_metric_combiner_expression, replace_metric_refs_in_ast, ) from datajunction_server.sql.parsing import ast @@ -322,3 +323,29 @@ def test_avg_over_not_modified(self): # AVG OVER () should remain empty assert "AVG(value) OVER ()" in result_sql assert "PARTITION BY" not in result_sql + + +class TestProcessMetricCombinerExpression: + """Tests for process_metric_combiner_expression.""" + + def test_no_partition_dimensions_skips_partition_by_injection(self): + """ + When partition_dimensions is None the function skips the PARTITION BY + injection block. 
+ """ + query = parse("SELECT SUM(revenue) / NULLIF(COUNT(DISTINCT order_id), 0)") + combiner_ast = query.select.projection[0] + + dimension_refs = { + "v3.order_details.status": ("order_details_0", "status"), + } + + result = process_metric_combiner_expression( + combiner_ast=combiner_ast, + dimension_refs=dimension_refs, + partition_dimensions=None, + ) + + # Function should return without error; no PARTITION BY injected + result_sql = str(result) + assert "PARTITION BY" not in result_sql diff --git a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py index 5a0aa12cc..5bd378b2f 100644 --- a/datajunction-server/tests/construction/build_v3/cube_matcher_test.py +++ b/datajunction-server/tests/construction/build_v3/cube_matcher_test.py @@ -882,7 +882,7 @@ async def test_builds_sql_from_cube_with_all_v3_order_details_metrics( unit_price_sum_55cff00f, unit_price_max_55cff00f, unit_price_min_55cff00f - FROM default.analytics.cube_all_order_metrics + FROM cube_all_order_metrics ) SELECT test_cube_all_order_metrics_0.category AS category, @@ -1146,6 +1146,107 @@ async def test_builds_sql_from_cube_with_trailing_metrics( assert "trailing_7d_revenue" in column_names assert "trailing_wow_revenue_change" in column_names + @pytest.mark.asyncio + async def test_builds_sql_from_cube_with_inferred_order_by_dimension( + self, + client_with_build_v3, + session, + ): + """Window function ORDER BY dim auto-detected even without required_dimensions. + + v3.trailing_7d_revenue_inferred_dim has no required_dimensions set, but its + metric expression uses ORDER BY v3.date.date_id[order]. The fix in + add_dimensions_from_metric_expressions scans the original metric query's + window function ORDER BY clauses and auto-adds the dimension to ctx.dimensions. + + Without the fix, the ORDER BY would contain the raw unresolved ref + "v3.date.date_id[order]" instead of "base_metrics.date_id_order". 
+ """ + all_metrics = ["v3.total_revenue", "v3.trailing_7d_revenue_inferred_dim"] + + response = await client_with_build_v3.post( + "/nodes/cube/", + json={ + "name": "v3.test_cube_inferred_order_by_dim", + "metrics": all_metrics, + "dimensions": ["v3.date.date_id[order]", "v3.product.category"], + "mode": "published", + "description": "Cube for testing inferred ORDER BY dimension", + }, + ) + assert response.status_code == 201, response.json() + + valid_through_ts = int(time.time() * 1000) + response = await client_with_build_v3.post( + "/data/v3.test_cube_inferred_order_by_dim/availability/", + json={ + "catalog": "default", + "schema_": "analytics", + "table": "cube_inferred_order_by_dim", + "valid_through_ts": valid_through_ts, + }, + ) + assert response.status_code == 200, response.json() + + cube = await find_matching_cube( + session, + metrics=all_metrics, + dimensions=["v3.date.date_id[order]", "v3.product.category"], + ) + assert cube is not None + + result = await build_sql_from_cube( + session=session, + cube=cube, + metrics=all_metrics, + dimensions=["v3.date.date_id[order]", "v3.product.category"], + filters=None, + dialect=Dialect.SPARK, + ) + + assert result is not None + assert result.sql is not None + + # The ORDER BY should use the aliased column name (date_id_order), + # NOT the raw dimension ref (v3.date.date_id[order]). 
+ assert "v3.date.date_id" not in result.sql, ( + "Raw dimension ref found in SQL — ORDER BY dimension was not resolved" + ) + assert "date_id_order" in result.sql, ( + "Aliased column name missing — ORDER BY dimension was not resolved to alias" + ) + + expected_sql = """ + WITH test_cube_inferred_order_by_dim_0 AS ( + SELECT + date_id_order, + category, + line_total_sum_e1f61696 + FROM default.analytics.cube_inferred_order_by_dim + ), + base_metrics AS ( + SELECT + test_cube_inferred_order_by_dim_0.date_id_order AS date_id_order, + test_cube_inferred_order_by_dim_0.category AS category, + SUM(test_cube_inferred_order_by_dim_0.line_total_sum_e1f61696) AS total_revenue + FROM test_cube_inferred_order_by_dim_0 + GROUP BY + test_cube_inferred_order_by_dim_0.date_id_order, + test_cube_inferred_order_by_dim_0.category + ) + SELECT + base_metrics.date_id_order AS date_id_order, + base_metrics.category AS category, + base_metrics.total_revenue AS total_revenue, + SUM(base_metrics.total_revenue) OVER ( PARTITION BY base_metrics.category ORDER BY base_metrics.date_id_order ROWS BETWEEN 6 PRECEDING AND CURRENT ROW) AS trailing_7d_revenue_inferred_dim + FROM base_metrics + """ + assert_sql_equal(result.sql, expected_sql, normalize_aliases=False) + + column_names = [col.name for col in result.columns] + assert "total_revenue" in column_names + assert "trailing_7d_revenue_inferred_dim" in column_names + @pytest.mark.asyncio async def test_builds_sql_with_rollup_dimensions( self, @@ -1282,8 +1383,8 @@ async def test_builds_sql_respects_dialect( dialect=Dialect.DRUID, ) - # Both should produce valid SQL with expected structure - expected_sql = """ + # Spark uses catalog.schema.table; Druid uses table name only + spark_expected_sql = """ WITH test_cube_dialect_0 AS ( SELECT category, @@ -1296,8 +1397,21 @@ async def test_builds_sql_respects_dialect( FROM test_cube_dialect_0 GROUP BY test_cube_dialect_0.category """ - assert_sql_equal(spark_result.sql, expected_sql) - 
assert_sql_equal(druid_result.sql, expected_sql) + druid_expected_sql = """ + WITH test_cube_dialect_0 AS ( + SELECT + category, + line_total_sum_e1f61696 + FROM cube_dialect + ) + SELECT + test_cube_dialect_0.category AS category, + SUM(test_cube_dialect_0.line_total_sum_e1f61696) AS total_revenue + FROM test_cube_dialect_0 + GROUP BY test_cube_dialect_0.category + """ + assert_sql_equal(spark_result.sql, spark_expected_sql) + assert_sql_equal(druid_result.sql, druid_expected_sql) @pytest.mark.asyncio async def test_builds_sql_from_cube_with_filter( @@ -1958,13 +2072,14 @@ async def test_build_metrics_sql_uses_cube_when_available( ) assert response.status_code == 200, response.json() - # Call build_metrics_sql - should use cube path + # Call build_metrics_sql with DRUID dialect - should use cube path + # (cube path is only taken when dialect is DRUID or unset) result = await build_metrics_sql( session=session, metrics=["v3.total_revenue"], dimensions=["v3.product.category"], filters=None, - dialect=Dialect.SPARK, + dialect=Dialect.DRUID, use_materialized=True, ) @@ -1975,7 +2090,7 @@ async def test_build_metrics_sql_uses_cube_when_available( WITH cube_for_metrics_sql_0 AS ( SELECT category, line_total_sum_e1f61696 - FROM default.analytics.cube_for_metrics_sql + FROM cube_for_metrics_sql ) SELECT cube_for_metrics_sql_0.category AS category, @@ -2163,7 +2278,7 @@ async def test_build_metrics_sql_cube_with_multi_component_metric( metrics=["v3.avg_unit_price"], dimensions=["v3.product.category"], filters=None, - dialect=Dialect.SPARK, + dialect=Dialect.DRUID, use_materialized=True, ) @@ -2175,12 +2290,12 @@ async def test_build_metrics_sql_cube_with_multi_component_metric( WITH cube_avg_metric_0 AS ( SELECT category, unit_price_count_55cff00f, unit_price_sum_55cff00f - FROM default.analytics.cube_avg_metric + FROM cube_avg_metric ) SELECT cube_avg_metric_0.category AS category, - SUM(cube_avg_metric_0.unit_price_sum_55cff00f) - / 
SUM(cube_avg_metric_0.unit_price_count_55cff00f) AS avg_unit_price + SAFE_DIVIDE(SUM(cube_avg_metric_0.unit_price_sum_55cff00f), + SUM(cube_avg_metric_0.unit_price_count_55cff00f)) AS avg_unit_price FROM cube_avg_metric_0 GROUP BY cube_avg_metric_0.category """, @@ -2234,7 +2349,7 @@ async def test_build_metrics_sql_cube_with_multiple_metrics( metrics=["v3.total_revenue", "v3.total_quantity"], dimensions=["v3.product.category"], filters=None, - dialect=Dialect.SPARK, + dialect=Dialect.DRUID, use_materialized=True, ) @@ -2245,7 +2360,7 @@ async def test_build_metrics_sql_cube_with_multiple_metrics( WITH cube_multi_metrics_0 AS ( SELECT category, line_total_sum_e1f61696, quantity_sum_06b64d2e - FROM default.analytics.cube_multi_metrics + FROM cube_multi_metrics ) SELECT cube_multi_metrics_0.category AS category, @@ -2307,7 +2422,7 @@ async def test_build_metrics_sql_cube_rollup( metrics=["v3.total_revenue"], dimensions=["v3.product.category"], # Subset of cube dims filters=None, - dialect=Dialect.SPARK, + dialect=Dialect.DRUID, use_materialized=True, ) @@ -2319,7 +2434,7 @@ async def test_build_metrics_sql_cube_rollup( WITH cube_rollup_test_0 AS ( SELECT category, line_total_sum_e1f61696 - FROM default.analytics.cube_rollup_test + FROM cube_rollup_test ) SELECT cube_rollup_test_0.category AS category, @@ -2337,6 +2452,239 @@ async def test_build_metrics_sql_cube_rollup( assert "category" in column_names assert "subcategory" not in column_names + @pytest.mark.asyncio + async def test_build_metrics_sql_dialect_none_with_cube_auto_detects_druid( + self, + client_with_build_v3, + session, + ): + """dialect=None + use_materialized=True + matching cube → auto-selects DRUID. + + Covers builder.py lines 502-503: the probe_cube path that sets dialect=DRUID + when no dialect is specified but a materialized cube is available. 
+ """ + from datajunction_server.construction.build_v3 import build_metrics_sql + + response = await client_with_build_v3.post( + "/nodes/cube/", + json={ + "name": "v3.cube_dialect_none_auto", + "metrics": ["v3.total_revenue"], + "dimensions": ["v3.product.category"], + "mode": "published", + "description": "Cube for dialect=None auto-detection test", + }, + ) + assert response.status_code == 201, response.json() + + valid_through_ts = int(time.time() * 1000) + response = await client_with_build_v3.post( + "/data/v3.cube_dialect_none_auto/availability/", + json={ + "catalog": "default", + "schema_": "analytics", + "table": "cube_dialect_none_auto", + "valid_through_ts": valid_through_ts, + }, + ) + assert response.status_code == 200, response.json() + + # dialect=None with use_materialized=True → should auto-detect DRUID + result = await build_metrics_sql( + session=session, + metrics=["v3.total_revenue"], + dimensions=["v3.product.category"], + filters=None, + dialect=None, + use_materialized=True, + ) + + # DRUID path → table-only name (no catalog/schema prefix) + assert result.cube_name == "v3.cube_dialect_none_auto" + assert_sql_equal( + result.sql, + """ + WITH + cube_dialect_none_auto_0 AS ( + SELECT category, line_total_sum_e1f61696 + FROM cube_dialect_none_auto + ) + SELECT + cube_dialect_none_auto_0.category AS category, + SUM(cube_dialect_none_auto_0.line_total_sum_e1f61696) AS total_revenue + FROM cube_dialect_none_auto_0 + GROUP BY cube_dialect_none_auto_0.category + """, + ) + + @pytest.mark.asyncio + async def test_build_metrics_sql_dialect_none_no_materialized_defaults_spark( + self, + client_with_build_v3, + session, + ): + """dialect=None + use_materialized=False → falls through to SPARK source tables. + + Covers builder.py line 509: when use_materialized=False, dialect defaults to + SPARK regardless of whether a cube exists. 
+ """ + from datajunction_server.construction.build_v3 import build_metrics_sql + + result = await build_metrics_sql( + session=session, + metrics=["v3.total_revenue"], + dimensions=["v3.product.category"], + filters=None, + dialect=None, + use_materialized=False, + ) + + # SPARK path -> catalog.schema.table references, no cube used + assert result.cube_name is None + assert "default.v3.orders" in result.sql + assert "default.v3.order_items" in result.sql + + @pytest.mark.asyncio + async def test_build_mat_col_lookup_returns_physical_name_mapping( + self, + client_with_build_v3, + session, + ): + """ + Verifies that _build_mat_col_lookup reads combiners[*].columns to build short name to + physical column name map. + + Creates a cube with a Materialization record whose config mimics the old-style + Druid format where physical column names use amenable_name encoding (e.g., + v3_DOT_product_DOT_category instead of category). + """ + from sqlalchemy import select + + from datajunction_server.construction.build_v3.cube_matcher import ( + _build_mat_col_lookup, + build_sql_from_cube, + find_matching_cube, + ) + from datajunction_server.database.materialization import Materialization + from datajunction_server.database.node import NodeRevision + + # Create a cube + response = await client_with_build_v3.post( + "/nodes/cube/", + json={ + "name": "v3.cube_old_style_druid", + "metrics": ["v3.total_revenue"], + "dimensions": ["v3.product.category"], + "mode": "published", + "description": "Cube for old-style Druid physical column test", + }, + ) + assert response.status_code == 201, response.json() + + valid_through_ts = int(time.time() * 1000) + response = await client_with_build_v3.post( + "/data/v3.cube_old_style_druid/availability/", + json={ + "catalog": "default", + "schema_": "analytics", + "table": "cube_old_style_druid", + "valid_through_ts": valid_through_ts, + }, + ) + assert response.status_code == 200, response.json() + + # Fetch the NodeRevision for the cube and 
add a materialization with + # old-style Druid column names (physical names use DOT encoding) + result = await session.execute( + select(NodeRevision).where(NodeRevision.name == "v3.cube_old_style_druid"), + ) + node_rev = result.scalars().first() + assert node_rev is not None + + old_style_mat = Materialization( + node_revision_id=node_rev.id, + name="old_style_druid_mat", + strategy=None, + schedule="", + config={ + "combiners": [ + { + "columns": [ + { + "column": "category", + "name": "v3_DOT_product_DOT_category", + }, + { + "column": "line_total_sum_e1f61696", + "name": "line_total_sum_e1f61696", + }, + # Entry with missing 'name' — exercises the + # if short_name and physical_name: False branch + {"column": "orphan_col"}, + ], + }, + ], + }, + job="DruidMaterializationJob", + ) + session.add(old_style_mat) + await session.commit() + + # Expire to force reload from DB + session.expire(node_rev) + + # Re-fetch with materializations eagerly loaded + from sqlalchemy.orm import selectinload + + result2 = await session.execute( + select(NodeRevision) + .where(NodeRevision.name == "v3.cube_old_style_druid") + .options(selectinload(NodeRevision.materializations)), + ) + node_rev2 = result2.scalars().first() + + # Verify _build_mat_col_lookup returns the physical name mapping + lookup = _build_mat_col_lookup(node_rev2) + assert lookup == { + "category": "v3_DOT_product_DOT_category", + "line_total_sum_e1f61696": "line_total_sum_e1f61696", + } + + # Verify build_sql_from_cube uses physical name alias + cube = await find_matching_cube( + session, + metrics=["v3.total_revenue"], + dimensions=["v3.product.category"], + ) + assert cube is not None + + sql_result = await build_sql_from_cube( + session=session, + cube=cube, + metrics=["v3.total_revenue"], + dimensions=["v3.product.category"], + filters=None, + dialect=Dialect.DRUID, + ) + + # The SELECT should alias the physical name to the short name + # (physical_name != short_name triggers the alias branch) + 
assert_sql_equal( + sql_result.sql, + """ + WITH + cube_old_style_druid_0 AS ( + SELECT v3_DOT_product_DOT_category category, line_total_sum_e1f61696 + FROM cube_old_style_druid + ) + SELECT + cube_old_style_druid_0.category AS category, + SUM(cube_old_style_druid_0.line_total_sum_e1f61696) AS total_revenue + FROM cube_old_style_druid_0 + GROUP BY cube_old_style_druid_0.category + """, + ) + class TestDataEndpointCubePath: """ diff --git a/datajunction-server/tests/construction/build_v3/measures_sql_test.py b/datajunction-server/tests/construction/build_v3/measures_sql_test.py index 07a3bf6e6..f6784c501 100644 --- a/datajunction-server/tests/construction/build_v3/measures_sql_test.py +++ b/datajunction-server/tests/construction/build_v3/measures_sql_test.py @@ -2198,6 +2198,7 @@ async def test_all_additional_metrics_combined(self, client_with_build_v3): class TestMeasuresSQLFilters: + @pytest.mark.asyncio async def test_simple_filter_on_local_column(self, client_with_build_v3): """Test a simple filter on a local (fact) column.""" response = await client_with_build_v3.get( @@ -2210,16 +2211,25 @@ async def test_simple_filter_on_local_column(self, client_with_build_v3): ) assert response.status_code == 200, response.json() - data = response.json() - sql = data["grain_groups"][0]["sql"] - - # Should have WHERE clause with the filter - assert "WHERE" in sql - assert "status" in sql - assert "'completed'" in sql + data = get_first_grain_group(response.json()) + assert_sql_equal( + data["sql"], + """ + WITH v3_order_details AS ( + SELECT o.status, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ) + SELECT t1.status, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + WHERE t1.status = 'completed' + GROUP BY t1.status + """, + ) + @pytest.mark.asyncio async def test_filter_on_dimension_column(self, client_with_build_v3): - """Test a filter on a joined dimension 
column.""" + """Test a filter on a joined dimension column that is also in GROUP BY.""" response = await client_with_build_v3.get( "/sql/measures/v3/", params={ @@ -2230,14 +2240,29 @@ async def test_filter_on_dimension_column(self, client_with_build_v3): ) assert response.status_code == 200, response.json() - data = response.json() - sql = data["grain_groups"][0]["sql"] - - # Should have WHERE clause referencing the dimension column - assert "WHERE" in sql - assert "category" in sql - assert "'Electronics'" in sql + data = get_first_grain_group(response.json()) + assert_sql_equal( + data["sql"], + """ + WITH + v3_order_details AS ( + SELECT oi.product_id, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + v3_product AS ( + SELECT product_id, category + FROM default.v3.products + ) + SELECT t2.category, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + WHERE t2.category = 'Electronics' + GROUP BY t2.category + """, + ) + @pytest.mark.asyncio async def test_multiple_filters_combined_with_and(self, client_with_build_v3): """Test multiple filters are combined with AND.""" response = await client_with_build_v3.get( @@ -2253,17 +2278,31 @@ async def test_multiple_filters_combined_with_and(self, client_with_build_v3): ) assert response.status_code == 200, response.json() - data = response.json() - sql = data["grain_groups"][0]["sql"] - - # Should have WHERE clause with both filters combined with AND - assert "WHERE" in sql - assert "AND" in sql - assert "'completed'" in sql - assert "'Electronics'" in sql + data = get_first_grain_group(response.json()) + assert_sql_equal( + data["sql"], + """ + WITH + v3_order_details AS ( + SELECT o.status, oi.product_id, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + v3_product 
AS ( + SELECT product_id, category + FROM default.v3.products + ) + SELECT t1.status, t2.category, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + WHERE t1.status = 'completed' AND t2.category = 'Electronics' + GROUP BY t1.status, t2.category + """, + ) + @pytest.mark.asyncio async def test_filter_with_comparison_operators(self, client_with_build_v3): - """Test filters with various comparison operators.""" + """Test filters with comparison operators on a role-qualified dimension.""" response = await client_with_build_v3.get( "/sql/measures/v3/", params={ @@ -2274,16 +2313,31 @@ async def test_filter_with_comparison_operators(self, client_with_build_v3): ) assert response.status_code == 200, response.json() - data = response.json() - sql = data["grain_groups"][0]["sql"] - - # Should have filter with >= operator - assert "WHERE" in sql - assert ">=" in sql - assert "2024" in sql + data = get_first_grain_group(response.json()) + assert_sql_equal( + data["sql"], + """ + WITH + v3_date AS ( + SELECT date_id, year + FROM default.v3.dates + ), + v3_order_details AS ( + SELECT o.order_date, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ) + SELECT t2.year year_order, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_date t2 ON t1.order_date = t2.date_id + WHERE t2.year >= 2024 + GROUP BY t2.year + """, + ) + @pytest.mark.asyncio async def test_filter_with_in_operator(self, client_with_build_v3): - """Test filter with IN operator.""" + """Test filter with IN operator on a local column.""" response = await client_with_build_v3.get( "/sql/measures/v3/", params={ @@ -2294,14 +2348,21 @@ async def test_filter_with_in_operator(self, client_with_build_v3): ) assert response.status_code == 200, response.json() - data = response.json() - sql = 
data["grain_groups"][0]["sql"] - - # Should have filter with IN operator - assert "WHERE" in sql - assert "IN" in sql - assert "'completed'" in sql - assert "'pending'" in sql + data = get_first_grain_group(response.json()) + assert_sql_equal( + data["sql"], + """ + WITH v3_order_details AS ( + SELECT o.status, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ) + SELECT t1.status, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + WHERE t1.status IN ('completed', 'pending') + GROUP BY t1.status + """, + ) class TestBaseMetricCaching: diff --git a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py index 03cc442d7..1072766d0 100644 --- a/datajunction-server/tests/construction/build_v3/metrics_sql_test.py +++ b/datajunction-server/tests/construction/build_v3/metrics_sql_test.py @@ -247,16 +247,29 @@ async def test_derived_metric_ratio(self, client_with_build_v3): LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id GROUP BY t2.category, t1.order_id ), + order_details_0_agg AS ( + SELECT + category, + COUNT(DISTINCT order_id) order_id_distinct_f93d50ab + FROM order_details_0 + GROUP BY category + ), page_views_enriched_0 AS ( SELECT t2.category, t1.customer_id FROM v3_page_views_enriched t1 LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id GROUP BY t2.category, t1.customer_id + ), + page_views_enriched_0_agg AS ( + SELECT category, + COUNT( DISTINCT customer_id) customer_id_distinct_dd4be7a5 + FROM page_views_enriched_0 + GROUP BY category ) - SELECT COALESCE(order_details_0.category, page_views_enriched_0.category) AS category, - CAST(COUNT(DISTINCT order_details_0.order_id) AS DOUBLE) / NULLIF(COUNT(DISTINCT page_views_enriched_0.customer_id), 0) AS conversion_rate - FROM order_details_0 - FULL OUTER JOIN page_views_enriched_0 ON order_details_0.category = 
page_views_enriched_0.category + SELECT COALESCE(order_details_0_agg.category, page_views_enriched_0_agg.category) AS category, + CAST(MAX(order_details_0_agg.order_id_distinct_f93d50ab) AS DOUBLE) / NULLIF(MAX(page_views_enriched_0_agg.customer_id_distinct_dd4be7a5), 0) AS conversion_rate + FROM order_details_0_agg + FULL OUTER JOIN page_views_enriched_0_agg ON order_details_0_agg.category = page_views_enriched_0_agg.category GROUP BY 1 """, ) @@ -1542,6 +1555,206 @@ async def test_cross_fact_metrics_without_shared_dimensions_raises_error( or "shared dimension" in str(error_detail).lower() ) + @pytest.mark.asyncio + async def test_cross_fact_full_plus_limited_fan_out( + self, + client_with_build_v3, + ): + """ + Test that combining a FULL metric with a LIMITED (COUNT DISTINCT) metric from a + different fact table produces correct values for both metrics. + + collect_and_build_ctes() adds a pre-aggregation wrapper CTE for LIMITED grain groups. + The wrapper collapses N rows per dimension (one per distinct grain key) into 1 row by + computing COUNT(DISTINCT grain_key) inside the CTE. The FULL OUTER JOIN then sees 1:1 + rows from both sides, so SUM() does not overcount. + + Setup: + - v3.total_revenue (FULL, from order_details) + grain group CTE: GROUP BY category → 1 row per category + - v3.visitor_count (LIMITED, COUNT DISTINCT customer_id, from page_views_enriched) + raw grain group CTE: GROUP BY (category, customer_id) -> N rows per category + wrapper CTE (page_views_enriched_0_agg): GROUP BY category -> 1 row per category + + After FULL OUTER JOIN on category both sides have 1 row per category. + SUM(order_details_0.line_total_sum_HASH) = correct revenue. + MAX(page_views_enriched_0_agg.customer_id_distinct_HASH) = correct visitor count. 
+ """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.total_revenue", "v3.visitor_count"], + "dimensions": ["v3.product.category"], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + + assert_sql_equal( + result["sql"], + """ + WITH + v3_order_details AS ( + SELECT oi.product_id, oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + v3_product AS ( + SELECT product_id, category + FROM default.v3.products + ), + v3_page_views_enriched AS ( + SELECT customer_id, product_id + FROM default.v3.page_views + ), + order_details_0 AS ( + SELECT t2.category, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + GROUP BY t2.category + ), + page_views_enriched_0 AS ( + SELECT t2.category, t1.customer_id + FROM v3_page_views_enriched t1 + LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + GROUP BY t2.category, t1.customer_id + ), + page_views_enriched_0_agg AS ( + SELECT category, COUNT(DISTINCT customer_id) customer_id_distinct_dd4be7a5 + FROM page_views_enriched_0 + GROUP BY category + ) + SELECT + COALESCE(order_details_0.category, page_views_enriched_0_agg.category) AS category, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue, + MAX(page_views_enriched_0_agg.customer_id_distinct_dd4be7a5) AS visitor_count + FROM order_details_0 + FULL OUTER JOIN page_views_enriched_0_agg + ON order_details_0.category = page_views_enriched_0_agg.category + GROUP BY 1 + """, + ) + + +class TestDerivedAndBaseMetricsTogether: + """Test that derived and base metrics can be queried together.""" + + @pytest.mark.asyncio + async def test_derived_metric_with_base_metric_in_same_query( + self, + client_with_build_v3, + ): + """Querying a derived metric alongside one of its base metrics works correctly. 
+ + avg_order_value = total_revenue / order_count. Requesting both avg_order_value + and total_revenue together should return both in the output columns. + """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.avg_order_value", "v3.total_revenue"], + "dimensions": ["v3.order_details.status"], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + column_names = [c["name"] for c in result["columns"]] + assert "avg_order_value" in column_names + assert "total_revenue" in column_names + assert "status" in column_names + assert_sql_equal( + result["sql"], + """ + WITH v3_order_details AS ( + SELECT + o.order_id, + o.status, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + order_details_0 AS ( + SELECT + t1.status, + t1.order_id, + SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + GROUP BY t1.status, t1.order_id + ) + SELECT + order_details_0.status AS status, + SUM(order_details_0.line_total_sum_e1f61696) / NULLIF(COUNT( DISTINCT order_details_0.order_id), 0) AS avg_order_value, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + GROUP BY order_details_0.status + """, + ) + + +class TestFilterOnlyDimensionLoop: + """Tests for filter-only dimension loop iteration in metrics.py.""" + + @pytest.mark.asyncio + async def test_filter_subscript_matches_second_filter_dimension( + self, + client_with_build_v3, + ): + """ + Filter with a role-qualified subscript whose base ref is checked against + multiple filter-only dimensions exercises the ctx.filter_dimensions loop + in metrics.py. 
+        """
+        response = await client_with_build_v3.get(
+            "/sql/metrics/v3/",
+            params={
+                "metrics": ["v3.total_revenue"],
+                "dimensions": ["v3.order_details.status"],
+                "filters": [
+                    "v3.product.subcategory = 'tools'",
+                    "v3.product.category[buyer->home] = 'electronics'",
+                ],
+            },
+        )
+
+        # Both filters resolve against filter-only dimensions; the role-qualified
+        # subscript exercises the ctx.filter_dimensions loop with multiple entries.
+        assert response.status_code == 200
+        assert_sql_equal(
+            response.json()["sql"],
+            """
+            WITH v3_order_details AS (
+                SELECT
+                    o.status,
+                    oi.product_id,
+                    oi.quantity * oi.unit_price AS line_total
+                FROM default.v3.orders o
+                JOIN default.v3.order_items oi ON o.order_id = oi.order_id
+            ),
+            v3_product AS (
+                SELECT
+                    product_id,
+                    category,
+                    subcategory
+                FROM default.v3.products
+            ),
+            order_details_0 AS (
+                SELECT
+                    t1.status,
+                    SUM(t1.line_total) line_total_sum_e1f61696
+                FROM v3_order_details t1 LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id
+                WHERE t2.subcategory = 'tools' AND t2.category = 'electronics'
+                GROUP BY t1.status
+            )
+            SELECT
+                order_details_0.status AS status,
+                SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue
+            FROM order_details_0
+            GROUP BY order_details_0.status
+            """,
+        )
+
 
 class TestNonDecomposableMetrics:
     """Tests for metrics that cannot be decomposed (Aggregability.NONE)."""
@@ -3600,6 +3813,56 @@ async def test_multi_hop_filter_only_on_multi_hop_dim(
         assert response.status_code == 200, response.json()
         result = response.json()
 
+        # customer and location CTEs must be joined to support the WHERE clause,
+        # even though neither column appears in GROUP BY or the final SELECT.
+        # The WHERE filter on country appears only inside the grain group CTE —
+        # the outer SELECT has no access to that column so no outer WHERE is needed.
+ assert_sql_equal( + result["sql"], + """ + WITH v3_customer AS ( + SELECT + customer_id, + location_id + FROM default.v3.customers + ), + v3_location AS ( + SELECT + location_id, + country + FROM default.v3.locations + ), + v3_order_details AS ( + SELECT + o.customer_id, + oi.product_id, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + v3_product AS ( + SELECT + product_id, + category + FROM default.v3.products + ), + order_details_0 AS ( + SELECT + t2.category, + SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + LEFT OUTER JOIN v3_customer t3 ON t1.customer_id = t3.customer_id + LEFT OUTER JOIN v3_location t4 ON t3.location_id = t4.location_id + WHERE t4.country = 'US' + GROUP BY t2.category + ) + SELECT + order_details_0.category AS category, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + GROUP BY order_details_0.category + """, + ) + assert result["columns"] == [ { "name": "category", @@ -4066,7 +4329,8 @@ async def test_two_roles_same_dim_group_by(self, client_with_build_v3): v3.location.city[from] and v3.location.city[to] both point to v3.location but via different FKs (from_location_id and to_location_id). The grain group - CTE must join v3_location twice with distinct aliases. + CTE must join v3_location twice with distinct aliases — one per role — so + neither role silently shadows the other. 
""" response = await client_with_build_v3.get( "/sql/metrics/v3/", @@ -4079,6 +4343,36 @@ async def test_two_roles_same_dim_group_by(self, client_with_build_v3): assert response.status_code == 200, response.json() result = response.json() + # One shared v3_location CTE, two separate JOINs with distinct aliases + assert_sql_equal( + result["sql"], + """ + WITH v3_location AS ( + SELECT location_id, city + FROM default.v3.locations + ), + v3_order_details AS ( + SELECT o.from_location_id, o.to_location_id, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + order_details_0 AS ( + SELECT t2.city city_from, t3.city city_to, + SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_location t2 ON t1.from_location_id = t2.location_id + LEFT OUTER JOIN v3_location t3 ON t1.to_location_id = t3.location_id + GROUP BY t2.city, t3.city + ) + SELECT order_details_0.city_from AS city_from, + order_details_0.city_to AS city_to, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + GROUP BY order_details_0.city_from, order_details_0.city_to + """, + ) + assert result["columns"] == [ { "name": "city_from", @@ -4204,3 +4498,292 @@ async def test_filter_on_multi_hop_role_dimension(self, client_with_build_v3): "semantic_type": "metric", }, ] + + +class TestMetricsSQLEdgeCases: + """ + Edge case tests for metric SQL generation. 
+ + Covers unusual but valid query patterns: + - Skip-join: filter on dimension PK == fact FK should skip the dimension join + - Scalar aggregate: no dimensions → no GROUP BY anywhere + - Dimension-only column filter: non-PK dim attribute forces JOIN even if not in SELECT + - FULL + LIMITED metrics from same fact: combined at the LIMITED grain + - Derived metric referencing NONE-aggregability metric: clear error or correct fallback + """ + + @pytest.mark.asyncio + async def test_skip_join_filter_on_dimension_pk_as_fact_fk( + self, + client_with_build_v3, + ): + """ + Verify filtering on a dimension PK that is also the FK on the fact table. + + v3.customer.customer_id[customer] is the PK of the customer dimension and + equals v3.order_details.customer_id (the FK on the fact). When no customer + attribute is requested in GROUP BY, the join to v3.customer should be + skipped and the filter applied directly to the FK column on the fact CTE. + """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.total_revenue"], + "dimensions": ["v3.order_details.status"], + "filters": ["v3.customer.customer_id[customer] = 42"], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + + # No join to v3.customer — filter pushes down to the fact FK column directly + assert_sql_equal( + result["sql"], + """ + WITH v3_order_details AS ( + SELECT o.customer_id, o.status, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + order_details_0 AS ( + SELECT t1.status, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + WHERE t1.customer_id = 42 + GROUP BY t1.status + ) + SELECT order_details_0.status AS status, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + GROUP BY order_details_0.status + """, + ) + + @pytest.mark.asyncio + async def 
test_scalar_aggregate_no_dimensions_emits_no_group_by( + self, + client_with_build_v3, + ): + """ + Verify metric with no dimensions requested — scalar aggregate. + + The grain group CTE and final SELECT must NOT emit a GROUP BY clause. + The result is a single-row scalar. + """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.total_revenue"], + "dimensions": [], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + + # No GROUP BY anywhere — single-row scalar result + assert_sql_equal( + result["sql"], + """ + WITH v3_order_details AS ( + SELECT oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + order_details_0 AS ( + SELECT SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + ) + SELECT SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + """, + ) + + assert result["columns"] == [ + { + "name": "total_revenue", + "type": "double", + "semantic_entity": "v3.total_revenue", + "semantic_type": "metric", + }, + ] + + @pytest.mark.asyncio + async def test_filter_on_non_pk_dim_attribute_not_in_group_by( + self, + client_with_build_v3, + ): + """ + Verify filtering on a non-PK dimension attribute not in GROUP BY. + + v3.customer.email[customer] is only present in the customer dimension node, + not on the fact table. Even though no customer column appears in GROUP BY, + the system must emit a JOIN to v3.customer to evaluate the LIKE filter. + + The WHERE clause on email appears only inside the grain group CTE - + the outer SELECT has no access to email so no outer WHERE is emitted. 
+ """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.total_revenue"], + "dimensions": ["v3.product.category"], + "filters": ["v3.customer.email[customer] LIKE '%@example.com'"], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + + # JOIN to v3.customer is required to evaluate the email filter, + # even though email is not in the SELECT or GROUP BY + assert_sql_equal( + result["sql"], + """ + WITH v3_customer AS ( + SELECT customer_id, email + FROM default.v3.customers + ), + v3_order_details AS ( + SELECT o.customer_id, oi.product_id, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + v3_product AS ( + SELECT product_id, category + FROM default.v3.products + ), + order_details_0 AS ( + SELECT t2.category, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + LEFT OUTER JOIN v3_product t2 ON t1.product_id = t2.product_id + LEFT OUTER JOIN v3_customer t3 ON t1.customer_id = t3.customer_id + WHERE t3.email LIKE '%@example.com' + GROUP BY t2.category + ) + SELECT order_details_0.category AS category, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue + FROM order_details_0 + GROUP BY order_details_0.category + """, + ) + + @pytest.mark.asyncio + async def test_full_and_limited_metrics_same_fact_computed_at_limited_grain( + self, + client_with_build_v3, + ): + """ + Verify FULL and LIMITED metrics from the same fact in a single query. + + total_revenue (Aggregability.FULL / SUM) and order_count + (Aggregability.LIMITED / COUNT DISTINCT order_id) share v3.order_details. + + The grain group CTE must pre-aggregate at order_id grain so that COUNT DISTINCT + is semantically correct. The outer SELECT then re-aggregates: SUM for the FULL + component, COUNT DISTINCT for the LIMITED component. 
+ """ + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.total_revenue", "v3.order_count"], + "dimensions": ["v3.order_details.status"], + }, + ) + + assert response.status_code == 200, response.json() + result = response.json() + + # Pre-agg at order_id grain so COUNT DISTINCT is correct; + # final SELECT re-aggregates to status grain + assert_sql_equal( + result["sql"], + """ + WITH v3_order_details AS ( + SELECT o.order_id, o.status, + oi.quantity * oi.unit_price AS line_total + FROM default.v3.orders o + JOIN default.v3.order_items oi ON o.order_id = oi.order_id + ), + order_details_0 AS ( + SELECT t1.status, t1.order_id, SUM(t1.line_total) line_total_sum_e1f61696 + FROM v3_order_details t1 + GROUP BY t1.status, t1.order_id + ) + SELECT order_details_0.status AS status, + SUM(order_details_0.line_total_sum_e1f61696) AS total_revenue, + COUNT(DISTINCT order_details_0.order_id) AS order_count + FROM order_details_0 + GROUP BY order_details_0.status + """, + ) + + assert result["columns"] == [ + { + "name": "status", + "type": "string", + "semantic_entity": "v3.order_details.status", + "semantic_type": "dimension", + }, + { + "name": "total_revenue", + "type": "double", + "semantic_entity": "v3.total_revenue", + "semantic_type": "metric", + }, + { + "name": "order_count", + "type": "bigint", + "semantic_entity": "v3.order_count", + "semantic_type": "metric", + }, + ] + + @pytest.mark.asyncio + async def test_derived_metric_referencing_none_aggregability_metric( + self, + client_with_build_v3, + ): + """ + Verify derived metric whose formula references a NONE-aggregability metric. + + v3.top_product_by_revenue uses MAX_BY (Aggregability.NONE) — it cannot be + decomposed into re-aggregatable components. 
A derived metric that incorporates + it must not silently produce wrong SQL; it should either: + a) Raise a clear error about non-decomposable components, or + b) Fall back to Aggregability.NONE for the entire derived metric (raw-row access) + + This test pins the actual behavior so any future regression is caught. + """ + create_response = await client_with_build_v3.post( + "/nodes/metric/", + json={ + "name": "v3.orders_plus_top_product", + "description": "Derived referencing NONE-aggregability metric", + "query": ( + "SELECT v3.order_count + CAST(v3.top_product_by_revenue AS BIGINT)" + ), + "mode": "published", + }, + ) + assert create_response.status_code in (200, 201), create_response.json() + + response = await client_with_build_v3.get( + "/sql/metrics/v3/", + params={ + "metrics": ["v3.orders_plus_top_product"], + "dimensions": ["v3.order_details.status"], + }, + ) + + assert response.status_code == 422, response.text + error_text = response.json().get("message", "") + assert error_text.lower() == ( + "cannot compute derived metric 'v3.orders_plus_top_product' because it references" + " non-decomposable metric(s) with aggregability.none: ['v3.top_product_by_revenue']." + " non-decomposable metrics (e.g. max_by) cannot be combined into derived metrics — " + "their expressions require raw-grain access that is not available at the derived " + "metric's aggregation level." + ), f"Expected clear error about non-decomposable metric, got: {response.text}" diff --git a/datajunction-server/tests/examples.py b/datajunction-server/tests/examples.py index e470cac4a..029fa8fd3 100644 --- a/datajunction-server/tests/examples.py +++ b/datajunction-server/tests/examples.py @@ -3576,6 +3576,24 @@ "required_dimensions": ["v3.date.date_id[order]"], }, ), + ( + "/nodes/metric/", + { + "name": "v3.trailing_7d_revenue_inferred_dim", + "description": ( + "Trailing 7-day revenue without required_dimensions set. 
" + "Tests that the ORDER BY dimension is auto-detected from the metric expression." + ), + "query": """ + SELECT + SUM(v3.total_revenue) OVER ( + ORDER BY v3.date.date_id[order] + ROWS BETWEEN 6 PRECEDING AND CURRENT ROW + ) + """, + "mode": "published", + }, + ), ( "/nodes/metric/", {