
Optimize COUNT on top of SEMI join to NDISTINCT on the right hand side #491

@knassre-bodo

Description


Consider the question below and two different implementations:

# How many customers have made an urgent order in 1994?

# Method 1: count how many customers have an order with those properties
selected_orders = orders.WHERE((order_priority == '1-URGENT') & (YEAR(order_date) == 1994))
result = TPCH.CALCULATE(
  n=COUNT(customers.WHERE(HAS(selected_orders)))
)

# Method 2: count the number of unique customer keys out of the orders with those properties
selected_orders = orders.WHERE((order_priority == '1-URGENT') & (YEAR(order_date) == 1994))
result = TPCH.CALCULATE(
  n=NDISTINCT(selected_orders.customer_key)
)

These two ways of writing the query are both valid and produce identical answers, but the former does a SEMI join from customers to orders before aggregating, while the latter only scans and aggregates orders (even though it uses COUNT(DISTINCT), which is less performant than COUNT(*)). The second method will almost always be faster, since the cost of the extra scan and join outweighs the cost of the less performant aggregation.
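The equivalence of the two forms can be checked concretely. Below is a minimal sketch using Python's sqlite3 module with a toy schema (the table and column names are illustrative, not the real TPC-H schema); the semi join is expressed as an EXISTS subquery. Note the second form is only equivalent because every order's customer_key matches a customer, which is exactly the reverse-cardinality condition discussed later:

```python
import sqlite3

# Tiny in-memory schema mirroring the customers/orders relationship.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (key INTEGER PRIMARY KEY);
CREATE TABLE orders (order_key INTEGER PRIMARY KEY, customer_key INTEGER);
INSERT INTO customers VALUES (1), (2), (3);
-- customer 1 has two matching orders, customer 2 has one, customer 3 has none
INSERT INTO orders VALUES (10, 1), (11, 1), (12, 2);
""")

# Method 1: COUNT on top of a semi join (EXISTS) from customers to orders.
method1 = conn.execute("""
    SELECT COUNT(*) FROM customers c
    WHERE EXISTS (SELECT 1 FROM orders o WHERE o.customer_key = c.key)
""").fetchone()[0]

# Method 2: COUNT(DISTINCT) over the right hand side only.
method2 = conn.execute(
    "SELECT COUNT(DISTINCT customer_key) FROM orders"
).fetchone()[0]

print(method1, method2)  # both 2
```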

The goal of this optimization is to rewrite the former case into the latter at the relational level. Here are examples of the two relational structures: the first is the pattern to be optimized, and the second is what it should be rewritten to:

AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
 JOIN(condition=t0.LEFT_UNIQUE_KEY == t1.RIGHT_KEY, type=SEMI, columns={})
  LEFT_TREE
  RIGHT_TREE

AGGREGATE(keys={}, aggregations={'n_rows': NDISTINCT(RIGHT_KEY)})
 RIGHT_TREE

The prerequisites for the transformation:

  • Current node is a no-groupby aggregation where the only aggfunc is COUNT(*)
  • Input to the node is a SEMI join where the condition is a single equality check, and the column from the left hand side is a uniqueness key of the left hand side (i.e. every row of LEFT_TREE has a unique value of LEFT_UNIQUE_KEY)
  • The join has a reverse cardinality that always matches (every row from RIGHT_TREE is matched by at least 1 row from LEFT_TREE, e.g. each row of orders has a match in customers).

If these conditions hold, the aggregate should be transformed to remove the join and LEFT_TREE: the aggregate's input becomes RIGHT_TREE, and COUNT() is replaced with NDISTINCT(RIGHT_KEY), where RIGHT_KEY is the right hand side column from the join's equality condition.
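The prerequisite checks can be sketched over toy node classes. This is only an illustration of the decision logic, not the real PyDough relational nodes; the class names, fields (e.g. `reverse_always_matches`), and the `can_rewrite` helper are all hypothetical stand-ins:

```python
from dataclasses import dataclass, field

# Toy stand-ins for the real relational nodes (illustrative only).
@dataclass
class Join:
    left: object
    right: object
    join_type: str                 # "SEMI", "INNER", ...
    left_key: str                  # column from the left side of the equality
    right_key: str                 # column from the right side of the equality
    reverse_always_matches: bool   # every RHS row matches >= 1 LHS row

@dataclass
class Aggregate:
    input: object
    keys: dict = field(default_factory=dict)
    aggregations: dict = field(default_factory=dict)  # name -> ("COUNT",) etc.

def can_rewrite(agg: Aggregate, left_unique_keys: set[frozenset[str]]) -> bool:
    """Check the three prerequisites for the COUNT -> NDISTINCT rewrite."""
    # 1. No-groupby aggregation whose only aggfunc is COUNT(*).
    if agg.keys or list(agg.aggregations.values()) != [("COUNT",)]:
        return False
    join = agg.input
    # 2. Input is a SEMI join on a single equality whose left column is
    #    a uniqueness key of the left subtree.
    if not isinstance(join, Join) or join.join_type != "SEMI":
        return False
    if frozenset({join.left_key}) not in left_unique_keys:
        return False
    # 3. The reverse cardinality always matches.
    return join.reverse_always_matches
```

When all three checks pass, the rewrite replaces the aggregate's input with the join's right subtree and swaps COUNT() for NDISTINCT on `join.right_key`.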

The reverse cardinality check is important because of queries such as this one:

# How many customers in the building market segment have made an urgent order in 1994?

selected_orders = orders.WHERE((order_priority == '1-URGENT') & (YEAR(order_date) == 1994))
result = TPCH.CALCULATE(
  n=COUNT(customers.WHERE((market_segment == 'BUILDING') & HAS(selected_orders)))
)

The relational structure for this is as follows:

AGGREGATE(keys={}, aggregations={'n_rows': COUNT()})
 JOIN(condition=t0.key == t1.customer_key, type=SEMI, columns={})
  FILTER(market_segment == 'BUILDING')
    SCAN(CUSTOMER)
  FILTER(YEAR(order_date) == 1994)
    SCAN(ORDERS)

Here, the reverse cardinality is SINGULAR_FILTERS, meaning each row from orders in the right hand side matches onto 0 or 1 rows from the left hand side, so the rewrite as-is is not valid since we would be losing the market segment filter. HOWEVER, the following alternative rewrite is possible:

AGGREGATE(keys={}, aggregations={'n_rows': NDISTINCT(key)})
 JOIN(condition=t0.key == t1.customer_key, type=INNER, columns={"key": t0.customer_key})
  FILTER(YEAR(order_date) == 1994)
    SCAN(ORDERS)
  FILTER(market_segment == 'BUILDING')
    SCAN(CUSTOMER)

This alternative rewrite operates on the same principle: if the reverse cardinality is not always matching, then instead of pruning the join and the LHS, we flip the join inputs/cardinalities, pass the former RHS key (now on the LHS) for the NDISTINCT to use, and change the join type to INNER. However, this case is NOT as clear-cut an optimization: it still performs a join (though INNER is cheaper than SEMI) and it downgrades COUNT(*) to NDISTINCT, so whether it is an improvement depends on the relative sizes of the tables and the number of unique rows.
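The alternative rewrite can also be sanity-checked with sqlite3. A minimal sketch with an illustrative schema (not the real TPC-H tables): the original plan counts customers through a semi join with the extra LHS filter, and the flipped plan counts distinct customer keys through an INNER join, keeping the filter so the result stays correct:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (key INTEGER PRIMARY KEY, market_segment TEXT);
CREATE TABLE orders (order_key INTEGER PRIMARY KEY, customer_key INTEGER);
-- customers 1 and 3 are BUILDING, but only customer 1 has orders
INSERT INTO customers VALUES (1, 'BUILDING'), (2, 'MACHINERY'), (3, 'BUILDING');
INSERT INTO orders VALUES (10, 1), (11, 1), (12, 2);
""")

# Original plan: COUNT over a semi join, with the market segment filter on the LHS.
semi = conn.execute("""
    SELECT COUNT(*) FROM customers c
    WHERE c.market_segment = 'BUILDING'
      AND EXISTS (SELECT 1 FROM orders o WHERE o.customer_key = c.key)
""").fetchone()[0]

# Alternative rewrite: flip to an INNER join and take NDISTINCT of the key;
# the LHS filter is preserved rather than lost.
inner = conn.execute("""
    SELECT COUNT(DISTINCT c.key)
    FROM orders o JOIN customers c ON o.customer_key = c.key
    WHERE c.market_segment = 'BUILDING'
""").fetchone()[0]

print(semi, inner)  # both 1
```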

Ignoring this alternative case, the main optimization can be added as a special case in the redundant aggregation removal optimization in agg_removal.py. Inside this file, the function aggregation_uniqueness_helper is the main recursive protocol responsible for traversing the tree, figuring out which columns (or combinations of columns) are unique, and removing redundant joins if the grouping columns are already unique.

Inside the Aggregate case of this function, if the removal optimization did not trigger, then this new optimization can be inserted as a call to a helper function dealing with the new logic. See below:

        case Aggregate():
            node._input, input_uniqueness = aggregation_uniqueness_helper(node.input)
            agg_keys: frozenset[str] = frozenset(node.keys)
            output_uniqueness: set[frozenset[str]] = {agg_keys}
            for unique_set in input_uniqueness:
                if agg_keys.issuperset(unique_set):
                    node = delete_aggregation(node)
                    # If deleting the aggregation, then the uniqueness of the
                    # input is propagated through the new projection.
                    output_uniqueness = bubble_uniqueness(
                        input_uniqueness, node.columns, None
                    )
                    break
            # NEW CODE: if the optimization does not fire, rewrite_count_semi
            # just returns `node` and `output_uniqueness` as-is.
            if isinstance(node, Aggregate):
                node, output_uniqueness = rewrite_count_semi(
                    node, input_uniqueness, output_uniqueness
                )
            return node, output_uniqueness

Several existing tests already cover this behavior, and new ones will also need to be written. The existing tests include:

  • cryptbank (masked column sqlite tests): filter_count_15, filter_count_16, general_join_02
  • sf_masked (masked column snowflake tests): patient_claims
  • misc TPC-H tests (all in test_pipeline_tpch_custom): redundant_has_on_plural, redundant_has_on_plural_lineitems, has_cross_correlated_singular

In particular, the two example questions shown above should be added as tests, along with more tests covering various edge cases:

  • Left hand side has a more complex subtree (e.g. customers.WHERE((nation.region == 'ASIA') & (nation.name != 'CHINA')))
  • All different sorts of combinations of cardinality between left vs right subtree. A few examples:
    • LEFT->RIGHT=PLURAL_FILTER, RIGHT->LEFT=SINGULAR_FILTER (should not rewrite). Example: left = nations in Europe, right = customers in those nations in the building market segment (every such nation can have 0+ such customers, every customer has 0 or 1 such nations), aggregation would be counting the number of nations with at least 1 such customer
    • LEFT->RIGHT=PLURAL_FILTER, RIGHT->LEFT=SINGULAR_ACCESS (should rewrite). Example: left = all nations, right = customers in those nations in the automobile market segment with an account balance below $-975 (every such nation can have 0+ such customers, every customer has 0 or 1 such nations), aggregation would be counting the number of nations with at least 1 such customer (would be transformed to do NDISTINCT on the nation keys of the customers)
    • LEFT->RIGHT=SINGULAR_FILTER, RIGHT->LEFT=PLURAL_FILTER (should not rewrite)
    • LEFT->RIGHT=SINGULAR_FILTER, RIGHT->LEFT=PLURAL_ACCESS (should rewrite)

Metadata

Labels: effort - medium (mid-sized issue with average implementation time/difficulty), enhancement (new feature or request), optimization (improving the speed/quality of PyDough's outputs)
