diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 877b59f05..921084e77 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -1,5 +1,6 @@ """Langchain tools for workspace operations.""" +import os from typing import Callable, ClassVar, Literal, Optional from langchain_core.tools.base import BaseTool @@ -668,13 +669,27 @@ class LinearGetIssueTool(BaseTool): name: ClassVar[str] = "linear_get_issue" description: ClassVar[str] = "Get details of a Linear issue by its ID" args_schema: ClassVar[type[BaseModel]] = LinearGetIssueInput - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + # Initialize with codebase and create LinearClient on first use + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + self.client = LinearClient(access_token) + return self.client def _run(self, issue_id: str) -> str: - result = linear_get_issue_tool(self.client, issue_id) + client = self._get_client() + result = linear_get_issue_tool(client, issue_id) return result.render() @@ -690,13 +705,26 @@ class LinearGetIssueCommentsTool(BaseTool): name: ClassVar[str] = "linear_get_issue_comments" description: ClassVar[str] = "Get all comments on a Linear issue" args_schema: ClassVar[type[BaseModel]] = LinearGetIssueCommentsInput - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + self.client = LinearClient(access_token) + return self.client def _run(self, issue_id: str) -> str: - result = linear_get_issue_comments_tool(self.client, issue_id) + client = self._get_client() + result = linear_get_issue_comments_tool(client, issue_id) return result.render() @@ -713,13 +741,26 @@ class LinearCommentOnIssueTool(BaseTool): name: ClassVar[str] = "linear_comment_on_issue" description: ClassVar[str] = "Add a comment to a Linear issue" args_schema: ClassVar[type[BaseModel]] = LinearCommentOnIssueInput - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + self.client = LinearClient(access_token) + return self.client def _run(self, issue_id: str, body: str) -> str: - result = linear_comment_on_issue_tool(self.client, issue_id, body) + client = self._get_client() + result = linear_comment_on_issue_tool(client, issue_id, body) return result.render() @@ -734,15 +775,28 @@ class LinearSearchIssuesTool(BaseTool): """Tool for searching Linear issues.""" name: ClassVar[str] = "linear_search_issues" - description: ClassVar[str] = "Search for Linear issues using a query string" + description: ClassVar[str] = "Search for Linear issues using a search string" args_schema: ClassVar[type[BaseModel]] = LinearSearchIssuesInput - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + self.client = LinearClient(access_token) + return self.client def _run(self, query: str, limit: int = 10) -> str: - result = linear_search_issues_tool(self.client, query, limit) + client = self._get_client() + result = linear_search_issues_tool(client, query, limit) return result.render() @@ -760,13 +814,27 @@ class LinearCreateIssueTool(BaseTool): name: ClassVar[str] = "linear_create_issue" description: ClassVar[str] = "Create a new Linear issue" args_schema: ClassVar[type[BaseModel]] = LinearCreateIssueInput - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + # Initialize without a default team_id to allow explicit team selection + self.client = LinearClient(access_token) + return self.client def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str: - result = linear_create_issue_tool(self.client, title, description, team_id) + client = self._get_client() + result = linear_create_issue_tool(client, title, description, team_id) return result.render() @@ -775,13 +843,26 @@ class LinearGetTeamsTool(BaseTool): name: ClassVar[str] = "linear_get_teams" description: ClassVar[str] = "Get all Linear teams the authenticated user has access to" - client: LinearClient = Field(exclude=True) + codebase: Codebase = Field(exclude=True) + client: LinearClient | None = Field(default=None, exclude=True) + + def __init__(self, codebase: Codebase) -> None: + super().__init__(codebase=codebase) - def __init__(self, client: LinearClient) -> None: - super().__init__(client=client) + def _get_client(self) -> LinearClient: + """Get or create a LinearClient instance.""" + if self.client is None: + # Create a new LinearClient instance + access_token = os.getenv("LINEAR_ACCESS_TOKEN") + if not access_token: + msg = "LINEAR_ACCESS_TOKEN environment variable not set" + raise ValueError(msg) + self.client = LinearClient(access_token) + return self.client def _run(self) -> str: - result = linear_get_teams_tool(self.client) + client = self._get_client() + result = linear_get_teams_tool(client) return result.render() diff --git a/tests/integration/extension/test_linear_team_id.py b/tests/integration/extension/test_linear_team_id.py new file mode 100644 index 000000000..a866aa69c --- /dev/null +++ b/tests/integration/extension/test_linear_team_id.py @@ -0,0 +1,59 @@ +"""Tests for Linear tools with team_id parameter.""" + +import os + +import pytest + +from codegen.extensions.linear.linear_client import LinearClient +from codegen.extensions.tools.linear.linear import ( + linear_create_issue_tool, + linear_get_teams_tool, +) + + +@pytest.fixture +def client() -> LinearClient: + """Create a Linear client for testing.""" + token = os.getenv("LINEAR_ACCESS_TOKEN") + if not token: + pytest.skip("LINEAR_ACCESS_TOKEN environment variable not set") + # Note: We're not setting team_id here to test explicit team_id passing + return LinearClient(token) + + +def test_create_issue_with_explicit_team_id(client: LinearClient) -> None: + """Test creating an issue with an explicit team_id.""" + # First, get available teams + teams_result = linear_get_teams_tool(client) + assert teams_result.status == "success" + assert len(teams_result.teams) > 0 + + # Use the first team's ID for our test + team_id = teams_result.teams[0]["id"] + team_name = teams_result.teams[0]["name"] + + # Create an issue with explicit team_id + title = f"Test Issue in {team_name} - Explicit Team ID" + description = f"This is a test issue created in team {team_name} with explicit team_id" + + result = linear_create_issue_tool(client, title, description, team_id) + assert result.status == "success" + assert result.title == title + assert result.team_id == team_id + assert result.issue_data["title"] == title + assert result.issue_data["description"] == description + + # If there are multiple teams, test with a different team + if len(teams_result.teams) > 1: + second_team_id = teams_result.teams[1]["id"] + second_team_name = teams_result.teams[1]["name"] + + title2 = f"Test Issue in {second_team_name} - Explicit Team ID" + description2 = f"This is a test issue created in team {second_team_name} with explicit team_id" + + result2 = linear_create_issue_tool(client, title2, description2, second_team_id) + assert result2.status == "success" + assert result2.title == title2 + assert result2.team_id == second_team_id + assert result2.issue_data["title"] == title2 + assert result2.issue_data["description"] == description2