Skip to content

Commit 5fa6e1c

Browse files
fix GenericMixin
1 parent 4fed7d8 commit 5fa6e1c

File tree

5 files changed

+378
-26
lines changed

5 files changed

+378
-26
lines changed

CHANGELOG.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
# Changelog
2+
## Pedantic 2.2.2
3+
- fix `GenericMixin`
4+
25
## Pedantic 2.2.1
36
- fixed `setuptools` deprecation warnings
47
- delete unused scripts
@@ -7,7 +10,7 @@
710
- migrated from `setup.py` to `pyproject.toml`
811

912
## Pedantic 2.1.11
10-
- improve `GenericMixin` such that it also find bound type variables in parent classes
13+
- improve `GenericMixin` such that it also finds bound type variables in parent classes
1114

1215
## Pedantic 2.1.10
1316
- added type check support for `functools.partial`

pedantic/mixins/generic_mixin.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,44 +68,64 @@ def _get_resolved_typevars(self) -> Dict[TypeVar, Type]:
6868

6969
mapping: dict[TypeVar, type] = {}
7070

71-
non_generic_error = AssertionError(
72-
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
73-
f'class {self.class_name}(Generic[T], GenericMixin):...')
74-
7571
if not hasattr(self, '__orig_bases__'):
76-
raise non_generic_error
72+
raise AssertionError(
73+
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
74+
f'class {self.class_name}(Generic[T], GenericMixin):...'
75+
)
7776

78-
def collect(base, substitutions: dict[TypeVar, type]) -> None:
79-
origin = get_origin(base)
77+
def collect(base, substitutions: dict[TypeVar, type]):
78+
"""Recursively collect type var mappings from a generic base."""
79+
origin = get_origin(base) or base
8080
args = get_args(base)
8181

82-
if origin is None:
83-
return
84-
8582
params = getattr(origin, '__parameters__', ())
86-
resolved = {}
83+
# copy substitutions so each recursion has its own view
84+
resolved = substitutions.copy()
85+
8786
for param, arg in zip(params, args):
88-
resolved_arg = substitutions.get(arg, arg) if isinstance(arg, TypeVar) else arg
89-
mapping[param] = resolved_arg
90-
resolved[param] = resolved_arg
87+
if isinstance(arg, TypeVar):
88+
arg = substitutions.get(arg, arg)
89+
mapping[param] = arg
90+
resolved[param] = arg
9191

92+
# Recurse into base classes, applying current substitutions
9293
for super_base in getattr(origin, '__orig_bases__', []):
93-
collect(super_base, resolved)
94-
95-
# Prefer __orig_class__ if available
94+
super_origin = get_origin(super_base) or super_base
95+
super_args = get_args(super_base)
96+
97+
if super_args:
98+
# Substitute any TypeVars in the super_base's args using resolved
99+
substituted_args = tuple(
100+
resolved.get(a, a) if isinstance(a, TypeVar) else a
101+
for a in super_args
102+
)
103+
# Build a new parametrized base so get_args() inside collect sees substituted_args
104+
try:
105+
substituted_base = super_origin[substituted_args] # type: ignore[index]
106+
except TypeError:
107+
# Some origins won't accept subscription; fall back to passing the origin and trusting resolved
108+
substituted_base = super_base
109+
collect(base=substituted_base, substitutions=resolved)
110+
else:
111+
collect(base=super_base, substitutions=resolved)
112+
113+
# Start from __orig_class__ if present, else walk the declared MRO bases
96114
cls = getattr(self, '__orig_class__', None)
97115
if cls is not None:
98116
collect(base=cls, substitutions={})
99117
else:
100-
for base in getattr(self.__class__, '__orig_bases__', []):
101-
collect(base=base, substitutions={})
118+
# Walk the full MRO to catch indirect generic ancestors
119+
for c in self.__class__.__mro__:
120+
for base in getattr(c, '__orig_bases__', []):
121+
collect(base=base, substitutions=mapping)
102122

103-
# Extra safety: ensure all declared typevars are resolved
123+
# Ensure no unresolved TypeVars remain
104124
all_params = set()
105-
for cls in self.__class__.__mro__:
106-
all_params.update(getattr(cls, '__parameters__', ()))
125+
for c in self.__class__.__mro__:
126+
all_params.update(getattr(c, '__parameters__', ()))
107127

108-
unresolved = {param for param in all_params if param not in mapping or isinstance(mapping[param], TypeVar)}
128+
unresolved = {p for p in all_params if p not in mapping or isinstance(mapping[p], TypeVar)}
109129
if unresolved:
110130
raise AssertionError(
111131
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
@@ -115,7 +135,6 @@ def collect(base, substitutions: dict[TypeVar, type]) -> None:
115135
)
116136

117137
return mapping
118-
119138
@property
120139
def class_name(self) -> str:
121140
""" Get the name of the class of this instance. """

pedantic/tests/test_generic_mixin.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from pedantic import GenericMixin
55

6+
A = TypeVar('A')
7+
E = TypeVar('E')
8+
S = TypeVar('S')
69
T = TypeVar('T')
710
U = TypeVar('U')
811

@@ -119,3 +122,44 @@ class Bar(Foo[int], Generic[U]): ...
119122

120123
bar = Bar[str]()
121124
assert bar.type_vars == {T: int, U: str}
125+
126+
def test_very_complex_inheritance(self):
127+
class Foo(Generic[E], GenericMixin): ...
128+
class Bar(Foo[int], Generic[S]): ...
129+
class Baz(Foo[int]): ...
130+
class Deep(Baz): ...
131+
class Deeper(Baz, Generic[T]): ...
132+
133+
foo = Foo[str]()
134+
actual = foo.type_vars
135+
assert actual == {E: str}
136+
137+
bar = Bar[str]()
138+
actual = bar.type_vars
139+
assert actual == {E: int, S: str}
140+
141+
baz = Baz()
142+
actual = baz.type_vars
143+
assert actual == {E: int}
144+
145+
deep = Deep()
146+
actual = deep.type_vars
147+
assert actual == {E: int}
148+
149+
deeper = Deeper[bool]()
150+
actual = deeper.type_vars
151+
assert actual == {E: int, T: bool}
152+
153+
with self.assertRaises(expected_exception=AssertionError) as err:
154+
Foo().type_vars
155+
156+
assert 'You need to instantiate this class with type parameters! Example: Foo[int]()' in err.exception.args[0]
157+
158+
def test_substitution_lookup_hits(self):
159+
class Base(Generic[A], GenericMixin): ...
160+
class Mid(Base[A], Generic[A]): ...
161+
class Final(Mid[int]): ...
162+
163+
obj = Final()
164+
actual = obj.type_vars
165+
assert actual == {A: int}

0 commit comments

Comments
 (0)