Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions src/codegen/extensions/tools/github/create_pr.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tool for creating pull requests."""

import re
import uuid
from typing import ClassVar

Expand All @@ -23,8 +24,66 @@ class CreatePRObservation(Observation):
title: str = Field(
description="Title of the PR",
)
changes_summary: str = Field(
description="Summary of changes included in the PR",
default="",
)

str_template: ClassVar[str] = "Created PR #{number}: {title}\n\nChanges Summary:\n{changes_summary}"


def generate_changes_summary(diff_text: str) -> str:
"""Generate a human-readable summary of changes from a git diff.

str_template: ClassVar[str] = "Created PR #{number}: {title}"
Args:
diff_text: The git diff text

Returns:
A formatted summary of the changes
"""
if not diff_text:
return "No changes detected."

# Parse the diff to extract file information
file_pattern = re.compile(r"diff --git a/(.*?) b/(.*?)\n")
file_matches = file_pattern.findall(diff_text)

# Count additions and deletions
addition_pattern = re.compile(r"^\+[^+]", re.MULTILINE)
deletion_pattern = re.compile(r"^-[^-]", re.MULTILINE)

additions = len(addition_pattern.findall(diff_text))
deletions = len(deletion_pattern.findall(diff_text))

# Get unique files changed
files_changed = set()
for match in file_matches:
# Use the second part of the match (b/file) as it represents the new file
files_changed.add(match[1])

# Group files by extension
file_extensions: dict[str, list[str]] = {}
for file in files_changed:
ext = file.split(".")[-1] if "." in file else "other"
if ext not in file_extensions:
file_extensions[ext] = []
file_extensions[ext].append(file)

# Build the summary
summary = []
summary.append(f"**Files Changed:** {len(files_changed)}")
summary.append(f"**Lines Added:** {additions}")
summary.append(f"**Lines Deleted:** {deletions}")

# Add file details grouped by extension
if file_extensions:
summary.append("\n**Modified Files:**")
for ext, files in file_extensions.items():
summary.append(f"\n*{ext.upper()} Files:*")
for file in sorted(files):
summary.append(f"- {file}")

return "\n".join(summary)


def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation:
Expand All @@ -37,15 +96,20 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation:
"""
try:
# Check for uncommitted changes and commit them
if len(codebase.get_diff()) == 0:
diff_text = codebase.get_diff()
if len(diff_text) == 0:
return CreatePRObservation(
status="error",
error="No changes to create a PR.",
url="",
number=0,
title=title,
changes_summary="",
)

# Generate a summary of changes
changes_summary = generate_changes_summary(diff_text)

# TODO: this is very jank. We should ideally check out the branch before
# making the changes, but it looks like `codebase.checkout` blows away
# all of your changes
Expand All @@ -65,13 +129,15 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation:
url="",
number=0,
title=title,
changes_summary="",
)

return CreatePRObservation(
status="success",
url=pr.html_url,
number=pr.number,
title=pr.title,
changes_summary=changes_summary,
)

except Exception as e:
Expand All @@ -81,4 +147,5 @@ def create_pr(codebase: Codebase, title: str, body: str) -> CreatePRObservation:
url="",
number=0,
title=title,
changes_summary="",
)