Commit c69b446

Add aggregation exception for left join constraint
Aggregation with keep_all_rows=True uses a left join internally but has the opposite requirement (B → A) compared to direct left joins (A → B). This is valid because the GROUP BY clause resets the primary key to PK(A), ensuring non-NULL primary key values.

Changes:
- Add _aggregation parameter to heading.join() and expression.join()
- Aggregation.create() passes _aggregation=True to bypass validation
- Document aggregation exception in spec
- Add tests for aggregation with keep_all_rows=True

Co-authored-by: dimitri-yatsenko <dimitri@datajoint.com>
1 parent 496e014 commit c69b446

4 files changed: +102 -9 lines changed

docs/src/design/semantic-matching-spec.md

Lines changed: 26 additions & 0 deletions
@@ -341,6 +341,32 @@ The following attributes from the right operand's primary key are not determined
 the left operand: ['z']. Use an inner join or restructure the query.
 ```
 
+### Aggregation Exception
+
+`A.aggr(B, keep_all_rows=True)` uses a left join internally but has the **opposite requirement**: **B → A** (the group expression B must have all of A's primary key attributes).
+
+This apparent contradiction is resolved by the `GROUP BY` clause:
+
+1. Aggregation requires B → A so that B can be grouped by A's primary key
+2. The intermediate left join `A LEFT JOIN B` would have an invalid PK under the normal left join rules (B → A case gives PK(B))
+3. However, aggregation's `GROUP BY PK(A)` clause **resets** the primary key to PK(A)
+4. The final result has PK(A), which consists entirely of non-NULL values from A
+
+**Example:**
+```
+Session: session_id*, date
+Trial: session_id*, trial_num*, response_time (references Session)
+
+# Aggregation with keep_all_rows=True
+Session.aggr(Trial, keep_all_rows=True, avg_rt='avg(response_time)')
+
+# Internally: Session LEFT JOIN Trial (B → A, would normally be invalid)
+# But GROUP BY session_id resets PK to {session_id}
+# Result: All sessions, with avg_rt=NULL for sessions without trials
+```
+
+The left join constraint validation is bypassed internally for aggregation because the `GROUP BY` clause guarantees a valid primary key in the final result.
+
 ## Universal Set `dj.U`
 
 `dj.U()` or `dj.U('attr1', 'attr2', ...)` represents the universal set of all possible values and lineages.
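
As a rough illustration of the `Session`/`Trial` example in the Aggregation Exception section, the sketch below runs a left join followed by `GROUP BY` in plain SQL through Python's built-in `sqlite3` module. The table definitions and sample rows are assumptions made up for this illustration, and the query only approximates the kind of SQL involved; it is not the statement DataJoint generates.

```python
import sqlite3

# In-memory database with hypothetical Session/Trial tables (illustrative only).
conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE session (session_id INTEGER PRIMARY KEY, date TEXT);
    CREATE TABLE trial (
        session_id INTEGER REFERENCES session(session_id),
        trial_num INTEGER,
        response_time REAL,
        PRIMARY KEY (session_id, trial_num)
    );
    INSERT INTO session VALUES (1, '2024-01-01'), (2, '2024-01-02');
    -- Session 2 deliberately has no trials.
    INSERT INTO trial VALUES (1, 1, 0.5), (1, 2, 1.5);
    """
)

# The left join keeps every session; GROUP BY session_id collapses the result
# back to one row per session, so the grouping key is never NULL.
rows = conn.execute(
    """
    SELECT s.session_id, AVG(t.response_time) AS avg_rt
    FROM session AS s
    LEFT JOIN trial AS t ON s.session_id = t.session_id
    GROUP BY s.session_id
    ORDER BY s.session_id
    """
).fetchall()

print(rows)  # [(1, 1.0), (2, None)]
```

Session 2 survives the left join with a NULL `avg_rt`, while `session_id` itself is never NULL, which is the property the spec relies on when it says the primary key is reset to PK(A).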

src/datajoint/expression.py

Lines changed: 6 additions & 4 deletions
@@ -282,7 +282,7 @@ def __matmul__(self, other):
             "The @ operator has been removed in DataJoint 2.0. " "Use .join(other, semantic_check=False) for permissive joins."
         )
 
-    def join(self, other, semantic_check=True, left=False):
+    def join(self, other, semantic_check=True, left=False, _aggregation=False):
         """
         Create the joined QueryExpression.
 
@@ -293,6 +293,7 @@ def join(self, other, semantic_check=True, left=False):
         :param semantic_check: If True (default), raise error on non-homologous namesakes.
             If False, bypass semantic check (use for legacy compatibility).
         :param left: If True, perform a left join retaining all rows from self.
+        :param _aggregation: Internal flag to bypass left join validation for aggregation.
 
         Examples:
             a * b is short for a.join(b)
@@ -336,10 +337,10 @@ def join(self, other, semantic_check=True, left=False):
         result._connection = self.connection
         result._support = self.support + other.support
         result._left = self._left + [left] + other._left
-        result._heading = self.heading.join(other.heading, left=left)
+        result._heading = self.heading.join(other.heading, left=left, _aggregation=_aggregation)
         result._restriction = AndList(self.restriction)
         result._restriction.append(other.restriction)
-        result._original_heading = self.original_heading.join(other.original_heading, left=left)
+        result._original_heading = self.original_heading.join(other.original_heading, left=left, _aggregation=_aggregation)
         assert len(result.support) == len(result._left) + 1
         return result
 
@@ -683,7 +684,8 @@ def create(cls, arg, group, keep_all_rows=False):
 
         if keep_all_rows and len(group.support) > 1 or group.heading.new_attributes:
             group = group.make_subquery()  # subquery if left joining a join
-        join = arg.join(group, left=keep_all_rows)  # reuse the join logic
+        # Pass _aggregation=True to bypass left join validation (aggregation resets PK via GROUP BY)
+        join = arg.join(group, left=keep_all_rows, _aggregation=True)
         result = cls()
         result._connection = join.connection
         result._heading = join.heading.set_primary_key(arg.primary_key)  # use left operand's primary key

src/datajoint/heading.py

Lines changed: 10 additions & 5 deletions
@@ -468,7 +468,7 @@ def select(self, select_list, rename_map=None, compute_map=None):
         )
         return Heading(chain(copy_attrs, compute_attrs))
 
-    def join(self, other, left=False):
+    def join(self, other, left=False, _aggregation=False):
         """
         Join two headings into a new one.
 
@@ -486,11 +486,16 @@ def join(self, other, left=False):
         - If B → A or Neither, the PK would include B's attributes, which could be NULL
         - Only when A → B does PK(A) uniquely identify all result rows
 
+        Exception: Aggregation (A.aggr(B, keep_all_rows=True)) uses a left join internally
+        but requires B → A instead. This is valid because the GROUP BY clause resets the
+        primary key to PK(A), which consists of non-NULL values from the left operand.
+
         It assumes that self and other are headings that share no common dependent attributes.
 
         :param other: The other heading to join with
-        :param left: If True, this is a left join (requires A → B)
-        :raises DataJointError: If left=True and A does not determine B
+        :param left: If True, this is a left join (requires A → B unless _aggregation=True)
+        :param _aggregation: If True, skip left join validation (used by Aggregation.create)
+        :raises DataJointError: If left=True and A does not determine B (unless _aggregation)
         """
         from .errors import DataJointError
 
@@ -502,8 +507,8 @@ def join(self, other, left=False):
             name in other.primary_key or name in other.secondary_attributes for name in self.primary_key
         )
 
-        # For left joins, require A → B
-        if left and not self_determines_other:
+        # For left joins, require A → B (unless this is an aggregation context)
+        if left and not _aggregation and not self_determines_other:
             missing = [
                 name for name in other.primary_key if name not in self.primary_key and name not in self.secondary_attributes
             ]
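
The guard changed in this hunk can be sketched in isolation. The following is a minimal standalone model, not DataJoint's `Heading` class: `MiniHeading`, `LeftJoinError`, `determines`, and `check_left_join` are hypothetical names introduced only for this illustration, and the error text is a paraphrase. It assumes, as the visible diff lines suggest, that "A → B" means every primary key attribute of B appears among A's primary or secondary attributes.

```python
from dataclasses import dataclass, field


class LeftJoinError(Exception):
    """Hypothetical stand-in for DataJointError in this sketch."""


@dataclass
class MiniHeading:
    """Hypothetical, stripped-down heading: attribute names only."""

    primary_key: list
    secondary_attributes: list = field(default_factory=list)

    def determines(self, other):
        """Return True if self → other: all of other's PK attributes are in self."""
        mine = set(self.primary_key) | set(self.secondary_attributes)
        return all(name in mine for name in other.primary_key)

    def check_left_join(self, other, left=False, _aggregation=False):
        """Mirror the guard in the diff: left joins need self → other unless aggregating."""
        if left and not _aggregation and not self.determines(other):
            missing = [
                name
                for name in other.primary_key
                if name not in self.primary_key and name not in self.secondary_attributes
            ]
            raise LeftJoinError(f"right operand PK attributes not determined by left operand: {missing}")


a = MiniHeading(primary_key=["x"])       # A: x*
b = MiniHeading(primary_key=["x", "y"])  # B: x*, y*

# Direct left join: A does not determine B (y is missing), so the guard raises.
try:
    a.check_left_join(b, left=True)
    direct_left_join_ok = True
except LeftJoinError:
    direct_left_join_ok = False

# Aggregation context: same operands, but validation is skipped.
a.check_left_join(b, left=True, _aggregation=True)
```

In this model the same pair of headings fails the direct left join and passes in the aggregation context, which matches the B → A case exercised by the new tests.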

tests/test_semantic_matching.py

Lines changed: 60 additions & 0 deletions
@@ -754,3 +754,63 @@ def test_inner_join_still_works_when_b_determines_a(self, schema_pk_rules):
 
         # PK should be {x, z} (B's PK)
         assert set(result.primary_key) == {"x", "z"}
+
+
+class TestAggregationWithKeepAllRows:
+    """
+    Test that aggregation with keep_all_rows=True works correctly.
+
+    Aggregation uses a left join internally but has the opposite requirement (B → A)
+    compared to direct left joins (which require A → B). This is valid because the
+    GROUP BY clause resets the PK to PK(A).
+    """
+
+    def test_aggregation_keep_all_rows_works_with_b_determines_a(self, schema_pk_rules):
+        """
+        Aggregation with keep_all_rows=True should work when B → A.
+
+        A: x* PK(A) = {x}
+        B: x*, y* PK(B) = {x, y}
+
+        B → A? x in PK(B) → Yes (aggregation requirement met)
+
+        The internal left join would normally fail (B → A, not A → B), but
+        aggregation bypasses this because GROUP BY resets PK to {x}.
+        """
+        TableX = schema_pk_rules["TableX"]
+        TableXY = schema_pk_rules["TableXY"]
+
+        # This should work - aggregation with keep_all_rows=True
+        result = TableX.aggr(TableXY, keep_all_rows=True, count="count(*)")
+
+        # PK should be PK(A) = {x} (reset by GROUP BY)
+        assert set(result.primary_key) == {"x"}
+
+    def test_aggregation_keep_all_rows_produces_correct_pk(self, schema_pk_rules):
+        """
+        Aggregation result should always have PK(A), regardless of functional dependencies.
+        """
+        TableXY = schema_pk_rules["TableXY"]
+        TableXZwithY = schema_pk_rules["TableXZwithY"]
+
+        # TableXY (A): PK = {x, y}
+        # TableXZwithY (B): PK = {x, z}, y is secondary
+        # B → A (y secondary in B), so left join would use PK(B) = {x, z}
+        # But aggregation resets to PK(A) = {x, y}
+        result = TableXY.aggr(TableXZwithY, keep_all_rows=True, count="count(*)")
+
+        # PK should be PK(A) = {x, y}
+        assert set(result.primary_key) == {"x", "y"}
+
+    def test_aggregation_without_keep_all_rows_also_works(self, schema_pk_rules):
+        """
+        Normal aggregation (keep_all_rows=False) should continue to work.
+        """
+        TableX = schema_pk_rules["TableX"]
+        TableXY = schema_pk_rules["TableXY"]
+
+        # Normal aggregation (inner join behavior)
+        result = TableX.aggr(TableXY, count="count(*)")
+
+        # PK should be PK(A) = {x}
+        assert set(result.primary_key) == {"x"}
