|
| 1 | +""" |
| 2 | +Smart chunking module for code-aware text splitting. |
| 3 | +Respects code structure (functions, classes, methods) for better semantic search. |
| 4 | +""" |
| 5 | +import re |
| 6 | +from typing import List, Tuple, Optional |
| 7 | +from pathlib import Path |
| 8 | + |
| 9 | + |
class SmartChunker:
    """
    Code-aware chunker that splits text based on language structure.
    Falls back to simple chunking for non-code or unknown languages.

    All sizes are measured in characters. Structure detection is heuristic
    (line prefixes and brace counting), not a real parser.
    """

    # Languages that get structure-aware splitting; everything else is
    # chunked purely by character count.
    _CODE_LANGUAGES = frozenset(
        {"python", "javascript", "typescript", "java", "go", "rust", "c", "cpp"}
    )

    # Line prefixes (after stripping indentation) that open a Python definition.
    _PY_DEF_PREFIXES = ("class ", "def ", "async def ")
    # Prefixes that announce an upcoming definition (decorators included).
    _PY_BOUNDARY_PREFIXES = _PY_DEF_PREFIXES + ("@",)

    # Headers used only to label brace-delimited units; they never decide
    # where a unit starts or ends.
    _JS_CLASS_RE = re.compile(
        r'^\s*(?:export\s+)?(?:default\s+)?(?:abstract\s+)?class\s+\w+',
        re.MULTILINE,
    )
    _JAVA_TYPE_RE = re.compile(
        r'^\s*(?:(?:public|private|protected|static|final|abstract)\s+)*'
        r'(?:class|interface|enum)\s+\w+',
        re.MULTILINE,
    )

    def __init__(self, chunk_size: int = 800, overlap: int = 100):
        """
        Args:
            chunk_size: Maximum chunk size in characters; must be positive.
            overlap: Characters shared between consecutive chunks produced by
                simple chunking; must not be negative.

        Raises:
            ValueError: If chunk_size is not positive or overlap is negative.
        """
        # A non-positive size only ever produced empty chunks and a negative
        # overlap left gaps between chunks, so reject both up front.
        if chunk_size <= 0:
            raise ValueError("chunk_size must be a positive integer")
        if overlap < 0:
            raise ValueError("overlap must not be negative")
        self.chunk_size = chunk_size
        self.overlap = overlap

    def chunk(self, text: str, language: str = "text") -> List[str]:
        """
        Chunk text based on language-specific rules.

        Args:
            text: Text content to chunk
            language: Programming language identifier (matched
                case-insensitively, surrounding whitespace ignored)

        Returns:
            List of text chunks
        """
        normalized = language.strip().lower() if isinstance(language, str) else ""
        if normalized in self._CODE_LANGUAGES:
            return self._chunk_code(text, normalized)
        return self._chunk_simple(text)

    def _chunk_code(self, text: str, language: str) -> List[str]:
        """
        Smart chunking for code that respects structure.

        Units returned by _split_into_units are packed greedily, joined with
        newlines, so that no chunk exceeds chunk_size. When a chunk is closed
        its last unit is repeated at the start of the next chunk for context,
        but only if it still fits together with the incoming unit. A unit that
        is larger than chunk_size on its own is split with _chunk_simple.
        """
        if not text:
            # Same result as _chunk_simple for empty input.
            return []

        # Split into logical units (functions, classes, etc.)
        units = self._split_into_units(text, language)
        if not units:
            # Fallback to simple chunking if structure detection fails
            return self._chunk_simple(text)

        chunks: List[str] = []
        current: List[str] = []
        current_size = 0  # always len("\n".join(current))

        for unit_text, _unit_type in units:
            unit_size = len(unit_text)

            # A unit that cannot fit in any chunk is split on its own.
            if unit_size > self.chunk_size:
                if current:
                    chunks.append("\n".join(current))
                    current = []
                    current_size = 0
                chunks.extend(self._chunk_simple(unit_text))
                continue

            # Joining costs one extra character for the newline separator.
            joined_size = current_size + 1 + unit_size if current else unit_size

            if current and joined_size > self.chunk_size:
                chunks.append("\n".join(current))

                # Carry the previous unit over for context, unless doing so
                # would already push the new chunk past chunk_size.
                last_unit = current[-1]
                carried_size = len(last_unit) + 1 + unit_size
                if len(current) > 1 and carried_size <= self.chunk_size:
                    current = [last_unit, unit_text]
                    current_size = carried_size
                else:
                    current = [unit_text]
                    current_size = unit_size
            else:
                current.append(unit_text)
                current_size = joined_size

        # Add remaining chunk
        if current:
            chunks.append("\n".join(current))

        return chunks if chunks else [text]

    def _split_into_units(self, text: str, language: str) -> List[Tuple[str, str]]:
        """
        Split code into logical units (functions, classes, etc.).
        Returns list of (text, unit_type) tuples; an empty list means the
        language has no structure-aware splitter.
        """
        if language == "python":
            return self._split_python(text)
        if language in ("javascript", "typescript"):
            return self._split_javascript(text)
        if language == "java":
            return self._split_java(text)
        if language in ("go", "rust", "c", "cpp"):
            return self._split_c_style(text)
        return []

    def _split_python(self, text: str) -> List[Tuple[str, str]]:
        """
        Split Python code into classes and functions.

        Every line that starts (after indentation) with ``class``, ``def`` or
        ``async def`` opens a new unit, at any nesting depth. Decorator lines
        directly above a definition are moved into that definition's unit.
        Code before the first definition forms a "module" unit. Joining the
        unit texts with newlines reproduces the input.
        """
        units: List[Tuple[str, str]] = []
        lines = text.split("\n")
        current_unit: List[str] = []
        current_type: Optional[str] = None
        # Indent of the definition that opened the current unit; None while
        # collecting module-level code (there is no definition to dedent from).
        base_indent: Optional[int] = None

        for i, line in enumerate(lines):
            stripped = line.lstrip()
            indent = len(line) - len(stripped)

            if stripped.startswith(self._PY_DEF_PREFIXES):
                # Pull trailing decorator lines into the new definition so a
                # decorator never ends up in a different chunk than its target.
                decorators: List[str] = []
                while current_unit and current_unit[-1].lstrip().startswith("@"):
                    decorators.insert(0, current_unit.pop())

                # Save previous unit if anything is left of it
                if current_unit:
                    units.append(("\n".join(current_unit), current_type or "code"))

                current_type = "class" if stripped.startswith("class ") else "function"
                current_unit = decorators + [line]
                base_indent = indent
                continue

            if not current_unit:
                # First line of a run of module-level code
                current_type = "module"
                base_indent = None
            current_unit.append(line)

            if base_indent is None:
                continue

            # A real (non-blank, non-comment) line back at the definition's
            # indent, immediately followed by a new definition or decorator,
            # is split off as module-level code instead of staying in the body.
            dedented = (
                bool(stripped)
                and not stripped.startswith("#")
                and indent <= base_indent
            )
            if dedented and i < len(lines) - 1:
                next_stripped = lines[i + 1].lstrip()
                if next_stripped.startswith(self._PY_BOUNDARY_PREFIXES):
                    units.append(("\n".join(current_unit[:-1]), current_type or "code"))
                    current_unit = [line]
                    current_type = "module"
                    base_indent = None

        # Add remaining unit
        if current_unit:
            units.append(("\n".join(current_unit), current_type or "code"))

        return units

    @staticmethod
    def _retag_blocks(units, type_pattern, matched_type, default_type):
        """Relabel brace blocks: matched_type if type_pattern is found, else default_type."""
        retagged = []
        for unit_text, unit_type in units:
            # Only closed blocks are relabelled; trailing "code" stays as is.
            if unit_type == "function":
                unit_type = matched_type if type_pattern.search(unit_text) else default_type
            retagged.append((unit_text, unit_type))
        return retagged

    def _split_javascript(self, text: str) -> List[Tuple[str, str]]:
        """
        Split JavaScript/TypeScript code into functions and classes.

        Boundaries come from balanced-brace splitting, so nested blocks stay
        inside their enclosing unit and text between declarations is kept.
        A block containing a class declaration header is typed "class";
        every other block is typed "function".
        """
        return self._retag_blocks(
            self._split_by_braces(text), self._JS_CLASS_RE, "class", "function"
        )

    def _split_java(self, text: str) -> List[Tuple[str, str]]:
        """
        Split Java code into top-level brace blocks.

        Boundaries come from balanced-brace splitting, so a type declaration
        is returned whole, together with its members. A block containing a
        class/interface/enum header is typed "class"; any other block is
        typed "method".
        """
        return self._retag_blocks(
            self._split_by_braces(text), self._JAVA_TYPE_RE, "class", "method"
        )

    def _split_c_style(self, text: str) -> List[Tuple[str, str]]:
        """Split C-style languages (Go, Rust, C, C++) into functions."""
        return self._split_by_braces(text)

    def _split_by_braces(self, text: str) -> List[Tuple[str, str]]:
        """
        Generic brace-based splitting for C-style languages.
        Finds balanced brace blocks.

        Lines are accumulated until a block that was opened with "{" is
        closed again; that run of lines becomes one "function" unit. Lines
        left over after the last block form a final "code" unit. Joining the
        unit texts with newlines reproduces the input.
        """
        units: List[Tuple[str, str]] = []
        current_unit: List[str] = []
        brace_count = 0
        in_block = False

        for line in text.split("\n"):
            current_unit.append(line)

            # Count braces (simple heuristic, doesn't handle strings/comments perfectly)
            brace_count += line.count("{") - line.count("}")

            if not in_block:
                if "{" in line:
                    in_block = True
                elif brace_count < 0:
                    # Stray closer outside any block: forget it, otherwise
                    # the next opening line would look already balanced.
                    brace_count = 0

            if in_block and brace_count <= 0:
                # Block closed
                units.append(("\n".join(current_unit), "function"))
                current_unit = []
                brace_count = 0
                in_block = False

        # Add remaining lines
        if current_unit:
            units.append(("\n".join(current_unit), "code"))

        return units

    def _chunk_simple(self, text: str) -> List[str]:
        """
        Simple character-based chunking with overlap.
        Used as fallback or for non-code content.

        Returns an empty list for empty text and a single chunk when the text
        already fits in chunk_size.
        """
        if not text:
            return []

        if len(text) <= self.chunk_size:
            return [text]

        chunks: List[str] = []
        # Always advance by at least one character, even if overlap >= chunk_size.
        step = max(1, self.chunk_size - self.overlap)
        start = 0

        while True:
            end = start + self.chunk_size
            chunks.append(text[start:end])
            if end >= len(text):
                # The text is fully covered; a further window would only
                # repeat the tail of this chunk.
                break
            start += step

        return chunks
| 261 | + |
| 262 | + |
# Module-level SmartChunker created with the constructor's default arguments.
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably kept for external importers - confirm before removing.
_default_chunker = SmartChunker()
| 265 | + |
| 266 | + |
def smart_chunk(text: str, language: str = "text", chunk_size: int = 800, overlap: int = 100) -> List[str]:
    """
    Chunk text in one call, without managing a SmartChunker yourself.

    A new SmartChunker is built from chunk_size and overlap on every call
    and its chunk() result is returned unchanged.

    Args:
        text: Text to chunk
        language: Programming language identifier passed to SmartChunker.chunk
        chunk_size: Maximum chunk size in characters
        overlap: Overlap between chunks in characters

    Returns:
        List of text chunks
    """
    return SmartChunker(chunk_size=chunk_size, overlap=overlap).chunk(text, language)
0 commit comments