@@ -68,44 +68,64 @@ def _get_resolved_typevars(self) -> Dict[TypeVar, Type]:
6868
6969 mapping : dict [TypeVar , type ] = {}
7070
71- non_generic_error = AssertionError (
72- f'{ self .class_name } is not a generic class. To make it generic, declare it like: '
73- f'class { self .class_name } (Generic[T], GenericMixin):...' )
74-
7571 if not hasattr (self , '__orig_bases__' ):
76- raise non_generic_error
72+ raise AssertionError (
73+ f'{ self .class_name } is not a generic class. To make it generic, declare it like: '
74+ f'class { self .class_name } (Generic[T], GenericMixin):...'
75+ )
7776
78- def collect (base , substitutions : dict [TypeVar , type ]) -> None :
79- origin = get_origin (base )
77+ def collect (base , substitutions : dict [TypeVar , type ]):
78+ """Recursively collect type var mappings from a generic base."""
79+ origin = get_origin (base ) or base
8080 args = get_args (base )
8181
82- if origin is None :
83- return
84-
8582 params = getattr (origin , '__parameters__' , ())
86- resolved = {}
83+ # copy substitutions so each recursion has its own view
84+ resolved = substitutions .copy ()
85+
8786 for param , arg in zip (params , args ):
88- resolved_arg = substitutions .get (arg , arg ) if isinstance (arg , TypeVar ) else arg
89- mapping [param ] = resolved_arg
90- resolved [param ] = resolved_arg
87+ if isinstance (arg , TypeVar ):
88+ arg = substitutions .get (arg , arg )
89+ mapping [param ] = arg
90+ resolved [param ] = arg
9191
92+ # Recurse into base classes, applying current substitutions
9293 for super_base in getattr (origin , '__orig_bases__' , []):
93- collect (super_base , resolved )
94-
95- # Prefer __orig_class__ if available
94+ super_origin = get_origin (super_base ) or super_base
95+ super_args = get_args (super_base )
96+
97+ if super_args :
98+ # Substitute any TypeVars in the super_base's args using resolved
99+ substituted_args = tuple (
100+ resolved .get (a , a ) if isinstance (a , TypeVar ) else a
101+ for a in super_args
102+ )
103+ # Build a new parametrized base so get_args() inside collect sees substituted_args
104+ try :
105+ substituted_base = super_origin [substituted_args ] # type: ignore[index]
106+ except TypeError :
107+ # Some origins won't accept subscription; fall back to passing the origin and trusting resolved
108+ substituted_base = super_base
109+ collect (base = substituted_base , substitutions = resolved )
110+ else :
111+ collect (base = super_base , substitutions = resolved )
112+
113+ # Start from __orig_class__ if present, else walk the declared MRO bases
96114 cls = getattr (self , '__orig_class__' , None )
97115 if cls is not None :
98116 collect (base = cls , substitutions = {})
99117 else :
100- for base in getattr (self .__class__ , '__orig_bases__' , []):
101- collect (base = base , substitutions = {})
118+ # Walk the full MRO to catch indirect generic ancestors
119+ for c in self .__class__ .__mro__ :
120+ for base in getattr (c , '__orig_bases__' , []):
121+ collect (base = base , substitutions = mapping )
102122
103- # Extra safety: ensure all declared typevars are resolved
123+ # Ensure no unresolved TypeVars remain
104124 all_params = set ()
105- for cls in self .__class__ .__mro__ :
106- all_params .update (getattr (cls , '__parameters__' , ()))
125+ for c in self .__class__ .__mro__ :
126+ all_params .update (getattr (c , '__parameters__' , ()))
107127
108- unresolved = {param for param in all_params if param not in mapping or isinstance (mapping [param ], TypeVar )}
128+ unresolved = {p for p in all_params if p not in mapping or isinstance (mapping [p ], TypeVar )}
109129 if unresolved :
110130 raise AssertionError (
111131 f'You need to instantiate this class with type parameters! Example: { self .class_name } [int]()\n '
@@ -115,7 +135,6 @@ def collect(base, substitutions: dict[TypeVar, type]) -> None:
115135 )
116136
117137 return mapping
118-
119138 @property
120139 def class_name (self ) -> str :
121140 """ Get the name of the class of this instance. """
0 commit comments