Skip to content

Commit 1b3fa91

Browse files
Ubuntu and claude committed
feat: Java testgen class name fix, remove per-test @timeout, and wire language_version
- Add class_name and qualified_name to the /testgen API payload so the backend has explicit access to the computed FunctionToOptimize properties
- Add client-side _fix_java_test_class_name() to correct wrong class name references in LLM-generated Java test code
- Remove the per-test @timeout annotation from Java instrumentation (it causes timing instability on CI runners; Maven Surefire handles timeouts)
- Remove the redundant default_language_version; use language_version as the canonical property

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dc7caa3 commit 1b3fa91

11 files changed

Lines changed: 183 additions & 60 deletions

File tree

.claude/rules/architecture.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScrip
6464
|----------|----------------|---------|
6565
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
6666
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
67-
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
67+
| AI service | `language_version` | Detected language version for API payloads (e.g., `"3.11.0"` for Python, `"17"` for Java) |
6868
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
6969
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
7070
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |

codeflash/api/aiservice.py

Lines changed: 26 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,7 @@ def optimize_code(
129129
experiment_metadata: ExperimentMetadata | None = None,
130130
*,
131131
language: str = "python",
132-
language_version: str
133-
| None = None, # TODO:{claude} add language version to the language support and it should be cached
132+
language_version: str | None = None,
134133
module_system: str | None = None,
135134
is_async: bool = False,
136135
n_candidates: int = 5,
@@ -177,16 +176,12 @@ def optimize_code(
177176
"is_numerical_code": is_numerical_code,
178177
}
179178

180-
# Add language-specific version fields
181-
# Always include python_version for backward compatibility with older backend
182-
payload["python_version"] = platform.python_version()
183-
if is_python():
184-
pass # python_version already set
185-
elif is_java():
186-
payload["language_version"] = language_version or "17" # Default Java version
187-
else:
188-
payload["language_version"] = language_version or "ES2022"
189-
# Add module system for JavaScript/TypeScript (esm or commonjs)
179+
# Add language version (canonical for all languages)
180+
payload["language_version"] = language_version
181+
# Backward compat: backend still expects python_version
182+
payload["python_version"] = language_version if is_python() else platform.python_version()
183+
184+
if not is_python():
190185
if module_system:
191186
payload["module_system"] = module_system
192187

@@ -262,7 +257,8 @@ def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[Optimi
262257
"source_code": source_code,
263258
"trace_id": trace_id,
264259
"dependency_code": "", # dummy value to please the api endpoint
265-
"python_version": "3.12.1", # dummy value to please the api endpoint
260+
"language_version": platform.python_version(),
261+
"python_version": platform.python_version(), # backward compat
266262
"current_username": get_last_commit_author_if_pr_exists(None),
267263
"repo_owner": git_repo_owner,
268264
"repo_name": git_repo_name,
@@ -329,18 +325,15 @@ def optimize_python_code_line_profiler(
329325
logger.info("Generating optimized candidates with line profiler…")
330326
console.rule()
331327

332-
# Set python_version for backward compatibility with Python, or use language_version
333-
python_version = language_version if language_version else platform.python_version()
334-
335328
payload = {
336329
"source_code": source_code,
337330
"dependency_code": dependency_code,
338331
"n_candidates": n_candidates,
339332
"line_profiler_results": line_profiler_results,
340333
"trace_id": trace_id,
341-
"python_version": python_version,
342334
"language": language,
343335
"language_version": language_version,
336+
"python_version": language_version if is_python() else platform.python_version(), # backward compat
344337
"experiment_metadata": experiment_metadata,
345338
"codeflash_version": codeflash_version,
346339
"call_sequence": self.get_next_sequence(),
@@ -434,14 +427,10 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li
434427
"language": opt.language,
435428
}
436429

437-
# Add language version - always include python_version for backward compatibility
438-
item["python_version"] = platform.python_version()
439-
if is_python():
440-
pass # python_version already set
441-
elif opt.language_version:
442-
item["language_version"] = opt.language_version
443-
else:
444-
item["language_version"] = "ES2022" # Default for JS/TS
430+
# Add language version (canonical for all languages)
431+
item["language_version"] = opt.language_version
432+
# Backward compat: backend still expects python_version
433+
item["python_version"] = opt.language_version if is_python() else platform.python_version()
445434

446435
# Add multi-file context if provided
447436
if opt.additional_context_files:
@@ -649,7 +638,8 @@ def generate_ranking(
649638
"diffs": diffs,
650639
"speedups": speedups,
651640
"optimization_ids": optimization_ids,
652-
"python_version": platform.python_version(),
641+
"language_version": platform.python_version(),
642+
"python_version": platform.python_version(), # backward compat
653643
"function_references": function_references,
654644
}
655645
logger.info("loading|Generating ranking")
@@ -785,18 +775,16 @@ def generate_regression_tests(
785775
"is_async": function_to_optimize.is_async,
786776
"call_sequence": self.get_next_sequence(),
787777
"is_numerical_code": is_numerical_code,
778+
"class_name": function_to_optimize.class_name,
779+
"qualified_name": function_to_optimize.qualified_name,
788780
}
789781

790-
# Add language-specific version fields
791-
# Always include python_version for backward compatibility with older backend
792-
payload["python_version"] = platform.python_version()
793-
if is_python():
794-
pass # python_version already set
795-
elif is_java():
796-
payload["language_version"] = language_version or "17" # Default Java version
797-
else:
798-
payload["language_version"] = language_version or "ES2022"
799-
# Add module system for JavaScript/TypeScript (esm or commonjs)
782+
# Add language version (canonical for all languages)
783+
payload["language_version"] = language_version
784+
# Backward compat: backend still expects python_version
785+
payload["python_version"] = language_version if is_python() else platform.python_version()
786+
787+
if not is_python():
800788
if module_system:
801789
payload["module_system"] = module_system
802790

@@ -884,7 +872,8 @@ def get_optimization_review(
884872
"codeflash_version": codeflash_version,
885873
"calling_fn_details": calling_fn_details,
886874
"language": language,
887-
"python_version": platform.python_version() if is_python() else None,
875+
"language_version": platform.python_version() if is_python() else None,
876+
"python_version": platform.python_version() if is_python() else None, # backward compat
888877
"call_sequence": self.get_next_sequence(),
889878
}
890879
console.rule()

codeflash/api/schemas.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,15 @@ def to_payload(self) -> dict[str, Any]:
137137
"is_numerical_code": self.is_numerical_code,
138138
}
139139

140-
# Add language-specific fields
141-
if self.language_info.version:
142-
payload["language_version"] = self.language_info.version
140+
# Add language version (canonical for all languages)
141+
payload["language_version"] = self.language_info.version
143142

144-
# Backward compat: always include python_version
143+
# Backward compat: backend still expects python_version
145144
import platform
146145

147-
payload["python_version"] = platform.python_version()
146+
payload["python_version"] = (
147+
self.language_info.version if self.language_info.name == "python" else platform.python_version()
148+
)
148149

149150
# Module system for JS/TS
150151
if self.language_info.module_system != ModuleSystem.UNKNOWN:
@@ -205,14 +206,15 @@ def to_payload(self) -> dict[str, Any]:
205206
"is_numerical_code": self.is_numerical_code,
206207
}
207208

208-
# Add language version
209-
if self.language_info.version:
210-
payload["language_version"] = self.language_info.version
209+
# Add language version (canonical for all languages)
210+
payload["language_version"] = self.language_info.version
211211

212-
# Backward compat: always include python_version
212+
# Backward compat: backend still expects python_version
213213
import platform
214214

215-
payload["python_version"] = platform.python_version()
215+
payload["python_version"] = (
216+
self.language_info.version if self.language_info.name == "python" else platform.python_version()
217+
)
216218

217219
# Module system for JS/TS
218220
if self.language_info.module_system != ModuleSystem.UNKNOWN:

codeflash/languages/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -325,13 +325,9 @@ def dir_excludes(self) -> frozenset[str]:
325325
...
326326

327327
@property
328-
def default_language_version(self) -> str | None:
329-
"""Default language version string sent to AI service.
330-
331-
Returns None for languages where the runtime version is auto-detected (e.g. Python).
332-
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
333-
"""
334-
return None
328+
def language_version(self) -> str | None:
329+
"""The detected language version (e.g., "17" for Java, "ES2022" for JS)."""
330+
...
335331

336332
@property
337333
def valid_test_frameworks(self) -> tuple[str, ...]:

codeflash/languages/java/instrumentation.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def wrap_target_calls_with_treesitter(
258258
precise_call_timing: bool = False,
259259
class_name: str = "",
260260
test_method_name: str = "",
261+
target_return_type: str = "",
261262
) -> tuple[list[str], int]:
262263
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.
263264
@@ -327,6 +328,8 @@ def wrap_target_calls_with_treesitter(
327328
call_counter += 1
328329
var_name = f"_cf_result{iter_id}_{call_counter}"
329330
cast_type = _infer_array_cast_type(body_line)
331+
if not cast_type and target_return_type and target_return_type != "void":
332+
cast_type = target_return_type
330333
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name
331334

332335
# Use per-call unique variables (with call_counter suffix) for behavior mode
@@ -524,6 +527,26 @@ def _infer_array_cast_type(line: str) -> str | None:
524527
return None
525528

526529

530+
def _extract_return_type(function_to_optimize: Any) -> str:
531+
"""Extract the return type of a Java function from its source file using tree-sitter."""
532+
file_path = getattr(function_to_optimize, "file_path", None)
533+
func_name = _get_function_name(function_to_optimize)
534+
if not file_path or not file_path.exists():
535+
return ""
536+
try:
537+
from codeflash.languages.java.parser import get_java_analyzer
538+
539+
analyzer = get_java_analyzer()
540+
source_text = file_path.read_text(encoding="utf-8")
541+
methods = analyzer.find_methods(source_text)
542+
for method in methods:
543+
if method.name == func_name and method.return_type:
544+
return method.return_type
545+
except Exception:
546+
logger.debug("Could not extract return type for %s", func_name)
547+
return ""
548+
549+
527550
def _get_qualified_name(func: Any) -> str:
528551
"""Get the qualified name from FunctionToOptimize."""
529552
if hasattr(func, "qualified_name"):
@@ -617,6 +640,7 @@ def instrument_existing_test(
617640
"""
618641
source = test_string
619642
func_name = _get_function_name(function_to_optimize)
643+
target_return_type = _extract_return_type(function_to_optimize)
620644

621645
# Get the original class name from the file name
622646
if test_path:
@@ -648,14 +672,16 @@ def instrument_existing_test(
648672
)
649673
else:
650674
# Behavior mode: add timing instrumentation that also writes to SQLite
651-
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
675+
modified_source = _add_behavior_instrumentation(
676+
modified_source, original_class_name, func_name, target_return_type
677+
)
652678

653679
logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
654680
# Why return True here?
655681
return True, modified_source
656682

657683

658-
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str:
684+
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, target_return_type: str = "") -> str:
659685
"""Add behavior instrumentation to test methods.
660686
661687
For behavior mode, this adds:
@@ -796,6 +822,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
796822
precise_call_timing=True,
797823
class_name=class_name,
798824
test_method_name=test_method_name,
825+
target_return_type=target_return_type,
799826
)
800827

801828
# Add behavior instrumentation setup code (shared variables for all calls in the method)

codeflash/languages/java/support.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def __init__(self) -> None:
6464
self._analyzer = get_java_analyzer()
6565
self.line_profiler_agent_arg: str | None = None
6666
self.line_profiler_warmup_iterations: int = 0
67+
self._language_version: str | None = None
6768

6869
@property
6970
def language(self) -> Language:
@@ -93,6 +94,10 @@ def default_file_extension(self) -> str:
9394
def dir_excludes(self) -> frozenset[str]:
9495
return frozenset({"target", "build", ".gradle", ".mvn", ".idea", "apidocs", "javadoc"})
9596

97+
@property
98+
def language_version(self) -> str | None:
99+
return self._language_version
100+
96101
def postprocess_generated_tests(
97102
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
98103
) -> GeneratedTestsList:
@@ -364,10 +369,36 @@ def ensure_runtime_environment(self, project_root: Path) -> bool:
364369
if config is None:
365370
return False
366371

372+
self._language_version = config.java_version
373+
if self._language_version is None:
374+
self._detect_java_version()
375+
367376
# For now, assume the runtime is available
368377
# A full implementation would check/install the JAR
369378
return True
370379

380+
def _detect_java_version(self) -> None:
381+
"""Detect and cache the Java runtime version."""
382+
import subprocess
383+
384+
try:
385+
result = subprocess.run(["java", "-version"], check=False, capture_output=True, text=True, timeout=10)
386+
# java -version outputs to stderr, e.g. 'openjdk version "17.0.2"'
387+
output = result.stderr or result.stdout
388+
for line in output.splitlines():
389+
if "version" in line:
390+
# Extract version between quotes: "17.0.2" -> "17"
391+
start = line.find('"')
392+
end = line.find('"', start + 1)
393+
if start != -1 and end != -1:
394+
full_version = line[start + 1 : end]
395+
# Use major version only: "17.0.2" -> "17", "1.8.0_292" -> "8"
396+
major = full_version.split(".")[0]
397+
self._language_version = "8" if major == "1" else major
398+
return
399+
except Exception:
400+
pass
401+
371402
def instrument_existing_test(
372403
self,
373404
test_string: str,

codeflash/languages/javascript/support.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class JavaScriptSupport:
3636
using tree-sitter for code analysis and Jest for test execution.
3737
"""
3838

39+
def __init__(self) -> None:
40+
self._language_version: str | None = None
41+
3942
# === Properties ===
4043

4144
@property
@@ -68,6 +71,10 @@ def comment_prefix(self) -> str:
6871
def dir_excludes(self) -> frozenset[str]:
6972
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})
7073

74+
@property
75+
def language_version(self) -> str | None:
76+
return self._language_version
77+
7178
# === Discovery ===
7279

7380
def discover_functions(
@@ -2077,6 +2084,15 @@ def verify_requirements(self, project_root: Path, test_framework: str = "jest")
20772084

20782085
return len(errors) == 0, errors
20792086

2087+
def _detect_node_version(self) -> None:
2088+
"""Detect and cache the Node.js runtime version."""
2089+
try:
2090+
result = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
2091+
if result.returncode == 0 and result.stdout.strip():
2092+
self._language_version = result.stdout.strip().lstrip("v")
2093+
except Exception:
2094+
pass
2095+
20802096
def ensure_runtime_environment(self, project_root: Path) -> bool:
20812097
"""Ensure codeflash npm package is installed.
20822098
@@ -2091,6 +2107,8 @@ def ensure_runtime_environment(self, project_root: Path) -> bool:
20912107
"""
20922108
from codeflash.cli_cmds.console import logger
20932109

2110+
self._detect_node_version()
2111+
20942112
node_modules_pkg = project_root / "node_modules" / "codeflash"
20952113
if node_modules_pkg.exists():
20962114
logger.debug("codeflash already installed")

codeflash/languages/python/support.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6+
import platform
67
from pathlib import Path
78
from typing import TYPE_CHECKING, Any
89

@@ -107,6 +108,10 @@ def dir_excludes(self) -> frozenset[str]:
107108
}
108109
)
109110

111+
@property
112+
def language_version(self) -> str | None:
113+
return platform.python_version()
114+
110115
# === Discovery ===
111116

112117
def discover_functions(

0 commit comments

Comments
 (0)