73 changes: 70 additions & 3 deletions src/utils/gradio/messages.py
@@ -1,13 +1,18 @@
"""Tools for integrating with the Gradio chatbot UI."""

import base64
import json
from io import BytesIO
from typing import TYPE_CHECKING

import gradio as gr
from agents import StreamEvent, stream_events
from agents.items import MessageOutputItem, RunItem, ToolCallItem, ToolCallOutputItem
from gradio.components.chatbot import ChatMessage, MetadataDict
from openai.types.responses import ResponseFunctionToolCall, ResponseOutputText
from openai.types.responses.response_completed_event import ResponseCompletedEvent
from openai.types.responses.response_output_message import ResponseOutputMessage
from PIL import Image


if TYPE_CHECKING:
@@ -31,6 +36,32 @@ def gradio_messages_to_oai_chat(
return output


def _process_tool_output_for_images(output_str: str) -> tuple[str, list[Image.Image]]:
"""Extract images from tool output JSON if present.

Returns tuple of (text_content, list_of_images).
"""
images = []
try:
# Try to parse as JSON to extract images
output_data = json.loads(output_str)

# Check if results contain PNG data
if isinstance(output_data, dict) and "results" in output_data:
for result in output_data["results"]:
if isinstance(result, dict) and "png" in result:
# Decode base64 PNG
png_data = result["png"]
img_bytes = base64.b64decode(png_data)
img = Image.open(BytesIO(img_bytes))
images.append(img)

return output_str, images
except Exception:
# Not valid JSON, or decoding a result failed; return the output as-is with no images
return output_str, images


def _oai_response_output_item_to_gradio(
item: RunItem, is_final_output: bool
) -> list[ChatMessage] | None:
@@ -57,17 +88,34 @@ def _oai_response_output_item_to_gradio(
call_id = item.raw_item.get("call_id", None)

if isinstance(function_output, str):
return [
text_content, images = _process_tool_output_for_images(function_output)
messages = [
ChatMessage(
role="assistant",
content=f"> {function_output}\n\n`{call_id}`",
content=f"> {text_content}\n\n`{call_id}`",
metadata={
"title": "*Tool call output*",
"status": "done", # This makes it collapsed by default
},
)
]

# Add images as separate messages
for img in images:
messages.append(
ChatMessage(
role="assistant",
content=gr.Image(
img,
format="png",
container=False,
interactive=True,
buttons=["download"],
),
)
)
return messages

if isinstance(item, MessageOutputItem):
message_content = item.raw_item

@@ -121,6 +169,7 @@ def oai_agent_stream_to_gradio_messages(
if isinstance(stream_event, stream_events.RawResponsesStreamEvent):
data = stream_event.data
if isinstance(data, ResponseCompletedEvent):
print(stream_event)
# The completed event may contain multiple output messages,
# including tool calls and final outputs.
# If there is at least one tool call, we mark the response as a thought.
@@ -161,15 +210,33 @@ def oai_agent_stream_to_gradio_messages(
item = stream_event.item

if name == "tool_output" and isinstance(item, ToolCallOutputItem):
print(stream_event)
text_content, images = _process_tool_output_for_images(item.output)

output.append(
ChatMessage(
role="assistant",
content=f"```\n{item.output}\n```",
content=f"```\n{text_content}\n```",
metadata={
"title": "*Tool call output*",
"status": "done", # This makes it collapsed by default
},
)
)

# Add images as separate messages
for img in images:
output.append(
ChatMessage(
role="assistant",
content=gr.Image(
img,
format="png",
container=False,
interactive=True,
buttons=["download"],
),
)
)

return output
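
For reference, a minimal sketch of the tool-output payload that _process_tool_output_for_images expects. It assumes the string is the JSON emitted by CodeInterpreterOutput.model_dump_json() in code_interpreter.py below and that the helper is importable as src.utils.gradio.messages (both assumptions; only the "results" and "png" keys come from the diff). The 1x1 PNG is a stand-in for a real figure rendered in the sandbox:

import base64
import json
from io import BytesIO

from PIL import Image

from src.utils.gradio.messages import _process_tool_output_for_images  # module path assumed

# A tiny 1x1 red PNG stands in for a chart produced by the code interpreter.
buf = BytesIO()
Image.new("RGB", (1, 1), "red").save(buf, format="PNG")
tiny_png = base64.b64encode(buf.getvalue()).decode()

tool_output = json.dumps(
    {"stdout": ["done"], "stderr": [], "results": [{"png": tiny_png}]}
)

text, images = _process_tool_output_for_images(tool_output)
assert text == tool_output        # the text content is passed through unchanged
assert images[0].size == (1, 1)   # the base64 "png" entry decodes back into a PIL image
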
9 changes: 7 additions & 2 deletions src/utils/tools/code_interpreter.py
@@ -5,6 +5,7 @@
from typing import Sequence

from e2b_code_interpreter import AsyncSandbox
from e2b_code_interpreter.models import serialize_results
from pydantic import BaseModel

from ..async_utils import gather_with_progress
@@ -23,9 +24,10 @@ class CodeInterpreterOutput(BaseModel):

stdout: list[str]
stderr: list[str]
results: list[dict[str, str]] | None = None
error: _CodeInterpreterOutputError | None = None

def __init__(self, stdout: list[str], stderr: list[str], **kwargs):
def __init__(self, stdout: list[str], stderr: list[str], **kwargs) -> None:
"""Split lines in stdout and stderr."""
stdout_processed = []
for _line in stdout:
@@ -109,7 +111,7 @@ def __init__(
local_files: "Sequence[Path | str]| None" = None,
timeout_seconds: int = 30,
template_name: str | None = None,
):
) -> None:
"""Configure your Code Interpreter session.

Note that the sandbox is not persistent, and each run_code will
@@ -161,6 +163,9 @@ async def run_code(self, code: str) -> str:
error.to_json()
)

if result.results:
response.results = serialize_results(result.results)

return response.model_dump_json()
finally:
await sbx.kill()
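
To sketch the round trip this change enables, a hedged usage example under stated assumptions: the sandbox wrapper class is importable as CodeInterpreter from this module (its name is not visible in these hunks), an E2B API key is configured in the environment, and the sandbox template has matplotlib installed; only run_code's signature and the serialized field names come from the diff.

import asyncio
import json

from src.utils.tools.code_interpreter import CodeInterpreter  # class name and path assumed


async def main() -> None:
    interpreter = CodeInterpreter(timeout_seconds=60)
    output_json = await interpreter.run_code(
        "import matplotlib.pyplot as plt\n"
        "plt.plot([1, 2, 3])\n"
        "plt.show()\n"
    )
    output = json.loads(output_json)
    # With the new `results` field, rich sandbox outputs (e.g. a rendered figure)
    # survive serialization; the Gradio layer then looks for a base64 "png" entry.
    for result in output.get("results") or []:
        print(sorted(result))  # expected to include "png" when a figure was produced


asyncio.run(main())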