diff --git a/codeflash/languages/java/context.py b/codeflash/languages/java/context.py index a5597351c..ea1778908 100644 --- a/codeflash/languages/java/context.py +++ b/codeflash/languages/java/context.py @@ -19,6 +19,12 @@ if TYPE_CHECKING: from tree_sitter import Node +_TYPE_DECLARATIONS = { + "class_declaration": "class", + "interface_declaration": "interface", + "enum_declaration": "enum", +} + logger = logging.getLogger(__name__) @@ -253,18 +259,14 @@ def _find_type_node(node: Node, type_name: str, source_bytes: bytes) -> tuple[No Tuple of (node, type_kind) where type_kind is "class", "interface", or "enum". """ - type_declarations = { - "class_declaration": "class", - "interface_declaration": "interface", - "enum_declaration": "enum", - } - - if node.type in type_declarations: + node_type = node.type + if node_type in _TYPE_DECLARATIONS: name_node = node.child_by_field_name("name") if name_node: node_name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8") if node_name == type_name: - return node, type_declarations[node.type] + return node, _TYPE_DECLARATIONS[node_type] + for child in node.children: result, kind = _find_type_node(child, type_name, source_bytes)