11import inspect
22from dataclasses import dataclass
33from typing import Annotated , Literal , get_args , get_origin , Callable , Any
4+ from enum import Enum
45import types
56
67from pydantic import TypeAdapter
@@ -67,6 +68,10 @@ class ParamInfo:
6768 - Contains Field constraints for the list container
6869 - None if no list-level constraints
6970 Example: Field(min_items=2, max_items=5)
71+ enum_type: The original Enum type if parameter was an Enum.
72+ - None for non-Enum parameters
73+ - Stored to convert string back to Enum in validation
74+ - Example: For `color: Color`, stores the Color Enum class
7075 """
7176 type : type
7277 default : Any = None
@@ -76,6 +81,7 @@ class ParamInfo:
7681 optional_enabled : bool = False
7782 is_list : bool = False
7883 list_field_info : Any = None
84+ enum_type : None = None
7985
8086def analyze (func : Callable [..., Any ]) -> dict [str , ParamInfo ]:
8187 """Analyze a function's signature and extract parameter metadata.
@@ -110,6 +116,7 @@ def analyze(func: Callable[..., Any]) -> dict[str, ParamInfo]:
110116 is_optional = False
111117 optional_default_enabled = None # None = auto, True = enabled, False = disabled
112118 is_list = False
119+ enum_type = None
113120
114121 # 1. Extract base type from Annotated (OUTER level)
115122 # This could be constraints for the list itself
@@ -239,6 +246,27 @@ def analyze(func: Callable[..., Any]) -> dict[str, ParamInfo]:
239246 else :
240247 t = type (None )
241248
249+ # 5b. Handle Enum types
250+ elif isinstance (t , type ) and issubclass (t , Enum ):
251+ opts = tuple (e .value for e in t )
252+
253+ if not opts :
254+ raise ValueError (f"'{ name } ': Enum must have at least one value" )
255+
256+ types_set = {type (v ) for v in opts }
257+ if len (types_set ) > 1 :
258+ raise TypeError (f"'{ name } ': Enum values must be same type" )
259+
260+ if default is not None :
261+ if not isinstance (default , t ):
262+ raise TypeError (f"'{ name } ': default must be { t .__name__ } instance" )
263+ default = default .value
264+
265+ enum_type = t
266+
267+ f = Literal [opts ]
268+ t = types_set .pop ()
269+
242270 # 6. Validate base type
243271 if t not in VALID :
244272 raise TypeError (f"'{ name } ': { t } not supported" )
@@ -289,6 +317,6 @@ def analyze(func: Callable[..., Any]) -> dict[str, ParamInfo]:
289317 # No default, start disabled
290318 final_optional_enabled = False
291319
292- result [name ] = ParamInfo (t , default , f , dynamic_func , is_optional , final_optional_enabled , is_list , list_f )
320+ result [name ] = ParamInfo (t , default , f , dynamic_func , is_optional , final_optional_enabled , is_list , list_f , enum_type )
293321
294322 return result
0 commit comments