Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/.env.example
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
# GEMINI_API_KEY=
# GEMINI_API_KEY=
# TAVILY_API_KEY= # Required when search_provider=tavily
1 change: 1 addition & 0 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies = [
"langgraph-api",
"fastapi",
"google-genai",
"tavily-python",
]


Expand Down
7 changes: 7 additions & 0 deletions backend/src/agent/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ class Configuration(BaseModel):
},
)

search_provider: str = Field(
default="google",
metadata={
"description": "The search provider to use for web research. Options: 'google' (default) or 'tavily'."
},
)
Comment on lines +32 to +37
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To make the Tavily integration more flexible, consider making max_results and search_depth configurable instead of hardcoding them in graph.py. This allows adjusting the search behavior without changing the code.

Suggested change
search_provider: str = Field(
default="google",
metadata={
"description": "The search provider to use for web research. Options: 'google' (default) or 'tavily'."
},
)
search_provider: str = Field(
default="google",
metadata={
"description": "The search provider to use for web research. Options: 'google' (default) or 'tavily'."
},
)
tavily_max_results: int = Field(
default=5,
metadata={"description": "The maximum number of results to return from Tavily search."},
)
tavily_search_depth: str = Field(
default="advanced",
metadata={"description": "The depth of search for Tavily. Options: 'basic' or 'advanced'."},
)


number_of_initial_queries: int = Field(
default=3,
metadata={"description": "The number of initial search queries to generate."},
Expand Down
52 changes: 49 additions & 3 deletions backend/src/agent/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from langgraph.graph import START, END
from langchain_core.runnables import RunnableConfig
from google.genai import Client
from tavily import TavilyClient

from agent.state import (
OverallState,
Expand Down Expand Up @@ -93,9 +94,10 @@ def continue_to_web_research(state: QueryGenerationState):


def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
    """LangGraph node that performs web research using Google Search or Tavily.

    Executes a web search using the configured search provider. When search_provider
    is 'tavily', uses the Tavily API; otherwise defaults to Google Search grounding.

    Args:
        state: Current graph state containing the search query and research loop count
        config: Configuration for the runnable, including LLM provider settings

    Returns:
        Dictionary with state update, including sources_gathered, research_loop_count, and web_research_results
    """
    # Resolve runtime configuration, then dispatch on the chosen provider.
    configurable = Configuration.from_runnable_config(config)

    provider = configurable.search_provider
    if provider == "tavily":
        return _web_research_tavily(state, configurable)

    # Anything other than 'tavily' falls back to Google Search grounding.
    return _web_research_google(state, configurable)


def _web_research_google(state: WebSearchState, configurable: Configuration) -> OverallState:
"""Perform web research using Google Search grounding via Gemini."""
formatted_prompt = web_searcher_instructions.format(
current_date=get_current_date(),
research_topic=state["search_query"],
Expand Down Expand Up @@ -136,6 +146,42 @@ def web_research(state: WebSearchState, config: RunnableConfig) -> OverallState:
}


def _web_research_tavily(state: WebSearchState, configurable: Configuration) -> OverallState:
    """Perform web research using the Tavily search API.

    Args:
        state: Current graph state containing the search query and a unique
            ``id`` used to build deterministic citation short-URLs.
        configurable: Resolved agent Configuration. If it defines
            ``tavily_max_results`` / ``tavily_search_depth`` they are used;
            otherwise the previous hardcoded defaults (5, "advanced") apply,
            keeping this backward-compatible.

    Returns:
        Dictionary with state update, including sources_gathered,
        search_query, and web_research_result.
    """
    # Reuse one client across calls: this node runs once per query per
    # research loop, and TavilyClient reads TAVILY_API_KEY from the
    # environment, so re-constructing it every call is pure overhead.
    client = getattr(_web_research_tavily, "_client", None)
    if client is None:
        client = TavilyClient()
        _web_research_tavily._client = client

    response = client.search(
        query=state["search_query"],
        # Fall back to the original hardcoded values when the Configuration
        # model has not been extended with the tavily_* fields.
        max_results=getattr(configurable, "tavily_max_results", 5),
        search_depth=getattr(configurable, "tavily_search_depth", "advanced"),
    )

    # Build sources and research text matching the existing schema
    sources_gathered = []
    research_parts = []
    for idx, result in enumerate(response.get("results", [])):
        url = result.get("url", "")
        title = result.get("title", "")
        content = result.get("content", "")
        # Deterministic per-result short URL, mirroring the citation scheme
        # used by the Google Search path.
        short_url = f"https://tavily.com/id/{state['id']}-{idx}"

        # label: use the domain or title
        label = title.split(" - ")[0] if title else url
        sources_gathered.append({
            "label": label,
            "short_url": short_url,
            "value": url,
        })
        research_parts.append(f"{content} [{label}]({short_url})")

    # Empty result sets yield an empty research string rather than failing.
    modified_text = "\n\n".join(research_parts) if research_parts else ""

    return {
        "sources_gathered": sources_gathered,
        "search_query": [state["search_query"]],
        "web_research_result": [modified_text],
    }


def reflection(state: OverallState, config: RunnableConfig) -> ReflectionState:
"""LangGraph node that identifies knowledge gaps and generates potential follow-up queries.

Expand Down