Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 106 additions & 25 deletions src/codegen/extensions/langchain/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Langchain tools for workspace operations."""

import os
from typing import Callable, ClassVar, Literal, Optional

from langchain_core.tools.base import BaseTool
Expand Down Expand Up @@ -668,13 +669,27 @@ class LinearGetIssueTool(BaseTool):
name: ClassVar[str] = "linear_get_issue"
description: ClassVar[str] = "Get details of a Linear issue by its ID"
args_schema: ClassVar[type[BaseModel]] = LinearGetIssueInput
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
# Initialize with codebase and create LinearClient on first use
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
self.client = LinearClient(access_token)
return self.client

def _run(self, issue_id: str) -> str:
result = linear_get_issue_tool(self.client, issue_id)
client = self._get_client()
result = linear_get_issue_tool(client, issue_id)
return result.render()


Expand All @@ -690,13 +705,26 @@ class LinearGetIssueCommentsTool(BaseTool):
name: ClassVar[str] = "linear_get_issue_comments"
description: ClassVar[str] = "Get all comments on a Linear issue"
args_schema: ClassVar[type[BaseModel]] = LinearGetIssueCommentsInput
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
self.client = LinearClient(access_token)
return self.client

def _run(self, issue_id: str) -> str:
result = linear_get_issue_comments_tool(self.client, issue_id)
client = self._get_client()
result = linear_get_issue_comments_tool(client, issue_id)
return result.render()


Expand All @@ -713,13 +741,26 @@ class LinearCommentOnIssueTool(BaseTool):
name: ClassVar[str] = "linear_comment_on_issue"
description: ClassVar[str] = "Add a comment to a Linear issue"
args_schema: ClassVar[type[BaseModel]] = LinearCommentOnIssueInput
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
self.client = LinearClient(access_token)
return self.client

def _run(self, issue_id: str, body: str) -> str:
result = linear_comment_on_issue_tool(self.client, issue_id, body)
client = self._get_client()
result = linear_comment_on_issue_tool(client, issue_id, body)
return result.render()


Expand All @@ -734,15 +775,28 @@ class LinearSearchIssuesTool(BaseTool):
"""Tool for searching Linear issues."""

name: ClassVar[str] = "linear_search_issues"
description: ClassVar[str] = "Search for Linear issues using a query string"
description: ClassVar[str] = "Search for Linear issues using a search string"
args_schema: ClassVar[type[BaseModel]] = LinearSearchIssuesInput
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
self.client = LinearClient(access_token)
return self.client

def _run(self, query: str, limit: int = 10) -> str:
result = linear_search_issues_tool(self.client, query, limit)
client = self._get_client()
result = linear_search_issues_tool(client, query, limit)
return result.render()


Expand All @@ -760,13 +814,27 @@ class LinearCreateIssueTool(BaseTool):
name: ClassVar[str] = "linear_create_issue"
description: ClassVar[str] = "Create a new Linear issue"
args_schema: ClassVar[type[BaseModel]] = LinearCreateIssueInput
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
# Initialize without a default team_id to allow explicit team selection
self.client = LinearClient(access_token)
return self.client

def _run(self, title: str, description: str | None = None, team_id: str | None = None) -> str:
result = linear_create_issue_tool(self.client, title, description, team_id)
client = self._get_client()
result = linear_create_issue_tool(client, title, description, team_id)
return result.render()


Expand All @@ -775,13 +843,26 @@ class LinearGetTeamsTool(BaseTool):

name: ClassVar[str] = "linear_get_teams"
description: ClassVar[str] = "Get all Linear teams the authenticated user has access to"
client: LinearClient = Field(exclude=True)
codebase: Codebase = Field(exclude=True)
client: LinearClient | None = Field(default=None, exclude=True)

def __init__(self, codebase: Codebase) -> None:
super().__init__(codebase=codebase)

def __init__(self, client: LinearClient) -> None:
super().__init__(client=client)
def _get_client(self) -> LinearClient:
"""Get or create a LinearClient instance."""
if self.client is None:
# Create a new LinearClient instance
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
msg = "LINEAR_ACCESS_TOKEN environment variable not set"
raise ValueError(msg)
self.client = LinearClient(access_token)
return self.client

def _run(self) -> str:
result = linear_get_teams_tool(self.client)
client = self._get_client()
result = linear_get_teams_tool(client)
return result.render()


Expand Down
59 changes: 59 additions & 0 deletions tests/integration/extension/test_linear_team_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
"""Tests for Linear tools with team_id parameter."""

import os

import pytest

from codegen.extensions.linear.linear_client import LinearClient
from codegen.extensions.tools.linear.linear import (
linear_create_issue_tool,
linear_get_teams_tool,
)


@pytest.fixture
def client() -> LinearClient:
"""Create a Linear client for testing."""
token = os.getenv("LINEAR_ACCESS_TOKEN")
if not token:
pytest.skip("LINEAR_ACCESS_TOKEN environment variable not set")
# Note: We're not setting team_id here to test explicit team_id passing
return LinearClient(token)


def test_create_issue_with_explicit_team_id(client: LinearClient) -> None:
"""Test creating an issue with an explicit team_id."""
# First, get available teams
teams_result = linear_get_teams_tool(client)
assert teams_result.status == "success"
assert len(teams_result.teams) > 0

# Use the first team's ID for our test
team_id = teams_result.teams[0]["id"]
team_name = teams_result.teams[0]["name"]

# Create an issue with explicit team_id
title = f"Test Issue in {team_name} - Explicit Team ID"
description = f"This is a test issue created in team {team_name} with explicit team_id"

result = linear_create_issue_tool(client, title, description, team_id)
assert result.status == "success"
assert result.title == title
assert result.team_id == team_id
assert result.issue_data["title"] == title
assert result.issue_data["description"] == description

# If there are multiple teams, test with a different team
if len(teams_result.teams) > 1:
second_team_id = teams_result.teams[1]["id"]
second_team_name = teams_result.teams[1]["name"]

title2 = f"Test Issue in {second_team_name} - Explicit Team ID"
description2 = f"This is a test issue created in team {second_team_name} with explicit team_id"

result2 = linear_create_issue_tool(client, title2, description2, second_team_id)
assert result2.status == "success"
assert result2.title == title2
assert result2.team_id == second_team_id
assert result2.issue_data["title"] == title2
assert result2.issue_data["description"] == description2
Loading