diff --git a/src/codegen/extensions/tools/github/view_pr.py b/src/codegen/extensions/tools/github/view_pr.py index 00c20f7bb..445f5e989 100644 --- a/src/codegen/extensions/tools/github/view_pr.py +++ b/src/codegen/extensions/tools/github/view_pr.py @@ -24,8 +24,16 @@ class ViewPRObservation(Observation): modified_symbols: list[str] = Field( description="Names of modified symbols in the PR", ) + pr_url: str = Field( + description="URL of the PR", + default="", + ) + repo_name: str = Field( + description="Name of the repository the PR belongs to", + default="", + ) - str_template: ClassVar[str] = "PR #{pr_id}" + str_template: ClassVar[str] = "PR #{pr_id} in {repo_name}: {pr_url}" def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: @@ -38,12 +46,19 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: try: patch, file_commit_sha, moddified_symbols = codebase.get_modified_symbols_in_pr(pr_id) + # Get the PR object to extract URL and repo name + pr = codebase._op.get_pull_request(pr_id) + pr_url = pr.html_url if pr else "" + repo_name = codebase._op.repo_config.full_name if codebase._op.repo_config.full_name else "" + return ViewPRObservation( status="success", pr_id=pr_id, patch=patch, file_commit_sha=file_commit_sha, modified_symbols=moddified_symbols, + pr_url=pr_url, + repo_name=repo_name, ) except Exception as e: @@ -54,4 +69,6 @@ def view_pr(codebase: Codebase, pr_id: int) -> ViewPRObservation: patch="", file_commit_sha={}, modified_symbols=[], + pr_url="", + repo_name="", )