diff --git a/src/codegen/extensions/langchain/tools.py b/src/codegen/extensions/langchain/tools.py index 0749384a4..a8786bc02 100644 --- a/src/codegen/extensions/langchain/tools.py +++ b/src/codegen/extensions/langchain/tools.py @@ -546,7 +546,7 @@ class GithubViewPRTool(BaseTool): """Tool for getting PR data.""" name: ClassVar[str] = "view_pr" - description: ClassVar[str] = "View the diff and associated context for a pull request" + description: ClassVar[str] = "View the diff and associated context for a pull request on Github. Returns PR URL and repository name." args_schema: ClassVar[type[BaseModel]] = GithubViewPRInput codebase: Codebase = Field(exclude=True) diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 00c20f7bb..5653c8806 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -15,6 +15,12 @@ class ViewPRObservation(Observation): pr_id: int = Field( description="ID of the PR", ) + repo_name: str = Field( + description="Name of the repository containing the PR", + ) + pr_url: str = Field( + description="URL of the PR", + ) patch: str = Field( description="The PR's patch/diff content", ) @@ -25,7 +31,7 @@ class ViewPRObservation(Observation): description="Names of modified symbols in the PR", ) - str_template: ClassVar[str] = "PR #{pr_id}" + str_template: ClassVar[str] = "PR #{pr_id} in {repo_name}" def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: @@ -36,14 +42,21 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: pr_id: Number of the PR to get the contents for """ try: - patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + patch, file_commit_sha, modified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + + # Get the PR object to extract URL and repository name + pr = codebase.op.remote_git_repo.get_pull(pr_id) + repo_name = codebase.op.remote_git_repo.repo.full_name + pr_url = pr.html_url return ViewPRObservation( status="success", pr_id=pr_id, + repo_name=repo_name, + pr_url=pr_url, patch=patch, file_commit_sha=file_commit_sha, - modified_symbols=moddified_symbols, + modified_symbols=modified_symbols, ) except Exception as e: @@ -51,6 +64,8 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: status="error", error=f"Failed to view PR: {e!s}", pr_id=pr_id, + repo_name="", + pr_url="", patch="", file_commit_sha={}, modified_symbols=[],