def _is_execution_context_destroyed_error(e: Exception) -> bool:
    """
    Return True when ``e`` looks like a "JS execution context vanished" error.

    Playwright (and other browser backends) can throw while a navigation is
    in-flight. Detection is a case-insensitive substring match on ``str(e)``.

    Common symptoms:
    - "Execution context was destroyed, most likely because of a navigation"
    - "Cannot find context with specified id"
    """
    msg = str(e).lower()
    markers = (
        "execution context was destroyed",
        "most likely because of a navigation",
        "cannot find context with specified id",
    )
    return any(marker in msg for marker in markers)


async def _eval_with_navigation_retry(
    backend: "BrowserBackend",
    expression: str,
    *,
    retries: int = 10,
    settle_state: str = "interactive",
    settle_timeout_ms: int = 10000,
) -> Any:
    """
    Evaluate JS via ``backend.eval``, retrying while the page is mid-navigation.

    This makes snapshots resilient to cases like:
    - press Enter (navigation) → snapshot immediately → context destroyed

    Args:
        backend: Object exposing awaitable ``eval(expression)`` and
            ``wait_ready_state(state=..., timeout_ms=...)``.
        expression: JavaScript source handed unchanged to ``backend.eval``.
        retries: Maximum number of re-attempts after the first try, i.e. at
            most ``retries + 1`` evaluations. Negative values are treated as
            0 so the expression is always evaluated at least once.
        settle_state: Ready state waited for before each retry.
        settle_timeout_ms: Timeout, in milliseconds, for that wait.

    Returns:
        Whatever ``backend.eval`` returns on the first successful attempt.

    Raises:
        Exception: The error raised by ``backend.eval`` is re-raised unchanged
            when it is not a navigation error or when retries are exhausted.
    """
    max_retries = max(retries, 0)
    for attempt in range(max_retries + 1):
        try:
            return await backend.eval(expression)
        except Exception as e:
            if attempt >= max_retries or not _is_execution_context_destroyed_error(e):
                raise

        # Navigation is in-flight; wait for the new document context, then retry.
        try:
            await backend.wait_ready_state(state=settle_state, timeout_ms=settle_timeout_ms)  # type: ignore[arg-type]
        except Exception:
            # Deliberate best-effort: readyState polling can itself fail
            # mid-navigation, so we still retry after the backoff below.
            pass
        # Linear backoff: 0.25s more per attempt, capped at 1.5s.
        await asyncio.sleep(min(0.25 * (attempt + 1), 1.5))

    # Unreachable: the last loop iteration always returns or re-raises.
    # Kept so type-checkers see every path terminate.
    raise RuntimeError("eval failed")
get_api_key_from_env, handle_provider_error, require_package from .llm_response_builder import LLMResponseBuilder @@ -777,21 +778,44 @@ def __init__( elif load_in_8bit: quantization_config = BitsAndBytesConfig(load_in_8bit=True) + device = (device or "auto").strip().lower() + # Determine torch dtype if torch_dtype == "auto": - dtype = torch.float16 if device != "cpu" else torch.float32 + dtype = torch.float16 if device not in {"cpu"} else torch.float32 else: dtype = getattr(torch, torch_dtype) - # Load model - self.model = AutoModelForCausalLM.from_pretrained( - model_name, - quantization_config=quantization_config, - torch_dtype=dtype if quantization_config is None else None, - device_map=device, - trust_remote_code=True, - low_cpu_mem_usage=True, - ) + # device_map is a Transformers concept (not a literal "cpu/mps/cuda" device string). + # - "auto" enables Accelerate device mapping. + # - Otherwise, we load normally and then move the model to the requested device. + device_map: str | None = "auto" if device == "auto" else None + + def _load(*, device_map_override: str | None) -> Any: + return AutoModelForCausalLM.from_pretrained( + model_name, + quantization_config=quantization_config, + torch_dtype=dtype if quantization_config is None else None, + device_map=device_map_override, + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + + try: + self.model = _load(device_map_override=device_map) + except KeyError as e: + # Some envs / accelerate versions can crash on auto mapping (e.g. KeyError: 'cpu'). + # Keep demo ergonomics: default stays "auto", but we gracefully fall back. + if device == "auto" and ("cpu" in str(e).lower()): + device = "cpu" + dtype = torch.float32 + self.model = _load(device_map_override=None) + else: + raise + + # If we didn't use device_map, move model explicitly (only safe for non-quantized loads). 
def _is_execution_context_destroyed_error(e: Exception) -> bool:
    """
    Return True when ``e`` looks like a "JS execution context vanished" error.

    Playwright can throw while a navigation is in-flight, invalidating the JS
    execution context. Detection is a case-insensitive substring match on
    ``str(e)``.

    Common symptoms:
    - "Execution context was destroyed, most likely because of a navigation"
    - "Cannot find context with specified id"
    """
    msg = str(e).lower()
    markers = (
        "execution context was destroyed",
        "most likely because of a navigation",
        "cannot find context with specified id",
    )
    return any(marker in msg for marker in markers)


async def _settle_after_navigation(page: Any, timeout_ms: int) -> None:
    """Best-effort wait for ``domcontentloaded``, then pause 0.25s before a retry."""
    try:
        await page.wait_for_load_state("domcontentloaded", timeout=timeout_ms)
    except Exception:
        # Deliberate best-effort: the load-state wait can itself fail while the
        # page is navigating; the caller retries regardless.
        pass
    await asyncio.sleep(0.25)


async def _page_evaluate_with_nav_retry(
    page: Any,
    expression: str,
    arg: Any = None,
    *,
    retries: int = 2,
    settle_timeout_ms: int = 10000,
) -> Any:
    """
    Evaluate JS with a small retry loop if the page is mid-navigation.

    This prevents flaky crashes when callers snapshot right after triggering a
    navigation (e.g., pressing Enter on Google).

    Args:
        page: Object exposing awaitable ``evaluate`` and ``wait_for_load_state``.
        expression: JavaScript source passed to ``page.evaluate``.
        arg: Optional single argument for the expression; when it is ``None``
            ``page.evaluate`` is called with the expression only.
        retries: Maximum number of re-attempts after the first try. Negative
            values are treated as 0 so one evaluation is always attempted.
        settle_timeout_ms: Timeout, in milliseconds, for the load-state wait
            performed before each retry.

    Returns:
        Whatever ``page.evaluate`` returns on the first successful attempt.

    Raises:
        Exception: The error from ``page.evaluate`` is re-raised unchanged when
            it is not a navigation error or when retries are exhausted.
    """
    max_retries = max(retries, 0)
    for attempt in range(max_retries + 1):
        try:
            if arg is None:
                return await page.evaluate(expression)
            return await page.evaluate(expression, arg)
        except Exception as e:
            if attempt >= max_retries or not _is_execution_context_destroyed_error(e):
                raise
        await _settle_after_navigation(page, settle_timeout_ms)

    # Unreachable: the last loop iteration always returns or re-raises.
    raise RuntimeError("Page.evaluate failed")


async def _wait_for_function_with_nav_retry(
    page: Any,
    expression: str,
    *,
    timeout_ms: int,
    retries: int = 2,
) -> None:
    """
    Await ``page.wait_for_function``, retrying if a navigation interrupts it.

    Args:
        page: Object exposing awaitable ``wait_for_function`` and
            ``wait_for_load_state``.
        expression: JavaScript predicate passed to ``page.wait_for_function``.
        timeout_ms: Timeout, in milliseconds, used both for each
            ``wait_for_function`` attempt and for the load-state wait that
            precedes a retry.
        retries: Maximum number of re-attempts after the first try. Negative
            values are treated as 0 so one wait is always attempted.

    Raises:
        Exception: The error from ``page.wait_for_function`` is re-raised
            unchanged when it is not a navigation error or when retries are
            exhausted.
    """
    max_retries = max(retries, 0)
    for attempt in range(max_retries + 1):
        try:
            await page.wait_for_function(expression, timeout=timeout_ms)
            return
        except Exception as e:
            if attempt >= max_retries or not _is_execution_context_destroyed_error(e):
                raise
        await _settle_after_navigation(page, timeout_ms)

    # Unreachable: the last loop iteration always returns or re-raises.
    raise RuntimeError("wait_for_function failed")
+ elements_for_overlay = result.get("elements") or result.get("raw_elements") or [] + if elements_for_overlay: browser.page.evaluate( """ (elements) => { @@ -275,7 +349,7 @@ def _snapshot_via_extension( } } """, - raw_elements, + elements_for_overlay, ) # Show grid overlay if requested @@ -455,18 +529,20 @@ async def _snapshot_via_extension_async( # Wait for extension injection to complete try: - await browser.page.wait_for_function( + await _wait_for_function_with_nav_retry( + browser.page, "typeof window.sentience !== 'undefined'", - timeout=5000, + timeout_ms=5000, ) except Exception as e: try: - diag = await browser.page.evaluate( + diag = await _page_evaluate_with_nav_retry( + browser.page, """() => ({ sentience_defined: typeof window.sentience !== 'undefined', extension_id: document.documentElement.dataset.sentienceExtensionId || 'not set', url: window.location.href - })""" + })""", ) except Exception: diag = {"error": "Could not gather diagnostics"} @@ -492,7 +568,8 @@ async def _snapshot_via_extension_async( ) # Call extension API - result = await browser.page.evaluate( + result = await _page_evaluate_with_nav_retry( + browser.page, """ (options) => { return window.sentience.snapshot(options); @@ -521,9 +598,12 @@ async def _snapshot_via_extension_async( # Show visual overlay if requested if options.show_overlay: - raw_elements = result.get("raw_elements", []) - if raw_elements: - await browser.page.evaluate( + # Prefer processed semantic elements for overlay (have bbox/importance/visual_cues). + # raw_elements may not match the overlay renderer's expected shape. 
+ elements_for_overlay = result.get("elements") or result.get("raw_elements") or [] + if elements_for_overlay: + await _page_evaluate_with_nav_retry( + browser.page, """ (elements) => { if (window.sentience && window.sentience.showOverlay) { @@ -531,7 +611,7 @@ async def _snapshot_via_extension_async( } } """, - raw_elements, + elements_for_overlay, ) # Show grid overlay if requested @@ -542,9 +622,11 @@ async def _snapshot_via_extension_async( grid_dicts = [grid.model_dump() for grid in grids] # Pass grid_id as targetGridId to highlight it in red target_grid_id = options.grid_id if options.grid_id is not None else None - await browser.page.evaluate( + await _page_evaluate_with_nav_retry( + browser.page, """ - (grids, targetGridId) => { + (args) => { + const [grids, targetGridId] = args; if (window.sentience && window.sentience.showGrid) { window.sentience.showGrid(grids, targetGridId); } else { @@ -552,8 +634,7 @@ async def _snapshot_via_extension_async( } } """, - grid_dicts, - target_grid_id, + [grid_dicts, target_grid_id], ) return snapshot_obj @@ -573,8 +654,10 @@ async def _snapshot_via_api_async( # Wait for extension injection try: - await browser.page.wait_for_function( - "typeof window.sentience !== 'undefined'", timeout=5000 + await _wait_for_function_with_nav_retry( + browser.page, + "typeof window.sentience !== 'undefined'", + timeout_ms=5000, ) except Exception as e: raise RuntimeError( @@ -600,7 +683,8 @@ async def _snapshot_via_api_async( options.filter.model_dump() if hasattr(options.filter, "model_dump") else options.filter ) - raw_result = await browser.page.evaluate( + raw_result = await _page_evaluate_with_nav_retry( + browser.page, """ (options) => { return window.sentience.snapshot(options); @@ -689,7 +773,8 @@ async def _snapshot_via_api_async( if options.show_overlay: elements = api_result.get("elements", []) if elements: - await browser.page.evaluate( + await _page_evaluate_with_nav_retry( + browser.page, """ (elements) => { if 
(window.sentience && window.sentience.showOverlay) { @@ -708,9 +793,11 @@ async def _snapshot_via_api_async( grid_dicts = [grid.model_dump() for grid in grids] # Pass grid_id as targetGridId to highlight it in red target_grid_id = options.grid_id if options.grid_id is not None else None - await browser.page.evaluate( + await _page_evaluate_with_nav_retry( + browser.page, """ - (grids, targetGridId) => { + (args) => { + const [grids, targetGridId] = args; if (window.sentience && window.sentience.showGrid) { window.sentience.showGrid(grids, targetGridId); } else { @@ -718,8 +805,7 @@ async def _snapshot_via_api_async( } } """, - grid_dicts, - target_grid_id, + [grid_dicts, target_grid_id], ) return snapshot_obj