From f779ad631493398643695baad781d397f35f6dce Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Fri, 16 Jan 2026 22:57:23 -0500 Subject: [PATCH] Enhance Gradio integration to process and display images from tool output --- src/utils/gradio/messages.py | 73 +++++++++++++++++++++++++++-- src/utils/tools/code_interpreter.py | 9 +++- 2 files changed, 77 insertions(+), 5 deletions(-) diff --git a/src/utils/gradio/messages.py b/src/utils/gradio/messages.py index 932929f..096217f 100644 --- a/src/utils/gradio/messages.py +++ b/src/utils/gradio/messages.py @@ -1,13 +1,18 @@ """Tools for integrating with the Gradio chatbot UI.""" +import base64 +import json +from io import BytesIO from typing import TYPE_CHECKING +import gradio as gr from agents import StreamEvent, stream_events from agents.items import MessageOutputItem, RunItem, ToolCallItem, ToolCallOutputItem from gradio.components.chatbot import ChatMessage, MetadataDict from openai.types.responses import ResponseFunctionToolCall, ResponseOutputText from openai.types.responses.response_completed_event import ResponseCompletedEvent from openai.types.responses.response_output_message import ResponseOutputMessage +from PIL import Image if TYPE_CHECKING: @@ -31,6 +36,32 @@ def gradio_messages_to_oai_chat( return output +def _process_tool_output_for_images(output_str: str) -> tuple[str, list[Image.Image]]: + """Extract images from tool output JSON if present. + + Returns tuple of (text_content, list_of_images). + """ + images = [] + try: + # Try to parse as JSON to extract images + output_data = json.loads(output_str) + + # Check if results contain PNG data + if isinstance(output_data, dict) and "results" in output_data: + for result in output_data["results"]: + if isinstance(result, dict) and "png" in result: + # Decode base64 PNG + png_data = result["png"] + img_bytes = base64.b64decode(png_data) + img = Image.open(BytesIO(img_bytes)) + images.append(img) + + return output_str, images + except (json.JSONDecodeError, Exception): + # If not JSON or error parsing, return as-is + return output_str, images + + def _oai_response_output_item_to_gradio( item: RunItem, is_final_output: bool ) -> list[ChatMessage] | None: @@ -57,10 +88,11 @@ def _oai_response_output_item_to_gradio( call_id = item.raw_item.get("call_id", None) if isinstance(function_output, str): - return [ + text_content, images = _process_tool_output_for_images(function_output) + messages = [ ChatMessage( role="assistant", - content=f"> {function_output}\n\n`{call_id}`", + content=f"> {text_content}\n\n`{call_id}`", metadata={ "title": "*Tool call output*", "status": "done", # This makes it collapsed by default @@ -68,6 +100,22 @@ def _oai_response_output_item_to_gradio( ) ] + # Add images as separate messages + for img in images: + messages.append( + ChatMessage( + role="assistant", + content=gr.Image( + img, + format="png", + container=False, + interactive=True, + buttons=["download"], + ), + ) + ) + return messages + if isinstance(item, MessageOutputItem): message_content = item.raw_item @@ -121,6 +169,7 @@ def oai_agent_stream_to_gradio_messages( if isinstance(stream_event, stream_events.RawResponsesStreamEvent): data = stream_event.data if isinstance(data, ResponseCompletedEvent): + print(stream_event) # The completed event may contain multiple output messages, # including tool calls and final outputs. # If there is at least one tool call, we mark the response as a thought. @@ -161,10 +210,13 @@ def oai_agent_stream_to_gradio_messages( item = stream_event.item if name == "tool_output" and isinstance(item, ToolCallOutputItem): + print(stream_event) + text_content, images = _process_tool_output_for_images(item.output) + output.append( ChatMessage( role="assistant", - content=f"```\n{item.output}\n```", + content=f"```\n{text_content}\n```", metadata={ "title": "*Tool call output*", "status": "done", # This makes it collapsed by default @@ -172,4 +224,19 @@ def oai_agent_stream_to_gradio_messages( ) ) + # Add images as separate messages + for img in images: + output.append( + ChatMessage( + role="assistant", + content=gr.Image( + img, + format="png", + container=False, + interactive=True, + buttons=["download"], + ), + ) + ) + return output diff --git a/src/utils/tools/code_interpreter.py b/src/utils/tools/code_interpreter.py index 3e05080..6057e54 100644 --- a/src/utils/tools/code_interpreter.py +++ b/src/utils/tools/code_interpreter.py @@ -5,6 +5,7 @@ from typing import Sequence from e2b_code_interpreter import AsyncSandbox +from e2b_code_interpreter.models import serialize_results from pydantic import BaseModel from ..async_utils import gather_with_progress @@ -23,9 +24,10 @@ class CodeInterpreterOutput(BaseModel): stdout: list[str] stderr: list[str] + results: list[dict[str, str]] | None = None error: _CodeInterpreterOutputError | None = None - def __init__(self, stdout: list[str], stderr: list[str], **kwargs): + def __init__(self, stdout: list[str], stderr: list[str], **kwargs) -> None: """Split lines in stdout and stderr.""" stdout_processed = [] for _line in stdout: @@ -109,7 +111,7 @@ def __init__( local_files: "Sequence[Path | str]| None" = None, timeout_seconds: int = 30, template_name: str | None = None, - ): + ) -> None: """Configure your Code Interpreter session. Note that the sandbox is not persistent, and each run_code will @@ -161,6 +163,9 @@ async def run_code(self, code: str) -> str: error.to_json() ) + if result.results: + response.results = serialize_results(result.results) + return response.model_dump_json() finally: await sbx.kill()