Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,9 @@ async def _call_tool_in_thread_pool(
# For sync FunctionTool, call the underlying function directly
def run_sync_tool():
if isinstance(tool, FunctionTool):
args_to_call = tool._preprocess_args(args)
args_to_call, validation_errors = tool._preprocess_args(args)
if validation_errors:
return tool._build_validation_error_response(validation_errors)
signature = inspect.signature(tool.func)
valid_params = {param for param in signature.parameters}
if tool._context_param_name in valid_params:
Expand Down
155 changes: 143 additions & 12 deletions src/google/adk/tools/crewai_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,149 @@

from __future__ import annotations

import warnings
import inspect
from typing import Any
from typing import Callable

from google.adk.integrations.crewai import CrewaiTool
from google.adk.integrations.crewai import CrewaiToolConfig
from google.genai import types
from typing_extensions import override

warnings.warn(
"google.adk.tools.crewai_tool is moved to google.adk.integrations.crewai",
DeprecationWarning,
stacklevel=2,
)
from . import _automatic_function_calling_util
from .function_tool import FunctionTool
from .tool_configs import BaseToolConfig
from .tool_configs import ToolArgsConfig
from .tool_context import ToolContext

__all__ = [
"CrewaiTool",
"CrewaiToolConfig",
]
try:
from crewai.tools import BaseTool as CrewaiBaseTool
except ImportError as e:
raise ImportError(
"Crewai Tools require pip install 'google-adk[extensions]'."
) from e


class CrewaiTool(FunctionTool):
"""Use this class to wrap a CrewAI tool.

If the original tool name and description are not suitable, you can override
them in the constructor.
"""

tool: CrewaiBaseTool
"""The wrapped CrewAI tool."""

def __init__(self, tool: CrewaiBaseTool, *, name: str, description: str):
super().__init__(tool.run)
self.tool = tool
if name:
self.name = name
elif tool.name:
# Right now, CrewAI tool name contains white spaces. White spaces are
# not supported in our framework. So we replace them with "_".
self.name = tool.name.replace(' ', '_').lower()
if description:
self.description = description
elif tool.description:
self.description = tool.description

@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
"""Override run_async to handle CrewAI-specific parameter filtering.

CrewAI tools use **kwargs pattern, so we need special parameter filtering
logic that allows all parameters to pass through while removing only
reserved parameters like 'self' and 'tool_context'.

Note: 'tool_context' is removed from the initial args dictionary to prevent
duplicates, but is re-added if the function signature explicitly requires it
as a parameter.
"""
# Preprocess arguments (includes Pydantic model conversion and type
# validation)
args_to_call, validation_errors = self._preprocess_args(args)

if validation_errors:
return self._build_validation_error_response(validation_errors)

signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}

# Check if function accepts **kwargs
has_kwargs = any(
param.kind == inspect.Parameter.VAR_KEYWORD
for param in signature.parameters.values()
)

if has_kwargs:
# For functions with **kwargs, we pass all arguments. We defensively
# remove arguments like `self` that are managed by the framework and not
# intended to be passed through **kwargs.
args_to_call.pop('self', None)
# We also remove context param that might have been passed in `args`,
# as it will be explicitly injected later if it's a valid parameter.
args_to_call.pop(self._context_param_name, None)
else:
# For functions without **kwargs, use the original filtering.
args_to_call = {
k: v for k, v in args_to_call.items() if k in valid_params
}

# Inject context if it's an explicit parameter. This will add it
# or overwrite any value that might have been passed in `args`.
if self._context_param_name in valid_params:
args_to_call[self._context_param_name] = tool_context

# Check for missing mandatory arguments
mandatory_args = self._get_mandatory_args()
missing_mandatory_args = [
arg for arg in mandatory_args if arg not in args_to_call
]

if missing_mandatory_args:
missing_mandatory_args_str = '\n'.join(missing_mandatory_args)
error_str = f"""Invoking `{self.name}()` failed as the following mandatory input parameters are not present:
{missing_mandatory_args_str}
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
return {'error': error_str}

return await self._invoke_callable(self.func, args_to_call)

@override
def _get_declaration(self) -> types.FunctionDeclaration:
"""Build the function declaration for the tool."""
function_declaration = _automatic_function_calling_util.build_function_declaration_for_params_for_crewai(
False,
self.name,
self.description,
self.func,
self.tool.args_schema.model_json_schema(),
)
return function_declaration

@override
@classmethod
def from_config(
cls: type[CrewaiTool], config: ToolArgsConfig, config_abs_path: str
) -> CrewaiTool:
from ..agents import config_agent_utils

crewai_tool_config = CrewaiToolConfig.model_validate(config.model_dump())
tool = config_agent_utils.resolve_fully_qualified_name(
crewai_tool_config.tool
)
name = crewai_tool_config.name
description = crewai_tool_config.description
return cls(tool, name=name, description=description)


class CrewaiToolConfig(BaseToolConfig):
tool: str
"""The fully qualified path of the CrewAI tool instance."""

name: str = ''
"""The name of the tool."""

description: str = ''
"""The description of the tool."""
118 changes: 71 additions & 47 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import logging
from typing import Any
from typing import Callable
from typing import get_args
from typing import get_origin
from typing import Optional
from typing import Union

Expand Down Expand Up @@ -85,6 +83,7 @@ def __init__(
self._context_param_name = find_context_parameter(func) or 'tool_context'
self._ignore_params = [self._context_param_name, 'input_stream']
self._require_confirmation = require_confirmation
self._type_adapter_cache: dict[Any, pydantic.TypeAdapter] = {}

@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
Expand All @@ -100,68 +99,93 @@ def _get_declaration(self) -> Optional[types.FunctionDeclaration]:

return function_decl

def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
"""Preprocess and convert function arguments before invocation.
def _preprocess_args(
self, args: dict[str, Any]
) -> tuple[dict[str, Any], list[str]]:
"""Preprocess, validate, and convert function arguments before invocation.

Currently handles:
Handles:
- Converting JSON dictionaries to Pydantic model instances where expected

Future extensions could include:
- Type coercion for other complex types
- Validation and sanitization
- Custom conversion logic
- Validating and coercing primitive types (int, float, str, bool)
- Validating enum values
- Validating container types (list[int], dict[str, float], etc.)

Args:
args: Raw arguments from the LLM tool call

Returns:
Processed arguments ready for function invocation
A tuple of (processed_args, validation_errors). If validation_errors is
non-empty, the caller should return the errors to the LLM instead of
invoking the function.
"""
signature = inspect.signature(self.func)
converted_args = args.copy()
validation_errors = []

for param_name, param in signature.parameters.items():
if param_name in args and param.annotation != inspect.Parameter.empty:
target_type = param.annotation

# Handle Optional[PydanticModel] types
if get_origin(param.annotation) is Union:
union_args = get_args(param.annotation)
# Find the non-None type in Optional[T] (which is Union[T, None])
non_none_types = [arg for arg in union_args if arg is not type(None)]
if len(non_none_types) == 1:
target_type = non_none_types[0]

# Check if the target type is a Pydantic model
if inspect.isclass(target_type) and issubclass(
target_type, pydantic.BaseModel
):
# Skip conversion if the value is None and the parameter is Optional
if args[param_name] is None:
continue

# Convert to Pydantic model if it's not already the correct type
if not isinstance(args[param_name], target_type):
try:
converted_args[param_name] = target_type.model_validate(
args[param_name]
)
except Exception as e:
logger.warning(
f"Failed to convert argument '{param_name}' to Pydantic model"
f' {target_type.__name__}: {e}'
)
# Keep the original value if conversion fails
pass

return converted_args
if (
param_name not in args
or param.annotation is inspect.Parameter.empty
or param_name in self._ignore_params
):
continue

target_type = param.annotation

# Validate and coerce using TypeAdapter. Handles primitives, enums,
# Pydantic models, Optional[T], T | None, and container types natively.
try:
try:
adapter = self._type_adapter_cache[target_type]
except TypeError:
adapter = pydantic.TypeAdapter(target_type)
except KeyError:
adapter = pydantic.TypeAdapter(target_type)
self._type_adapter_cache[target_type] = adapter
converted_args[param_name] = adapter.validate_python(args[param_name])
except pydantic.ValidationError as e:
validation_errors.append(
f"Parameter '{param_name}': expected type"
f" '{getattr(target_type, '__name__', target_type)}', validation"
f' error: {e}'
)
except (TypeError, NameError) as e:
# TypeAdapter could not handle this annotation (e.g. a forward
# reference string). Skip validation but log a warning.
logger.warning(
"Skipping validation for parameter '%s' due to unhandled"
" annotation type '%s': %s",
param_name,
target_type,
e,
)

return converted_args, validation_errors

def _build_validation_error_response(
self, validation_errors: list[str]
) -> dict[str, str]:
"""Formats validation errors into an error dict for the LLM."""
validation_errors_str = '\n'.join(validation_errors)
return {
'error': (
f'Invoking `{self.name}()` failed due to argument validation'
f' errors:\n{validation_errors_str}\nYou could retry calling'
' this tool with corrected argument types.'
)
}

@override
async def run_async(
self, *, args: dict[str, Any], tool_context: ToolContext
) -> Any:
# Preprocess arguments (includes Pydantic model conversion)
args_to_call = self._preprocess_args(args)
# Preprocess arguments (includes Pydantic model conversion and type
# validation). Validation errors are returned to the LLM so it can
# self-correct and retry with proper argument types.
args_to_call, validation_errors = self._preprocess_args(args)

if validation_errors:
return self._build_validation_error_response(validation_errors)

signature = inspect.signature(self.func)
valid_params = {param for param in signature.parameters}
Expand Down
Loading