Skip to content

Commit dfcccbd

Browse files
sararobcopybara-github
authored andcommitted
chore: Resolve remaining agent engines mypy errors
PiperOrigin-RevId: 862709652
1 parent d685d81 commit dfcccbd

6 files changed

Lines changed: 129 additions & 105 deletions

File tree

.github/workflows/mypy.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
runs-on: ubuntu-latest
1717
strategy:
1818
matrix:
19-
python-version: ['3.9', '3.10', '3.11', '3.12', '3.13', '3.14']
19+
python-version: ['3.10', '3.11', '3.12', '3.13', '3.14']
2020

2121
steps:
2222
- name: Checkout code

vertexai/_genai/_agent_engines_utils.py

Lines changed: 60 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
from typing import (
3333
Any,
3434
AsyncIterator,
35-
Awaitable,
3635
Callable,
3736
Coroutine,
3837
Dict,
@@ -44,6 +43,7 @@
4443
Sequence,
4544
Set,
4645
TypedDict,
46+
TypeAlias,
4747
Union,
4848
)
4949

@@ -78,20 +78,30 @@
7878
_STDLIB_MODULE_NAMES: frozenset[str] = frozenset() # type: ignore[no-redef]
7979

8080

81-
try:
82-
from google.cloud import storage
81+
if typing.TYPE_CHECKING:
82+
from google.cloud import storage # type: ignore[attr-defined]
8383

84-
_StorageBucket: type[Any] = storage.Bucket
85-
except (ImportError, AttributeError):
86-
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
84+
_StorageBucket: typing.TypeAlias = storage.Bucket
85+
else:
86+
try:
87+
from google.cloud import storage # type: ignore[attr-defined]
8788

89+
_StorageBucket: type[Any] = storage.Bucket
90+
except (ImportError, AttributeError):
91+
_StorageBucket: type[Any] = Any # type: ignore[no-redef]
8892

89-
try:
93+
94+
if typing.TYPE_CHECKING:
9095
import packaging
9196

92-
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
93-
except (ImportError, AttributeError):
94-
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
97+
_SpecifierSet = packaging.specifiers.SpecifierSet
98+
else:
99+
try:
100+
import packaging
101+
102+
_SpecifierSet: type[Any] = packaging.specifiers.SpecifierSet
103+
except (ImportError, AttributeError):
104+
_SpecifierSet: type[Any] = Any # type: ignore[no-redef]
95105

96106

97107
try:
@@ -258,16 +268,22 @@ class OperationRegistrable(Protocol):
258268
"""Protocol for agents that have registered operations."""
259269

260270
@abc.abstractmethod
261-
def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ignore[no-untyped-def]
271+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
262272
"""Register the user provided operations (modes and methods)."""
273+
pass
263274

264275

265-
try:
276+
if typing.TYPE_CHECKING:
266277
from google.adk.agents import BaseAgent
267278

268-
ADKAgent: type[Any] = BaseAgent
269-
except (ImportError, AttributeError):
270-
ADKAgent: type[Any] = Any # type: ignore[no-redef]
279+
ADKAgent: TypeAlias = BaseAgent
280+
else:
281+
try:
282+
from google.adk.agents import BaseAgent
283+
284+
ADKAgent: Optional[TypeAlias] = BaseAgent
285+
except (ImportError, AttributeError):
286+
ADKAgent: Optional[TypeAlias] = None # type: ignore[no-redef]
271287

272288
_AgentEngineInterface = Union[
273289
ADKAgent,
@@ -283,8 +299,9 @@ def register_operations(self, **kwargs) -> Dict[str, Sequence[str]]: # type: ig
283299
class _ModuleAgentAttributes(TypedDict, total=False):
284300
module_name: str
285301
agent_name: str
286-
register_operations: Dict[str, Sequence[str]]
302+
register_operations: Dict[str, list[str]]
287303
sys_paths: Optional[Sequence[str]]
304+
agent: _AgentEngineInterface
288305

289306

290307
class ModuleAgent(Cloneable, OperationRegistrable):
@@ -300,7 +317,7 @@ def __init__(
300317
*,
301318
module_name: str,
302319
agent_name: str,
303-
register_operations: Dict[str, Sequence[str]],
320+
register_operations: Dict[str, list[str]],
304321
sys_paths: Optional[Sequence[str]] = None,
305322
):
306323
"""Initializes a module-based agent.
@@ -310,7 +327,7 @@ def __init__(
310327
Required. The name of the module to import.
311328
agent_name (str):
312329
Required. The name of the agent in the module to instantiate.
313-
register_operations (Dict[str, Sequence[str]]):
330+
register_operations (Dict[str, list[str]]):
314331
Required. A dictionary of API modes to a list of method names.
315332
sys_paths (Sequence[str]):
316333
Optional. The system paths to search for the module. It should
@@ -336,8 +353,11 @@ def clone(self) -> "ModuleAgent":
336353
sys_paths=self._tmpl_attrs.get("sys_paths"),
337354
)
338355

339-
def register_operations(self) -> Dict[str, Sequence[str]]:
340-
self._tmpl_attrs.get("register_operations")
356+
def register_operations(self, **kwargs: Any) -> dict[str, list[str]]:
357+
reg_operations = self._tmpl_attrs.get("register_operations")
358+
if reg_operations is None:
359+
raise ValueError("Register operations is not set.")
360+
return reg_operations
341361

342362
def set_up(self) -> None:
343363
"""Sets up the agent for execution of queries at runtime.
@@ -411,7 +431,7 @@ def __call__(
411431
class GetAsyncOperationFunction(Protocol):
412432
async def __call__(
413433
self, *, operation_name: str, **kwargs: Any
414-
) -> Awaitable[AgentEngineOperationUnion]:
434+
) -> AgentEngineOperationUnion:
415435
pass
416436

417437

@@ -507,7 +527,7 @@ def _await_operation(
507527
def _compare_requirements(
508528
*,
509529
requirements: Mapping[str, str],
510-
constraints: Union[Sequence[str], Mapping[str, "_SpecifierSet"]],
530+
constraints: Union[Sequence[str], Mapping[str, Optional["_SpecifierSet"]]],
511531
required_packages: Optional[Iterator[str]] = None,
512532
) -> _RequirementsValidationResult:
513533
"""Compares the requirements with the constraints.
@@ -536,7 +556,7 @@ def _compare_requirements(
536556
"""
537557
packaging_version = _import_packaging_version_or_raise()
538558
if required_packages is None:
539-
required_packages = _DEFAULT_REQUIRED_PACKAGES
559+
required_packages = _DEFAULT_REQUIRED_PACKAGES # type: ignore[assignment]
540560
result = _RequirementsValidationResult(
541561
warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()),
542562
actions=_RequirementsValidationActions(append=set()),
@@ -583,7 +603,7 @@ def _generate_class_methods_spec_or_raise(
583603
if isinstance(agent, ModuleAgent):
584604
# We do a dry-run of setting up the agent engine to have the operations
585605
# needed for registration.
586-
agent: ModuleAgent = agent.clone()
606+
agent: ModuleAgent = agent.clone() # type: ignore[no-redef]
587607
try:
588608
agent.set_up()
589609
except Exception as e:
@@ -819,13 +839,13 @@ def _get_gcs_bucket(
819839
new_bucket = storage_client.bucket(staging_bucket)
820840
gcs_bucket = storage_client.create_bucket(new_bucket, location=location)
821841
logger.info(f"Creating bucket {staging_bucket} in {location=}")
822-
return gcs_bucket # type: ignore[no-any-return]
842+
return gcs_bucket
823843

824844

825845
def _get_registered_operations(
826846
*,
827847
agent: _AgentEngineInterface,
828-
) -> Dict[str, List[str]]:
848+
) -> dict[str, list[str]]:
829849
"""Retrieves registered operations for a AgentEngine."""
830850
if isinstance(agent, OperationRegistrable):
831851
return agent.register_operations()
@@ -859,13 +879,13 @@ def _import_cloudpickle_or_raise() -> types.ModuleType:
859879
def _import_cloud_storage_or_raise() -> types.ModuleType:
860880
"""Tries to import the Cloud Storage module."""
861881
try:
862-
from google.cloud import storage
882+
from google.cloud import storage # type: ignore[attr-defined]
863883
except ImportError as e:
864884
raise ImportError(
865885
"Cloud Storage is not installed. Please call "
866886
"'pip install google-cloud-aiplatform[agent_engines]'."
867887
) from e
868-
return storage
888+
return storage # type: ignore[no-any-return]
869889

870890

871891
def _import_packaging_requirements_or_raise() -> types.ModuleType:
@@ -1202,7 +1222,7 @@ def _upload_agent_engine(
12021222
) -> None:
12031223
"""Uploads the agent engine to GCS."""
12041224
cloudpickle = _import_cloudpickle_or_raise()
1205-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") # type: ignore[attr-defined]
1225+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}")
12061226
with blob.open("wb") as f:
12071227
try:
12081228
cloudpickle.dump(agent, f)
@@ -1216,7 +1236,7 @@ def _upload_agent_engine(
12161236
_ = cloudpickle.load(f)
12171237
except Exception as e:
12181238
raise TypeError("Agent engine serialized to an invalid format") from e
1219-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1239+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12201240
logger.info(f"Wrote to {dir_name}/{_BLOB_FILENAME}")
12211241

12221242

@@ -1227,9 +1247,9 @@ def _upload_requirements(
12271247
gcs_dir_name: str,
12281248
) -> None:
12291249
"""Uploads the requirements file to GCS."""
1230-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}") # type: ignore[attr-defined]
1250+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_REQUIREMENTS_FILE}")
12311251
blob.upload_from_string("\n".join(requirements))
1232-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1252+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12331253
logger.info(f"Writing to {dir_name}/{_REQUIREMENTS_FILE}")
12341254

12351255

@@ -1246,9 +1266,9 @@ def _upload_extra_packages(
12461266
for file in extra_packages:
12471267
tar.add(file)
12481268
tar_fileobj.seek(0)
1249-
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}") # type: ignore[attr-defined]
1269+
blob = gcs_bucket.blob(f"{gcs_dir_name}/{_EXTRA_PACKAGES_FILE}")
12501270
blob.upload_from_string(tar_fileobj.read())
1251-
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}" # type: ignore[attr-defined]
1271+
dir_name = f"gs://{gcs_bucket.name}/{gcs_dir_name}"
12521272
logger.info(f"Writing to {dir_name}/{_EXTRA_PACKAGES_FILE}")
12531273

12541274

@@ -1369,7 +1389,7 @@ def _validate_requirements_or_warn(
13691389
*,
13701390
obj: Any,
13711391
requirements: List[str],
1372-
) -> Mapping[str, str]:
1392+
) -> List[str]:
13731393
"""Compiles the requirements into a list of requirements."""
13741394
requirements = requirements.copy()
13751395
try:
@@ -1380,16 +1400,14 @@ def _validate_requirements_or_warn(
13801400
requirements=current_requirements,
13811401
constraints=constraints,
13821402
)
1383-
for warning_type, warnings in missing_requirements.get(
1384-
_WARNINGS_KEY, {}
1385-
).items():
1403+
for warning_type, warnings in missing_requirements["warnings"].items():
13861404
if warnings:
13871405
logger.warning(
13881406
f"The following requirements are {warning_type}: {warnings}"
13891407
)
1390-
for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items():
1408+
for action_type, actions in missing_requirements["actions"].items():
13911409
if actions and action_type == _ACTION_APPEND:
1392-
for action in actions:
1410+
for action in actions: # type: ignore[attr-defined]
13931411
requirements.append(action)
13941412
logger.info(f"The following requirements are appended: {actions}")
13951413
except Exception as e:
@@ -1413,7 +1431,7 @@ def _validate_requirements_or_raise(
14131431
logger.info(f"Read the following lines: {requirements}")
14141432
except IOError as err:
14151433
raise IOError(f"Failed to read requirements from {requirements=}") from err
1416-
requirements = _validate_requirements_or_warn( # type: ignore[assignment]
1434+
requirements = _validate_requirements_or_warn(
14171435
obj=agent,
14181436
requirements=requirements,
14191437
)
@@ -1560,19 +1578,6 @@ def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
15601578
return _method
15611579

15621580

1563-
AgentEngineOperationUnion = Union[
1564-
genai_types.AgentEngineOperation,
1565-
genai_types.AgentEngineMemoryOperation,
1566-
genai_types.AgentEngineGenerateMemoriesOperation,
1567-
]
1568-
1569-
1570-
class GetOperationFunction(Protocol):
1571-
def __call__( # noqa: E704
1572-
self, *, operation_name: str, **kwargs: Any
1573-
) -> AgentEngineOperationUnion: ...
1574-
1575-
15761581
def _wrap_query_operation(*, method_name: str) -> Callable[..., Any]:
15771582
"""Wraps an Agent Engine method, creating a callable for `query` API.
15781583
@@ -1835,7 +1840,7 @@ async def _method(self, **kwargs) -> Any: # type: ignore[no-untyped-def]
18351840

18361841
return response
18371842

1838-
return _method
1843+
return _method # type: ignore[return-value]
18391844

18401845

18411846
def _yield_parsed_json(http_response: google_genai_types.HttpResponse) -> Iterator[Any]:

0 commit comments

Comments
 (0)