3232from typing import (
3333 Any ,
3434 AsyncIterator ,
35- Awaitable ,
3635 Callable ,
3736 Coroutine ,
3837 Dict ,
4443 Sequence ,
4544 Set ,
4645 TypedDict ,
46+ TypeAlias ,
4747 Union ,
4848)
4949
7878 _STDLIB_MODULE_NAMES : frozenset [str ] = frozenset () # type: ignore[no-redef]
7979
8080
81- try :
82- from google .cloud import storage
81+ if typing . TYPE_CHECKING :
82+ from google .cloud import storage # type: ignore[attr-defined]
8383
84- _StorageBucket : type [Any ] = storage .Bucket
85- except (ImportError , AttributeError ):
86- _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
84+ _StorageBucket : typing .TypeAlias = storage .Bucket
85+ else :
86+ try :
87+ from google .cloud import storage # type: ignore[attr-defined]
8788
89+ _StorageBucket : type [Any ] = storage .Bucket
90+ except (ImportError , AttributeError ):
91+ _StorageBucket : type [Any ] = Any # type: ignore[no-redef]
8892
89- try :
93+
94+ if typing .TYPE_CHECKING :
9095 import packaging
9196
92- _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
93- except (ImportError , AttributeError ):
94- _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
97+ _SpecifierSet = packaging .specifiers .SpecifierSet
98+ else :
99+ try :
100+ import packaging
101+
102+ _SpecifierSet : type [Any ] = packaging .specifiers .SpecifierSet
103+ except (ImportError , AttributeError ):
104+ _SpecifierSet : type [Any ] = Any # type: ignore[no-redef]
95105
96106
97107try :
@@ -258,16 +268,22 @@ class OperationRegistrable(Protocol):
258268 """Protocol for agents that have registered operations."""
259269
260270 @abc .abstractmethod
261- def register_operations (self , ** kwargs ) -> Dict [str , Sequence [str ]]: # type: ignore[no-untyped-def]
271+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
262272 """Register the user provided operations (modes and methods)."""
273+ pass
263274
264275
265- try :
276+ if typing . TYPE_CHECKING :
266277 from google .adk .agents import BaseAgent
267278
268- ADKAgent : type [Any ] = BaseAgent
269- except (ImportError , AttributeError ):
270- ADKAgent : type [Any ] = Any # type: ignore[no-redef]
279+ ADKAgent : TypeAlias = BaseAgent
280+ else :
281+ try :
282+ from google .adk .agents import BaseAgent
283+
284+ ADKAgent : Optional [TypeAlias ] = BaseAgent
285+ except (ImportError , AttributeError ):
286+ ADKAgent : Optional [TypeAlias ] = None # type: ignore[no-redef]
271287
272288_AgentEngineInterface = Union [
273289 ADKAgent ,
@@ -283,8 +299,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283299class _ModuleAgentAttributes (TypedDict , total = False ):
284300 module_name : str
285301 agent_name : str
286- register_operations : Dict [str , Sequence [str ]]
302+ register_operations : Dict [str , list [str ]]
287303 sys_paths : Optional [Sequence [str ]]
304+ agent : _AgentEngineInterface
288305
289306
290307class ModuleAgent (Cloneable , OperationRegistrable ):
@@ -300,7 +317,7 @@ def __init__(
300317 * ,
301318 module_name : str ,
302319 agent_name : str ,
303- register_operations : Dict [str , Sequence [str ]],
320+ register_operations : Dict [str , list [str ]],
304321 sys_paths : Optional [Sequence [str ]] = None ,
305322 ):
306323 """Initializes a module-based agent.
@@ -310,7 +327,7 @@ def __init__(
310327 Required. The name of the module to import.
311328 agent_name (str):
312329 Required. The name of the agent in the module to instantiate.
313- register_operations (Dict[str, Sequence [str]]):
330+ register_operations (Dict[str, list [str]]):
314331 Required. A dictionary of API modes to a list of method names.
315332 sys_paths (Sequence[str]):
316333 Optional. The system paths to search for the module. It should
@@ -336,8 +353,11 @@ def clone(self) -> "ModuleAgent":
336353 sys_paths = self ._tmpl_attrs .get ("sys_paths" ),
337354 )
338355
339- def register_operations (self ) -> Dict [str , Sequence [str ]]:
340- self ._tmpl_attrs .get ("register_operations" )
356+ def register_operations (self , ** kwargs : Any ) -> dict [str , list [str ]]:
357+ reg_operations = self ._tmpl_attrs .get ("register_operations" )
358+ if reg_operations is None :
359+ raise ValueError ("Register operations is not set." )
360+ return reg_operations
341361
342362 def set_up (self ) -> None :
343363 """Sets up the agent for execution of queries at runtime.
@@ -411,7 +431,7 @@ def __call__(
411431class GetAsyncOperationFunction (Protocol ):
412432 async def __call__ (
413433 self , * , operation_name : str , ** kwargs : Any
414- ) -> Awaitable [ AgentEngineOperationUnion ] :
434+ ) -> AgentEngineOperationUnion :
415435 pass
416436
417437
@@ -507,7 +527,7 @@ def _await_operation(
507527def _compare_requirements (
508528 * ,
509529 requirements : Mapping [str , str ],
510- constraints : Union [Sequence [str ], Mapping [str , "_SpecifierSet" ]],
530+ constraints : Union [Sequence [str ], Mapping [str , Optional [ "_SpecifierSet" ] ]],
511531 required_packages : Optional [Iterator [str ]] = None ,
512532) -> _RequirementsValidationResult :
513533 """Compares the requirements with the constraints.
@@ -536,7 +556,7 @@ def _compare_requirements(
536556 """
537557 packaging_version = _import_packaging_version_or_raise ()
538558 if required_packages is None :
539- required_packages = _DEFAULT_REQUIRED_PACKAGES
559+ required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540560 result = _RequirementsValidationResult (
541561 warnings = _RequirementsValidationWarnings (missing = set (), incompatible = set ()),
542562 actions = _RequirementsValidationActions (append = set ()),
@@ -583,7 +603,7 @@ def _generate_class_methods_spec_or_raise(
583603 if isinstance (agent , ModuleAgent ):
584604 # We do a dry-run of setting up the agent engine to have the operations
585605 # needed for registration.
586- agent : ModuleAgent = agent .clone ()
606+ agent : ModuleAgent = agent .clone () # type: ignore[no-redef]
587607 try :
588608 agent .set_up ()
589609 except Exception as e :
@@ -819,13 +839,13 @@ def _get_gcs_bucket(
819839 new_bucket = storage_client .bucket (staging_bucket )
820840 gcs_bucket = storage_client .create_bucket (new_bucket , location = location )
821841 logger .info (f"Creating bucket { staging_bucket } in { location = } " )
822- return gcs_bucket # type: ignore[no-any-return]
842+ return gcs_bucket
823843
824844
825845def _get_registered_operations (
826846 * ,
827847 agent : _AgentEngineInterface ,
828- ) -> Dict [str , List [str ]]:
848+ ) -> dict [str , list [str ]]:
829849 """Retrieves registered operations for a AgentEngine."""
830850 if isinstance (agent , OperationRegistrable ):
831851 return agent .register_operations ()
@@ -859,13 +879,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859879def _import_cloud_storage_or_raise () -> types .ModuleType :
860880 """Tries to import the Cloud Storage module."""
861881 try :
862- from google .cloud import storage
882+ from google .cloud import storage # type: ignore[attr-defined]
863883 except ImportError as e :
864884 raise ImportError (
865885 "Cloud Storage is not installed. Please call "
866886 "'pip install google-cloud-aiplatform[agent_engines]'."
867887 ) from e
868- return storage
888+ return storage # type: ignore[no-any-return]
869889
870890
871891def _import_packaging_requirements_or_raise () -> types .ModuleType :
@@ -1202,7 +1222,7 @@ def _upload_agent_engine(
12021222) -> None :
12031223 """Uploads the agent engine to GCS."""
12041224 cloudpickle = _import_cloudpickle_or_raise ()
1205- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " ) # type: ignore[attr-defined]
1225+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _BLOB_FILENAME } " )
12061226 with blob .open ("wb" ) as f :
12071227 try :
12081228 cloudpickle .dump (agent , f )
@@ -1216,7 +1236,7 @@ def _upload_agent_engine(
12161236 _ = cloudpickle .load (f )
12171237 except Exception as e :
12181238 raise TypeError ("Agent engine serialized to an invalid format" ) from e
1219- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1239+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12201240 logger .info (f"Wrote to { dir_name } /{ _BLOB_FILENAME } " )
12211241
12221242
@@ -1227,9 +1247,9 @@ def _upload_requirements(
12271247 gcs_dir_name : str ,
12281248) -> None :
12291249 """Uploads the requirements file to GCS."""
1230- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " ) # type: ignore[attr-defined]
1250+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _REQUIREMENTS_FILE } " )
12311251 blob .upload_from_string ("\n " .join (requirements ))
1232- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1252+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12331253 logger .info (f"Writing to { dir_name } /{ _REQUIREMENTS_FILE } " )
12341254
12351255
@@ -1246,9 +1266,9 @@ def _upload_extra_packages(
12461266 for file in extra_packages :
12471267 tar .add (file )
12481268 tar_fileobj .seek (0 )
1249- blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " ) # type: ignore[attr-defined]
1269+ blob = gcs_bucket .blob (f"{ gcs_dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12501270 blob .upload_from_string (tar_fileobj .read ())
1251- dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } " # type: ignore[attr-defined]
1271+ dir_name = f"gs://{ gcs_bucket .name } /{ gcs_dir_name } "
12521272 logger .info (f"Writing to { dir_name } /{ _EXTRA_PACKAGES_FILE } " )
12531273
12541274
@@ -1369,7 +1389,7 @@ def _validate_requirements_or_warn(
13691389 * ,
13701390 obj : Any ,
13711391 requirements : List [str ],
1372- ) -> Mapping [ str , str ]:
1392+ ) -> List [ str ]:
13731393 """Compiles the requirements into a list of requirements."""
13741394 requirements = requirements .copy ()
13751395 try :
@@ -1380,16 +1400,14 @@ def _validate_requirements_or_warn(
13801400 requirements = current_requirements ,
13811401 constraints = constraints ,
13821402 )
1383- for warning_type , warnings in missing_requirements .get (
1384- _WARNINGS_KEY , {}
1385- ).items ():
1403+ for warning_type , warnings in missing_requirements ["warnings" ].items ():
13861404 if warnings :
13871405 logger .warning (
13881406 f"The following requirements are { warning_type } : { warnings } "
13891407 )
1390- for action_type , actions in missing_requirements . get ( _ACTIONS_KEY , {}) .items ():
1408+ for action_type , actions in missing_requirements [ "actions" ] .items ():
13911409 if actions and action_type == _ACTION_APPEND :
1392- for action in actions :
1410+ for action in actions : # type: ignore[attr-defined]
13931411 requirements .append (action )
13941412 logger .info (f"The following requirements are appended: { actions } " )
13951413 except Exception as e :
@@ -1413,7 +1431,7 @@ def _validate_requirements_or_raise(
14131431 logger .info (f"Read the following lines: { requirements } " )
14141432 except IOError as err :
14151433 raise IOError (f"Failed to read requirements from { requirements = } " ) from err
1416- requirements = _validate_requirements_or_warn ( # type: ignore[assignment]
1434+ requirements = _validate_requirements_or_warn (
14171435 obj = agent ,
14181436 requirements = requirements ,
14191437 )
@@ -1560,19 +1578,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601578 return _method
15611579
15621580
1563- AgentEngineOperationUnion = Union [
1564- genai_types .AgentEngineOperation ,
1565- genai_types .AgentEngineMemoryOperation ,
1566- genai_types .AgentEngineGenerateMemoriesOperation ,
1567- ]
1568-
1569-
1570- class GetOperationFunction (Protocol ):
1571- def __call__ ( # noqa: E704
1572- self , * , operation_name : str , ** kwargs : Any
1573- ) -> AgentEngineOperationUnion : ...
1574-
1575-
15761581def _wrap_query_operation (* , method_name : str ) -> Callable [..., Any ]:
15771582 """Wraps an Agent Engine method, creating a callable for `query` API.
15781583
@@ -1835,7 +1840,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351840
18361841 return response
18371842
1838- return _method
1843+ return _method # type: ignore[return-value]
18391844
18401845
18411846def _yield_parsed_json (http_response : google_genai_types .HttpResponse ) -> Iterator [Any ]:
0 commit comments