Commit d68de16

Add functional dependency check for aggregation operator

In A.aggr(B, ...), ensures every entry in B matches exactly one entry in A:
- B must have all of A's primary key attributes
- Primary key attributes must be homologous (same lineage)
- Clear error messages for missing attributes or non-homologous lineage

Updated docstrings for:
- Aggregation.create()
- QueryExpression.aggr()
- U.aggr()

Updated spec document with:
- Functional dependency requirements
- Error message examples
- Additional test cases

Co-authored-by: dimitri-yatsenko <dimitri@datajoint.com>
1 parent e64e7a0 commit d68de16

2 files changed, +105 -14 lines changed

docs/src/design/semantic-matching-spec.md

Lines changed: 44 additions & 3 deletions
@@ -136,7 +136,7 @@ Semantic matching applies to all binary operations that match attributes:
 | `A * B` | Join | Matches on homologous namesakes |
 | `A & B` | Restriction | Matches on homologous namesakes |
 | `A - B` | Anti-restriction | Matches on homologous namesakes |
-| `A.aggr(B, ...)` | Aggregation | Matches on homologous namesakes |
+| `A.aggr(B, ...)` | Aggregation | Requires functional dependency (see below) |

 ### The `.join()` Method

@@ -438,8 +438,27 @@ def proj(self, *attributes, **named_attributes):

 ### Aggregation (`aggr`)

-Aggregation creates a new expression with:
-- Group attributes retain their lineage from the group operand
+In `A.aggr(B, ...)`, entries from B are grouped by A's primary key and aggregate functions are computed.
+
+**Functional Dependency Requirement**: Every entry in B must match exactly one entry in A. This requires:
+
+1. **B must have all of A's primary key attributes**: If A's primary key is `(a, b)`, then B must contain attributes named `a` and `b`.
+
+2. **Primary key attributes must be homologous**: The namesake attributes in B must have the same lineage as in A. This ensures they represent the same entity.
+
+```python
+# Valid: Session.aggr(Trial, ...) where Trial has session_id from Session
+Session.aggr(Trial, n='count(*)')  # OK - Trial.session_id traces to Session.session_id
+
+# Invalid: Missing primary key attribute
+Session.aggr(Stimulus, n='count(*)')  # Error if Stimulus lacks session_id
+
+# Invalid: Non-homologous primary key
+TableA.aggr(TableB, n='count(*)')  # Error if TableB.id has different lineage than TableA.id
+```
+
+**Result lineage**:
+- Group attributes retain their lineage from the grouping expression (A)
 - Aggregated attributes have `lineage=None` (they are computations)

 ### Union (`+`)
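
To make the grouping semantics in the aggregation section concrete, here is a plain-Python sketch of what `Session.aggr(Trial, n='count(*)')` computes conceptually with the default `keep_all_rows=False`. It does not use DataJoint: the row data and the helper `aggr_count` are hypothetical names invented for this illustration, and the real operator compiles to SQL rather than looping over rows.

```python
from collections import Counter

# Hypothetical rows; attribute names follow the Session/Trial example in the spec.
sessions = [{"session_id": 1}, {"session_id": 2}, {"session_id": 3}]
trials = [
    {"session_id": 1, "trial_id": 1},
    {"session_id": 1, "trial_id": 2},
    {"session_id": 2, "trial_id": 1},
]


def aggr_count(a_rows, a_pk, b_rows):
    """Count rows of B per primary-key value of A (illustration only)."""
    # B must carry every primary key attribute of A, otherwise a row of B
    # cannot be assigned to exactly one row of A.
    b_attrs = set(b_rows[0]) if b_rows else set(a_pk)
    missing = set(a_pk) - b_attrs
    if missing:
        raise ValueError(f"B is missing primary key attributes of A: {sorted(missing)}")

    # Group B by A's primary key and count the members of each group.
    counts = Counter(tuple(row[k] for k in a_pk) for row in b_rows)

    # Keep only the rows of A that have at least one match in B.
    result = []
    for a in a_rows:
        key = tuple(a[k] for k in a_pk)
        if key in counts:
            result.append({**{k: a[k] for k in a_pk}, "n": counts[key]})
    return result


print(aggr_count(sessions, ["session_id"], trials))
```

Session 3 has no trials, so it is absent from the result. Passing rows that lack `session_id` as the third argument raises the error, which mirrors the missing-primary-key case in the spec.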
@@ -470,6 +489,22 @@ DataJointError: dj.U(...) * table is deprecated in DataJoint 2.0.
 Use dj.U(...) & table instead.
 ```

+### Aggregation Missing Primary Key
+
+```
+DataJointError: Aggregation requires functional dependency: `group` must have all primary key
+attributes of the grouping expression. Missing: {'session_id'}.
+Use .proj() to add the missing attributes or verify the schema design.
+```
+
+### Aggregation Non-Homologous Primary Key
+
+```
+DataJointError: Aggregation requires homologous primary key attributes.
+Attribute `id` has different lineages: university.student.id (grouping) vs university.course.id (group).
+Use .proj() to rename one of the attributes or .join(semantic_check=False) in a manual aggregation.
+```
+
 ## Testing Strategy

 ### Unit Tests
@@ -496,6 +531,12 @@ Use dj.U(...) & table instead.
    - `dj.U - table` raises error
    - `dj.U * table` raises deprecation error

+5. **Aggregation functional dependency**:
+   - `A.aggr(B)` works when B has all of A's PK attributes with same lineage
+   - `A.aggr(B)` raises error when B is missing PK attributes
+   - `A.aggr(B)` raises error when PK attributes have different lineage
+   - `dj.U('a', 'b').aggr(B)` works when B has `a` and `b` attributes
+
 ### Integration Tests

 1. **Schema migration**: Existing schema gets `~lineage` table populated correctly

src/datajoint/expression.py

Lines changed: 61 additions & 11 deletions
@@ -464,13 +464,20 @@ def proj(self, *attributes, **named_attributes):

     def aggr(self, group, *attributes, keep_all_rows=False, **named_attributes):
         """
-        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
-        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.
+        Aggregate `group` over the primary key of `self`.

-        :param group: The query expression to be aggregated.
-        :param keep_all_rows: True=keep all the rows from self. False=keep only rows that match entries in group.
-        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
-        :return: The derived query expression
+        In A.aggr(B, ...), groups entries from B by the primary key of A and computes
+        aggregate functions. Requires functional dependency: every entry in B must match
+        exactly one entry in A. This means B must have all of A's primary key attributes
+        as homologous namesakes (same name AND same lineage).
+
+        :param group: the query expression to aggregate (B in A.aggr(B))
+        :param attributes: attributes from self to include in the result
+        :param keep_all_rows: True=keep all rows from self (left join). False=keep only matching rows.
+        :param named_attributes: aggregation computations, e.g., count='count(*)', avg_val='avg(value)'
+        :return: query expression with self's primary key and the computed aggregations
+        :raises DataJointError: if group is missing primary key attributes from self,
+            or if namesake primary key attributes have different lineages
         """
         if Ellipsis in attributes:
             # expand ellipsis to include only attributes from the left table
@@ -631,9 +638,47 @@ class Aggregation(QueryExpression):

     @classmethod
     def create(cls, arg, group, keep_all_rows=False):
+        """
+        Create an aggregation expression.
+
+        For A.aggr(B, ...), ensures functional dependency: every entry in B must match
+        exactly one entry in A. This requires B to have all of A's primary key attributes
+        as homologous namesakes (same name AND same lineage).
+
+        :param arg: the grouping expression (A in A.aggr(B))
+        :param group: the expression to aggregate (B in A.aggr(B))
+        :param keep_all_rows: if True, keep all rows from arg (left join behavior)
+        :raises DataJointError: if group is missing any primary key attributes from arg,
+            or if namesake attributes have different lineages
+        """
         if inspect.isclass(group) and issubclass(group, QueryExpression):
             group = group()  # instantiate if a class
         assert isinstance(group, QueryExpression)
+
+        # Check functional dependency: group must have all of arg's primary key attributes
+        missing_pk = set(arg.primary_key) - set(group.heading.names)
+        if missing_pk:
+            raise DataJointError(
+                f"Aggregation requires functional dependency: `group` must have all primary key "
+                f"attributes of the grouping expression. Missing: {missing_pk}. "
+                f"Use .proj() to add the missing attributes or verify the schema design."
+            )
+
+        # Check that primary key attributes are homologous (same lineage)
+        # This is done for QueryExpression args; U is always compatible
+        if not isinstance(arg, U):
+            for attr_name in arg.primary_key:
+                arg_lineage = arg.heading[attr_name].lineage
+                group_lineage = group.heading[attr_name].lineage
+                if arg_lineage != group_lineage:
+                    raise DataJointError(
+                        f"Aggregation requires homologous primary key attributes. "
+                        f"Attribute `{attr_name}` has different lineages: "
+                        f"{arg_lineage} (grouping) vs {group_lineage} (group). "
+                        f"Use .proj() to rename one of the attributes or "
+                        f".join(semantic_check=False) in a manual aggregation."
+                    )
+
         if keep_all_rows and len(group.support) > 1 or group.heading.new_attributes:
             group = group.make_subquery()  # subquery if left joining a join
         join = arg.join(group, left=keep_all_rows)  # reuse the join logic
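
The two checks added to `Aggregation.create()` can be tried in isolation with the following minimal sketch, which runs without DataJoint. It is an approximation under stated assumptions, not the library's code: headings are modeled as plain dicts mapping attribute name to a lineage string, `LineageError` stands in for `DataJointError`, and `check_functional_dependency`, its `universal` flag, and the `lab.*` lineage strings are hypothetical names chosen for illustration.

```python
class LineageError(Exception):
    """Stand-in for DataJointError so the sketch runs without DataJoint."""


def check_functional_dependency(arg_pk, arg_lineage, group_lineage, universal=False):
    """
    Mirror the two checks from the diff on simplified inputs.

    arg_pk:        primary key attribute names of the grouping expression (A)
    arg_lineage:   dict of attribute name -> lineage string for A
    group_lineage: dict of attribute name -> lineage string for B
    universal:     True mimics dj.U, for which the lineage check is skipped
    """
    # Check 1: B must contain every primary key attribute of A.
    missing_pk = set(arg_pk) - set(group_lineage)
    if missing_pk:
        raise LineageError(f"group is missing primary key attributes: {sorted(missing_pk)}")

    # Check 2: namesake primary key attributes must share the same lineage.
    if not universal:
        for name in arg_pk:
            if arg_lineage[name] != group_lineage[name]:
                raise LineageError(
                    f"attribute `{name}` has different lineages: "
                    f"{arg_lineage[name]} (grouping) vs {group_lineage[name]} (group)"
                )


# Valid: Trial inherits session_id from Session (hypothetical lineage strings).
session = {"session_id": "lab.session.session_id"}
trial = {"session_id": "lab.session.session_id", "trial_id": "lab.trial.trial_id"}
check_functional_dependency(["session_id"], session, trial)

# Invalid: same attribute name, different lineage (example from the spec).
student = {"id": "university.student.id"}
course = {"id": "university.course.id"}
try:
    check_functional_dependency(["id"], student, course)
except LineageError as err:
    print(err)

# A universal set only needs the attribute names to be present.
check_functional_dependency(["id"], {}, course, universal=True)
```

Running the missing-attribute check first means the lineage loop can index `group_lineage[name]` without a `KeyError`, which is the same ordering the commit uses.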
@@ -853,12 +898,17 @@ def __sub__(self, other):

     def aggr(self, group, **named_attributes):
         """
-        Aggregation of the type U('attr1','attr2').aggr(group, computation="QueryExpression")
-        has the primary key ('attr1','attr2') and performs aggregation computations for all matching elements of `group`.
+        Aggregate `group` over the attributes of this universal set.
+
+        In dj.U('attr1', 'attr2').aggr(B, ...), groups entries from B by attr1 and attr2
+        and computes aggregate functions. Requires B to have all specified attributes.
+        Since dj.U is homologous to any namesake attribute, lineage compatibility is
+        always satisfied.

-        :param group: The query expression to be aggregated.
-        :param named_attributes: computations of the form new_attribute="sql expression on attributes of group"
-        :return: The derived query expression
+        :param group: the query expression to aggregate
+        :param named_attributes: aggregation computations, e.g., count='count(*)', avg_val='avg(value)'
+        :return: query expression with U's attributes as primary key and the computed aggregations
+        :raises DataJointError: if group is missing any of U's primary key attributes
         """
         if named_attributes.get("keep_all_rows", False):
             raise DataJointError("Cannot set keep_all_rows=True when aggregating on a universal set.")
