Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 21 additions & 15 deletions src/codemodder/codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from codemodder.codetf import CodeTF
from codemodder.context import CodemodExecutionContext
from codemodder.dependency import Dependency
from codemodder.llm import MisconfiguredAIClient
from codemodder.llm import MisconfiguredAIClient, TokenUsage, log_token_usage
from codemodder.logging import configure_logger, log_list, log_section, logger
from codemodder.project_analysis.file_parsers.package_store import PackageStore
from codemodder.project_analysis.python_repo_manager import PythonRepoManager
Expand Down Expand Up @@ -46,7 +46,7 @@ def find_semgrep_results(
return run_semgrep(context, yaml_files, files_to_analyze)


def log_report(context, output, elapsed_ms, files_to_analyze):
def log_report(context, output, elapsed_ms, files_to_analyze, token_usage):
log_section("report")
logger.info("scanned: %s files", len(files_to_analyze))
all_failures = context.get_failed_files()
Expand All @@ -62,6 +62,7 @@ def log_report(context, output, elapsed_ms, files_to_analyze):
len(set(all_changes)),
)
logger.info("report file: %s", output)
log_token_usage("All", token_usage)
logger.info("total elapsed: %s ms", elapsed_ms)
logger.info(" semgrep: %s ms", context.timer.get_time_ms("semgrep"))
logger.info(" parse: %s ms", context.timer.get_time_ms("parse"))
Expand All @@ -72,24 +73,29 @@ def log_report(context, output, elapsed_ms, files_to_analyze):
def apply_codemods(
context: CodemodExecutionContext,
codemods_to_run: Sequence[BaseCodemod],
):
) -> TokenUsage:
log_section("scanning")
token_usage = TokenUsage()

if not context.files_to_analyze:
logger.info("no files to scan")
return
return token_usage

if not codemods_to_run:
logger.info("no codemods to run")
return
return token_usage

# run codemods one at a time making sure to respect the given sequence
for codemod in codemods_to_run:
# NOTE: this may be used as a progress indicator by upstream tools
logger.info("running codemod %s", codemod.id)
codemod.apply(context)
if codemod_token_usage := codemod.apply(context):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replaces multiple expressions involving if operator with 'walrus' operator.

log_token_usage(f"Codemod {codemod.id}", codemod_token_usage)
token_usage += codemod_token_usage

record_dependency_update(context.process_dependencies(codemod.id))
context.log_changes(codemod.id)
return token_usage


def record_dependency_update(dependency_results: dict[Dependency, PackageStore | None]):
Expand Down Expand Up @@ -128,7 +134,7 @@ def run(
codemod_registry: registry.CodemodRegistry | None = None,
sast_only: bool = False,
ai_client: bool = True,
) -> tuple[CodeTF | None, int]:
) -> tuple[CodeTF | None, int, TokenUsage]:
start = datetime.datetime.now()

codemod_registry = codemod_registry or registry.load_registered_codemods()
Expand All @@ -139,6 +145,7 @@ def run(
codemod_exclude = codemod_exclude or []

provider_registry = providers.load_providers()
token_usage = TokenUsage()

log_section("startup")
logger.info("codemodder: python/%s", __version__)
Expand All @@ -148,7 +155,7 @@ def run(
logger.error(
f"FileNotFoundError: [Errno 2] No such file or directory: '{file_name}'"
)
return None, 1
return None, 1, token_usage

repo_manager = PythonRepoManager(Path(directory))

Expand All @@ -168,7 +175,8 @@ def run(
)
except MisconfiguredAIClient as e:
logger.error(e)
return None, 3 # Codemodder instructions conflicted (according to spec)
# Codemodder instructions conflicted (according to spec)
return None, 3, token_usage

context.repo_manager.parse_project()

Expand All @@ -194,10 +202,7 @@ def run(
context.find_and_fix_paths,
)

apply_codemods(
context,
codemods_to_run,
)
token_usage = apply_codemods(context, codemods_to_run)

elapsed = datetime.datetime.now() - start
elapsed_ms = int(elapsed.total_seconds() * 1000)
Expand All @@ -217,8 +222,9 @@ def run(
output,
elapsed_ms,
[] if not codemods_to_run else context.files_to_analyze,
token_usage,
)
return codetf, 0
return codetf, 0, token_usage


def _run_cli(original_args) -> int:
Expand Down Expand Up @@ -258,7 +264,7 @@ def _run_cli(original_args) -> int:
logger.info("command: %s %s", Path(sys.argv[0]).name, " ".join(original_args))
configure_logger(argv.verbose, argv.log_format, argv.project_name)

_, status = run(
_, status, _ = run(
argv.directory,
argv.dry_run,
argv.output,
Expand Down
20 changes: 11 additions & 9 deletions src/codemodder/codemods/base_codemod.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from codemodder.codetf import DetectionTool, Reference
from codemodder.context import CodemodExecutionContext
from codemodder.file_context import FileContext
from codemodder.llm import TokenUsage
from codemodder.logging import logger
from codemodder.result import ResultSet

Expand Down Expand Up @@ -188,15 +189,15 @@ def _apply(
self,
context: CodemodExecutionContext,
rules: list[str],
) -> None:
) -> None | TokenUsage:
if self.provider and (
not (provider := context.providers.get_provider(self.provider))
or not provider.is_available
):
logger.warning(
"provider %s is not available, skipping codemod", self.provider
)
return
return None

if isinstance(self.detector, SemgrepRuleDetector):
if (
Expand All @@ -208,7 +209,7 @@ def _apply(
"no results from semgrep for %s, skipping analysis",
self.id,
)
return
return None

results: ResultSet | None = (
# It seems like semgrep doesn't like our fully-specified id format so pass in short name instead.
Expand All @@ -219,11 +220,11 @@ def _apply(

if results is not None and not results:
logger.debug("No results for %s", self.id)
return
return None

if not (files_to_analyze := self.get_files_to_analyze(context, results)):
logger.debug("No files matched for %s", self.id)
return
return None

process_file = functools.partial(
self._process_file, context=context, results=results, rules=rules
Expand All @@ -240,8 +241,9 @@ def _apply(
executor.shutdown(wait=True)

context.process_results(self.id, contexts)
return None

def apply(self, context: CodemodExecutionContext) -> None:
def apply(self, context: CodemodExecutionContext) -> None | TokenUsage:
"""
Apply the codemod with the given codemod execution context

Expand All @@ -257,7 +259,7 @@ def apply(self, context: CodemodExecutionContext) -> None:

:param context: The codemod execution context
"""
self._apply(context, [self._internal_name])
return self._apply(context, [self._internal_name])

def _process_file(
self,
Expand Down Expand Up @@ -355,8 +357,8 @@ def __init__(
if requested_rules:
self.requested_rules.extend(requested_rules)

def apply(self, context: CodemodExecutionContext) -> None:
self._apply(context, self.requested_rules)
def apply(self, context: CodemodExecutionContext) -> None | TokenUsage:
return self._apply(context, self.requested_rules)

def get_files_to_analyze(
self,
Expand Down
10 changes: 6 additions & 4 deletions src/codemodder/codemods/utils_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ def class_has_method(self, classdef: cst.ClassDef, method_name: str) -> bool:
"""Check if a given class definition implements a method of name `method_name`."""
for node in classdef.body.body:
match node:
case cst.FunctionDef(
name=cst.Name(value=value)
) if value == method_name:
case cst.FunctionDef(name=cst.Name(value=value)) if (
value == method_name
):
return True
return False

Expand All @@ -331,7 +331,9 @@ def is_value_of_assignment(
| cst.Assign(value=value)
| cst.WithItem(item=value)
| cst.NamedExpr(value=value)
) if expr == value: # type: ignore
) if (
expr == value
): # type: ignore
return parent
return None

Expand Down
31 changes: 30 additions & 1 deletion src/codemodder/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

from typing_extensions import Self

try:
from openai import AzureOpenAI, OpenAI
except ImportError:
Expand All @@ -28,6 +31,8 @@
"setup_openai_llm_client",
"setup_azure_llama_llm_client",
"MisconfiguredAIClient",
"TokenUsage",
"log_token_usage",
]

models = ["gpt-4-turbo-2024-04-09", "gpt-4o-2024-05-13", "gpt-35-turbo-0125"]
Expand Down Expand Up @@ -58,8 +63,8 @@ def __getattr__(self, name):
def setup_openai_llm_client() -> OpenAI | None:
"""Configure either the Azure OpenAI LLM client or the OpenAI client, in that order."""
if not AzureOpenAI:
logger.info("Azure OpenAI API client not available")
return None
logger.info("Azure OpenAI API client not available")

azure_openapi_key = os.getenv("CODEMODDER_AZURE_OPENAI_API_KEY")
azure_openapi_endpoint = os.getenv("CODEMODDER_AZURE_OPENAI_ENDPOINT")
Expand Down Expand Up @@ -115,3 +120,27 @@ def setup_azure_llama_llm_client() -> ChatCompletionsClient | None:

class MisconfiguredAIClient(ValueError):
pass


@dataclass
class TokenUsage:
completion_tokens: int = 0
prompt_tokens: int = 0

def __iadd__(self, other: Self) -> Self:
self.completion_tokens += other.completion_tokens
self.prompt_tokens += other.prompt_tokens
return self

@property
def total(self):
return self.completion_tokens + self.prompt_tokens


def log_token_usage(header: str, token_usage: TokenUsage):
logger.info(
"%s token usage\n\tcompletion_tokens = %s\n\tprompt_tokens = %s",
header,
token_usage.completion_tokens,
token_usage.prompt_tokens,
)
2 changes: 1 addition & 1 deletion src/codemodder/utils/format_string_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def expressions_from_replacements(


def dict_to_values_dict(
expr_dict: dict[cst.BaseExpression, cst.BaseExpression]
expr_dict: dict[cst.BaseExpression, cst.BaseExpression],
) -> dict[str | cst.BaseExpression, cst.BaseExpression]:
return {
extract_raw_value(k): v
Expand Down
10 changes: 7 additions & 3 deletions tests/test_codemodder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from codemodder import run
from codemodder.codemodder import _run_cli, find_semgrep_results
from codemodder.diff import create_diff_from_tree
from codemodder.llm import TokenUsage
from codemodder.registry import load_registered_codemods
from codemodder.result import ResultSet
from codemodder.semgrep import run as semgrep_run
Expand All @@ -30,7 +31,9 @@ def disable_codemod_apply(mocker, request):
"test_run_codemod_name_or_id",
):
return
mocker.patch("codemodder.codemods.base_codemod.BaseCodemod.apply")
mocker.patch(
"codemodder.codemods.base_codemod.BaseCodemod.apply", return_value=TokenUsage()
)


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -395,7 +398,8 @@ class TestRun:
def test_run_basic_call(self, mock_parse, dir_structure):
code_dir, codetf = dir_structure

codetf_output, status = run(code_dir, dry_run=True)
codetf_output, status, token_usage = run(code_dir, dry_run=True)
assert token_usage.total == 0
assert status == 0
assert codetf_output
assert codetf_output.run.directory == str(code_dir)
Expand All @@ -406,7 +410,7 @@ def test_run_basic_call(self, mock_parse, dir_structure):
def test_run_with_output(self, mock_parse, dir_structure):
code_dir, codetf = dir_structure

codetf_output, status = run(
codetf_output, status, _ = run(
code_dir,
output=codetf,
dry_run=True,
Expand Down
10 changes: 9 additions & 1 deletion tests/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from codemodder.llm import MODELS, models
from codemodder.llm import MODELS, TokenUsage, models


class TestModels:
Expand All @@ -20,3 +20,11 @@ def test_model_get_name_from_env(self, mocker, model):
},
)
assert getattr(MODELS, attr_name) == name


def test_token_usage():
token_usage = TokenUsage()
token_usage += TokenUsage(10, 5)
assert token_usage.completion_tokens == 10
assert token_usage.prompt_tokens == 5
assert token_usage.total == 15