From e846d4643eddb8dc8508dbf161c02c5d2e86ba35 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 6 Jan 2026 15:35:20 +0100 Subject: [PATCH 1/6] Add extension --- python/restate/ext/adk/plugin.py | 74 +++++++++++++++++++------------- python/restate/server_context.py | 51 +++++++++++++++++++++- 2 files changed, 94 insertions(+), 31 deletions(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index 65bdbc2..2ac1d83 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -13,6 +13,7 @@ """ import asyncio +from copyreg import clear_extension_cache import restate from datetime import timedelta @@ -36,6 +37,7 @@ from restate.extensions import current_context from restate.ext.turnstile import Turnstile +from restate.server_context import clear_extension_data, get_extension_data, set_extension_data def _create_turnstile(s: LlmResponse) -> Turnstile: @@ -43,17 +45,23 @@ def _create_turnstile(s: LlmResponse) -> Turnstile: turnstile = Turnstile(ids) return turnstile +class PluginState: + def __init__(self, model: BaseLlm): + self.model = model + self.turnstiles: Turnstile = Turnstile([]) + + def __close__(self): + """Clean up resources.""" + self.turnstiles.cancel_all() + + class RestatePlugin(BasePlugin): """A plugin to integrate Restate with the ADK framework.""" - _models: dict[str, BaseLlm] - _turnstiles: dict[str, Turnstile | None] def __init__(self, *, max_model_call_retries: int = 10): super().__init__(name="restate_plugin") - self._models = {} - self._turnstiles = {} self._max_model_call_retries = max_model_call_retries async def before_agent_callback( @@ -68,31 +76,24 @@ async def before_agent_callback( Ensure that the agent is invoked within a restate handler and, using a ```with restate_overrides(ctx):``` block. around your agent use.""" ) + model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model) - self._models[callback_context.invocation_id] = model - self._turnstiles[callback_context.invocation_id] = None - - id = callback_context.invocation_id - event = ctx.request().attempt_finished_event - - async def release_task(): - """make sure to release resources when the agent finishes""" - try: - await event.wait() - finally: - self._models.pop(id, None) - maybe_turnstile = self._turnstiles.pop(id, None) - if maybe_turnstile is not None: - maybe_turnstile.cancel_all() - - _ = asyncio.create_task(release_task()) + set_extension_data(ctx, "adk_" + callback_context.invocation_id, PluginState(model)) return None async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: - self._models.pop(callback_context.invocation_id, None) - self._turnstiles.pop(callback_context.invocation_id, None) + ctx = current_context() + if ctx is None: + raise RuntimeError( + "No Restate context found, the restate plugin must be used from within a restate handler." + ) + state = get_extension_data(ctx, "adk_" + callback_context.invocation_id) + if state is not None: + state.__close__() + clear_extension_data(ctx, "adk_" + callback_context.invocation_id) + return None async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: @@ -103,15 +104,22 @@ async def after_run_callback(self, *, invocation_context: InvocationContext) -> async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: - model = self._models[callback_context.invocation_id] ctx = current_context() if ctx is None: raise RuntimeError( "No Restate context found, the restate plugin must be used from within a restate handler." ) + state = get_extension_data( + ctx, "adk_" + callback_context.invocation_id + ) + if state is None: + raise RuntimeError( + "No RestatePlugin state found, the restate plugin must be used from within a restate handler." + ) + model = state.model response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request) turnstile = _create_turnstile(response) - self._turnstiles[callback_context.invocation_id] = turnstile + state.turnstiles = turnstile return response async def before_tool_callback( @@ -121,17 +129,23 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - turnstile = self._turnstiles[tool_context.invocation_id] + ctx = current_context() + if ctx is None: + raise RuntimeError( + "No Restate context found, the restate plugin must be used from within a restate handler." + ) + state: PluginState | None = get_extension_data(ctx, "adk_" + tool_context.invocation_id) + if state is None: + raise RuntimeError( + "No RestatePlugin state found, the restate plugin must be used from within a restate handler." + ) + turnstile = state.turnstiles assert turnstile is not None, "Turnstile not found for tool invocation." id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." await turnstile.wait_for(id) - - ctx = current_context() - tool_context.session.state["restate_context"] = ctx - return None async def after_tool_callback( diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 31ae4be..1ddb9f3 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -17,7 +17,7 @@ """This module contains the restate context implementation based on the server""" import asyncio -from contextlib import AsyncExitStack +from contextlib import AsyncExitStack, asynccontextmanager import contextvars import copy from random import Random @@ -314,6 +314,44 @@ def current_context() -> Context | None: """Get the current context.""" return _restate_context_var.get() +def set_extension_data(ctx: Context, key: str, value: T) -> None: + """Set extension data in the current context.""" + if not isinstance(ctx, ServerInvocationContext): + raise RuntimeError("Current context is not a ServerInvocationContext") + ctx.extension_data[key] = value + +def get_extension_data(ctx: Context, key: str) -> Any: + """Get extension data from the current context.""" + if not isinstance(ctx, ServerInvocationContext): + raise RuntimeError("Current context is not a ServerInvocationContext") + return ctx.extension_data.get(key, None) + +def clear_extension_data(ctx: Context, key: str) -> None: + """Clear extension data from the current context.""" + if not isinstance(ctx, ServerInvocationContext): + raise RuntimeError("Current context is not a ServerInvocationContext") + if key in ctx.extension_data: + del ctx.extension_data[key] + +@asynccontextmanager +def auto_close_extension_data(data: Dict[str, Any]): + """Context manager to auto close extension data.""" + try: + yield + finally: + for value in data.values(): + if hasattr(value, "__close__") and callable(getattr(value, "__close__")): + try: + close_method = getattr(value, "close") + if inspect.iscoroutinefunction(close_method): + await close_method() + else: + close_method() + except Exception as e: + # extension data close failure should not block further processing + # TODO: add logging here + pass + data.clear() # pylint: disable=R0902 class ServerInvocationContext(ObjectContext): @@ -339,6 +377,7 @@ def __init__( self.run_coros_to_execute: dict[int, Callable[[], Awaitable[None]]] = {} self.request_finished_event = asyncio.Event() self.tasks = Tasks() + self.extension_data: Dict[str, Any] = {} async def enter(self): """Invoke the user code.""" @@ -349,6 +388,8 @@ async def enter(self): async with AsyncExitStack() as stack: for manager in self.handler.context_managers or []: await stack.enter_async_context(manager()) + stack.enter_async_context(auto_close_extension_data(self.extension_data)) + out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer) restate_context_is_replaying.set(False) self.vm.sys_write_output_success(bytes(out_buffer)) @@ -971,3 +1012,11 @@ def attach_invocation( handle = self.vm.attach_invocation(invocation_id) update_restate_context_is_replaying(self.vm) return self.create_future(handle, serde) + + def get_extension_data(self, key: str) -> Any: + """Get extension data by key.""" + return self.extension_data.get(key) + + def set_extension_data(self, key: str, value: Any) -> None: + """Set extension data by key.""" + self.extension_data[key] = value \ No newline at end of file From 78bf8c13cbcf0e9f290e507f569b0c4ce05736da Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 6 Jan 2026 16:23:34 +0100 Subject: [PATCH 2/6] [extensions] Avoid global state in the ADK plugin --- python/restate/ext/adk/plugin.py | 55 +++++++++++++------------------- python/restate/server_context.py | 27 +++++++++------- 2 files changed, 38 insertions(+), 44 deletions(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index 2ac1d83..b2a74b3 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -12,8 +12,6 @@ ADK plugin implementation for restate. """ -import asyncio -from copyreg import clear_extension_cache import restate from datetime import timedelta @@ -45,21 +43,34 @@ def _create_turnstile(s: LlmResponse) -> Turnstile: turnstile = Turnstile(ids) return turnstile + +def _turnstile_from_context(invocation_id: str) -> Turnstile: + ctx = current_context() + if ctx is None: + raise RuntimeError("No Restate context found, the restate plugin must be used from within a restate handler.") + state = get_extension_data(ctx, "adk_" + invocation_id) + if state is None: + raise RuntimeError( + "No RestatePlugin state found, the restate plugin must be used from within a restate handler." + ) + turnstile = state.turnstiles + assert turnstile is not None, "Turnstile not found for invocation." + return turnstile + + class PluginState: def __init__(self, model: BaseLlm): self.model = model - self.turnstiles: Turnstile = Turnstile([]) + self.turnstiles: Turnstile = Turnstile([]) def __close__(self): """Clean up resources.""" self.turnstiles.cancel_all() - class RestatePlugin(BasePlugin): """A plugin to integrate Restate with the ADK framework.""" - def __init__(self, *, max_model_call_retries: int = 10): super().__init__(name="restate_plugin") self._max_model_call_retries = max_model_call_retries @@ -76,7 +87,7 @@ async def before_agent_callback( Ensure that the agent is invoked within a restate handler and, using a ```with restate_overrides(ctx):``` block. around your agent use.""" ) - + model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model) set_extension_data(ctx, "adk_" + callback_context.invocation_id, PluginState(model)) return None @@ -109,9 +120,7 @@ async def before_model_callback( raise RuntimeError( "No Restate context found, the restate plugin must be used from within a restate handler." ) - state = get_extension_data( - ctx, "adk_" + callback_context.invocation_id - ) + state = get_extension_data(ctx, "adk_" + callback_context.invocation_id) if state is None: raise RuntimeError( "No RestatePlugin state found, the restate plugin must be used from within a restate handler." @@ -129,22 +138,9 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - ctx = current_context() - if ctx is None: - raise RuntimeError( - "No Restate context found, the restate plugin must be used from within a restate handler." - ) - state: PluginState | None = get_extension_data(ctx, "adk_" + tool_context.invocation_id) - if state is None: - raise RuntimeError( - "No RestatePlugin state found, the restate plugin must be used from within a restate handler." - ) - turnstile = state.turnstiles - assert turnstile is not None, "Turnstile not found for tool invocation." - + turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." - await turnstile.wait_for(id) return None @@ -156,12 +152,11 @@ async def after_tool_callback( tool_context: ToolContext, result: dict, ) -> Optional[dict]: - tool_context.session.state.pop("restate_context", None) - turnstile = self._turnstiles[tool_context.invocation_id] - assert turnstile is not None, "Turnstile not found for tool invocation." + turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." turnstile.allow_next_after(id) + return None async def on_tool_error_callback( @@ -172,18 +167,12 @@ async def on_tool_error_callback( tool_context: ToolContext, error: Exception, ) -> Optional[dict]: - tool_context.session.state.pop("restate_context", None) - turnstile = self._turnstiles[tool_context.invocation_id] - assert turnstile is not None, "Turnstile not found for tool invocation." + turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." turnstile.cancel_all_after(id) return None - async def close(self): - self._models.clear() - self._turnstiles.clear() - def _get_function_call_ids(s: LlmResponse) -> list[str]: ids = [] diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 1ddb9f3..7ed4f5b 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -314,18 +314,21 @@ def current_context() -> Context | None: """Get the current context.""" return _restate_context_var.get() -def set_extension_data(ctx: Context, key: str, value: T) -> None: + +def set_extension_data(ctx: Context, key: str, value: Any) -> None: """Set extension data in the current context.""" if not isinstance(ctx, ServerInvocationContext): raise RuntimeError("Current context is not a ServerInvocationContext") ctx.extension_data[key] = value - + + def get_extension_data(ctx: Context, key: str) -> Any: """Get extension data from the current context.""" if not isinstance(ctx, ServerInvocationContext): raise RuntimeError("Current context is not a ServerInvocationContext") return ctx.extension_data.get(key, None) + def clear_extension_data(ctx: Context, key: str) -> None: """Clear extension data from the current context.""" if not isinstance(ctx, ServerInvocationContext): @@ -333,8 +336,9 @@ def clear_extension_data(ctx: Context, key: str) -> None: if key in ctx.extension_data: del ctx.extension_data[key] + @asynccontextmanager -def auto_close_extension_data(data: Dict[str, Any]): +async def auto_close_extension_data(data: Dict[str, Any]): """Context manager to auto close extension data.""" try: yield @@ -342,17 +346,18 @@ def auto_close_extension_data(data: Dict[str, Any]): for value in data.values(): if hasattr(value, "__close__") and callable(getattr(value, "__close__")): try: - close_method = getattr(value, "close") + close_method = getattr(value, "__close__") if inspect.iscoroutinefunction(close_method): await close_method() else: close_method() - except Exception as e: + except Exception: # extension data close failure should not block further processing # TODO: add logging here - pass + pass data.clear() + # pylint: disable=R0902 class ServerInvocationContext(ObjectContext): """This class implements the context for the restate framework based on the server.""" @@ -388,7 +393,7 @@ async def enter(self): async with AsyncExitStack() as stack: for manager in self.handler.context_managers or []: await stack.enter_async_context(manager()) - stack.enter_async_context(auto_close_extension_data(self.extension_data)) + await stack.enter_async_context(auto_close_extension_data(self.extension_data)) out_buffer = await invoke_handler(handler=self.handler, ctx=self, in_buffer=in_buffer) restate_context_is_replaying.set(False) @@ -1012,11 +1017,11 @@ def attach_invocation( handle = self.vm.attach_invocation(invocation_id) update_restate_context_is_replaying(self.vm) return self.create_future(handle, serde) - + def get_extension_data(self, key: str) -> Any: """Get extension data by key.""" return self.extension_data.get(key) - + def set_extension_data(self, key: str, value: Any) -> None: - """Set extension data by key.""" - self.extension_data[key] = value \ No newline at end of file + """Set extension data by key.""" + self.extension_data[key] = value From 8a29a3ea9ed7e2621837d7afcd8cf36416293ebb Mon Sep 17 00:00:00 2001 From: Igal Shilman Date: Tue, 6 Jan 2026 16:31:39 +0100 Subject: [PATCH 3/6] Update python/restate/ext/adk/plugin.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/restate/ext/adk/plugin.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index b2a74b3..2add606 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -54,7 +54,6 @@ def _turnstile_from_context(invocation_id: str) -> Turnstile: "No RestatePlugin state found, the restate plugin must be used from within a restate handler." ) turnstile = state.turnstiles - assert turnstile is not None, "Turnstile not found for invocation." return turnstile From 338410aa3711e4fb66d1deb9073c107b2a916e98 Mon Sep 17 00:00:00 2001 From: Igal Shilman Date: Tue, 6 Jan 2026 16:32:04 +0100 Subject: [PATCH 4/6] Update python/restate/ext/adk/plugin.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/restate/ext/adk/plugin.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index 2add606..b420f53 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -44,11 +44,15 @@ def _create_turnstile(s: LlmResponse) -> Turnstile: return turnstile +def _invocation_extension_key(invocation_id: str) -> str: + return "adk_" + invocation_id + + def _turnstile_from_context(invocation_id: str) -> Turnstile: ctx = current_context() if ctx is None: raise RuntimeError("No Restate context found, the restate plugin must be used from within a restate handler.") - state = get_extension_data(ctx, "adk_" + invocation_id) + state = get_extension_data(ctx, _invocation_extension_key(invocation_id)) if state is None: raise RuntimeError( "No RestatePlugin state found, the restate plugin must be used from within a restate handler." From 021fd74174d254f494e1732d6e6c9d7aeaa2ed44 Mon Sep 17 00:00:00 2001 From: Igal Shilman Date: Tue, 6 Jan 2026 16:32:30 +0100 Subject: [PATCH 5/6] Update python/restate/server_context.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- python/restate/server_context.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/restate/server_context.py b/python/restate/server_context.py index 7ed4f5b..d10c032 100644 --- a/python/restate/server_context.py +++ b/python/restate/server_context.py @@ -344,9 +344,9 @@ async def auto_close_extension_data(data: Dict[str, Any]): yield finally: for value in data.values(): - if hasattr(value, "__close__") and callable(getattr(value, "__close__")): + close_method = getattr(value, "__close__", None) + if callable(close_method): try: - close_method = getattr(value, "__close__") if inspect.iscoroutinefunction(close_method): await close_method() else: From ed202fe7131ef152183e9000781061466522e715 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Tue, 6 Jan 2026 16:53:15 +0100 Subject: [PATCH 6/6] Extract plugin state management --- python/restate/ext/adk/plugin.py | 85 ++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 38 deletions(-) diff --git a/python/restate/ext/adk/plugin.py b/python/restate/ext/adk/plugin.py index b420f53..8fda68d 100644 --- a/python/restate/ext/adk/plugin.py +++ b/python/restate/ext/adk/plugin.py @@ -48,27 +48,46 @@ def _invocation_extension_key(invocation_id: str) -> str: return "adk_" + invocation_id -def _turnstile_from_context(invocation_id: str) -> Turnstile: - ctx = current_context() - if ctx is None: - raise RuntimeError("No Restate context found, the restate plugin must be used from within a restate handler.") - state = get_extension_data(ctx, _invocation_extension_key(invocation_id)) - if state is None: - raise RuntimeError( - "No RestatePlugin state found, the restate plugin must be used from within a restate handler." - ) - turnstile = state.turnstiles - return turnstile - - class PluginState: def __init__(self, model: BaseLlm): self.model = model - self.turnstiles: Turnstile = Turnstile([]) + self.turnstile: Turnstile = Turnstile([]) def __close__(self): """Clean up resources.""" - self.turnstiles.cancel_all() + self.turnstile.cancel_all() + + @staticmethod + def create(invocation_id: str, model: BaseLlm, ctx: restate.Context) -> None: + extension_key = _invocation_extension_key(invocation_id) + state = PluginState(model) + set_extension_data(ctx, extension_key, state) + + @staticmethod + def clear(invocation_id: str) -> None: + ctx = current_context() + if ctx is None: + raise RuntimeError( + "No Restate context found, the restate plugin must be used from within a restate handler." + ) + extension_key = _invocation_extension_key(invocation_id) + clear_extension_data(ctx, extension_key) + + @staticmethod + def from_context(invocation_id: str, ctx: restate.Context | None = None) -> "PluginState": + if ctx is None: + ctx = current_context() + if ctx is None: + raise RuntimeError( + "No Restate context found, the restate plugin must be used from within a restate handler." + ) + extension_key = _invocation_extension_key(invocation_id) + state: PluginState | None = get_extension_data(ctx, extension_key) + if state is None: + raise RuntimeError( + "No RestatePlugin state found, the restate plugin must be used from within a restate handler." + ) + return state class RestatePlugin(BasePlugin): @@ -92,22 +111,13 @@ async def before_agent_callback( ) model = agent.model if isinstance(agent.model, BaseLlm) else LLMRegistry.new_llm(agent.model) - set_extension_data(ctx, "adk_" + callback_context.invocation_id, PluginState(model)) + PluginState.create(callback_context.invocation_id, model, ctx) return None async def after_agent_callback( self, *, agent: BaseAgent, callback_context: CallbackContext ) -> Optional[types.Content]: - ctx = current_context() - if ctx is None: - raise RuntimeError( - "No Restate context found, the restate plugin must be used from within a restate handler." - ) - state = get_extension_data(ctx, "adk_" + callback_context.invocation_id) - if state is not None: - state.__close__() - clear_extension_data(ctx, "adk_" + callback_context.invocation_id) - + PluginState.clear(callback_context.invocation_id) return None async def after_run_callback(self, *, invocation_context: InvocationContext) -> None: @@ -118,20 +128,18 @@ async def after_run_callback(self, *, invocation_context: InvocationContext) -> async def before_model_callback( self, *, callback_context: CallbackContext, llm_request: LlmRequest ) -> Optional[LlmResponse]: - ctx = current_context() + ctx = current_context() # Ensure we have a Restate context if ctx is None: - raise RuntimeError( - "No Restate context found, the restate plugin must be used from within a restate handler." - ) - state = get_extension_data(ctx, "adk_" + callback_context.invocation_id) - if state is None: - raise RuntimeError( - "No RestatePlugin state found, the restate plugin must be used from within a restate handler." + raise restate.TerminalError( + """No Restate context found for RestatePlugin. + Ensure that the agent is invoked within a restate handler and, + using a ```with restate_overrides(ctx):``` block. around your agent use.""" ) + state = PluginState.from_context(callback_context.invocation_id, ctx) model = state.model response = await _generate_content_async(ctx, self._max_model_call_retries, model, llm_request) turnstile = _create_turnstile(response) - state.turnstiles = turnstile + state.turnstile = turnstile return response async def before_tool_callback( @@ -141,9 +149,10 @@ async def before_tool_callback( tool_args: dict[str, Any], tool_context: ToolContext, ) -> Optional[dict]: - turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." + + turnstile = PluginState.from_context(tool_context.invocation_id).turnstile await turnstile.wait_for(id) return None @@ -155,9 +164,9 @@ async def after_tool_callback( tool_context: ToolContext, result: dict, ) -> Optional[dict]: - turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." + turnstile = PluginState.from_context(tool_context.invocation_id).turnstile turnstile.allow_next_after(id) return None @@ -170,9 +179,9 @@ async def on_tool_error_callback( tool_context: ToolContext, error: Exception, ) -> Optional[dict]: - turnstile = _turnstile_from_context(tool_context.invocation_id) id = tool_context.function_call_id assert id is not None, "Function call ID is required for tool invocation." + turnstile = PluginState.from_context(tool_context.invocation_id).turnstile turnstile.cancel_all_after(id) return None