Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Changelog
## Pedantic 2.2.2
- fix `GenericMixin`

## Pedantic 2.2.1
- fixed `setuptools` deprecation warnings
- delete unused scripts
Expand All @@ -7,7 +10,7 @@
- migrated from `setup.py` to `pyproject.toml`

## Pedantic 2.1.11
- improve `GenericMixin` such that it also find bound type variables in parent classes
- improve `GenericMixin` such that it also finds bound type variables in parent classes

## Pedantic 2.1.10
- added type check support for `functools.partial`
Expand Down
67 changes: 43 additions & 24 deletions pedantic/mixins/generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,44 +68,64 @@ def _get_resolved_typevars(self) -> Dict[TypeVar, Type]:

mapping: dict[TypeVar, type] = {}

non_generic_error = AssertionError(
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
f'class {self.class_name}(Generic[T], GenericMixin):...')

if not hasattr(self, '__orig_bases__'):
raise non_generic_error
raise AssertionError(
f'{self.class_name} is not a generic class. To make it generic, declare it like: '
f'class {self.class_name}(Generic[T], GenericMixin):...'
)

def collect(base, substitutions: dict[TypeVar, type]) -> None:
origin = get_origin(base)
def collect(base, substitutions: dict[TypeVar, type]):
"""Recursively collect type var mappings from a generic base."""
origin = get_origin(base) or base
args = get_args(base)

if origin is None:
return

params = getattr(origin, '__parameters__', ())
resolved = {}
# copy substitutions so each recursion has its own view
resolved = substitutions.copy()

for param, arg in zip(params, args):
resolved_arg = substitutions.get(arg, arg) if isinstance(arg, TypeVar) else arg
mapping[param] = resolved_arg
resolved[param] = resolved_arg
if isinstance(arg, TypeVar):
arg = substitutions.get(arg, arg)
mapping[param] = arg
resolved[param] = arg

# Recurse into base classes, applying current substitutions
for super_base in getattr(origin, '__orig_bases__', []):
collect(super_base, resolved)

# Prefer __orig_class__ if available
super_origin = get_origin(super_base) or super_base
super_args = get_args(super_base)

if super_args:
# Substitute any TypeVars in the super_base's args using resolved
substituted_args = tuple(
resolved.get(a, a) if isinstance(a, TypeVar) else a
for a in super_args
)
# Build a new parametrized base so get_args() inside collect sees substituted_args
try:
substituted_base = super_origin[substituted_args] # type: ignore[index]
except TypeError:
# Some origins won't accept subscription; fall back to passing the origin and trusting resolved
substituted_base = super_base
collect(base=substituted_base, substitutions=resolved)
else:
collect(base=super_base, substitutions=resolved)

# Start from __orig_class__ if present, else walk the declared MRO bases
cls = getattr(self, '__orig_class__', None)
if cls is not None:
collect(base=cls, substitutions={})
else:
for base in getattr(self.__class__, '__orig_bases__', []):
collect(base=base, substitutions={})
# Walk the full MRO to catch indirect generic ancestors
for c in self.__class__.__mro__:
for base in getattr(c, '__orig_bases__', []):
collect(base=base, substitutions=mapping)

# Extra safety: ensure all declared typevars are resolved
# Ensure no unresolved TypeVars remain
all_params = set()
for cls in self.__class__.__mro__:
all_params.update(getattr(cls, '__parameters__', ()))
for c in self.__class__.__mro__:
all_params.update(getattr(c, '__parameters__', ()))

unresolved = {param for param in all_params if param not in mapping or isinstance(mapping[param], TypeVar)}
unresolved = {p for p in all_params if p not in mapping or isinstance(mapping[p], TypeVar)}
if unresolved:
raise AssertionError(
f'You need to instantiate this class with type parameters! Example: {self.class_name}[int]()\n'
Expand All @@ -115,7 +135,6 @@ def collect(base, substitutions: dict[TypeVar, type]) -> None:
)

return mapping

@property
def class_name(self) -> str:
""" Get the name of the class of this instance. """
Expand Down
44 changes: 44 additions & 0 deletions pedantic/tests/test_generic_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

from pedantic import GenericMixin

A = TypeVar('A')
E = TypeVar('E')
S = TypeVar('S')
T = TypeVar('T')
U = TypeVar('U')

Expand Down Expand Up @@ -119,3 +122,44 @@ class Bar(Foo[int], Generic[U]): ...

bar = Bar[str]()
assert bar.type_vars == {T: int, U: str}

def test_very_complex_inheritance(self):
class Foo(Generic[E], GenericMixin): ...
class Bar(Foo[int], Generic[S]): ...
class Baz(Foo[int]): ...
class Deep(Baz): ...
class Deeper(Baz, Generic[T]): ...

foo = Foo[str]()
actual = foo.type_vars
assert actual == {E: str}

bar = Bar[str]()
actual = bar.type_vars
assert actual == {E: int, S: str}

baz = Baz()
actual = baz.type_vars
assert actual == {E: int}

deep = Deep()
actual = deep.type_vars
assert actual == {E: int}

deeper = Deeper[bool]()
actual = deeper.type_vars
assert actual == {E: int, T: bool}

with self.assertRaises(expected_exception=AssertionError) as err:
Foo().type_vars

assert 'You need to instantiate this class with type parameters! Example: Foo[int]()' in err.exception.args[0]

def test_substitution_lookup_hits(self):
class Base(Generic[A], GenericMixin): ...
class Mid(Base[A], Generic[A]): ...
class Final(Mid[int]): ...

obj = Final()
actual = obj.type_vars
assert actual == {A: int}
Loading