diff --git a/sentience/agent.py b/sentience/agent.py index 9cf4367..1e87b25 100644 --- a/sentience/agent.py +++ b/sentience/agent.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Optional from .actions import click, click_async, press, press_async, type_text, type_text_async +from .agent_config import AgentConfig from .base_agent import BaseAgent, BaseAgentAsync from .browser import AsyncSentienceBrowser, SentienceBrowser from .llm_provider import LLMProvider, LLMResponse @@ -25,7 +26,6 @@ from .snapshot import snapshot, snapshot_async if TYPE_CHECKING: - from .agent_config import AgentConfig from .tracing import Tracer @@ -78,8 +78,9 @@ def __init__( self.default_snapshot_limit = default_snapshot_limit self.verbose = verbose self.tracer = tracer - self.config = config + self.config = config or AgentConfig() + # Screenshot sequence counter # Execution history self.history: list[dict[str, Any]] = [] @@ -150,6 +151,21 @@ def act( # noqa: C901 if snap_opts.goal is None: snap_opts.goal = goal + # Apply AgentConfig screenshot settings if not overridden by snapshot_options + if snapshot_options is None and self.config: + if self.config.capture_screenshots: + # Create ScreenshotConfig from AgentConfig + snap_opts.screenshot = ScreenshotConfig( + format=self.config.screenshot_format, + quality=( + self.config.screenshot_quality + if self.config.screenshot_format == "jpeg" + else None + ), + ) + else: + snap_opts.screenshot = False + # Call snapshot with options object (matches TypeScript API) snap = snapshot(self.browser, snap_opts) @@ -178,14 +194,36 @@ def act( # noqa: C901 for el in filtered_elements[:50] # Limit to first 50 for performance ] + # Build snapshot event data + snapshot_data = { + "url": snap.url, + "element_count": len(snap.elements), + "timestamp": snap.timestamp, + "elements": elements_data, # Add element data for overlay + } + + # Always include screenshot in trace event for studio viewer compatibility + # CloudTraceSink will extract and upload screenshots separately, then remove + # screenshot_base64 from events before uploading the trace file. + if snap.screenshot: + # Extract base64 string from data URL if needed + if snap.screenshot.startswith("data:image"): + # Format: "data:image/jpeg;base64,{base64_string}" + screenshot_base64 = ( + snap.screenshot.split(",", 1)[1] + if "," in snap.screenshot + else snap.screenshot + ) + else: + screenshot_base64 = snap.screenshot + + snapshot_data["screenshot_base64"] = screenshot_base64 + if snap.screenshot_format: + snapshot_data["screenshot_format"] = snap.screenshot_format + self.tracer.emit( "snapshot", - { - "url": snap.url, - "element_count": len(snap.elements), - "timestamp": snap.timestamp, - "elements": elements_data, # Add element data for overlay - }, + snapshot_data, step_id=step_id, ) @@ -721,8 +759,9 @@ def __init__( self.default_snapshot_limit = default_snapshot_limit self.verbose = verbose self.tracer = tracer - self.config = config + self.config = config or AgentConfig() + # Screenshot sequence counter # Execution history self.history: list[dict[str, Any]] = [] @@ -790,6 +829,23 @@ async def act( # noqa: C901 if snap_opts.goal is None: snap_opts.goal = goal + # Apply AgentConfig screenshot settings if not overridden by snapshot_options + # Only apply if snapshot_options wasn't provided OR if screenshot wasn't explicitly set + # (snapshot_options.screenshot defaults to False, so we check if it's still False) + if self.config and (snapshot_options is None or snap_opts.screenshot is False): + if self.config.capture_screenshots: + # Create ScreenshotConfig from AgentConfig + snap_opts.screenshot = ScreenshotConfig( + format=self.config.screenshot_format, + quality=( + self.config.screenshot_quality + if self.config.screenshot_format == "jpeg" + else None + ), + ) + else: + snap_opts.screenshot = False + # Call snapshot with options object (matches TypeScript API) snap = await snapshot_async(self.browser, snap_opts) @@ -818,14 +874,36 @@ async def act( # noqa: C901 for el in filtered_elements[:50] # Limit to first 50 for performance ] + # Build snapshot event data + snapshot_data = { + "url": snap.url, + "element_count": len(snap.elements), + "timestamp": snap.timestamp, + "elements": elements_data, # Add element data for overlay + } + + # Always include screenshot in trace event for studio viewer compatibility + # CloudTraceSink will extract and upload screenshots separately, then remove + # screenshot_base64 from events before uploading the trace file. + if snap.screenshot: + # Extract base64 string from data URL if needed + if snap.screenshot.startswith("data:image"): + # Format: "data:image/jpeg;base64,{base64_string}" + screenshot_base64 = ( + snap.screenshot.split(",", 1)[1] + if "," in snap.screenshot + else snap.screenshot + ) + else: + screenshot_base64 = snap.screenshot + + snapshot_data["screenshot_base64"] = screenshot_base64 + if snap.screenshot_format: + snapshot_data["screenshot_format"] = snap.screenshot_format + self.tracer.emit( "snapshot", - { - "url": snap.url, - "element_count": len(snap.elements), - "timestamp": snap.timestamp, - "elements": elements_data, # Add element data for overlay - }, + snapshot_data, step_id=step_id, ) diff --git a/sentience/cloud_tracing.py b/sentience/cloud_tracing.py index ff8c0a0..5d1d9e0 100644 --- a/sentience/cloud_tracing.py +++ b/sentience/cloud_tracing.py @@ -4,11 +4,13 @@ Implements "Local Write, Batch Upload" pattern for enterprise cloud tracing. """ +import base64 import gzip import json import os import threading from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path from typing import Any, Protocol @@ -103,9 +105,10 @@ def __init__( self._closed = False self._upload_successful = False - # File size tracking (NEW) + # File size tracking self.trace_file_size_bytes = 0 self.screenshot_total_size_bytes = 0 + self.screenshot_count = 0 # Track number of screenshots extracted def emit(self, event: dict[str, Any]) -> None: """ @@ -165,21 +168,36 @@ def _do_upload(self, on_progress: Callable[[int, int], None] | None = None) -> N """ Internal upload method with progress tracking. + Extracts screenshots from trace events, uploads them separately, + then removes screenshot_base64 from events before uploading trace. + Args: on_progress: Optional callback(uploaded_bytes, total_bytes) for progress updates """ try: - # Read and compress - with open(self._path, "rb") as f: + # Step 1: Extract screenshots from trace events + screenshots = self._extract_screenshots_from_trace() + self.screenshot_count = len(screenshots) + + # Step 2: Upload screenshots separately + if screenshots: + self._upload_screenshots(screenshots, on_progress) + + # Step 3: Create cleaned trace file (without screenshot_base64) + cleaned_trace_path = self._path.with_suffix(".cleaned.jsonl") + self._create_cleaned_trace(cleaned_trace_path) + + # Step 4: Read and compress cleaned trace + with open(cleaned_trace_path, "rb") as f: trace_data = f.read() compressed_data = gzip.compress(trace_data) compressed_size = len(compressed_data) - # Measure trace file size (NEW) + # Measure trace file size self.trace_file_size_bytes = compressed_size - # Log file sizes if logger is provided (NEW) + # Log file sizes if logger is provided if self.logger: self.logger.info( f"Trace file size: {self.trace_file_size_bytes / 1024 / 1024:.2f} MB" @@ -192,8 +210,9 @@ def _do_upload(self, on_progress: Callable[[int, int], None] | None = None) -> N if on_progress: on_progress(0, compressed_size) - # Upload to DigitalOcean Spaces via pre-signed URL - print(f"📤 [Sentience] Uploading trace to cloud ({compressed_size} bytes)...") + # Step 5: Upload cleaned trace to cloud + if self.logger: + self.logger.info(f"Uploading trace to cloud ({compressed_size} bytes)") response = requests.put( self.upload_url, @@ -208,6 +227,8 @@ def _do_upload(self, on_progress: Callable[[int, int], None] | None = None) -> N if response.status_code == 200: self._upload_successful = True print("✅ [Sentience] Trace uploaded successfully") + if self.logger: + self.logger.info("Trace uploaded successfully") # Report progress: complete if on_progress: @@ -219,22 +240,28 @@ def _do_upload(self, on_progress: Callable[[int, int], None] | None = None) -> N # Call /v1/traces/complete to report file sizes self._complete_trace() - # Delete file only on successful upload - if os.path.exists(self._path): - try: - os.remove(self._path) - except Exception: - pass # Ignore cleanup errors + # Delete files only on successful upload + self._cleanup_files() + + # Clean up temporary cleaned trace file + if cleaned_trace_path.exists(): + cleaned_trace_path.unlink() else: self._upload_successful = False print(f"❌ [Sentience] Upload failed: HTTP {response.status_code}") - print(f" Response: {response.text}") + print(f" Response: {response.text[:200]}") print(f" Local trace preserved at: {self._path}") + if self.logger: + self.logger.error( + f"Upload failed: HTTP {response.status_code}, Response: {response.text[:200]}" + ) except Exception as e: self._upload_successful = False print(f"❌ [Sentience] Error uploading trace: {e}") print(f" Local trace preserved at: {self._path}") + if self.logger: + self.logger.error(f"Error uploading trace: {e}") # Don't raise - preserve trace locally even if upload fails def _generate_index(self) -> None: @@ -246,6 +273,8 @@ def _generate_index(self) -> None: except Exception as e: # Non-fatal: log but don't crash print(f"⚠️ Failed to generate trace index: {e}") + if self.logger: + self.logger.warning(f"Failed to generate trace index: {e}") def _upload_index(self) -> None: """ @@ -301,8 +330,7 @@ def _upload_index(self) -> None: if self.logger: self.logger.info(f"Index file size: {index_size / 1024:.2f} KB") - - print(f"📤 [Sentience] Uploading trace index ({index_size} bytes)...") + self.logger.info(f"Uploading trace index ({index_size} bytes)") # Upload index to cloud storage index_response = requests.put( @@ -316,7 +344,8 @@ def _upload_index(self) -> None: ) if index_response.status_code == 200: - print("✅ [Sentience] Trace index uploaded successfully") + if self.logger: + self.logger.info("Trace index uploaded successfully") # Delete local index file after successful upload try: @@ -326,13 +355,11 @@ def _upload_index(self) -> None: else: if self.logger: self.logger.warning(f"Index upload failed: HTTP {index_response.status_code}") - print(f"⚠️ [Sentience] Index upload failed: HTTP {index_response.status_code}") except Exception as e: # Non-fatal: log but don't crash if self.logger: self.logger.warning(f"Error uploading trace index: {e}") - print(f"⚠️ [Sentience] Error uploading trace index: {e}") def _complete_trace(self) -> None: """ @@ -353,6 +380,7 @@ def _complete_trace(self) -> None: "stats": { "trace_file_size_bytes": self.trace_file_size_bytes, "screenshot_total_size_bytes": self.screenshot_total_size_bytes, + "screenshot_count": self.screenshot_count, }, }, timeout=10, @@ -372,6 +400,266 @@ def _complete_trace(self) -> None: if self.logger: self.logger.warning(f"Error reporting trace completion: {e}") + def _extract_screenshots_from_trace(self) -> dict[int, dict[str, Any]]: + """ + Extract screenshots from trace events. + + Returns: + dict mapping sequence number to screenshot data: + {seq: {"base64": str, "format": str, "step_id": str}} + """ + screenshots: dict[int, dict[str, Any]] = {} + sequence = 0 + + try: + with open(self._path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + + try: + event = json.loads(line) + # Check if this is a snapshot event with screenshot + if event.get("type") == "snapshot": + data = event.get("data", {}) + screenshot_base64 = data.get("screenshot_base64") + + if screenshot_base64: + sequence += 1 + screenshots[sequence] = { + "base64": screenshot_base64, + "format": data.get("screenshot_format", "jpeg"), + "step_id": event.get("step_id"), + } + except json.JSONDecodeError: + continue + except Exception as e: + if self.logger: + self.logger.error(f"Error extracting screenshots: {e}") + + return screenshots + + def _create_cleaned_trace(self, output_path: Path) -> None: + """ + Create trace file without screenshot_base64 fields. + + Args: + output_path: Path to write cleaned trace file + """ + try: + with ( + open(self._path, encoding="utf-8") as infile, + open(output_path, "w", encoding="utf-8") as outfile, + ): + for line in infile: + line = line.strip() + if not line: + continue + + try: + event = json.loads(line) + # Remove screenshot_base64 from snapshot events + if event.get("type") == "snapshot": + data = event.get("data", {}) + if "screenshot_base64" in data: + # Create copy without screenshot fields + cleaned_data = { + k: v + for k, v in data.items() + if k not in ("screenshot_base64", "screenshot_format") + } + event["data"] = cleaned_data + + # Write cleaned event + outfile.write(json.dumps(event, ensure_ascii=False) + "\n") + except json.JSONDecodeError: + # Skip invalid lines + continue + except Exception as e: + if self.logger: + self.logger.error(f"Error creating cleaned trace: {e}") + raise + + def _request_screenshot_urls(self, sequences: list[int]) -> dict[int, str]: + """ + Request pre-signed upload URLs for screenshots from gateway. + + Args: + sequences: List of screenshot sequence numbers + + Returns: + dict mapping sequence number to upload URL + """ + if not self.api_key or not sequences: + return {} + + try: + response = requests.post( + f"{self.api_url}/v1/screenshots/init", + headers={"Authorization": f"Bearer {self.api_key}"}, + json={ + "run_id": self.run_id, + "sequences": sequences, + }, + timeout=10, + ) + + if response.status_code == 200: + data = response.json() + # Gateway returns sequences as strings in JSON, convert to int keys + upload_urls = data.get("upload_urls", {}) + result = {int(k): v for k, v in upload_urls.items()} + if self.logger: + self.logger.info(f"Received {len(result)} screenshot upload URLs") + return result + else: + error_msg = f"Failed to get screenshot URLs: HTTP {response.status_code}" + if self.logger: + # Try to get error details + try: + error_data = response.json() + error_detail = error_data.get("error") or error_data.get("message", "") + if error_detail: + self.logger.warning(f"{error_msg}: {error_detail}") + else: + self.logger.warning(f"{error_msg}: {response.text[:200]}") + except Exception: + self.logger.warning(f"{error_msg}: {response.text[:200]}") + return {} + except Exception as e: + error_msg = f"Error requesting screenshot URLs: {e}" + if self.logger: + self.logger.warning(error_msg) + return {} + + def _upload_screenshots( + self, + screenshots: dict[int, dict[str, Any]], + on_progress: Callable[[int, int], None] | None = None, + ) -> None: + """ + Upload screenshots extracted from trace events. + + Steps: + 1. Request pre-signed URLs from gateway (/v1/screenshots/init) + 2. Decode base64 to image bytes + 3. Upload screenshots in parallel (10 concurrent workers) + 4. Track upload progress + + Args: + screenshots: dict mapping sequence to screenshot data + on_progress: Optional callback(uploaded_count, total_count) + """ + if not screenshots: + return + + # 1. Request pre-signed URLs from gateway + sequences = sorted(screenshots.keys()) + if self.logger: + self.logger.info(f"Requesting upload URLs for {len(sequences)} screenshot(s)") + upload_urls = self._request_screenshot_urls(sequences) + + if not upload_urls: + if self.logger: + self.logger.warning( + "No screenshot upload URLs received, skipping upload. " + "This may indicate API key permission issue, gateway error, or network problem." + ) + return + + # 2. Upload screenshots in parallel + uploaded_count = 0 + total_count = len(upload_urls) + failed_sequences: list[int] = [] + + def upload_one(seq: int, url: str) -> bool: + """Upload a single screenshot. Returns True if successful.""" + try: + screenshot_data = screenshots[seq] + base64_str = screenshot_data["base64"] + format_str = screenshot_data.get("format", "jpeg") + + # Decode base64 to image bytes + image_bytes = base64.b64decode(base64_str) + image_size = len(image_bytes) + + # Update total size + self.screenshot_total_size_bytes += image_size + + # Upload to pre-signed URL + response = requests.put( + url, + data=image_bytes, # Binary image data + headers={ + "Content-Type": f"image/{format_str}", + }, + timeout=30, # 30 second timeout per screenshot + ) + + if response.status_code == 200: + if self.logger: + self.logger.info( + f"Screenshot {seq} uploaded successfully ({image_size / 1024:.1f} KB)" + ) + return True + else: + error_msg = f"Screenshot {seq} upload failed: HTTP {response.status_code}" + if self.logger: + try: + error_detail = response.text[:200] + if error_detail: + self.logger.warning(f"{error_msg}: {error_detail}") + else: + self.logger.warning(error_msg) + except Exception: + self.logger.warning(error_msg) + return False + except Exception as e: + error_msg = f"Screenshot {seq} upload error: {e}" + if self.logger: + self.logger.warning(error_msg) + return False + + # Upload in parallel (max 10 concurrent) + with ThreadPoolExecutor(max_workers=10) as executor: + futures = { + executor.submit(upload_one, seq, url): seq for seq, url in upload_urls.items() + } + + for future in as_completed(futures): + seq = futures[future] + if future.result(): + uploaded_count += 1 + if on_progress: + on_progress(uploaded_count, total_count) + else: + failed_sequences.append(seq) + + # 3. Report results + if uploaded_count == total_count: + total_size_mb = self.screenshot_total_size_bytes / 1024 / 1024 + if self.logger: + self.logger.info( + f"All {total_count} screenshots uploaded successfully " + f"(total size: {total_size_mb:.2f} MB)" + ) + else: + if self.logger: + self.logger.warning( + f"Uploaded {uploaded_count}/{total_count} screenshots. " + f"Failed sequences: {failed_sequences if failed_sequences else 'none'}" + ) + + def _cleanup_files(self) -> None: + """Delete local files after successful upload.""" + # Delete trace file + if os.path.exists(self._path): + try: + os.remove(self._path) + except Exception: + pass # Ignore cleanup errors + def __enter__(self): """Context manager support.""" return self diff --git a/sentience/extension/background.js b/sentience/extension/background.js index 811303f..f359ba6 100644 --- a/sentience/extension/background.js +++ b/sentience/extension/background.js @@ -144,13 +144,13 @@ async function handleScreenshotCapture(_tabId, options = {}) { async function handleSnapshotProcessing(rawData, options = {}) { const MAX_ELEMENTS = 10000; // Safety limit to prevent hangs const startTime = performance.now(); - + try { // Safety check: limit element count to prevent hangs if (!Array.isArray(rawData)) { throw new Error('rawData must be an array'); } - + if (rawData.length > MAX_ELEMENTS) { console.warn(`[Sentience Background] ⚠️ Large dataset: ${rawData.length} elements. Limiting to ${MAX_ELEMENTS} to prevent hangs.`); rawData = rawData.slice(0, MAX_ELEMENTS); @@ -186,7 +186,7 @@ async function handleSnapshotProcessing(rawData, options = {}) { // Add timeout protection (18 seconds - less than content.js timeout) analyzedElements = await Promise.race([ wasmPromise, - new Promise((_, reject) => + new Promise((_, reject) => setTimeout(() => reject(new Error('WASM processing timeout (>18s)')), 18000) ) ]); diff --git a/sentience/extension/content.js b/sentience/extension/content.js index 62ae408..8d3b0d4 100644 --- a/sentience/extension/content.js +++ b/sentience/extension/content.js @@ -92,7 +92,7 @@ function handleSnapshotRequest(data) { if (responded) return; // Already responded via timeout responded = true; clearTimeout(timeoutId); - + const duration = performance.now() - startTime; // Handle Chrome extension errors (e.g., background script crashed) diff --git a/sentience/extension/injected_api.js b/sentience/extension/injected_api.js index 45c4337..e81c9be 100644 --- a/sentience/extension/injected_api.js +++ b/sentience/extension/injected_api.js @@ -66,10 +66,10 @@ // --- HELPER: Safe Class Name Extractor (Handles SVGAnimatedString) --- function getClassName(el) { if (!el || !el.className) return ''; - + // Handle string (HTML elements) if (typeof el.className === 'string') return el.className; - + // Handle SVGAnimatedString (SVG elements) if (typeof el.className === 'object') { if ('baseVal' in el.className && typeof el.className.baseVal === 'string') { @@ -85,17 +85,17 @@ return ''; } } - + return ''; } // --- HELPER: Paranoid String Converter (Handles SVGAnimatedString) --- function toSafeString(value) { if (value === null || value === undefined) return null; - + // 1. If it's already a primitive string, return it if (typeof value === 'string') return value; - + // 2. Handle SVG objects (SVGAnimatedString, SVGAnimatedNumber, etc.) if (typeof value === 'object') { // Try extracting baseVal (standard SVG property) @@ -114,7 +114,7 @@ return null; } } - + // 3. Last resort cast for primitives try { return String(value); @@ -127,9 +127,9 @@ // For SVG elements, get the fill or stroke color (SVGs use fill/stroke, not backgroundColor) function getSVGColor(el) { if (!el || el.tagName !== 'SVG') return null; - + const style = window.getComputedStyle(el); - + // Try fill first (most common for SVG icons) const fill = style.fill; if (fill && fill !== 'none' && fill !== 'transparent' && fill !== 'rgba(0, 0, 0, 0)') { @@ -144,7 +144,7 @@ return fill; } } - + // Fallback to stroke if fill is not available const stroke = style.stroke; if (stroke && stroke !== 'none' && stroke !== 'transparent' && stroke !== 'rgba(0, 0, 0, 0)') { @@ -158,7 +158,7 @@ return stroke; } } - + return null; } @@ -168,28 +168,28 @@ // This handles rgba(0,0,0,0) and transparent values that browsers commonly return function getEffectiveBackgroundColor(el) { if (!el) return null; - + // For SVG elements, use fill/stroke instead of backgroundColor if (el.tagName === 'SVG') { const svgColor = getSVGColor(el); if (svgColor) return svgColor; } - + let current = el; const maxDepth = 10; // Prevent infinite loops let depth = 0; - + while (current && depth < maxDepth) { const style = window.getComputedStyle(current); - + // For SVG elements in the tree, also check fill/stroke if (current.tagName === 'SVG') { const svgColor = getSVGColor(current); if (svgColor) return svgColor; } - + const bgColor = style.backgroundColor; - + if (bgColor && bgColor !== 'transparent' && bgColor !== 'rgba(0, 0, 0, 0)') { // Check if it's rgba with alpha < 1 (semi-transparent) const rgbaMatch = bgColor.match(/rgba?\((\d+),\s*(\d+),\s*(\d+)(?:,\s*([\d.]+))?\)/); @@ -209,12 +209,12 @@ return bgColor; } } - + // Move up the DOM tree current = current.parentElement; depth++; } - + // Fallback: return null if nothing found return null; } @@ -235,7 +235,7 @@ // Only check for elements that are likely to be occluded (overlays, modals, tooltips) const zIndex = parseInt(style.zIndex, 10); const position = style.position; - + // Skip occlusion check for normal flow elements (vast majority) // Only check for positioned elements or high z-index (likely overlays) if (position === 'static' && (isNaN(zIndex) || zIndex <= 10)) { @@ -308,7 +308,7 @@ }; window.addEventListener('message', listener); - + try { window.postMessage({ type: 'SENTIENCE_SNAPSHOT_REQUEST', @@ -514,7 +514,7 @@ function extractRawElementData(el) { const style = window.getComputedStyle(el); const rect = el.getBoundingClientRect(); - + return { tag: el.tagName, rect: { @@ -548,12 +548,12 @@ // --- HELPER: Generate Unique CSS Selector (for Golden Set) --- function getUniqueSelector(el) { if (!el || !el.tagName) return ''; - + // If element has a unique ID, use it if (el.id) { return `#${el.id}`; } - + // Try data attributes or aria-label for uniqueness for (const attr of el.attributes) { if (attr.name.startsWith('data-') || attr.name === 'aria-label') { @@ -561,21 +561,21 @@ return `${el.tagName.toLowerCase()}[${attr.name}="${value}"]`; } } - + // Build path with classes and nth-child for uniqueness const path = []; let current = el; - + while (current && current !== document.body && current !== document.documentElement) { let selector = current.tagName.toLowerCase(); - + // If current element has ID, use it and stop if (current.id) { selector = `#${current.id}`; path.unshift(selector); break; } - + // Add class if available if (current.className && typeof current.className === 'string') { const classes = current.className.trim().split(/\s+/).filter(c => c); @@ -584,7 +584,7 @@ selector += `.${classes[0]}`; } } - + // Add nth-of-type if needed for uniqueness if (current.parentElement) { const siblings = Array.from(current.parentElement.children); @@ -594,11 +594,11 @@ selector += `:nth-of-type(${index + 1})`; } } - + path.unshift(selector); current = current.parentElement; } - + return path.join(' > ') || el.tagName.toLowerCase(); } @@ -613,7 +613,7 @@ } = options; const startTime = Date.now(); - + return new Promise((resolve) => { // Check if DOM already has enough nodes const nodeCount = document.querySelectorAll('*').length; @@ -623,17 +623,17 @@ const observer = new MutationObserver(() => { lastChange = Date.now(); }); - + observer.observe(document.body, { childList: true, subtree: true, attributes: false }); - + const checkStable = () => { const timeSinceLastChange = Date.now() - lastChange; const totalWait = Date.now() - startTime; - + if (timeSinceLastChange >= quietPeriod) { observer.disconnect(); resolve(); @@ -645,14 +645,14 @@ setTimeout(checkStable, 50); } }; - + checkStable(); } else { // DOM doesn't have enough nodes yet, wait for them const observer = new MutationObserver(() => { const currentCount = document.querySelectorAll('*').length; const totalWait = Date.now() - startTime; - + if (currentCount >= minNodeCount) { observer.disconnect(); // Now wait for quiet period @@ -660,17 +660,17 @@ const quietObserver = new MutationObserver(() => { lastChange = Date.now(); }); - + quietObserver.observe(document.body, { childList: true, subtree: true, attributes: false }); - + const checkQuiet = () => { const timeSinceLastChange = Date.now() - lastChange; const totalWait = Date.now() - startTime; - + if (timeSinceLastChange >= quietPeriod) { quietObserver.disconnect(); resolve(); @@ -682,7 +682,7 @@ setTimeout(checkQuiet, 50); } }; - + checkQuiet(); } else if (totalWait >= maxWait) { observer.disconnect(); @@ -690,13 +690,13 @@ resolve(); } }); - + observer.observe(document.body, { childList: true, subtree: true, attributes: false }); - + // Timeout fallback setTimeout(() => { observer.disconnect(); @@ -710,21 +710,21 @@ // --- HELPER: Collect Iframe Snapshots (Frame Stitching) --- // Recursively collects snapshot data from all child iframes // This enables detection of elements inside iframes (e.g., Stripe forms) - // + // // NOTE: Cross-origin iframes cannot be accessed due to browser security (Same-Origin Policy). // Only same-origin iframes will return snapshot data. Cross-origin iframes will be skipped // with a warning. For cross-origin iframes, users must manually switch frames using // Playwright's page.frame() API. async function collectIframeSnapshots(options = {}) { const iframeData = new Map(); // Map of iframe element -> snapshot data - + // Find all iframe elements in current document const iframes = Array.from(document.querySelectorAll('iframe')); - + if (iframes.length === 0) { return iframeData; } - + console.log(`[SentienceAPI] Found ${iframes.length} iframe(s), requesting snapshots...`); // Request snapshot from each iframe const iframePromises = iframes.map((iframe, idx) => { @@ -737,13 +737,13 @@ return new Promise((resolve) => { const requestId = `iframe-${idx}-${Date.now()}`; - + // 1. EXTENDED TIMEOUT (Handle slow children) const timeout = setTimeout(() => { console.warn(`[SentienceAPI] ⚠️ Iframe ${idx} snapshot TIMEOUT (id: ${requestId})`); resolve(null); }, 5000); // Increased to 5s to handle slow processing - + // 2. ROBUST LISTENER with debugging const listener = (event) => { // Debug: Log all SENTIENCE_IFRAME_SNAPSHOT_RESPONSE messages to see what's happening @@ -753,14 +753,14 @@ // console.log(`[SentienceAPI] Received response for different request: ${event.data.requestId} (expected: ${requestId})`); } } - + // Check if this is the response we're waiting for - if (event.data?.type === 'SENTIENCE_IFRAME_SNAPSHOT_RESPONSE' && + if (event.data?.type === 'SENTIENCE_IFRAME_SNAPSHOT_RESPONSE' && event.data?.requestId === requestId) { - + clearTimeout(timeout); window.removeEventListener('message', listener); - + if (event.data.error) { console.warn(`[SentienceAPI] Iframe ${idx} returned error:`, event.data.error); resolve(null); @@ -775,9 +775,9 @@ } } }; - + window.addEventListener('message', listener); - + // 3. SEND REQUEST with error handling try { if (iframe.contentWindow) { @@ -785,8 +785,8 @@ iframe.contentWindow.postMessage({ type: 'SENTIENCE_IFRAME_SNAPSHOT_REQUEST', requestId: requestId, - options: { - ...options, + options: { + ...options, collectIframes: true // Enable recursion for nested iframes } }, '*'); // Use '*' for cross-origin, but browser will enforce same-origin policy @@ -804,10 +804,10 @@ } }); }); - + // Wait for all iframe responses const results = await Promise.all(iframePromises); - + // Store iframe data results.forEach((result, idx) => { if (result && result.data && !result.error) { @@ -819,7 +819,7 @@ console.warn(`[SentienceAPI] Iframe ${idx} returned no data (timeout or error)`); } }); - + return iframeData; } @@ -832,7 +832,7 @@ // Security: only respond to snapshot requests from parent frames if (event.data?.type === 'SENTIENCE_IFRAME_SNAPSHOT_REQUEST') { const { requestId, options } = event.data; - + try { // Generate snapshot for this iframe's content // Allow recursive collection - querySelectorAll('iframe') only finds direct children, @@ -840,7 +840,7 @@ // waitForStability: false makes performance better - i.e. don't wait for children frames const snapshotOptions = { ...options, collectIframes: true, waitForStability: options.waitForStability === false ? false : false }; const snapshot = await window.sentience.snapshot(snapshotOptions); - + // Send response back to parent if (event.source && event.source.postMessage) { event.source.postMessage({ @@ -864,7 +864,7 @@ } }); } - + // Setup iframe handler when script loads (only once) if (!window.sentience_iframe_handler_setup) { setupIframeSnapshotHandler(); @@ -880,7 +880,7 @@ if (options.waitForStability !== false) { await waitForStability(options.waitForStability || {}); } - + // Step 1: Collect raw DOM data (Main World - CSP can't block this!) const rawData = []; window.sentience_registry = []; @@ -896,17 +896,17 @@ const textVal = getText(el); const inView = isInViewport(rect); - + // Get computed style once (needed for both occlusion check and data collection) const style = window.getComputedStyle(el); - + // Only check occlusion for elements likely to be occluded (optimized) // This avoids layout thrashing for the vast majority of elements const occluded = inView ? isOccluded(el, rect, style) : false; - + // Get effective background color (traverses DOM to find non-transparent color) const effectiveBgColor = getEffectiveBackgroundColor(el); - + rawData.push({ id: idx, tag: el.tagName.toLowerCase(), @@ -946,26 +946,26 @@ // This allows WASM to process all elements uniformly (no recursion needed) let allRawElements = [...rawData]; // Start with main frame elements let totalIframeElements = 0; - + if (options.collectIframes !== false) { try { console.log(`[SentienceAPI] Starting iframe collection...`); const iframeSnapshots = await collectIframeSnapshots(options); console.log(`[SentienceAPI] Iframe collection complete. Received ${iframeSnapshots.size} snapshot(s)`); - + if (iframeSnapshots.size > 0) { // FLATTEN IMMEDIATELY: Don't nest them. Just append them with coordinate translation. iframeSnapshots.forEach((iframeSnapshot, iframeEl) => { // Debug: Log structure to verify data is correct // console.log(`[SentienceAPI] Processing iframe snapshot:`, iframeSnapshot); - + if (iframeSnapshot && iframeSnapshot.raw_elements) { const rawElementsCount = iframeSnapshot.raw_elements.length; console.log(`[SentienceAPI] Processing ${rawElementsCount} elements from iframe (src: ${iframeEl.src || 'unknown'})`); // Get iframe's bounding rect (offset for coordinate translation) const iframeRect = iframeEl.getBoundingClientRect(); const offset = { x: iframeRect.x, y: iframeRect.y }; - + // Get iframe context for frame switching (Playwright needs this) const iframeSrc = iframeEl.src || iframeEl.getAttribute('src') || ''; let isSameOrigin = false; @@ -975,11 +975,11 @@ } catch (e) { isSameOrigin = false; } - + // Adjust coordinates and add iframe context to each element const adjustedElements = iframeSnapshot.raw_elements.map(el => { const adjusted = { ...el }; - + // Adjust rect coordinates to parent viewport if (adjusted.rect) { adjusted.rect = { @@ -988,22 +988,22 @@ y: adjusted.rect.y + offset.y }; } - + // Add iframe context so agents can switch frames in Playwright adjusted.iframe_context = { src: iframeSrc, is_same_origin: isSameOrigin }; - + return adjusted; }); - + // Append flattened iframe elements to main array allRawElements.push(...adjustedElements); totalIframeElements += adjustedElements.length; } }); - + // console.log(`[SentienceAPI] Merged ${iframeSnapshots.size} iframe(s). Total elements: ${allRawElements.length} (${rawData.length} main + ${totalIframeElements} iframe)`); } } catch (error) { @@ -1016,7 +1016,7 @@ // No recursion needed - everything is already flat console.log(`[SentienceAPI] Sending ${allRawElements.length} total elements to WASM (${rawData.length} main + ${totalIframeElements} iframe)`); const processed = await processSnapshotInBackground(allRawElements, options); - + if (!processed || !processed.elements) { throw new Error('WASM processing returned invalid result'); } @@ -1032,10 +1032,10 @@ const cleanedRawElements = cleanElement(processed.raw_elements); // FIXED: Removed undefined 'totalIframeRawElements' - // FIXED: Logic updated for "Flatten Early" architecture. + // FIXED: Logic updated for "Flatten Early" architecture. // processed.elements ALREADY contains the merged iframe elements, // so we simply use .length. No addition needed. - + const totalCount = cleanedElements.length; const totalRaw = cleanedRawElements.length; const iframeCount = totalIframeElements || 0; @@ -1253,23 +1253,23 @@ autoDisableTimeout = 30 * 60 * 1000, // 30 minutes default keyboardShortcut = 'Ctrl+Shift+I' } = options; - + console.log("🔴 [Sentience] Recording Mode STARTED. Click an element to copy its Ground Truth JSON."); console.log(` Press ${keyboardShortcut} or call stopRecording() to stop.`); - + // Validate registry is populated if (!window.sentience_registry || window.sentience_registry.length === 0) { console.warn("⚠️ Registry empty. Call `await window.sentience.snapshot()` first to populate registry."); alert("Registry empty. Run `await window.sentience.snapshot()` first!"); return () => {}; // Return no-op cleanup function } - + // Create reverse mapping for O(1) lookup (fixes registry lookup bug) window.sentience_registry_map = new Map(); window.sentience_registry.forEach((el, idx) => { if (el) window.sentience_registry_map.set(el, idx); }); - + // Create highlight box overlay let highlightBox = document.getElementById('sentience-highlight-box'); if (!highlightBox) { @@ -1287,7 +1287,7 @@ `; document.body.appendChild(highlightBox); } - + // Create visual indicator (red border on page when recording) let recordingIndicator = document.getElementById('sentience-recording-indicator'); if (!recordingIndicator) { @@ -1306,12 +1306,12 @@ document.body.appendChild(recordingIndicator); } recordingIndicator.style.display = 'block'; - + // Hover handler (visual feedback) const mouseOverHandler = (e) => { const el = e.target; if (!el || el === highlightBox || el === recordingIndicator) return; - + const rect = el.getBoundingClientRect(); highlightBox.style.display = 'block'; highlightBox.style.top = (rect.top + window.scrollY) + 'px'; @@ -1319,15 +1319,15 @@ highlightBox.style.width = rect.width + 'px'; highlightBox.style.height = rect.height + 'px'; }; - + // Click handler (capture ground truth data) const clickHandler = (e) => { e.preventDefault(); e.stopPropagation(); - + const el = e.target; if (!el || el === highlightBox || el === recordingIndicator) return; - + // Use Map for reliable O(1) lookup const sentienceId = window.sentience_registry_map.get(el); if (sentienceId === undefined) { @@ -1335,13 +1335,13 @@ alert("Element not in registry. Run `await window.sentience.snapshot()` first!"); return; } - + // Extract raw data (ground truth + raw signals, NOT model outputs) const rawData = extractRawElementData(el); const selector = getUniqueSelector(el); const role = el.getAttribute('role') || el.tagName.toLowerCase(); const text = getText(el); - + // Build golden set JSON (ground truth + raw signals only) const snippet = { task: `Interact with ${text.substring(0, 20)}${text.length > 20 ? '...' : ''}`, @@ -1355,12 +1355,12 @@ }, debug_snapshot: rawData }; - + // Copy to clipboard const jsonString = JSON.stringify(snippet, null, 2); navigator.clipboard.writeText(jsonString).then(() => { console.log("✅ Copied Ground Truth to clipboard:", snippet); - + // Flash green to indicate success highlightBox.style.border = `2px solid ${successColor}`; highlightBox.style.background = 'rgba(0, 255, 0, 0.2)'; @@ -1373,42 +1373,42 @@ alert("Failed to copy to clipboard. Check console for JSON."); }); }; - + // Auto-disable timeout let timeoutId = null; - + // Cleanup function to stop recording (defined before use) const stopRecording = () => { document.removeEventListener('mouseover', mouseOverHandler, true); document.removeEventListener('click', clickHandler, true); document.removeEventListener('keydown', keyboardHandler, true); - + if (timeoutId) { clearTimeout(timeoutId); timeoutId = null; } - + if (highlightBox) { highlightBox.style.display = 'none'; } - + if (recordingIndicator) { recordingIndicator.style.display = 'none'; } - + // Clean up registry map (optional, but good practice) if (window.sentience_registry_map) { window.sentience_registry_map.clear(); } - + // Remove global reference if (window.sentience_stopRecording === stopRecording) { delete window.sentience_stopRecording; } - + console.log("⚪ [Sentience] Recording Mode STOPPED."); }; - + // Keyboard shortcut handler (defined after stopRecording) const keyboardHandler = (e) => { // Ctrl+Shift+I or Cmd+Shift+I @@ -1417,12 +1417,12 @@ stopRecording(); } }; - + // Attach event listeners (use capture phase to intercept early) document.addEventListener('mouseover', mouseOverHandler, true); document.addEventListener('click', clickHandler, true); document.addEventListener('keydown', keyboardHandler, true); - + // Set up auto-disable timeout if (autoDisableTimeout > 0) { timeoutId = setTimeout(() => { @@ -1430,10 +1430,10 @@ stopRecording(); }, autoDisableTimeout); } - + // Store stop function globally for keyboard shortcut access window.sentience_stopRecording = stopRecording; - + return stopRecording; } }; diff --git a/sentience/models.py b/sentience/models.py index d8fe61c..a16b035 100644 --- a/sentience/models.py +++ b/sentience/models.py @@ -2,6 +2,7 @@ Pydantic models for Sentience SDK - matches spec/snapshot.schema.json """ +from dataclasses import dataclass from typing import Literal, Optional from pydantic import BaseModel, Field @@ -410,3 +411,19 @@ class TextRectSearchResult(BaseModel): ) viewport: Viewport | None = Field(None, description="Current viewport dimensions") error: str | None = Field(None, description="Error message if status is 'error'") + + +@dataclass +class ScreenshotMetadata: + """ + Metadata for a stored screenshot. + + Used by CloudTraceSink to track screenshots before upload. + All fields are required for type safety. + """ + + sequence: int + format: Literal["png", "jpeg"] + size_bytes: int + step_id: str | None + filepath: str diff --git a/sentience/snapshot.py b/sentience/snapshot.py index 4f74bb6..786161f 100644 --- a/sentience/snapshot.py +++ b/sentience/snapshot.py @@ -120,7 +120,11 @@ def _snapshot_via_extension( # Build options dict for extension API (exclude save_trace/trace_path) ext_options: dict[str, Any] = {} if options.screenshot is not False: - ext_options["screenshot"] = options.screenshot + # Serialize ScreenshotConfig to dict if it's a Pydantic model + if hasattr(options.screenshot, "model_dump"): + ext_options["screenshot"] = options.screenshot.model_dump() + else: + ext_options["screenshot"] = options.screenshot if options.limit != 50: ext_options["limit"] = options.limit if options.filter is not None: @@ -355,7 +359,11 @@ async def _snapshot_via_extension_async( # Build options dict for extension API ext_options: dict[str, Any] = {} if options.screenshot is not False: - ext_options["screenshot"] = options.screenshot + # Serialize ScreenshotConfig to dict if it's a Pydantic model + if hasattr(options.screenshot, "model_dump"): + ext_options["screenshot"] = options.screenshot.model_dump() + else: + ext_options["screenshot"] = options.screenshot if options.limit != 50: ext_options["limit"] = options.limit if options.filter is not None: @@ -372,6 +380,8 @@ async def _snapshot_via_extension_async( """, ext_options, ) + if result.get("error"): + print(f" Snapshot error: {result.get('error')}") # Save trace if requested if options.save_trace: @@ -392,6 +402,15 @@ async def _snapshot_via_extension_async( raw_elements, ) + # Extract screenshot_format from data URL if not provided by extension + if result.get("screenshot") and not result.get("screenshot_format"): + screenshot_data_url = result.get("screenshot", "") + if screenshot_data_url.startswith("data:image/"): + # Extract format from "data:image/jpeg;base64,..." or "data:image/png;base64,..." + format_match = screenshot_data_url.split(";")[0].split("/")[-1] + if format_match in ["jpeg", "jpg", "png"]: + result["screenshot_format"] = "jpeg" if format_match in ["jpeg", "jpg"] else "png" + # Validate and parse with Pydantic snapshot_obj = Snapshot(**result) return snapshot_obj @@ -421,10 +440,16 @@ async def _snapshot_via_api_async( "Sentience extension failed to inject. Cannot collect raw data for API processing." ) from e - # Step 1: Get raw data from local extension + # Step 1: Get raw data from local extension (including screenshot) raw_options: dict[str, Any] = {} + screenshot_requested = False if options.screenshot is not False: - raw_options["screenshot"] = options.screenshot + screenshot_requested = True + # Serialize ScreenshotConfig to dict if it's a Pydantic model + if hasattr(options.screenshot, "model_dump"): + raw_options["screenshot"] = options.screenshot.model_dump() + else: + raw_options["screenshot"] = options.screenshot raw_result = await browser.page.evaluate( """ @@ -435,6 +460,16 @@ async def _snapshot_via_api_async( raw_options, ) + # Extract screenshot from raw result (extension captures it, but API doesn't return it) + screenshot_data_url = raw_result.get("screenshot") + screenshot_format = None + if screenshot_data_url: + # Extract format from data URL + if screenshot_data_url.startswith("data:image/"): + format_match = screenshot_data_url.split(";")[0].split("/")[-1] + if format_match in ["jpeg", "jpg", "png"]: + screenshot_format = "jpeg" if format_match in ["jpeg", "jpg"] else "png" + # Save trace if requested if options.save_trace: _save_trace_to_file(raw_result.get("raw_elements", []), options.trace_path) @@ -479,6 +514,13 @@ async def _snapshot_via_api_async( response.raise_for_status() api_result = response.json() + # Extract screenshot format from data URL if not provided + if screenshot_data_url and not screenshot_format: + if screenshot_data_url.startswith("data:image/"): + format_match = screenshot_data_url.split(";")[0].split("/")[-1] + if format_match in ["jpeg", "jpg", "png"]: + screenshot_format = "jpeg" if format_match in ["jpeg", "jpg"] else "png" + # Merge API result with local data snapshot_data = { "status": api_result.get("status", "success"), @@ -486,8 +528,8 @@ async def _snapshot_via_api_async( "url": api_result.get("url", raw_result.get("url", "")), "viewport": api_result.get("viewport", raw_result.get("viewport")), "elements": api_result.get("elements", []), - "screenshot": raw_result.get("screenshot"), - "screenshot_format": raw_result.get("screenshot_format"), + "screenshot": screenshot_data_url, # Use the extracted screenshot + "screenshot_format": screenshot_format, # Use the extracted format "error": api_result.get("error"), } diff --git a/sentience/tracer_factory.py b/sentience/tracer_factory.py index 57b61a5..86c3b01 100644 --- a/sentience/tracer_factory.py +++ b/sentience/tracer_factory.py @@ -99,13 +99,42 @@ def create_tracer( ) else: print("⚠️ [Sentience] Cloud init response missing upload_url") + print(f" Response data: {data}") print(" Falling back to local-only tracing") elif response.status_code == 403: print("⚠️ [Sentience] Cloud tracing requires Pro tier") + try: + error_data = response.json() + error_msg = error_data.get("error") or error_data.get("message", "") + if error_msg: + print(f" API Error: {error_msg}") + except Exception: + pass + print(" Falling back to local-only tracing") + elif response.status_code == 401: + print("⚠️ [Sentience] Cloud init failed: HTTP 401 Unauthorized") + print(" API key is invalid or expired") + try: + error_data = response.json() + error_msg = error_data.get("error") or error_data.get("message", "") + if error_msg: + print(f" API Error: {error_msg}") + except Exception: + pass print(" Falling back to local-only tracing") else: print(f"⚠️ [Sentience] Cloud init failed: HTTP {response.status_code}") + try: + error_data = response.json() + error_msg = error_data.get("error") or error_data.get( + "message", "Unknown error" + ) + print(f" Error: {error_msg}") + if "tier" in error_msg.lower() or "subscription" in error_msg.lower(): + print(f" 💡 This may be a tier/subscription issue") + except Exception: + print(f" Response: {response.text[:200]}") print(" Falling back to local-only tracing") except requests.exceptions.Timeout: @@ -149,10 +178,23 @@ def _recover_orphaned_traces(api_key: str, api_url: str = SENTIENCE_API_URL) -> if not orphaned: return - print(f"⚠️ [Sentience] Found {len(orphaned)} un-uploaded trace(s) from previous runs") + # Filter out test files (run_ids that start with "test-" or are clearly test data) + # These are likely from local testing and shouldn't be uploaded + test_patterns = ["test-", "test_", "test."] + valid_orphaned = [ + f + for f in orphaned + if not any(f.stem.startswith(pattern) for pattern in test_patterns) + and not f.stem.startswith("test") + ] + + if not valid_orphaned: + return + + print(f"⚠️ [Sentience] Found {len(valid_orphaned)} un-uploaded trace(s) from previous runs") print(" Attempting to upload now...") - for trace_file in orphaned: + for trace_file in valid_orphaned: try: # Extract run_id from filename (format: {run_id}.jsonl) run_id = trace_file.stem @@ -166,6 +208,11 @@ def _recover_orphaned_traces(api_key: str, api_url: str = SENTIENCE_API_URL) -> ) if response.status_code != 200: + # HTTP 422 typically means invalid run_id (e.g., test files) + # Skip silently for 422, but log other errors + if response.status_code == 422: + # Likely a test file or invalid run_id, skip silently + continue print(f"❌ Failed to get upload URL for {run_id}: HTTP {response.status_code}") continue diff --git a/tests/test_cloud_tracing.py b/tests/test_cloud_tracing.py index 1979a81..88dfd63 100644 --- a/tests/test_cloud_tracing.py +++ b/tests/test_cloud_tracing.py @@ -1,5 +1,6 @@ """Tests for sentience.cloud_tracing module""" +import base64 import gzip import json import os @@ -231,6 +232,131 @@ def progress_callback(uploaded: int, total: int): # Last call should have uploaded == total assert progress_calls[-1][0] == progress_calls[-1][1], "Final progress should be 100%" + def test_cloud_trace_sink_uploads_screenshots_after_trace(self): + """Test that CloudTraceSink uploads screenshots after trace upload succeeds.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-integration-1" + api_key = "sk_test_123" + + # Create test screenshot + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + + # Emit trace event with screenshot embedded + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "step_id": "step-1", + "data": { + "url": "https://example.com", + "element_count": 10, + "screenshot_base64": test_image_base64, + "screenshot_format": "png", + }, + } + ) + + # Mock all HTTP calls + mock_upload_urls = { + "1": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0001.png?signature=...", + } + + with ( + patch("sentience.cloud_tracing.requests.put") as mock_put, + patch("sentience.cloud_tracing.requests.post") as mock_post, + ): + # Mock trace upload (first PUT) + mock_trace_response = Mock() + mock_trace_response.status_code = 200 + mock_put.return_value = mock_trace_response + + # Mock screenshot init (first POST) + mock_init_response = Mock() + mock_init_response.status_code = 200 + mock_init_response.json.return_value = {"upload_urls": mock_upload_urls} + + # Mock screenshot upload (second PUT) + mock_screenshot_response = Mock() + mock_screenshot_response.status_code = 200 + + # Mock complete (second POST) + mock_complete_response = Mock() + mock_complete_response.status_code = 200 + + # Setup mock to return different responses for different calls + def put_side_effect(*args, **kwargs): + url = args[0] if args else kwargs.get("url", "") + if "screenshots" in url: + return mock_screenshot_response + return mock_trace_response + + def post_side_effect(*args, **kwargs): + url = args[0] if args else kwargs.get("url", "") + if "screenshots/init" in url: + return mock_init_response + return mock_complete_response + + mock_put.side_effect = put_side_effect + mock_post.side_effect = post_side_effect + + # Close triggers upload (which extracts screenshots and uploads them) + sink.close() + + # Verify trace was uploaded + assert mock_put.call_count >= 1 + + # Verify screenshot init was called + post_calls = [call[0][0] for call in mock_post.call_args_list] + assert any("screenshots/init" in url for url in post_calls) + + # Verify screenshot was uploaded (second PUT call) + put_urls = [call[0][0] for call in mock_put.call_args_list] + assert any("screenshots" in url for url in put_urls) + + # Verify uploaded trace data does NOT contain screenshot_base64 + trace_upload_call = None + for call in mock_put.call_args_list: + headers = call[1].get("headers", {}) + if headers.get("Content-Type") == "application/x-gzip": + trace_upload_call = call + break + + assert trace_upload_call is not None, "Trace upload should have been called" + + # Decompress and verify screenshot_base64 is removed + compressed_data = trace_upload_call[1]["data"] + decompressed_data = gzip.decompress(compressed_data) + trace_content = decompressed_data.decode("utf-8") + events = [ + json.loads(line) for line in trace_content.strip().split("\n") if line.strip() + ] + + snapshot_events = [e for e in events if e.get("type") == "snapshot"] + assert len(snapshot_events) > 0, "Should have snapshot event" + + for event in snapshot_events: + data = event.get("data", {}) + assert ( + "screenshot_base64" not in data + ), "screenshot_base64 should be removed from uploaded trace" + assert ( + "screenshot_format" not in data + ), "screenshot_format should be removed from uploaded trace" + + # Cleanup + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + trace_path = cache_dir / f"{run_id}.jsonl" + cleaned_trace_path = cache_dir / f"{run_id}.cleaned.jsonl" + if trace_path.exists(): + os.remove(trace_path) + if cleaned_trace_path.exists(): + os.remove(cleaned_trace_path) + class TestTracerFactory: """Test create_tracer factory function.""" diff --git a/tests/test_screenshot_storage.py b/tests/test_screenshot_storage.py new file mode 100644 index 0000000..a95a63d --- /dev/null +++ b/tests/test_screenshot_storage.py @@ -0,0 +1,517 @@ +"""Tests for screenshot extraction and upload in CloudTraceSink""" + +import base64 +import gzip +import json +import os +from pathlib import Path +from unittest.mock import Mock, patch + +import pytest + +from sentience.cloud_tracing import CloudTraceSink + + +class TestScreenshotExtraction: + """Test screenshot extraction functionality in CloudTraceSink.""" + + def test_extract_screenshots_from_trace(self): + """Test that _extract_screenshots_from_trace extracts screenshots from events.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-extraction-1" + + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Emit a snapshot event with screenshot + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "step_id": "step-1", + "data": { + "url": "https://example.com", + "element_count": 10, + "screenshot_base64": test_image_base64, + "screenshot_format": "png", + }, + } + ) + + # Close to write file + sink.close(blocking=False) + + # Wait a bit for file to be written + import time + + time.sleep(0.1) + + # Extract screenshots + screenshots = sink._extract_screenshots_from_trace() + + assert len(screenshots) == 1 + assert 1 in screenshots + assert screenshots[1]["base64"] == test_image_base64 + assert screenshots[1]["format"] == "png" + assert screenshots[1]["step_id"] == "step-1" + + # Cleanup + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + trace_path = cache_dir / f"{run_id}.jsonl" + if trace_path.exists(): + trace_path.unlink() + + def test_extract_screenshots_handles_multiple(self): + """Test that _extract_screenshots_from_trace handles multiple screenshots.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-extraction-2" + + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Emit multiple snapshot events with screenshots + for i in range(1, 4): + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": i, + "step_id": f"step-{i}", + "data": { + "url": "https://example.com", + "element_count": 10, + "screenshot_base64": test_image_base64, + "screenshot_format": "png", + }, + } + ) + + sink.close(blocking=False) + import time + + time.sleep(0.1) + + screenshots = sink._extract_screenshots_from_trace() + assert len(screenshots) == 3 + + # Cleanup + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + trace_path = cache_dir / f"{run_id}.jsonl" + if trace_path.exists(): + trace_path.unlink() + + def test_extract_screenshots_skips_events_without_screenshots(self): + """Test that _extract_screenshots_from_trace skips events without screenshots.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-extraction-3" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Emit snapshot without screenshot + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "data": { + "url": "https://example.com", + "element_count": 10, + # No screenshot_base64 + }, + } + ) + + sink.close(blocking=False) + import time + + time.sleep(0.1) + + screenshots = sink._extract_screenshots_from_trace() + assert len(screenshots) == 0 + + # Cleanup + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + trace_path = cache_dir / f"{run_id}.jsonl" + if trace_path.exists(): + trace_path.unlink() + + +class TestCleanedTrace: + """Test cleaned trace creation functionality.""" + + def test_create_cleaned_trace_removes_screenshot_fields(self): + """Test that _create_cleaned_trace removes screenshot_base64 from events.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-cleaned-trace-1" + + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Emit snapshot event with screenshot + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "data": { + "url": "https://example.com", + "element_count": 10, + "screenshot_base64": test_image_base64, + "screenshot_format": "png", + }, + } + ) + + sink.close(blocking=False) + import time + + time.sleep(0.1) + + # Create cleaned trace + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + cleaned_trace_path = cache_dir / f"{run_id}.cleaned.jsonl" + sink._create_cleaned_trace(cleaned_trace_path) + + # Read cleaned trace + with open(cleaned_trace_path) as f: + cleaned_event = json.loads(f.readline()) + + # Verify screenshot fields are removed + assert "screenshot_base64" not in cleaned_event["data"] + assert "screenshot_format" not in cleaned_event["data"] + assert cleaned_event["data"]["url"] == "https://example.com" + assert cleaned_event["data"]["element_count"] == 10 + + # Cleanup + trace_path = cache_dir / f"{run_id}.jsonl" + if trace_path.exists(): + trace_path.unlink() + if cleaned_trace_path.exists(): + cleaned_trace_path.unlink() + + def test_create_cleaned_trace_preserves_other_events(self): + """Test that _create_cleaned_trace preserves non-snapshot events unchanged.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-cleaned-trace-2" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Emit non-snapshot event + sink.emit( + { + "v": 1, + "type": "action", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "data": { + "action": "click", + "element_id": 123, + }, + } + ) + + sink.close(blocking=False) + import time + + time.sleep(0.1) + + # Create cleaned trace + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + cleaned_trace_path = cache_dir / f"{run_id}.cleaned.jsonl" + sink._create_cleaned_trace(cleaned_trace_path) + + # Read cleaned trace + with open(cleaned_trace_path) as f: + cleaned_event = json.loads(f.readline()) + + # Verify action event is unchanged + assert cleaned_event["type"] == "action" + assert cleaned_event["data"]["action"] == "click" + assert cleaned_event["data"]["element_id"] == 123 + + # Cleanup + trace_path = cache_dir / f"{run_id}.jsonl" + if trace_path.exists(): + trace_path.unlink() + if cleaned_trace_path.exists(): + cleaned_trace_path.unlink() + + +class TestScreenshotUpload: + """Test screenshot upload functionality in CloudTraceSink.""" + + def test_request_screenshot_urls_success(self): + """Test that _request_screenshot_urls requests URLs from gateway.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-upload-1" + api_key = "sk_test_123" + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + + # Mock gateway response + mock_urls = { + "1": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0001.png?signature=...", + "2": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0002.png?signature=...", + } + + with patch("sentience.cloud_tracing.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"upload_urls": mock_urls} + mock_post.return_value = mock_response + + # Request URLs + result = sink._request_screenshot_urls([1, 2]) + + # Verify request was made + assert mock_post.called + call_args = mock_post.call_args + assert "v1/screenshots/init" in call_args[0][0] + assert call_args[1]["headers"]["Authorization"] == f"Bearer {api_key}" + assert call_args[1]["json"]["run_id"] == run_id + assert call_args[1]["json"]["sequences"] == [1, 2] + + # Verify result (keys converted to int) + assert result == {1: mock_urls["1"], 2: mock_urls["2"]} + + sink.close(blocking=False) + + def test_request_screenshot_urls_handles_failure(self): + """Test that _request_screenshot_urls handles gateway failure.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-upload-2" + api_key = "sk_test_123" + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + + with patch("sentience.cloud_tracing.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 500 + mock_post.return_value = mock_response + + # Request URLs (should return empty dict on failure) + result = sink._request_screenshot_urls([1, 2]) + assert result == {} + + sink.close(blocking=False) + + def test_upload_screenshots_uploads_in_parallel(self): + """Test that _upload_screenshots uploads screenshots in parallel.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-upload-3" + api_key = "sk_test_123" + + # Create test screenshots data + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + screenshots = { + 1: {"base64": test_image_base64, "format": "png", "step_id": "step-1"}, + 2: {"base64": test_image_base64, "format": "png", "step_id": "step-2"}, + } + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + + # Mock gateway and upload responses + mock_upload_urls = { + "1": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0001.png?signature=...", + "2": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0002.png?signature=...", + } + + with ( + patch("sentience.cloud_tracing.requests.post") as mock_post, + patch("sentience.cloud_tracing.requests.put") as mock_put, + ): + # Mock gateway response + mock_gateway_response = Mock() + mock_gateway_response.status_code = 200 + mock_gateway_response.json.return_value = {"upload_urls": mock_upload_urls} + mock_post.return_value = mock_gateway_response + + # Mock upload responses + mock_upload_response = Mock() + mock_upload_response.status_code = 200 + mock_put.return_value = mock_upload_response + + # Upload screenshots + sink._upload_screenshots(screenshots) + + # Verify gateway was called + assert mock_post.called + + # Verify uploads were called (2 screenshots) + # Filter PUT calls to only screenshot uploads (exclude trace file uploads) + put_calls = mock_put.call_args_list + screenshot_uploads = [ + call for call in put_calls if "screenshots" in str(call[0][0] if call[0] else "") + ] + assert len(screenshot_uploads) == 2 + + # Verify upload URLs and content + upload_urls = [call[0][0] for call in screenshot_uploads] + assert mock_upload_urls["1"] in upload_urls + assert mock_upload_urls["2"] in upload_urls + + sink.close(blocking=False) + + def test_upload_screenshots_skips_when_no_screenshots(self, capsys): + """Test that _upload_screenshots skips when no screenshots provided.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-upload-4" + + sink = CloudTraceSink(upload_url, run_id=run_id) + + # Call upload with no screenshots (should do nothing) + sink._upload_screenshots({}) + + # Verify no errors + captured = capsys.readouterr() + assert "Uploading" not in captured.out + + sink.close(blocking=False) + + def test_complete_trace_includes_screenshot_count(self): + """Test that _complete_trace includes screenshot_count in stats.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-complete-1" + api_key = "sk_test_123" + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + # Set screenshot count (normally set during extraction) + sink.screenshot_count = 2 + + with patch("sentience.cloud_tracing.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + # Call complete + sink._complete_trace() + + # Verify request included screenshot_count + assert mock_post.called + call_args = mock_post.call_args + stats = call_args[1]["json"]["stats"] + assert "screenshot_count" in stats + assert stats["screenshot_count"] == 2 + + sink.close(blocking=False) + + def test_upload_removes_screenshot_base64_from_trace(self): + """Test that uploaded trace data does not contain screenshot_base64.""" + upload_url = "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/trace.jsonl.gz" + run_id = "test-screenshot-upload-clean-1" + api_key = "sk_test_123" + + test_image_base64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==" + + sink = CloudTraceSink(upload_url, run_id=run_id, api_key=api_key) + + # Emit snapshot event with screenshot + sink.emit( + { + "v": 1, + "type": "snapshot", + "ts": "2026-01-01T00:00:00.000Z", + "run_id": run_id, + "seq": 1, + "step_id": "step-1", + "data": { + "url": "https://example.com", + "element_count": 10, + "screenshot_base64": test_image_base64, + "screenshot_format": "png", + }, + } + ) + + sink.close(blocking=False) + import time + + time.sleep(0.1) + + # Mock gateway and upload responses + mock_upload_urls = { + "1": "https://sentience.nyc3.digitaloceanspaces.com/user123/run456/screenshots/step_0001.png?signature=...", + } + + with ( + patch("sentience.cloud_tracing.requests.post") as mock_post, + patch("sentience.cloud_tracing.requests.put") as mock_put, + ): + # Mock gateway response for screenshot URLs + mock_gateway_response = Mock() + mock_gateway_response.status_code = 200 + mock_gateway_response.json.return_value = {"upload_urls": mock_upload_urls} + mock_post.return_value = mock_gateway_response + + # Mock screenshot upload response + mock_screenshot_upload = Mock() + mock_screenshot_upload.status_code = 200 + mock_put.return_value = mock_screenshot_upload + + # Call _do_upload to simulate the full upload process + sink._do_upload() + + # Verify trace was uploaded (PUT was called) + assert mock_put.called + + # Find the trace upload call (not screenshot upload) + # Screenshot uploads happen first, then trace upload + put_calls = mock_put.call_args_list + trace_upload_call = None + for call in put_calls: + # Trace upload has Content-Type: application/x-gzip + headers = call[1].get("headers", {}) + if headers.get("Content-Type") == "application/x-gzip": + trace_upload_call = call + break + + assert trace_upload_call is not None, "Trace upload should have been called" + + # Decompress and verify the uploaded trace data + compressed_data = trace_upload_call[1]["data"] + decompressed_data = gzip.decompress(compressed_data) + trace_content = decompressed_data.decode("utf-8") + + # Parse the trace events + events = [ + json.loads(line) for line in trace_content.strip().split("\n") if line.strip() + ] + + # Find snapshot event + snapshot_events = [e for e in events if e.get("type") == "snapshot"] + assert len(snapshot_events) > 0, "Should have at least one snapshot event" + + # Verify screenshot_base64 is NOT in the uploaded trace + for event in snapshot_events: + data = event.get("data", {}) + assert ( + "screenshot_base64" not in data + ), "screenshot_base64 should be removed from uploaded trace" + assert ( + "screenshot_format" not in data + ), "screenshot_format should be removed from uploaded trace" + # Verify other fields are preserved + assert "url" in data + assert "element_count" in data + + # Cleanup + cache_dir = Path.home() / ".sentience" / "traces" / "pending" + trace_path = cache_dir / f"{run_id}.jsonl" + cleaned_trace_path = cache_dir / f"{run_id}.cleaned.jsonl" + if trace_path.exists(): + trace_path.unlink() + if cleaned_trace_path.exists(): + cleaned_trace_path.unlink()