diff --git a/README.md b/README.md index 4fa9219..d78bb04 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,24 @@ pip install transformers torch # For local LLMs pip install -e . ``` +## 🧭 Manual driver CLI + +Use the interactive CLI to open a page, inspect clickables, and drive actions: + +```bash +sentience driver --url https://example.com +``` + +Commands: +- `open ` +- `state [limit]` +- `click ` +- `type ` +- `press ` +- `screenshot [path]` +- `help` +- `close` + ## Jest for AI Web Agent ### Semantic snapshots and assertions that let agents act, verify, and know when they're done. diff --git a/sentience/__init__.py b/sentience/__init__.py index 4425bf0..f2e4b58 100644 --- a/sentience/__init__.py +++ b/sentience/__init__.py @@ -19,7 +19,11 @@ click_rect, press, scroll_to, + search, + search_async, select_option, + send_keys, + send_keys_async, submit, type_text, uncheck, @@ -51,7 +55,7 @@ # Agent Layer (Phase 1 & 2) from .base_agent import BaseAgent -from .browser import SentienceBrowser +from .browser import AsyncSentienceBrowser, SentienceBrowser from .captcha import CaptchaContext, CaptchaHandlingError, CaptchaOptions, CaptchaResolution from .captcha_strategies import ExternalSolver, HumanHandoffSolver, VisionSolver @@ -86,6 +90,7 @@ Snapshot, SnapshotFilter, SnapshotOptions, + StepHookContext, StorageState, TextContext, TextMatch, @@ -101,13 +106,14 @@ from .ordinal import OrdinalIntent, boost_ordinal_elements, detect_ordinal_intent, select_by_ordinal from .overlay import clear_overlay, show_overlay from .query import find, query -from .read import read +from .read import extract, extract_async, read from .recorder import Recorder, Trace, TraceStep, record from .runtime_agent import RuntimeAgent, RuntimeStep, StepVerification from .screenshot import screenshot from .sentience_methods import AgentAction, SentienceMethod from .snapshot import snapshot from .text_search import find_text_rect +from .tools import BackendCapabilities, ToolContext, ToolRegistry, ToolSpec, register_default_tools from .tracer_factory import SENTIENCE_API_URL, create_tracer from .tracing import JsonlTraceSink, TraceEvent, Tracer, TraceSink @@ -186,6 +192,7 @@ "backend_wait_for_stable", # Core SDK "SentienceBrowser", + "AsyncSentienceBrowser", "Snapshot", "Element", "BBox", diff --git a/sentience/actions.py b/sentience/actions.py index 0be5f48..90f9774 100644 --- a/sentience/actions.py +++ b/sentience/actions.py @@ -5,11 +5,12 @@ import asyncio import time from pathlib import Path +from urllib.parse import quote_plus from .browser import AsyncSentienceBrowser, SentienceBrowser from .browser_evaluator import BrowserEvaluator from .cursor_policy import CursorPolicy, build_human_cursor_path -from .models import ActionResult, BBox, Snapshot +from .models import ActionResult, BBox, Snapshot, SnapshotOptions from .sentience_methods import SentienceMethod from .snapshot import snapshot, snapshot_async @@ -709,6 +710,146 @@ def press(browser: SentienceBrowser, key: str, take_snapshot: bool = False) -> A ) +def _normalize_key_token(token: str) -> str: + lookup = { + "CMD": "Meta", + "COMMAND": "Meta", + "CTRL": "Control", + "CONTROL": "Control", + "ALT": "Alt", + "OPTION": "Alt", + "SHIFT": "Shift", + "ESC": "Escape", + "ESCAPE": "Escape", + "ENTER": "Enter", + "RETURN": "Enter", + "TAB": "Tab", + "SPACE": "Space", + } + upper = token.strip().upper() + return lookup.get(upper, token.strip()) + + +def _parse_key_sequence(sequence: str) -> list[str]: + parts = [] + for raw in sequence.replace(",", " ").split(): + raw = raw.strip() + if not raw: + continue + if raw.startswith("{") and raw.endswith("}"): + raw = raw[1:-1] + if "+" in raw: + combo = "+".join(_normalize_key_token(tok) for tok in raw.split("+") if tok) + parts.append(combo) + else: + parts.append(_normalize_key_token(raw)) + return parts + + +def send_keys( + browser: SentienceBrowser, + sequence: str, + take_snapshot: bool = False, + delay_ms: int = 50, +) -> ActionResult: + """ + Send a sequence of key presses (e.g., "CMD+H", "CTRL+SHIFT+P"). + + Supports sequences separated by commas/spaces, and brace-wrapped tokens + like "{ENTER}" or "{CTRL+L}". + """ + if not browser.page: + raise RuntimeError("Browser not started. Call browser.start() first.") + + start_time = time.time() + url_before = browser.page.url + + keys = _parse_key_sequence(sequence) + if not keys: + raise ValueError("send_keys sequence is empty") + for key in keys: + browser.page.keyboard.press(key) + if delay_ms > 0: + browser.page.wait_for_timeout(delay_ms) + + duration_ms = int((time.time() - start_time) * 1000) + url_after = browser.page.url + url_changed = url_before != url_after + outcome = "navigated" if url_changed else "dom_updated" + + snapshot_after: Snapshot | None = None + if take_snapshot: + snapshot_after = snapshot(browser) + + return ActionResult( + success=True, + duration_ms=duration_ms, + outcome=outcome, + url_changed=url_changed, + snapshot_after=snapshot_after, + ) + + +def _build_search_url(query: str, engine: str) -> str: + q = quote_plus(query) + key = engine.strip().lower() + if key in {"duckduckgo", "ddg"}: + return f"https://duckduckgo.com/?q={q}" + if key in {"google.com", "google"}: + return f"https://www.google.com/search?q={q}" + if key in {"google"}: + return f"https://www.google.com/search?q={q}" + if key in {"bing"}: + return f"https://www.bing.com/search?q={q}" + raise ValueError(f"unsupported search engine: {engine}") + + +def search( + browser: SentienceBrowser, + query: str, + engine: str = "duckduckgo", + take_snapshot: bool = False, + snapshot_options: SnapshotOptions | None = None, +) -> ActionResult: + """ + Navigate to a search results page for the given query. + + Args: + browser: SentienceBrowser instance + query: Search query string + engine: Search engine name (duckduckgo, google, google.com, bing) + take_snapshot: Whether to take snapshot after navigation + snapshot_options: Snapshot options passed to snapshot() when take_snapshot is True. + """ + if not browser.page: + raise RuntimeError("Browser not started. Call browser.start() first.") + if not query.strip(): + raise ValueError("search query is empty") + + start_time = time.time() + url_before = browser.page.url + url = _build_search_url(query, engine) + browser.goto(url) + browser.page.wait_for_load_state("networkidle") + + duration_ms = int((time.time() - start_time) * 1000) + url_after = browser.page.url + url_changed = url_before != url_after + outcome = "navigated" if url_changed else "dom_updated" + + snapshot_after: Snapshot | None = None + if take_snapshot: + snapshot_after = snapshot(browser, snapshot_options) + + return ActionResult( + success=True, + duration_ms=duration_ms, + outcome=outcome, + url_changed=url_changed, + snapshot_after=snapshot_after, + ) + + def scroll_to( browser: SentienceBrowser, element_id: int, @@ -1698,6 +1839,93 @@ async def press_async( ) +async def send_keys_async( + browser: AsyncSentienceBrowser, + sequence: str, + take_snapshot: bool = False, + delay_ms: int = 50, +) -> ActionResult: + """ + Async version of send_keys(). + """ + if not browser.page: + raise RuntimeError("Browser not started. Call await browser.start() first.") + + start_time = time.time() + url_before = browser.page.url + + keys = _parse_key_sequence(sequence) + if not keys: + raise ValueError("send_keys sequence is empty") + for key in keys: + await browser.page.keyboard.press(key) + if delay_ms > 0: + await browser.page.wait_for_timeout(delay_ms) + + duration_ms = int((time.time() - start_time) * 1000) + url_after = browser.page.url + url_changed = url_before != url_after + outcome = "navigated" if url_changed else "dom_updated" + + snapshot_after: Snapshot | None = None + if take_snapshot: + snapshot_after = await snapshot_async(browser) + + return ActionResult( + success=True, + duration_ms=duration_ms, + outcome=outcome, + url_changed=url_changed, + snapshot_after=snapshot_after, + ) + + +async def search_async( + browser: AsyncSentienceBrowser, + query: str, + engine: str = "duckduckgo", + take_snapshot: bool = False, + snapshot_options: SnapshotOptions | None = None, +) -> ActionResult: + """ + Async version of search(). + + Args: + browser: AsyncSentienceBrowser instance + query: Search query string + engine: Search engine name (duckduckgo, google, google.com, bing) + take_snapshot: Whether to take snapshot after navigation + snapshot_options: Snapshot options passed to snapshot_async() when take_snapshot is True. + """ + if not browser.page: + raise RuntimeError("Browser not started. Call await browser.start() first.") + if not query.strip(): + raise ValueError("search query is empty") + + start_time = time.time() + url_before = browser.page.url + url = _build_search_url(query, engine) + await browser.goto(url) + await browser.page.wait_for_load_state("networkidle") + + duration_ms = int((time.time() - start_time) * 1000) + url_after = browser.page.url + url_changed = url_before != url_after + outcome = "navigated" if url_changed else "dom_updated" + + snapshot_after: Snapshot | None = None + if take_snapshot: + snapshot_after = await snapshot_async(browser, snapshot_options) + + return ActionResult( + success=True, + duration_ms=duration_ms, + outcome=outcome, + url_changed=url_changed, + snapshot_after=snapshot_after, + ) + + async def scroll_to_async( browser: AsyncSentienceBrowser, element_id: int, diff --git a/sentience/agent.py b/sentience/agent.py index 2f7f4ac..eb42a13 100644 --- a/sentience/agent.py +++ b/sentience/agent.py @@ -5,7 +5,10 @@ import asyncio import hashlib +import inspect +import logging import time +from collections.abc import Callable from typing import TYPE_CHECKING, Any, Optional, Union from .action_executor import ActionExecutor @@ -23,6 +26,7 @@ ScreenshotConfig, Snapshot, SnapshotOptions, + StepHookContext, TokenStats, ) from .protocols import AsyncBrowserProtocol, BrowserProtocol @@ -65,6 +69,47 @@ def _safe_tracer_call( print(f"āš ļø Tracer error (non-fatal): {tracer_error}") +def _safe_hook_call_sync( + hook: Callable[[StepHookContext], Any] | None, + ctx: StepHookContext, + verbose: bool, +) -> None: + if not hook: + return + try: + result = hook(ctx) + if inspect.isawaitable(result): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(result) + else: + loop.create_task(result) + except Exception as hook_error: + if verbose: + print(f"āš ļø Hook error (non-fatal): {hook_error}") + else: + logging.getLogger(__name__).warning("Hook error (non-fatal): %s", hook_error) + + +async def _safe_hook_call_async( + hook: Callable[[StepHookContext], Any] | None, + ctx: StepHookContext, + verbose: bool, +) -> None: + if not hook: + return + try: + result = hook(ctx) + if inspect.isawaitable(result): + await result + except Exception as hook_error: + if verbose: + print(f"āš ļø Hook error (non-fatal): {hook_error}") + else: + logging.getLogger(__name__).warning("Hook error (non-fatal): %s", hook_error) + + class SentienceAgent(BaseAgent): """ High-level agent that combines Sentience SDK with any LLM provider. @@ -181,6 +226,8 @@ def act( # noqa: C901 goal: str, max_retries: int = 2, snapshot_options: SnapshotOptions | None = None, + on_step_start: Callable[[StepHookContext], Any] | None = None, + on_step_end: Callable[[StepHookContext], Any] | None = None, ) -> AgentActionResult: """ Execute a high-level goal using observe → think → act loop @@ -210,9 +257,9 @@ def act( # noqa: C901 self._step_count += 1 step_id = f"step-{self._step_count}" + pre_url = self.browser.page.url if self.browser.page else None # Emit step_start trace event if tracer is enabled if self.tracer: - pre_url = self.browser.page.url if self.browser.page else None _safe_tracer_call( self.tracer, "emit_step_start", @@ -224,6 +271,18 @@ def act( # noqa: C901 pre_url=pre_url, ) + _safe_hook_call_sync( + on_step_start, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=0, + url=pre_url, + ), + self.verbose, + ) + # Track data collected during step execution for step_end emission on failure _step_snap_with_diff: Snapshot | None = None _step_pre_url: str | None = None @@ -396,8 +455,8 @@ def act( # noqa: C901 _step_duration_ms = duration_ms # Emit action execution trace event if tracer is enabled + post_url = self.browser.page.url if self.browser.page else None if self.tracer: - post_url = self.browser.page.url if self.browser.page else None # Include element data for live overlay visualization elements_data = [ @@ -454,7 +513,6 @@ def act( # noqa: C901 if self.tracer: # Get pre_url from step_start (stored in tracer or use current) pre_url = snap.url - post_url = self.browser.page.url if self.browser.page else None # Compute snapshot digest (simplified - use URL + timestamp) snapshot_digest = f"sha256:{self._compute_hash(f'{pre_url}{snap.timestamp}')}" @@ -561,6 +619,20 @@ def act( # noqa: C901 step_id=step_id, ) + _safe_hook_call_sync( + on_step_end, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + url=post_url, + success=result.success, + outcome=result.outcome, + error=result.error, + ), + self.verbose, + ) return result except Exception as e: @@ -660,6 +732,20 @@ def act( # noqa: C901 "duration_ms": 0, } ) + _safe_hook_call_sync( + on_step_end, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + url=_step_pre_url, + success=False, + outcome="exception", + error=str(e), + ), + self.verbose, + ) raise RuntimeError(f"Failed after {max_retries} retries: {e}") def _track_tokens(self, goal: str, llm_response: LLMResponse): @@ -833,6 +919,8 @@ async def act( # noqa: C901 goal: str, max_retries: int = 2, snapshot_options: SnapshotOptions | None = None, + on_step_start: Callable[[StepHookContext], Any] | None = None, + on_step_end: Callable[[StepHookContext], Any] | None = None, ) -> AgentActionResult: """ Execute a high-level goal using observe → think → act loop (async) @@ -859,9 +947,9 @@ async def act( # noqa: C901 self._step_count += 1 step_id = f"step-{self._step_count}" + pre_url = self.browser.page.url if self.browser.page else None # Emit step_start trace event if tracer is enabled if self.tracer: - pre_url = self.browser.page.url if self.browser.page else None _safe_tracer_call( self.tracer, "emit_step_start", @@ -873,6 +961,18 @@ async def act( # noqa: C901 pre_url=pre_url, ) + await _safe_hook_call_async( + on_step_start, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=0, + url=pre_url, + ), + self.verbose, + ) + # Track data collected during step execution for step_end emission on failure _step_snap_with_diff: Snapshot | None = None _step_pre_url: str | None = None @@ -1209,6 +1309,21 @@ async def act( # noqa: C901 step_id=step_id, ) + post_url = self.browser.page.url if self.browser.page else None + await _safe_hook_call_async( + on_step_end, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + url=post_url, + success=result.success, + outcome=result.outcome, + error=result.error, + ), + self.verbose, + ) return result except Exception as e: @@ -1308,6 +1423,20 @@ async def act( # noqa: C901 "duration_ms": 0, } ) + await _safe_hook_call_async( + on_step_end, + StepHookContext( + step_id=step_id, + step_index=self._step_count, + goal=goal, + attempt=attempt, + url=_step_pre_url, + success=False, + outcome="exception", + error=str(e), + ), + self.verbose, + ) raise RuntimeError(f"Failed after {max_retries} retries: {e}") def _track_tokens(self, goal: str, llm_response: LLMResponse): diff --git a/sentience/agent_runtime.py b/sentience/agent_runtime.py index 2852699..bebe547 100644 --- a/sentience/agent_runtime.py +++ b/sentience/agent_runtime.py @@ -72,7 +72,16 @@ from .captcha import CaptchaContext, CaptchaHandlingError, CaptchaOptions, CaptchaResolution from .failure_artifacts import FailureArtifactBuffer, FailureArtifactsOptions -from .models import Snapshot, SnapshotOptions +from .models import ( + EvaluateJsRequest, + EvaluateJsResult, + Snapshot, + SnapshotOptions, + TabInfo, + TabListResult, + TabOperationResult, +) +from .tools import BackendCapabilities, ToolRegistry from .trace_event_builder import TraceEventBuilder from .verification import AssertContext, AssertOutcome, Predicate @@ -110,6 +119,7 @@ def __init__( tracer: Tracer, snapshot_options: SnapshotOptions | None = None, sentience_api_key: str | None = None, + tool_registry: ToolRegistry | None = None, ): """ Initialize agent runtime with any BrowserBackend-compatible browser. @@ -122,9 +132,11 @@ def __init__( tracer: Tracer for emitting verification events snapshot_options: Default options for snapshots sentience_api_key: API key for Pro/Enterprise tier (enables Gateway refinement) + tool_registry: Optional ToolRegistry for LLM-callable tools """ self.backend = backend self.tracer = tracer + self.tool_registry = tool_registry # Build default snapshot options with API key if provided default_opts = snapshot_options or SnapshotOptions() @@ -286,6 +298,124 @@ async def snapshot(self, **kwargs: Any) -> Snapshot: await self._handle_captcha_if_needed(self.last_snapshot, source="gateway") return self.last_snapshot + async def evaluate_js(self, request: EvaluateJsRequest) -> EvaluateJsResult: + """ + Evaluate JavaScript expression in the active backend. + + Args: + request: EvaluateJsRequest with code and output limits. + + Returns: + EvaluateJsResult with normalized text output. + """ + try: + value = await self.backend.eval(request.code) + except Exception as exc: # pragma: no cover - backend-specific errors + return EvaluateJsResult(ok=False, error=str(exc)) + + text = self._stringify_eval_value(value) + truncated = False + if request.truncate and len(text) > request.max_output_chars: + text = text[: request.max_output_chars] + "..." + truncated = True + + return EvaluateJsResult( + ok=True, + value=value, + text=text, + truncated=truncated, + ) + + async def list_tabs(self) -> TabListResult: + backend = self._get_tab_backend() + if backend is None: + return TabListResult(ok=False, error="unsupported_capability") + try: + tabs = await backend.list_tabs() + except Exception as exc: # pragma: no cover - backend specific + return TabListResult(ok=False, error=str(exc)) + return TabListResult(ok=True, tabs=tabs) + + async def open_tab(self, url: str) -> TabOperationResult: + backend = self._get_tab_backend() + if backend is None: + return TabOperationResult(ok=False, error="unsupported_capability") + try: + tab = await backend.open_tab(url) + except Exception as exc: # pragma: no cover - backend specific + return TabOperationResult(ok=False, error=str(exc)) + return TabOperationResult(ok=True, tab=tab) + + async def switch_tab(self, tab_id: str) -> TabOperationResult: + backend = self._get_tab_backend() + if backend is None: + return TabOperationResult(ok=False, error="unsupported_capability") + try: + tab = await backend.switch_tab(tab_id) + except Exception as exc: # pragma: no cover - backend specific + return TabOperationResult(ok=False, error=str(exc)) + return TabOperationResult(ok=True, tab=tab) + + async def close_tab(self, tab_id: str) -> TabOperationResult: + backend = self._get_tab_backend() + if backend is None: + return TabOperationResult(ok=False, error="unsupported_capability") + try: + tab = await backend.close_tab(tab_id) + except Exception as exc: # pragma: no cover - backend specific + return TabOperationResult(ok=False, error=str(exc)) + return TabOperationResult(ok=True, tab=tab) + + def _get_tab_backend(self): + backend = getattr(self, "backend", None) + if backend is None: + return None + if not all( + hasattr(backend, attr) for attr in ("list_tabs", "open_tab", "switch_tab", "close_tab") + ): + return None + return backend + + def capabilities(self) -> BackendCapabilities: + backend = getattr(self, "backend", None) + if backend is None: + return BackendCapabilities() + has_eval = hasattr(backend, "eval") + has_keyboard = hasattr(backend, "type_text") or bool( + getattr(getattr(backend, "_page", None), "keyboard", None) + ) + has_downloads = bool(getattr(backend, "downloads", None)) + has_files = False + if self.tool_registry is not None: + try: + has_files = self.tool_registry.get("read_file") is not None + except Exception: + has_files = False + return BackendCapabilities( + tabs=self._get_tab_backend() is not None, + evaluate_js=bool(has_eval), + downloads=has_downloads, + filesystem_tools=has_files, + keyboard=bool(has_keyboard or has_eval), + ) + + def can(self, capability: str) -> bool: + caps = self.capabilities() + return bool(getattr(caps, capability, False)) + + @staticmethod + def _stringify_eval_value(value: Any) -> str: + if value is None: + return "null" + if isinstance(value, (dict, list)): + try: + import json + + return json.dumps(value, ensure_ascii=False) + except Exception: + return str(value) + return str(value) + def set_captcha_options(self, options: CaptchaOptions) -> None: """ Configure CAPTCHA handling (disabled by default unless set). @@ -450,10 +580,7 @@ def _compute_snapshot_digest(self, snap: Snapshot | None) -> str | None: if snap is None: return None try: - return ( - "sha256:" - + hashlib.sha256(f"{snap.url}{snap.timestamp}".encode("utf-8")).hexdigest() - ) + return "sha256:" + hashlib.sha256(f"{snap.url}{snap.timestamp}".encode()).hexdigest() except Exception: return None @@ -477,10 +604,7 @@ async def emit_step_end( goal = self._step_goal or "" pre_snap = self._step_pre_snapshot or self.last_snapshot pre_url = ( - self._step_pre_url - or (pre_snap.url if pre_snap else None) - or self._cached_url - or "" + self._step_pre_url or (pre_snap.url if pre_snap else None) or self._cached_url or "" ) if post_url is None: @@ -488,8 +612,8 @@ async def emit_step_end( post_url = await self.get_url() except Exception: post_url = ( - (self.last_snapshot.url if self.last_snapshot else None) or self._cached_url - ) + self.last_snapshot.url if self.last_snapshot else None + ) or self._cached_url post_url = post_url or pre_url pre_digest = self._compute_snapshot_digest(pre_snap) @@ -505,13 +629,15 @@ async def emit_step_end( signals["error"] = error passed = ( - bool(verify_passed) - if verify_passed is not None - else self.required_assertions_passed() + bool(verify_passed) if verify_passed is not None else self.required_assertions_passed() ) - exec_success = bool(success) if success is not None else bool( - self._last_action_success if self._last_action_success is not None else passed + exec_success = ( + bool(success) + if success is not None + else bool( + self._last_action_success if self._last_action_success is not None else passed + ) ) exec_data: dict[str, Any] = { @@ -716,7 +842,9 @@ def assert_done( True if task is complete (assertion passed), False otherwise """ # Convenience wrapper for assert_ with required=True + # pylint: disable=deprecated-method ok = self.assert_(predicate, label=label, required=True) + # pylint: enable=deprecated-method if ok: self._task_done = True self._task_done_label = label diff --git a/sentience/backends/playwright_backend.py b/sentience/backends/playwright_backend.py index cfbc808..47b9960 100644 --- a/sentience/backends/playwright_backend.py +++ b/sentience/backends/playwright_backend.py @@ -29,6 +29,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Literal +from ..models import TabInfo from .protocol import BrowserBackend, LayoutMetrics, ViewportInfo if TYPE_CHECKING: @@ -53,6 +54,7 @@ def __init__(self, page: "AsyncPage") -> None: self._page = page self._cached_viewport: ViewportInfo | None = None self._downloads: list[dict[str, Any]] = [] + self._tab_registry: dict[str, "AsyncPage"] = {} # Best-effort download tracking (does not change behavior unless a download occurs). # pylint: disable=broad-exception-caught @@ -108,6 +110,104 @@ def page(self) -> "AsyncPage": """Access the underlying Playwright page.""" return self._page + async def list_tabs(self) -> list[TabInfo]: + self._prune_tabs() + context = self._page.context + tabs: list[TabInfo] = [] + for page in context.pages: + tab_id = self._ensure_tab_id(page) + title = None + try: + title = await page.title() + except Exception: # pylint: disable=broad-exception-caught + title = None + tabs.append( + TabInfo( + tab_id=tab_id, + url=getattr(page, "url", None), + title=title, + is_active=page == self._page, + ) + ) + return tabs + + async def open_tab(self, url: str) -> TabInfo: + self._prune_tabs() + context = self._page.context + page = await context.new_page() + await page.goto(url) + self._page = page + tab_id = self._ensure_tab_id(page) + title = None + try: + title = await page.title() + except Exception: # pylint: disable=broad-exception-caught + title = None + return TabInfo(tab_id=tab_id, url=getattr(page, "url", None), title=title, is_active=True) + + async def switch_tab(self, tab_id: str) -> TabInfo: + self._prune_tabs() + page = self._tab_registry.get(tab_id) + if page is None: + raise ValueError(f"unknown tab_id: {tab_id}") + self._page = page + try: + await page.bring_to_front() + except Exception: # pylint: disable=broad-exception-caught + pass + title = None + try: + title = await page.title() + except Exception: # pylint: disable=broad-exception-caught + title = None + return TabInfo(tab_id=tab_id, url=getattr(page, "url", None), title=title, is_active=True) + + async def close_tab(self, tab_id: str) -> TabInfo: + self._prune_tabs() + page = self._tab_registry.get(tab_id) + if page is None: + raise ValueError(f"unknown tab_id: {tab_id}") + info = TabInfo( + tab_id=tab_id, + url=getattr(page, "url", None), + title=None, + is_active=page == self._page, + ) + try: + info.title = await page.title() + except Exception: # pylint: disable=broad-exception-caught + info.title = None + await page.close() + self._tab_registry.pop(tab_id, None) + if self._page == page: + context = page.context + pages = context.pages + if pages: + self._page = pages[0] + return info + + def _ensure_tab_id(self, page: "AsyncPage") -> str: + self._prune_tabs() + for tab_id, entry in self._tab_registry.items(): + if entry == page: + return tab_id + tab_id = f"tab-{id(page)}" + self._tab_registry[tab_id] = page + return tab_id + + def _prune_tabs(self) -> None: + dead: list[str] = [] + for tab_id, page in self._tab_registry.items(): + is_closed = getattr(page, "is_closed", None) + try: + closed = is_closed() if callable(is_closed) else bool(is_closed) + except Exception: # pragma: no cover - defensive + closed = False + if closed: + dead.append(tab_id) + for tab_id in dead: + self._tab_registry.pop(tab_id, None) + async def refresh_page_info(self) -> ViewportInfo: """Cache viewport + scroll offsets; cheap & safe to call often.""" result = await self._page.evaluate( diff --git a/sentience/browser.py b/sentience/browser.py index 7e40cb1..2191963 100644 --- a/sentience/browser.py +++ b/sentience/browser.py @@ -34,6 +34,51 @@ STEALTH_AVAILABLE = False +def _normalize_domain(domain: str) -> str: + raw = domain.strip() + if "://" in raw: + host = urlparse(raw).hostname or "" + else: + host = raw.split("/", 1)[0] + host = host.split(":", 1)[0] + return host.strip().lower().lstrip(".") + + +def _domain_matches(host: str, pattern: str) -> bool: + host_norm = _normalize_domain(host) + pat = _normalize_domain(pattern) + if pat.startswith("*."): + pat = pat[2:] + return host_norm == pat or host_norm.endswith(f".{pat}") + + +def _extract_host(url: str) -> str | None: + raw = url.strip() + if "://" not in raw: + raw = f"https://{raw}" + parsed = urlparse(raw) + return parsed.hostname + + +def _is_domain_allowed( + host: str | None, allowed: list[str] | None, prohibited: list[str] | None +) -> bool: + """ + Return True if host is allowed based on allow/deny lists. + + Deny list takes precedence. Empty allow list means allow all. + """ + if not host: + return False + if prohibited: + for pattern in prohibited: + if _domain_matches(host, pattern): + return False + if allowed: + return any(_domain_matches(host, pattern) for pattern in allowed) + return True + + class SentienceBrowser: """Main browser session with Sentience extension loaded""" @@ -49,6 +94,9 @@ def __init__( record_video_size: dict[str, int] | None = None, viewport: Viewport | dict[str, int] | None = None, device_scale_factor: float | None = None, + allowed_domains: list[str] | None = None, + prohibited_domains: list[str] | None = None, + keep_alive: bool = False, ): """ Initialize Sentience browser @@ -113,6 +161,11 @@ def __init__( self.record_video_dir = record_video_dir self.record_video_size = record_video_size or {"width": 1280, "height": 800} + # Domain policies + keep-alive + self.allowed_domains = allowed_domains or [] + self.prohibited_domains = prohibited_domains or [] + self.keep_alive = keep_alive + # Viewport configuration - convert dict to Viewport if needed if viewport is None: self.viewport = Viewport(width=1280, height=800) @@ -299,9 +352,16 @@ def start(self) -> None: time.sleep(0.5) def goto(self, url: str) -> None: - """Navigate to a URL and ensure extension is ready""" + """Navigate to a URL and ensure extension is ready. + + This enforces domain allow/deny policies. Direct page.goto() calls + bypass policy checks. + """ if not self.page: raise RuntimeError("Browser not started. Call start() first.") + host = _extract_host(url) + if not _is_domain_allowed(host, self.allowed_domains, self.prohibited_domains): + raise ValueError(f"domain not allowed: {host}") self.page.goto(url, wait_until="domcontentloaded") @@ -474,11 +534,16 @@ def close(self, output_path: str | Path | None = None) -> str | None: Path to video file if recording was enabled, None otherwise Note: Video files are saved automatically by Playwright when context closes. If multiple pages exist, returns the path to the first page's video. + If keep_alive is True, returns None and skips shutdown. """ # CRITICAL: Don't access page.video.path() BEFORE closing context # This can poke the video subsystem at an awkward time and cause crashes on macOS # Instead, we'll locate the video file after context closes + if self.keep_alive: + logger.info("Keep-alive enabled; skipping browser shutdown.") + return None + # Close context (this triggers video file finalization) if self.context: self.context.close() @@ -644,6 +709,9 @@ def __init__( viewport: Viewport | dict[str, int] | None = None, device_scale_factor: float | None = None, executable_path: str | None = None, + allowed_domains: list[str] | None = None, + prohibited_domains: list[str] | None = None, + keep_alive: bool = False, ): """ Initialize Async Sentience browser @@ -698,6 +766,11 @@ def __init__( self.record_video_dir = record_video_dir self.record_video_size = record_video_size or {"width": 1280, "height": 800} + # Domain policies + keep-alive + self.allowed_domains = allowed_domains or [] + self.prohibited_domains = prohibited_domains or [] + self.keep_alive = keep_alive + # Viewport configuration - convert dict to Viewport if needed if viewport is None: self.viewport = Viewport(width=1280, height=800) @@ -880,9 +953,16 @@ async def start(self) -> None: await asyncio.sleep(0.5) async def goto(self, url: str) -> None: - """Navigate to a URL and ensure extension is ready (async)""" + """Navigate to a URL and ensure extension is ready (async). + + This enforces domain allow/deny policies. Direct page.goto() calls + bypass policy checks. + """ if not self.page: raise RuntimeError("Browser not started. Call await start() first.") + host = _extract_host(url) + if not _is_domain_allowed(host, self.allowed_domains, self.prohibited_domains): + raise ValueError(f"domain not allowed: {host}") await self.page.goto(url, wait_until="domcontentloaded") @@ -1031,7 +1111,12 @@ async def close(self, output_path: str | Path | None = None) -> tuple[str | None Note: Video path is resolved AFTER context close to avoid touching video subsystem during teardown, which can cause crashes on macOS. + If keep_alive is True, returns (None, True) and skips shutdown. """ + if self.keep_alive: + logger.info("Keep-alive enabled; skipping browser shutdown.") + return None, True + # CRITICAL: Don't access page.video.path() BEFORE closing context # This can poke the video subsystem at an awkward time and cause crashes # Instead, we'll locate the video file after context closes diff --git a/sentience/cli.py b/sentience/cli.py index c5f669c..823c5be 100644 --- a/sentience/cli.py +++ b/sentience/cli.py @@ -3,16 +3,25 @@ """ import argparse +import base64 +import shlex import sys +import time +from pathlib import Path +from .actions import click, press, type_text from .browser import SentienceBrowser from .generator import ScriptGenerator from .inspector import inspect +from .models import SnapshotOptions from .recorder import Trace, record +from .screenshot import screenshot +from .snapshot import snapshot def cmd_inspect(args): """Start inspector mode""" + _ = args browser = SentienceBrowser(headless=False) try: browser.start() @@ -21,8 +30,6 @@ def cmd_inspect(args): with inspect(browser): # Keep running until interrupted - import time - try: while True: time.sleep(1) @@ -52,8 +59,6 @@ def cmd_record(args): rec.add_mask_pattern(pattern) # Keep running until interrupted - import time - try: while True: time.sleep(1) @@ -87,6 +92,144 @@ def cmd_gen(args): print(f"āœ… Generated {args.lang.upper()} script: {output}") +def _print_driver_help(): + print( + "\nCommands:\n" + " open Navigate to URL\n" + " state [limit] List clickable elements (optional limit)\n" + " click Click element by id\n" + " type Type text into element\n" + " press Press a key (e.g., Enter)\n" + " screenshot [path] Save screenshot (png/jpg)\n" + " close Close browser and exit\n" + " help Show this help\n" + ) + + +def cmd_driver(args): + """Manual driver CLI for open/state/click/type/screenshot/close.""" + browser = SentienceBrowser(headless=args.headless) + try: + browser.start() + if args.url: + browser.page.goto(args.url) + browser.page.wait_for_load_state("networkidle") + + print("āœ… Manual driver started. Type 'help' for commands.") + + while True: + try: + raw = input("sentience> ").strip() + except (EOFError, KeyboardInterrupt): + print("\nšŸ‘‹ Exiting manual driver.") + break + + if not raw: + continue + + try: + parts = shlex.split(raw) + except ValueError as exc: + print(f"āŒ Parse error: {exc}") + continue + + cmd = parts[0].lower() + cmd_args = parts[1:] + + if cmd in {"help", "?"}: + _print_driver_help() + continue + + if cmd == "open": + if not cmd_args: + print("āŒ Usage: open ") + continue + url = cmd_args[0] + browser.page.goto(url) + browser.page.wait_for_load_state("networkidle") + print(f"āœ… Opened {url}") + continue + + if cmd == "state": + limit = args.limit + if cmd_args: + try: + limit = int(cmd_args[0]) + except ValueError: + print("āŒ Usage: state [limit]") + continue + snap = snapshot(browser, SnapshotOptions(limit=limit)) + clickables = [ + el + for el in snap.elements + if getattr(getattr(el, "visual_cues", None), "is_clickable", False) + ] + print(f"URL: {snap.url}") + print(f"Clickable elements: {len(clickables)}") + for el in clickables: + text = (el.text or "").replace("\n", " ").strip() + if len(text) > 60: + text = text[:57] + "..." + print(f"- id={el.id} role={el.role} text='{text}'") + continue + + if cmd == "click": + if len(cmd_args) != 1: + print("āŒ Usage: click ") + continue + try: + element_id = int(cmd_args[0]) + except ValueError: + print("āŒ element_id must be an integer") + continue + click(browser, element_id) + print(f"āœ… Clicked element {element_id}") + continue + + if cmd == "type": + if len(cmd_args) < 2: + print("āŒ Usage: type ") + continue + try: + element_id = int(cmd_args[0]) + except ValueError: + print("āŒ element_id must be an integer") + continue + text = " ".join(cmd_args[1:]) + type_text(browser, element_id, text) + print(f"āœ… Typed into element {element_id}") + continue + + if cmd == "press": + if len(cmd_args) != 1: + print('āŒ Usage: press (e.g., "Enter")') + continue + press(browser, cmd_args[0]) + print(f"āœ… Pressed {cmd_args[0]}") + continue + + if cmd == "screenshot": + path = cmd_args[0] if cmd_args else None + if path is None: + path = f"screenshot-{int(time.time())}.png" + out_path = Path(path) + ext = out_path.suffix.lower() + fmt = "jpeg" if ext in {".jpg", ".jpeg"} else "png" + data_url = screenshot(browser, format=fmt) + _, b64 = data_url.split(",", 1) + out_path.write_bytes(base64.b64decode(b64)) + print(f"āœ… Saved screenshot to {out_path}") + continue + + if cmd in {"close", "exit", "quit"}: + print("šŸ‘‹ Closing browser.") + break + + print(f"āŒ Unknown command: {cmd}. Type 'help' for options.") + finally: + browser.close() + + def main(): """Main CLI entry point""" parser = argparse.ArgumentParser(description="Sentience SDK CLI") @@ -117,6 +260,15 @@ def main(): gen_parser.add_argument("--output", "-o", help="Output script file") gen_parser.set_defaults(func=cmd_gen) + # Manual driver command + driver_parser = subparsers.add_parser("driver", help="Manual driver CLI") + driver_parser.add_argument("--url", help="Start URL") + driver_parser.add_argument("--limit", type=int, default=50, help="Snapshot limit for state") + driver_parser.add_argument( + "--headless", action="store_true", help="Run browser in headless mode" + ) + driver_parser.set_defaults(func=cmd_driver) + args = parser.parse_args() if not args.command: diff --git a/sentience/conversational_agent.py b/sentience/conversational_agent.py index f9f2fc8..1ac39f1 100644 --- a/sentience/conversational_agent.py +++ b/sentience/conversational_agent.py @@ -311,7 +311,7 @@ def _execute_step(self, step: dict[str, Any]) -> StepExecutionResult: except Exception as e: if self.verbose: print(f"āŒ Step failed: {e}") - return StepExecutionResult(success=False, action=action, error=str(e)) + return StepExecutionResult(success=False, action=action, data={}, error=str(e)) def _extract_information(self, snap: Snapshot, info_type: str) -> ExtractionResult: """ diff --git a/sentience/extension/background.js b/sentience/extension/background.js index b5192d9..02c0408 100644 --- a/sentience/extension/background.js +++ b/sentience/extension/background.js @@ -28,14 +28,14 @@ async function handleSnapshotProcessing(rawData, options = {}) { const startTime = performance.now(); try { if (!Array.isArray(rawData)) throw new Error("rawData must be an array"); - if (rawData.length > 1e4 && (rawData = rawData.slice(0, 1e4)), await initWASM(), + if (rawData.length > 1e4 && (rawData = rawData.slice(0, 1e4)), await initWASM(), !wasmReady) throw new Error("WASM module not initialized"); let analyzedElements, prunedRawData; try { const wasmPromise = new Promise((resolve, reject) => { try { let result; - result = options.limit || options.filter ? analyze_page_with_options(rawData, options) : analyze_page(rawData), + result = options.limit || options.filter ? analyze_page_with_options(rawData, options) : analyze_page(rawData), resolve(result); } catch (e) { reject(e); @@ -101,4 +101,4 @@ initWASM().catch(err => {}), chrome.runtime.onMessage.addListener((request, send event.preventDefault(); }), self.addEventListener("unhandledrejection", event => { event.preventDefault(); -}); \ No newline at end of file +}); diff --git a/sentience/extension/content.js b/sentience/extension/content.js index b65cfb5..97923a2 100644 --- a/sentience/extension/content.js +++ b/sentience/extension/content.js @@ -82,7 +82,7 @@ if (!elements || !Array.isArray(elements)) return; removeOverlay(); const host = document.createElement("div"); - host.id = OVERLAY_HOST_ID, host.style.cssText = "\n position: fixed !important;\n top: 0 !important;\n left: 0 !important;\n width: 100vw !important;\n height: 100vh !important;\n pointer-events: none !important;\n z-index: 2147483647 !important;\n margin: 0 !important;\n padding: 0 !important;\n ", + host.id = OVERLAY_HOST_ID, host.style.cssText = "\n position: fixed !important;\n top: 0 !important;\n left: 0 !important;\n width: 100vw !important;\n height: 100vh !important;\n pointer-events: none !important;\n z-index: 2147483647 !important;\n margin: 0 !important;\n padding: 0 !important;\n ", document.body.appendChild(host); const shadow = host.attachShadow({ mode: "closed" @@ -94,15 +94,15 @@ let color; color = isTarget ? "#FF0000" : isPrimary ? "#0066FF" : "#00FF00"; const importanceRatio = maxImportance > 0 ? importance / maxImportance : .5, borderOpacity = isTarget ? 1 : isPrimary ? .9 : Math.max(.4, .5 + .5 * importanceRatio), fillOpacity = .2 * borderOpacity, borderWidth = isTarget ? 2 : isPrimary ? 1.5 : Math.max(.5, Math.round(2 * importanceRatio)), hexOpacity = Math.round(255 * fillOpacity).toString(16).padStart(2, "0"), box = document.createElement("div"); - if (box.style.cssText = `\n position: absolute;\n left: ${bbox.x}px;\n top: ${bbox.y}px;\n width: ${bbox.width}px;\n height: ${bbox.height}px;\n border: ${borderWidth}px solid ${color};\n background-color: ${color}${hexOpacity};\n box-sizing: border-box;\n opacity: ${borderOpacity};\n pointer-events: none;\n `, + if (box.style.cssText = `\n position: absolute;\n left: ${bbox.x}px;\n top: ${bbox.y}px;\n width: ${bbox.width}px;\n height: ${bbox.height}px;\n border: ${borderWidth}px solid ${color};\n background-color: ${color}${hexOpacity};\n box-sizing: border-box;\n opacity: ${borderOpacity};\n pointer-events: none;\n `, importance > 0 || isPrimary) { const badge = document.createElement("span"); - badge.textContent = isPrimary ? `⭐${importance}` : `${importance}`, badge.style.cssText = `\n position: absolute;\n top: -18px;\n left: 0;\n background: ${color};\n color: white;\n font-size: 11px;\n font-weight: bold;\n padding: 2px 6px;\n font-family: Arial, sans-serif;\n border-radius: 3px;\n opacity: 0.95;\n white-space: nowrap;\n pointer-events: none;\n `, + badge.textContent = isPrimary ? `⭐${importance}` : `${importance}`, badge.style.cssText = `\n position: absolute;\n top: -18px;\n left: 0;\n background: ${color};\n color: white;\n font-size: 11px;\n font-weight: bold;\n padding: 2px 6px;\n font-family: Arial, sans-serif;\n border-radius: 3px;\n opacity: 0.95;\n white-space: nowrap;\n pointer-events: none;\n `, box.appendChild(badge); } if (isTarget) { const targetIndicator = document.createElement("span"); - targetIndicator.textContent = "šŸŽÆ", targetIndicator.style.cssText = "\n position: absolute;\n top: -18px;\n right: 0;\n font-size: 16px;\n pointer-events: none;\n ", + targetIndicator.textContent = "šŸŽÆ", targetIndicator.style.cssText = "\n position: absolute;\n top: -18px;\n right: 0;\n font-size: 16px;\n pointer-events: none;\n ", box.appendChild(targetIndicator); } shadow.appendChild(box); @@ -122,7 +122,7 @@ if (!grids || !Array.isArray(grids)) return; removeOverlay(); const host = document.createElement("div"); - host.id = OVERLAY_HOST_ID, host.style.cssText = "\n position: fixed !important;\n top: 0 !important;\n left: 0 !important;\n width: 100vw !important;\n height: 100vh !important;\n pointer-events: none !important;\n z-index: 2147483647 !important;\n margin: 0 !important;\n padding: 0 !important;\n ", + host.id = OVERLAY_HOST_ID, host.style.cssText = "\n position: fixed !important;\n top: 0 !important;\n left: 0 !important;\n width: 100vw !important;\n height: 100vh !important;\n pointer-events: none !important;\n z-index: 2147483647 !important;\n margin: 0 !important;\n padding: 0 !important;\n ", document.body.appendChild(host); const shadow = host.attachShadow({ mode: "closed" @@ -138,10 +138,10 @@ let labelText = grid.label ? `Grid ${grid.grid_id}: ${grid.label}` : `Grid ${grid.grid_id}`; grid.is_dominant && (labelText = `⭐ ${labelText} (dominant)`); const badge = document.createElement("span"); - if (badge.textContent = labelText, badge.style.cssText = `\n position: absolute;\n top: -18px;\n left: 0;\n background: ${color};\n color: white;\n font-size: 11px;\n font-weight: bold;\n padding: 2px 6px;\n font-family: Arial, sans-serif;\n border-radius: 3px;\n opacity: 0.95;\n white-space: nowrap;\n pointer-events: none;\n `, + if (badge.textContent = labelText, badge.style.cssText = `\n position: absolute;\n top: -18px;\n left: 0;\n background: ${color};\n color: white;\n font-size: 11px;\n font-weight: bold;\n padding: 2px 6px;\n font-family: Arial, sans-serif;\n border-radius: 3px;\n opacity: 0.95;\n white-space: nowrap;\n pointer-events: none;\n `, box.appendChild(badge), isTarget) { const targetIndicator = document.createElement("span"); - targetIndicator.textContent = "šŸŽÆ", targetIndicator.style.cssText = "\n position: absolute;\n top: -18px;\n right: 0;\n font-size: 16px;\n pointer-events: none;\n ", + targetIndicator.textContent = "šŸŽÆ", targetIndicator.style.cssText = "\n position: absolute;\n top: -18px;\n right: 0;\n font-size: 16px;\n pointer-events: none;\n ", box.appendChild(targetIndicator); } shadow.appendChild(box); @@ -155,7 +155,7 @@ let overlayTimeout = null; function removeOverlay() { const existing = document.getElementById(OVERLAY_HOST_ID); - existing && existing.remove(), overlayTimeout && (clearTimeout(overlayTimeout), + existing && existing.remove(), overlayTimeout && (clearTimeout(overlayTimeout), overlayTimeout = null); } -}(); \ No newline at end of file +}(); diff --git a/sentience/extension/injected_api.js b/sentience/extension/injected_api.js index 12ad84b..73dda41 100644 --- a/sentience/extension/injected_api.js +++ b/sentience/extension/injected_api.js @@ -103,9 +103,9 @@ const iframes = document.querySelectorAll("iframe"); for (const iframe of iframes) { const src = iframe.getAttribute("src") || "", title = iframe.getAttribute("title") || ""; - if (src) for (const [provider, hints] of Object.entries(CAPTCHA_IFRAME_HINTS)) matchHints(src, hints) && (hasIframeHit = !0, + if (src) for (const [provider, hints] of Object.entries(CAPTCHA_IFRAME_HINTS)) matchHints(src, hints) && (hasIframeHit = !0, providerSignals[provider] += 1, addEvidence(evidence.iframe_src_hits, truncateText(src, 120))); - if (title && matchHints(title, [ "captcha", "recaptcha" ]) && (hasContainerHit = !0, + if (title && matchHints(title, [ "captcha", "recaptcha" ]) && (hasContainerHit = !0, addEvidence(evidence.selector_hits, 'iframe[title*="captcha"]')), evidence.iframe_src_hits.length >= 5) break; } } catch (e) {} @@ -114,14 +114,14 @@ for (const script of scripts) { const src = script.getAttribute("src") || ""; if (src) { - for (const [provider, hints] of Object.entries(CAPTCHA_SCRIPT_HINTS)) matchHints(src, hints) && (hasScriptHit = !0, + for (const [provider, hints] of Object.entries(CAPTCHA_SCRIPT_HINTS)) matchHints(src, hints) && (hasScriptHit = !0, providerSignals[provider] += 1, addEvidence(evidence.selector_hits, `script[src*="${hints[0]}"]`)); if (evidence.selector_hits.length >= 5) break; } } } catch (e) {} for (const {selector: selector, provider: provider} of CAPTCHA_CONTAINER_SELECTORS) try { - document.querySelector(selector) && (hasContainerHit = !0, addEvidence(evidence.selector_hits, selector), + document.querySelector(selector) && (hasContainerHit = !0, addEvidence(evidence.selector_hits, selector), "unknown" !== provider && (providerSignals[provider] += 1)); } catch (e) {} const textSnippet = function() { @@ -139,7 +139,7 @@ } catch (e) {} try { let bodyText = document.body?.innerText || ""; - return !bodyText && document.body?.textContent && (bodyText = document.body.textContent), + return !bodyText && document.body?.textContent && (bodyText = document.body.textContent), truncateText(bodyText.replace(/\s+/g, " ").trim(), 2e3); } catch (e) { return ""; @@ -147,21 +147,21 @@ }(); if (textSnippet) { const lowerText = textSnippet.toLowerCase(); - for (const keyword of CAPTCHA_TEXT_KEYWORDS) lowerText.includes(keyword) && (hasKeywordHit = !0, + for (const keyword of CAPTCHA_TEXT_KEYWORDS) lowerText.includes(keyword) && (hasKeywordHit = !0, addEvidence(evidence.text_hits, keyword)); } try { const lowerUrl = (window.location?.href || "").toLowerCase(); - for (const hint of CAPTCHA_URL_HINTS) lowerUrl.includes(hint) && (hasUrlHit = !0, + for (const hint of CAPTCHA_URL_HINTS) lowerUrl.includes(hint) && (hasUrlHit = !0, addEvidence(evidence.url_hits, hint)); } catch (e) {} let confidence = 0; - hasIframeHit && (confidence += .7), hasContainerHit && (confidence += .5), hasScriptHit && (confidence += .5), - hasKeywordHit && (confidence += .3), hasUrlHit && (confidence += .2), confidence = Math.min(1, confidence), + hasIframeHit && (confidence += .7), hasContainerHit && (confidence += .5), hasScriptHit && (confidence += .5), + hasKeywordHit && (confidence += .3), hasUrlHit && (confidence += .2), confidence = Math.min(1, confidence), hasIframeHit && (confidence = Math.max(confidence, .8)), !hasKeywordHit || hasIframeHit || hasContainerHit || hasScriptHit || hasUrlHit || (confidence = Math.min(confidence, .4)); const detected = confidence >= .7; let providerHint = null; - return providerSignals.recaptcha > 0 ? providerHint = "recaptcha" : providerSignals.hcaptcha > 0 ? providerHint = "hcaptcha" : providerSignals.turnstile > 0 ? providerHint = "turnstile" : providerSignals.arkose > 0 ? providerHint = "arkose" : providerSignals.awswaf > 0 ? providerHint = "awswaf" : detected && (providerHint = "unknown"), + return providerSignals.recaptcha > 0 ? providerHint = "recaptcha" : providerSignals.hcaptcha > 0 ? providerHint = "hcaptcha" : providerSignals.turnstile > 0 ? providerHint = "turnstile" : providerSignals.arkose > 0 ? providerHint = "arkose" : providerSignals.awswaf > 0 ? providerHint = "awswaf" : detected && (providerHint = "unknown"), { detected: detected, provider_hint: providerHint, @@ -271,7 +271,7 @@ if (labelEl) { let text = ""; try { - if (text = (labelEl.innerText || "").trim(), !text && labelEl.textContent && (text = labelEl.textContent.trim()), + if (text = (labelEl.innerText || "").trim(), !text && labelEl.textContent && (text = labelEl.textContent.trim()), !text && labelEl.getAttribute) { const ariaLabel = labelEl.getAttribute("aria-label"); ariaLabel && (text = ariaLabel.trim()); @@ -466,7 +466,7 @@ }); const checkStable = () => { const timeSinceLastChange = Date.now() - lastChange, totalWait = Date.now() - startTime; - timeSinceLastChange >= quietPeriod || totalWait >= maxWait ? (observer.disconnect(), + timeSinceLastChange >= quietPeriod || totalWait >= maxWait ? (observer.disconnect(), resolve()) : setTimeout(checkStable, 50); }; checkStable(); @@ -492,7 +492,7 @@ }); const checkQuiet = () => { const timeSinceLastChange = Date.now() - lastChange, totalWait = Date.now() - startTime; - timeSinceLastChange >= quietPeriod || totalWait >= maxWait ? (quietObserver.disconnect(), + timeSinceLastChange >= quietPeriod || totalWait >= maxWait ? (quietObserver.disconnect(), resolve()) : setTimeout(checkQuiet, 50); }; checkQuiet(); @@ -607,7 +607,7 @@ }(el); let safeValue = null, valueRedacted = null; try { - if (void 0 !== el.value || el.getAttribute && null !== el.getAttribute("value")) if (isPasswordInput) safeValue = null, + if (void 0 !== el.value || el.getAttribute && null !== el.getAttribute("value")) if (isPasswordInput) safeValue = null, valueRedacted = "true"; else { const rawValue = void 0 !== el.value ? String(el.value) : String(el.getAttribute("value")); safeValue = rawValue.length > 200 ? rawValue.substring(0, 200) : rawValue, valueRedacted = "false"; @@ -734,8 +734,8 @@ const requestId = `iframe-${idx}-${Date.now()}`, timeout = setTimeout(() => { resolve(null); }, 5e3), listener = event => { - "SENTIENCE_IFRAME_SNAPSHOT_RESPONSE" === event.data?.type && event.data, "SENTIENCE_IFRAME_SNAPSHOT_RESPONSE" === event.data?.type && event.data?.requestId === requestId && (clearTimeout(timeout), - window.removeEventListener("message", listener), event.data.error ? resolve(null) : (event.data.snapshot, + "SENTIENCE_IFRAME_SNAPSHOT_RESPONSE" === event.data?.type && event.data, "SENTIENCE_IFRAME_SNAPSHOT_RESPONSE" === event.data?.type && event.data?.requestId === requestId && (clearTimeout(timeout), + window.removeEventListener("message", listener), event.data.error ? resolve(null) : (event.data.snapshot, resolve({ iframe: iframe, data: event.data.snapshot, @@ -751,7 +751,7 @@ ...options, collectIframes: !0 } - }, "*") : (clearTimeout(timeout), window.removeEventListener("message", listener), + }, "*") : (clearTimeout(timeout), window.removeEventListener("message", listener), resolve(null)); } catch (error) { clearTimeout(timeout), window.removeEventListener("message", listener), resolve(null); @@ -836,7 +836,7 @@ }, 25e3), listener = e => { if ("SENTIENCE_SNAPSHOT_RESULT" === e.data.type && e.data.requestId === requestId) { if (resolved) return; - resolved = !0, clearTimeout(timeout), window.removeEventListener("message", listener), + resolved = !0, clearTimeout(timeout), window.removeEventListener("message", listener), e.data.error ? reject(new Error(e.data.error)) : resolve({ elements: e.data.elements, raw_elements: e.data.raw_elements, @@ -853,7 +853,7 @@ options: options }, "*"); } catch (error) { - resolved || (resolved = !0, clearTimeout(timeout), window.removeEventListener("message", listener), + resolved || (resolved = !0, clearTimeout(timeout), window.removeEventListener("message", listener), reject(new Error(`Failed to send snapshot request: ${error.message}`))); } }); @@ -874,7 +874,7 @@ options.screenshot && (screenshot = await function(options) { return new Promise(resolve => { const requestId = Math.random().toString(36).substring(7), listener = e => { - "SENTIENCE_SCREENSHOT_RESULT" === e.data.type && e.data.requestId === requestId && (window.removeEventListener("message", listener), + "SENTIENCE_SCREENSHOT_RESULT" === e.data.type && e.data.requestId === requestId && (window.removeEventListener("message", listener), resolve(e.data.screenshot)); }; window.addEventListener("message", listener), window.postMessage({ @@ -893,7 +893,7 @@ const lastMutationTs = window.__sentience_lastMutationTs, now = performance.now(), quietMs = "number" == typeof lastMutationTs && Number.isFinite(lastMutationTs) ? Math.max(0, now - lastMutationTs) : null, nodeCount = document.querySelectorAll("*").length; let requiresVision = !1, requiresVisionReason = null; const canvasCount = document.getElementsByTagName("canvas").length; - canvasCount > 0 && (requiresVision = !0, requiresVisionReason = `canvas:${canvasCount}`), + canvasCount > 0 && (requiresVision = !0, requiresVisionReason = `canvas:${canvasCount}`), diagnostics = { metrics: { ready_state: document.readyState || null, @@ -939,15 +939,15 @@ } if (node.nodeType !== Node.ELEMENT_NODE) return; const tag = node.tagName.toLowerCase(); - if ("h1" === tag && (markdown += "\n# "), "h2" === tag && (markdown += "\n## "), - "h3" === tag && (markdown += "\n### "), "li" === tag && (markdown += "\n- "), insideLink || "p" !== tag && "div" !== tag && "br" !== tag || (markdown += "\n"), - "strong" !== tag && "b" !== tag || (markdown += "**"), "em" !== tag && "i" !== tag || (markdown += "_"), - "a" === tag && (markdown += "[", insideLink = !0), node.shadowRoot ? Array.from(node.shadowRoot.childNodes).forEach(walk) : node.childNodes.forEach(walk), + if ("h1" === tag && (markdown += "\n# "), "h2" === tag && (markdown += "\n## "), + "h3" === tag && (markdown += "\n### "), "li" === tag && (markdown += "\n- "), insideLink || "p" !== tag && "div" !== tag && "br" !== tag || (markdown += "\n"), + "strong" !== tag && "b" !== tag || (markdown += "**"), "em" !== tag && "i" !== tag || (markdown += "_"), + "a" === tag && (markdown += "[", insideLink = !0), node.shadowRoot ? Array.from(node.shadowRoot.childNodes).forEach(walk) : node.childNodes.forEach(walk), "a" === tag) { const href = node.getAttribute("href"); markdown += href ? `](${href})` : "]", insideLink = !1; } - "strong" !== tag && "b" !== tag || (markdown += "**"), "em" !== tag && "i" !== tag || (markdown += "_"), + "strong" !== tag && "b" !== tag || (markdown += "**"), "em" !== tag && "i" !== tag || (markdown += "_"), insideLink || "h1" !== tag && "h2" !== tag && "h3" !== tag && "p" !== tag && "div" !== tag || (markdown += "\n"); }(tempDiv), markdown.replace(/\n{3,}/g, "\n\n").trim(); }(document.body) : function(root) { @@ -960,7 +960,7 @@ const style = window.getComputedStyle(node); if ("none" === style.display || "hidden" === style.visibility) return; const isBlock = "block" === style.display || "flex" === style.display || "P" === node.tagName || "DIV" === node.tagName; - isBlock && (text += " "), node.shadowRoot ? Array.from(node.shadowRoot.childNodes).forEach(walk) : node.childNodes.forEach(walk), + isBlock && (text += " "), node.shadowRoot ? Array.from(node.shadowRoot.childNodes).forEach(walk) : node.childNodes.forEach(walk), isBlock && (text += "\n"); } } else text += node.textContent; @@ -1059,25 +1059,25 @@ } function startRecording(options = {}) { const {highlightColor: highlightColor = "#ff0000", successColor: successColor = "#00ff00", autoDisableTimeout: autoDisableTimeout = 18e5, keyboardShortcut: keyboardShortcut = "Ctrl+Shift+I"} = options; - if (!window.sentience_registry || 0 === window.sentience_registry.length) return alert("Registry empty. Run `await window.sentience.snapshot()` first!"), + if (!window.sentience_registry || 0 === window.sentience_registry.length) return alert("Registry empty. Run `await window.sentience.snapshot()` first!"), () => {}; window.sentience_registry_map = new Map, window.sentience_registry.forEach((el, idx) => { el && window.sentience_registry_map.set(el, idx); }); let highlightBox = document.getElementById("sentience-highlight-box"); - highlightBox || (highlightBox = document.createElement("div"), highlightBox.id = "sentience-highlight-box", - highlightBox.style.cssText = `\n position: fixed;\n pointer-events: none;\n z-index: 2147483647;\n border: 2px solid ${highlightColor};\n background: rgba(255, 0, 0, 0.1);\n display: none;\n transition: all 0.1s ease;\n box-sizing: border-box;\n `, + highlightBox || (highlightBox = document.createElement("div"), highlightBox.id = "sentience-highlight-box", + highlightBox.style.cssText = `\n position: fixed;\n pointer-events: none;\n z-index: 2147483647;\n border: 2px solid ${highlightColor};\n background: rgba(255, 0, 0, 0.1);\n display: none;\n transition: all 0.1s ease;\n box-sizing: border-box;\n `, document.body.appendChild(highlightBox)); let recordingIndicator = document.getElementById("sentience-recording-indicator"); - recordingIndicator || (recordingIndicator = document.createElement("div"), recordingIndicator.id = "sentience-recording-indicator", - recordingIndicator.style.cssText = `\n position: fixed;\n top: 0;\n left: 0;\n right: 0;\n height: 3px;\n background: ${highlightColor};\n z-index: 2147483646;\n pointer-events: none;\n `, + recordingIndicator || (recordingIndicator = document.createElement("div"), recordingIndicator.id = "sentience-recording-indicator", + recordingIndicator.style.cssText = `\n position: fixed;\n top: 0;\n left: 0;\n right: 0;\n height: 3px;\n background: ${highlightColor};\n z-index: 2147483646;\n pointer-events: none;\n `, document.body.appendChild(recordingIndicator)), recordingIndicator.style.display = "block"; const mouseOverHandler = e => { const el = e.target; if (!el || el === highlightBox || el === recordingIndicator) return; const rect = el.getBoundingClientRect(); - highlightBox.style.display = "block", highlightBox.style.top = rect.top + window.scrollY + "px", - highlightBox.style.left = rect.left + window.scrollX + "px", highlightBox.style.width = rect.width + "px", + highlightBox.style.display = "block", highlightBox.style.top = rect.top + window.scrollY + "px", + highlightBox.style.left = rect.left + window.scrollX + "px", highlightBox.style.width = rect.width + "px", highlightBox.style.height = rect.height + "px"; }, clickHandler = e => { e.preventDefault(), e.stopPropagation(); @@ -1154,7 +1154,7 @@ debug_snapshot: rawData }, jsonString = JSON.stringify(snippet, null, 2); navigator.clipboard.writeText(jsonString).then(() => { - highlightBox.style.border = `2px solid ${successColor}`, highlightBox.style.background = "rgba(0, 255, 0, 0.2)", + highlightBox.style.border = `2px solid ${successColor}`, highlightBox.style.background = "rgba(0, 255, 0, 0.2)", setTimeout(() => { highlightBox.style.border = `2px solid ${highlightColor}`, highlightBox.style.background = "rgba(255, 0, 0, 0.1)"; }, 500); @@ -1164,15 +1164,15 @@ }; let timeoutId = null; const stopRecording = () => { - document.removeEventListener("mouseover", mouseOverHandler, !0), document.removeEventListener("click", clickHandler, !0), - document.removeEventListener("keydown", keyboardHandler, !0), timeoutId && (clearTimeout(timeoutId), - timeoutId = null), highlightBox && (highlightBox.style.display = "none"), recordingIndicator && (recordingIndicator.style.display = "none"), + document.removeEventListener("mouseover", mouseOverHandler, !0), document.removeEventListener("click", clickHandler, !0), + document.removeEventListener("keydown", keyboardHandler, !0), timeoutId && (clearTimeout(timeoutId), + timeoutId = null), highlightBox && (highlightBox.style.display = "none"), recordingIndicator && (recordingIndicator.style.display = "none"), window.sentience_registry_map && window.sentience_registry_map.clear(), window.sentience_stopRecording === stopRecording && delete window.sentience_stopRecording; }, keyboardHandler = e => { - (e.ctrlKey || e.metaKey) && e.shiftKey && "I" === e.key && (e.preventDefault(), + (e.ctrlKey || e.metaKey) && e.shiftKey && "I" === e.key && (e.preventDefault(), stopRecording()); }; - return document.addEventListener("mouseover", mouseOverHandler, !0), document.addEventListener("click", clickHandler, !0), + return document.addEventListener("mouseover", mouseOverHandler, !0), document.addEventListener("click", clickHandler, !0), document.addEventListener("keydown", keyboardHandler, !0), autoDisableTimeout > 0 && (timeoutId = setTimeout(() => { stopRecording(); }, autoDisableTimeout)), window.sentience_stopRecording = stopRecording, stopRecording; @@ -1241,4 +1241,4 @@ } }), window.sentience_iframe_handler_setup = !0)); })(); -}(); \ No newline at end of file +}(); diff --git a/sentience/extension/pkg/sentience_core.js b/sentience/extension/pkg/sentience_core.js index bb9cae0..c50ad61 100644 --- a/sentience/extension/pkg/sentience_core.js +++ b/sentience/extension/pkg/sentience_core.js @@ -25,7 +25,7 @@ function __wbg_get_imports() { }, __wbg___wbindgen_bigint_get_as_i64_8fcf4ce7f1ca72a2: function(arg0, arg1) { const v = getObject(arg1), ret = "bigint" == typeof v ? v : void 0; - getDataViewMemory0().setBigInt64(arg0 + 8, isLikeNone(ret) ? BigInt(0) : ret, !0), + getDataViewMemory0().setBigInt64(arg0 + 8, isLikeNone(ret) ? BigInt(0) : ret, !0), getDataViewMemory0().setInt32(arg0 + 0, !isLikeNone(ret), !0); }, __wbg___wbindgen_boolean_get_bbbb1c18aa2f5e25: function(arg0) { @@ -224,7 +224,7 @@ function getArrayU8FromWasm0(ptr, len) { let cachedDataViewMemory0 = null; function getDataViewMemory0() { - return (null === cachedDataViewMemory0 || !0 === cachedDataViewMemory0.buffer.detached || void 0 === cachedDataViewMemory0.buffer.detached && cachedDataViewMemory0.buffer !== wasm.memory.buffer) && (cachedDataViewMemory0 = new DataView(wasm.memory.buffer)), + return (null === cachedDataViewMemory0 || !0 === cachedDataViewMemory0.buffer.detached || void 0 === cachedDataViewMemory0.buffer.detached && cachedDataViewMemory0.buffer !== wasm.memory.buffer) && (cachedDataViewMemory0 = new DataView(wasm.memory.buffer)), cachedDataViewMemory0; } @@ -235,7 +235,7 @@ function getStringFromWasm0(ptr, len) { let cachedUint8ArrayMemory0 = null; function getUint8ArrayMemory0() { - return null !== cachedUint8ArrayMemory0 && 0 !== cachedUint8ArrayMemory0.byteLength || (cachedUint8ArrayMemory0 = new Uint8Array(wasm.memory.buffer)), + return null !== cachedUint8ArrayMemory0 && 0 !== cachedUint8ArrayMemory0.byteLength || (cachedUint8ArrayMemory0 = new Uint8Array(wasm.memory.buffer)), cachedUint8ArrayMemory0; } @@ -264,7 +264,7 @@ function isLikeNone(x) { function passStringToWasm0(arg, malloc, realloc) { if (void 0 === realloc) { const buf = cachedTextEncoder.encode(arg), ptr = malloc(buf.length, 1) >>> 0; - return getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf), WASM_VECTOR_LEN = buf.length, + return getUint8ArrayMemory0().subarray(ptr, ptr + buf.length).set(buf), WASM_VECTOR_LEN = buf.length, ptr; } let len = arg.length, ptr = malloc(len, 1) >>> 0; @@ -319,7 +319,7 @@ const cachedTextEncoder = new TextEncoder; let wasmModule, wasm, WASM_VECTOR_LEN = 0; function __wbg_finalize_init(instance, module) { - return wasm = instance.exports, wasmModule = module, cachedDataViewMemory0 = null, + return wasm = instance.exports, wasmModule = module, cachedDataViewMemory0 = null, cachedUint8ArrayMemory0 = null, wasm; } @@ -360,7 +360,7 @@ function initSync(module) { async function __wbg_init(module_or_path) { if (void 0 !== wasm) return wasm; - void 0 !== module_or_path && Object.getPrototypeOf(module_or_path) === Object.prototype && ({module_or_path: module_or_path} = module_or_path), + void 0 !== module_or_path && Object.getPrototypeOf(module_or_path) === Object.prototype && ({module_or_path: module_or_path} = module_or_path), void 0 === module_or_path && (module_or_path = new URL("sentience_core_bg.wasm", import.meta.url)); const imports = __wbg_get_imports(); ("string" == typeof module_or_path || "function" == typeof Request && module_or_path instanceof Request || "function" == typeof URL && module_or_path instanceof URL) && (module_or_path = fetch(module_or_path)); @@ -368,4 +368,4 @@ async function __wbg_init(module_or_path) { return __wbg_finalize_init(instance, module); } -export { initSync, __wbg_init as default }; \ No newline at end of file +export { initSync, __wbg_init as default }; diff --git a/sentience/llm_provider.py b/sentience/llm_provider.py index 4874e47..e220e75 100644 --- a/sentience/llm_provider.py +++ b/sentience/llm_provider.py @@ -3,6 +3,7 @@ Enables "Bring Your Own Brain" (BYOB) pattern - plug in any LLM provider """ +import asyncio from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any @@ -59,6 +60,12 @@ def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMRespons """ pass + async def generate_async(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse: + """ + Async wrapper around generate() for providers without native async support. + """ + return await asyncio.to_thread(self.generate, system_prompt, user_prompt, **kwargs) + @abstractmethod def supports_json_mode(self) -> bool: """ @@ -335,9 +342,7 @@ def __init__( model: str = "meta-llama/Meta-Llama-3-8B-Instruct", base_url: str = "https://api.deepinfra.com/v1/openai", ): - api_key = get_api_key_from_env( - ["DEEPINFRA_TOKEN", "DEEPINFRA_API_KEY"], api_key - ) + api_key = get_api_key_from_env(["DEEPINFRA_TOKEN", "DEEPINFRA_API_KEY"], api_key) super().__init__(api_key=api_key, model=model, base_url=base_url) diff --git a/sentience/models.py b/sentience/models.py index 2daef70..ebc18f8 100644 --- a/sentience/models.py +++ b/sentience/models.py @@ -630,6 +630,75 @@ class ActionResult(BaseModel): cursor: dict[str, Any] | None = None +class TabInfo(BaseModel): + """Metadata about an open browser tab/page.""" + + tab_id: str + url: str | None = None + title: str | None = None + is_active: bool = False + + +class TabListResult(BaseModel): + """Result of listing tabs.""" + + ok: bool + tabs: list[TabInfo] = Field(default_factory=list) + error: str | None = None + + +class TabOperationResult(BaseModel): + """Result of tab operations (open/switch/close).""" + + ok: bool + tab: TabInfo | None = None + error: str | None = None + + +class StepHookContext(BaseModel): + """Context passed to lifecycle hooks.""" + + step_id: str + step_index: int + goal: str + attempt: int = 0 + url: str | None = None + success: bool | None = None + outcome: str | None = None + error: str | None = None + + +class EvaluateJsRequest(BaseModel): + """Request for evaluate_js helper.""" + + code: str = Field( + ..., + min_length=1, + max_length=8000, + description="JavaScript source code to evaluate in the page context.", + ) + max_output_chars: int = Field( + 4000, + ge=1, + le=20000, + description="Maximum number of characters to return in the text field.", + ) + truncate: bool = Field( + True, + description="Whether to truncate text output when it exceeds max_output_chars.", + ) + + +class EvaluateJsResult(BaseModel): + """Result of evaluate_js helper.""" + + ok: bool = Field(..., description="Whether evaluation succeeded.") + value: Any | None = Field(None, description="Raw value returned by the page evaluation.") + text: str | None = Field(None, description="Best-effort string representation of the value.") + truncated: bool = Field(False, description="True if text output was truncated.") + error: str | None = Field(None, description="Error string when ok=False.") + + class WaitResult(BaseModel): """Result of wait_for operation""" @@ -988,6 +1057,15 @@ class ReadResult(BaseModel): error: str | None = None +class ExtractResult(BaseModel): + """Result of extract() or extract_async() operation""" + + ok: bool + data: Any | None = None + raw: str | None = None + error: str | None = None + + class TraceStats(BaseModel): """Execution statistics for trace completion""" diff --git a/sentience/read.py b/sentience/read.py index 6d95534..8245f03 100644 --- a/sentience/read.py +++ b/sentience/read.py @@ -2,10 +2,15 @@ Read page content - supports raw HTML, text, and markdown formats """ -from typing import Literal +import json +import re +from typing import Any, Literal + +from pydantic import BaseModel, ValidationError from .browser import AsyncSentienceBrowser, SentienceBrowser -from .models import ReadResult +from .llm_provider import LLMProvider +from .models import ExtractResult, ReadResult def read( @@ -66,13 +71,13 @@ def read( from markdownify import MarkdownifyError, markdownify markdown_content = markdownify(html_content, heading_style="ATX", wrap=True) - return { - "status": "success", - "url": raw_html_result["url"], - "format": "markdown", - "content": markdown_content, - "length": len(markdown_content), - } + return ReadResult( + status="success", + url=raw_html_result["url"], + format="markdown", + content=markdown_content, + length=len(markdown_content), + ) except ImportError: print( "Warning: 'markdownify' not installed. Install with 'pip install markdownify' for enhanced markdown. Falling back to extension's markdown." @@ -156,13 +161,13 @@ async def read_async( from markdownify import MarkdownifyError, markdownify markdown_content = markdownify(html_content, heading_style="ATX", wrap=True) - return { - "status": "success", - "url": raw_html_result["url"], - "format": "markdown", - "content": markdown_content, - "length": len(markdown_content), - } + return ReadResult( + status="success", + url=raw_html_result["url"], + format="markdown", + content=markdown_content, + length=len(markdown_content), + ) except ImportError: print( "Warning: 'markdownify' not installed. Install with 'pip install markdownify' for enhanced markdown. Falling back to extension's markdown." @@ -186,3 +191,81 @@ async def read_async( # Convert dict result to ReadResult model return ReadResult(**result) + + +def _extract_json_payload(text: str) -> dict[str, Any]: + fenced = re.search(r"```json\s*(\{.*?\})\s*```", text, re.DOTALL | re.IGNORECASE) + if fenced: + return json.loads(fenced.group(1)) + inline = re.search(r"(\{.*\})", text, re.DOTALL) + if inline: + return json.loads(inline.group(1)) + return json.loads(text) + + +def extract( + browser: SentienceBrowser, + llm: LLMProvider, + query: str, + schema: type[BaseModel] | None = None, + max_chars: int = 12000, +) -> ExtractResult: + """ + Extract structured data from the current page using read() markdown + LLM. + """ + result = read(browser, output_format="markdown", enhance_markdown=True) + if result.status != "success": + return ExtractResult(ok=False, error=result.error) + + content = result.content[:max_chars] + schema_desc = "" + if schema is not None: + schema_desc = json.dumps(schema.model_json_schema(), ensure_ascii=False) + system = "You extract structured data from markdown content. " "Return only JSON. No prose." + user = f"QUERY:\n{query}\n\nSCHEMA:\n{schema_desc}\n\nCONTENT:\n{content}" + response = llm.generate(system, user) + raw = response.content.strip() + + if schema is None: + return ExtractResult(ok=True, data={"text": raw}, raw=raw) + + try: + payload = _extract_json_payload(raw) + validated = schema.model_validate(payload) + return ExtractResult(ok=True, data=validated, raw=raw) + except (json.JSONDecodeError, ValidationError) as exc: + return ExtractResult(ok=False, error=str(exc), raw=raw) + + +async def extract_async( + browser: AsyncSentienceBrowser, + llm: LLMProvider, + query: str, + schema: type[BaseModel] | None = None, + max_chars: int = 12000, +) -> ExtractResult: + """ + Async version of extract(). + """ + result = await read_async(browser, output_format="markdown", enhance_markdown=True) + if result.status != "success": + return ExtractResult(ok=False, error=result.error) + + content = result.content[:max_chars] + schema_desc = "" + if schema is not None: + schema_desc = json.dumps(schema.model_json_schema(), ensure_ascii=False) + system = "You extract structured data from markdown content. " "Return only JSON. No prose." + user = f"QUERY:\n{query}\n\nSCHEMA:\n{schema_desc}\n\nCONTENT:\n{content}" + response = await llm.generate_async(system, user) + raw = response.content.strip() + + if schema is None: + return ExtractResult(ok=True, data={"text": raw}, raw=raw) + + try: + payload = _extract_json_payload(raw) + validated = schema.model_validate(payload) + return ExtractResult(ok=True, data=validated, raw=raw) + except (json.JSONDecodeError, ValidationError) as exc: + return ExtractResult(ok=False, error=str(exc), raw=raw) diff --git a/sentience/runtime_agent.py b/sentience/runtime_agent.py index 3c107e2..b6c4193 100644 --- a/sentience/runtime_agent.py +++ b/sentience/runtime_agent.py @@ -10,7 +10,9 @@ from __future__ import annotations import base64 +import inspect import re +from collections.abc import Callable from dataclasses import dataclass, field from typing import Any, Literal @@ -18,7 +20,7 @@ from .backends import actions as backend_actions from .llm_interaction_handler import LLMInteractionHandler from .llm_provider import LLMProvider -from .models import BBox, Snapshot +from .models import BBox, Snapshot, StepHookContext from .verification import AssertContext, AssertOutcome, Predicate @@ -84,15 +86,29 @@ async def run_step( *, task_goal: str, step: RuntimeStep, + on_step_start: Callable[[StepHookContext], Any] | None = None, + on_step_end: Callable[[StepHookContext], Any] | None = None, ) -> bool: - self.runtime.begin_step(step.goal) + step_id = self.runtime.begin_step(step.goal) + await self._run_hook( + on_step_start, + StepHookContext( + step_id=step_id, + step_index=self.runtime.step_index, + goal=step.goal, + url=getattr(self.runtime.last_snapshot, "url", None), + ), + ) emitted = False ok = False + error_msg: str | None = None + outcome: str | None = None try: snap = await self._snapshot_with_ramp(step=step) if await self._should_short_circuit_to_vision(step=step, snap=snap): ok = await self._vision_executor_attempt(task_goal=task_goal, step=step, snap=snap) + outcome = "ok" if ok else "verification_failed" return ok # 1) Structured executor attempt. @@ -100,15 +116,20 @@ async def run_step( await self._execute_action(action=action, snap=snap) ok = await self._apply_verifications(step=step) if ok: + outcome = "ok" return True # 2) Optional vision executor fallback (bounded). if step.vision_executor_enabled and step.max_vision_executor_attempts > 0: ok = await self._vision_executor_attempt(task_goal=task_goal, step=step, snap=snap) + outcome = "ok" if ok else "verification_failed" return ok + outcome = "verification_failed" return False except Exception as exc: + error_msg = str(exc) + outcome = "exception" try: await self.runtime.emit_step_end( success=False, @@ -130,6 +151,29 @@ async def run_step( ) except Exception: pass + await self._run_hook( + on_step_end, + StepHookContext( + step_id=step_id, + step_index=self.runtime.step_index, + goal=step.goal, + url=getattr(self.runtime.last_snapshot, "url", None), + success=ok, + outcome=outcome, + error=error_msg, + ), + ) + + async def _run_hook( + self, + hook: Callable[[StepHookContext], Any] | None, + ctx: StepHookContext, + ) -> None: + if hook is None: + return + result = hook(ctx) + if inspect.isawaitable(result): + await result async def _snapshot_with_ramp(self, *, step: RuntimeStep) -> Snapshot: limit = step.snapshot_limit_base diff --git a/sentience/tools/__init__.py b/sentience/tools/__init__.py new file mode 100644 index 0000000..cf4de8f --- /dev/null +++ b/sentience/tools/__init__.py @@ -0,0 +1,19 @@ +""" +Tool registry for LLM-callable tools. +""" + +from .context import BackendCapabilities, ToolContext, UnsupportedCapabilityError +from .defaults import register_default_tools +from .filesystem import FileSandbox, register_filesystem_tools +from .registry import ToolRegistry, ToolSpec + +__all__ = [ + "BackendCapabilities", + "FileSandbox", + "ToolContext", + "ToolRegistry", + "ToolSpec", + "UnsupportedCapabilityError", + "register_default_tools", + "register_filesystem_tools", +] diff --git a/sentience/tools/context.py b/sentience/tools/context.py new file mode 100644 index 0000000..dfcd562 --- /dev/null +++ b/sentience/tools/context.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +if TYPE_CHECKING: + from ..agent_runtime import AgentRuntime + from .filesystem import FileSandbox + + +class BackendCapabilities(BaseModel): + """Best-effort backend capability flags.""" + + tabs: bool = False + evaluate_js: bool = False + downloads: bool = False + filesystem_tools: bool = False + keyboard: bool = False + + +class UnsupportedCapabilityError(RuntimeError): + """Structured error for unsupported capabilities.""" + + def __init__(self, capability: str, detail: str | None = None) -> None: + msg = detail or f"{capability} not supported by backend" + super().__init__(msg) + self.error = "unsupported_capability" + self.detail = msg + self.capability = capability + + +class ToolContext: + """Context passed to tool handlers.""" + + def __init__( + self, + runtime: AgentRuntime, + files: "FileSandbox" | None = None, + base_dir: Path | None = None, + ) -> None: + self.runtime = runtime + if files is None: + root = base_dir or (Path.cwd() / ".sentience" / "files") + from .filesystem import FileSandbox + + files = FileSandbox(root) + self.files = files + + def capabilities(self) -> BackendCapabilities: + return self.runtime.capabilities() + + def can(self, name: str) -> bool: + return self.runtime.can(name) + + def require(self, name: str) -> None: + if not self.can(name): + raise UnsupportedCapabilityError(name) diff --git a/sentience/tools/defaults.py b/sentience/tools/defaults.py new file mode 100644 index 0000000..5b319dd --- /dev/null +++ b/sentience/tools/defaults.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +from pydantic import BaseModel, Field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from ..agent_runtime import AgentRuntime +from ..backends import actions as backend_actions +from ..models import ActionResult, BBox, EvaluateJsRequest, Snapshot +from .context import ToolContext +from .registry import ToolRegistry + + +class SnapshotToolInput(BaseModel): + limit: int = Field(50, ge=1, le=500, description="Max elements to return.") + + +class ClickToolInput(BaseModel): + element_id: int = Field(..., ge=1, description="Sentience element id from snapshot.") + + +class TypeToolInput(BaseModel): + element_id: int = Field(..., ge=1, description="Sentience element id from snapshot.") + text: str = Field(..., min_length=1, description="Text to type into the element.") + clear_first: bool = Field(False, description="Clear existing content before typing.") + + +class ScrollToolInput(BaseModel): + delta_y: float = Field(..., description="Scroll amount (positive = down, negative = up).") + x: float | None = Field(None, description="Optional scroll x coordinate.") + y: float | None = Field(None, description="Optional scroll y coordinate.") + + +class ScrollToElementToolInput(BaseModel): + element_id: int = Field(..., ge=1, description="Sentience element id from snapshot.") + behavior: str = Field("instant", description="Scroll behavior.") + block: str = Field("center", description="Vertical alignment.") + + +class ClickRectToolInput(BaseModel): + x: float = Field(..., description="Rect x coordinate.") + y: float = Field(..., description="Rect y coordinate.") + width: float = Field(..., ge=0, description="Rect width.") + height: float = Field(..., ge=0, description="Rect height.") + + +class PressToolInput(BaseModel): + key: str = Field(..., min_length=1, description="Key to press (e.g., Enter).") + + +class EvaluateJsToolInput(BaseModel): + code: str = Field(..., min_length=1, max_length=8000, description="JavaScript to execute.") + max_output_chars: int = Field(4000, ge=1, le=20000, description="Output cap.") + truncate: bool = Field(True, description="Truncate output when too long.") + + +def register_default_tools( + registry: ToolRegistry, runtime: ToolContext | "AgentRuntime" | None = None +) -> ToolRegistry: + """Register default browser tools on a registry.""" + + def _get_runtime(ctx: ToolContext | None): + if ctx is not None: + return ctx.runtime + if runtime is not None: + if isinstance(runtime, ToolContext): + return runtime.runtime + return runtime + raise RuntimeError("ToolContext with runtime is required") + + @registry.tool( + name="snapshot_state", + input_model=SnapshotToolInput, + output_model=Snapshot, + description="Capture a snapshot of the current page state.", + ) + async def snapshot_state(ctx, params: SnapshotToolInput) -> Snapshot: + runtime_ref = _get_runtime(ctx) + snap = await runtime_ref.snapshot(limit=params.limit, goal="tool_snapshot_state") + if snap is None: + raise RuntimeError("snapshot() returned None") + return snap + + @registry.tool( + name="click", + input_model=ClickToolInput, + output_model=ActionResult, + description="Click an element by id from the latest snapshot.", + ) + async def click_tool(ctx, params: ClickToolInput) -> ActionResult: + runtime_ref = _get_runtime(ctx) + snap = runtime_ref.last_snapshot or await runtime_ref.snapshot(goal="tool_click") + if snap is None: + raise RuntimeError("snapshot() returned None") + el = next((e for e in snap.elements if e.id == params.element_id), None) + if el is None: + raise ValueError(f"element_id not found: {params.element_id}") + return await backend_actions.click(runtime_ref.backend, el.bbox) + + @registry.tool( + name="type", + input_model=TypeToolInput, + output_model=ActionResult, + description="Type text into an element by id from the latest snapshot.", + ) + async def type_tool(ctx, params: TypeToolInput) -> ActionResult: + runtime_ref = _get_runtime(ctx) + snap = runtime_ref.last_snapshot or await runtime_ref.snapshot(goal="tool_type") + if snap is None: + raise RuntimeError("snapshot() returned None") + el = next((e for e in snap.elements if e.id == params.element_id), None) + if el is None: + raise ValueError(f"element_id not found: {params.element_id}") + return await backend_actions.type_text( + runtime_ref.backend, params.text, target=el.bbox, clear_first=params.clear_first + ) + + @registry.tool( + name="scroll", + input_model=ScrollToolInput, + output_model=ActionResult, + description="Scroll the page by a delta amount.", + ) + async def scroll_tool(ctx, params: ScrollToolInput) -> ActionResult: + runtime_ref = _get_runtime(ctx) + target = None + if params.x is not None and params.y is not None: + target = (params.x, params.y) + return await backend_actions.scroll(runtime_ref.backend, params.delta_y, target=target) + + @registry.tool( + name="scroll_to_element", + input_model=ScrollToElementToolInput, + output_model=ActionResult, + description="Scroll a specific element into view by element id.", + ) + async def scroll_to_element_tool(ctx, params: ScrollToElementToolInput) -> ActionResult: + if ctx is not None: + ctx.require("evaluate_js") + runtime_ref = _get_runtime(ctx) + return await backend_actions.scroll_to_element( + runtime_ref.backend, + params.element_id, + behavior=params.behavior, + block=params.block, + ) + + @registry.tool( + name="click_rect", + input_model=ClickRectToolInput, + output_model=ActionResult, + description="Click the center of a rectangle.", + ) + async def click_rect_tool(ctx, params: ClickRectToolInput) -> ActionResult: + runtime_ref = _get_runtime(ctx) + bbox = BBox( + x=params.x, + y=params.y, + width=params.width, + height=params.height, + ) + return await backend_actions.click(runtime_ref.backend, bbox) + + @registry.tool( + name="press", + input_model=PressToolInput, + output_model=ActionResult, + description="Press a keyboard key on the active element.", + ) + async def press_tool(ctx, params: PressToolInput) -> ActionResult: + if ctx is not None: + ctx.require("keyboard") + runtime_ref = _get_runtime(ctx) + page = getattr(runtime_ref.backend, "_page", None) or getattr( + runtime_ref.backend, "page", None + ) + if page is not None and getattr(page, "keyboard", None) is not None: + await page.keyboard.press(params.key) + return ActionResult(success=True, duration_ms=0, outcome="dom_updated") + try: + await runtime_ref.backend.eval( + f""" + (() => {{ + const el = document.activeElement; + if (!el) return false; + const key = {params.key!r}; + el.dispatchEvent(new KeyboardEvent('keydown', {{ key }})); + el.dispatchEvent(new KeyboardEvent('keyup', {{ key }})); + return true; + }})() + """ + ) + return ActionResult(success=True, duration_ms=0, outcome="dom_updated") + except Exception as exc: + return ActionResult( + success=False, + duration_ms=0, + outcome="error", + error={"code": "press_failed", "reason": str(exc)}, + ) + + @registry.tool( + name="evaluate_js", + input_model=EvaluateJsToolInput, + output_model=ActionResult, + description="Evaluate JavaScript in the page context (returns text output).", + ) + async def evaluate_js_tool(ctx, params: EvaluateJsToolInput) -> ActionResult: + if ctx is not None: + ctx.require("evaluate_js") + runtime_ref = _get_runtime(ctx) + result = await runtime_ref.evaluate_js( + EvaluateJsRequest( + code=params.code, + max_output_chars=params.max_output_chars, + truncate=params.truncate, + ) + ) + if not result.ok: + return ActionResult( + success=False, + duration_ms=0, + outcome="error", + error={"code": "evaluate_js_failed", "reason": result.error or "error"}, + ) + return ActionResult( + success=True, + duration_ms=0, + outcome="dom_updated", + ) + + return registry diff --git a/sentience/tools/filesystem.py b/sentience/tools/filesystem.py new file mode 100644 index 0000000..3d468d4 --- /dev/null +++ b/sentience/tools/filesystem.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from pathlib import Path + +from pydantic import BaseModel, Field +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .context import ToolContext +from .registry import ToolRegistry + + +class FileSandbox: + """Sandboxed file access rooted at a base directory.""" + + def __init__(self, base_dir: Path) -> None: + self.base_dir = base_dir.resolve() + self.base_dir.mkdir(parents=True, exist_ok=True) + + def _resolve(self, path: str) -> Path: + candidate = (self.base_dir / path).resolve() + if not candidate.is_relative_to(self.base_dir): + raise ValueError("path escapes sandbox root") + return candidate + + def read_text(self, path: str) -> str: + return self._resolve(path).read_text(encoding="utf-8") + + def write_text(self, path: str, content: str, *, overwrite: bool = True) -> int: + target = self._resolve(path) + if target.exists() and not overwrite: + raise ValueError("file exists and overwrite is False") + target.parent.mkdir(parents=True, exist_ok=True) + target.write_text(content, encoding="utf-8") + return len(content.encode("utf-8")) + + def append_text(self, path: str, content: str) -> int: + target = self._resolve(path) + target.parent.mkdir(parents=True, exist_ok=True) + with target.open("a", encoding="utf-8") as handle: + handle.write(content) + return len(content.encode("utf-8")) + + def replace_text(self, path: str, old: str, new: str) -> int: + target = self._resolve(path) + data = target.read_text(encoding="utf-8") + replaced = data.count(old) + target.write_text(data.replace(old, new), encoding="utf-8") + return replaced + + +class ReadFileInput(BaseModel): + path: str = Field(..., min_length=1, description="Path relative to sandbox root.") + + +class ReadFileOutput(BaseModel): + content: str = Field(..., description="File contents.") + + +class WriteFileInput(BaseModel): + path: str = Field(..., min_length=1, description="Path relative to sandbox root.") + content: str = Field(..., description="Content to write.") + overwrite: bool = Field(True, description="Whether to overwrite if file exists.") + + +class WriteFileOutput(BaseModel): + path: str + bytes_written: int + + +class AppendFileInput(BaseModel): + path: str = Field(..., min_length=1, description="Path relative to sandbox root.") + content: str = Field(..., description="Content to append.") + + +class AppendFileOutput(BaseModel): + path: str + bytes_written: int + + +class ReplaceFileInput(BaseModel): + path: str = Field(..., min_length=1, description="Path relative to sandbox root.") + old: str = Field(..., description="Text to replace.") + new: str = Field(..., description="Replacement text.") + + +class ReplaceFileOutput(BaseModel): + path: str + replaced: int + + +def register_filesystem_tools( + registry: ToolRegistry, sandbox: FileSandbox | None = None +) -> ToolRegistry: + """Register sandboxed filesystem tools.""" + + def _get_files(ctx: "ToolContext" | None) -> FileSandbox: + if ctx is not None: + return ctx.files + if sandbox is not None: + return sandbox + raise RuntimeError("FileSandbox is required for filesystem tools") + + @registry.tool( + name="read_file", + input_model=ReadFileInput, + output_model=ReadFileOutput, + description="Read a file from the sandbox.", + ) + async def read_file(ctx: "ToolContext" | None, params: ReadFileInput) -> ReadFileOutput: + files = _get_files(ctx) + return ReadFileOutput(content=files.read_text(params.path)) + + @registry.tool( + name="write_file", + input_model=WriteFileInput, + output_model=WriteFileOutput, + description="Write a file to the sandbox.", + ) + async def write_file(ctx: "ToolContext" | None, params: WriteFileInput) -> WriteFileOutput: + files = _get_files(ctx) + written = files.write_text(params.path, params.content, overwrite=params.overwrite) + return WriteFileOutput(path=params.path, bytes_written=written) + + @registry.tool( + name="append_file", + input_model=AppendFileInput, + output_model=AppendFileOutput, + description="Append text to a file in the sandbox.", + ) + async def append_file(ctx: "ToolContext" | None, params: AppendFileInput) -> AppendFileOutput: + files = _get_files(ctx) + written = files.append_text(params.path, params.content) + return AppendFileOutput(path=params.path, bytes_written=written) + + @registry.tool( + name="replace_file", + input_model=ReplaceFileInput, + output_model=ReplaceFileOutput, + description="Replace text in a file in the sandbox.", + ) + async def replace_file(ctx: "ToolContext" | None, params: ReplaceFileInput) -> ReplaceFileOutput: + files = _get_files(ctx) + replaced = files.replace_text(params.path, params.old, params.new) + return ReplaceFileOutput(path=params.path, replaced=replaced) + + return registry diff --git a/sentience/tools/registry.py b/sentience/tools/registry.py new file mode 100644 index 0000000..2d162b2 --- /dev/null +++ b/sentience/tools/registry.py @@ -0,0 +1,143 @@ +from __future__ import annotations + +import inspect +import time +from collections.abc import Callable +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class ToolSpec(BaseModel): + """Definition of a tool with typed input/output schemas.""" + + name: str = Field(..., min_length=1, description="Unique tool name.") + description: str | None = Field(None, description="Human-readable tool description.") + input_model: type[BaseModel] + output_model: type[BaseModel] + handler: Callable[..., Any] | None = None + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def llm_spec(self) -> dict[str, Any]: + """Return a normalized tool spec for LLM prompts.""" + return { + "name": self.name, + "description": self.description or "", + "parameters": self.input_model.model_json_schema(), + } + + def validate_input(self, payload: Any) -> BaseModel: + """Validate tool input payload.""" + return self.input_model.model_validate(payload) + + def validate_output(self, payload: Any) -> BaseModel: + """Validate tool output payload.""" + return self.output_model.model_validate(payload) + + +class ToolRegistry: + """Registry for tool specs and validation.""" + + def __init__(self) -> None: + self._tools: dict[str, ToolSpec] = {} + + def register(self, spec: ToolSpec) -> ToolRegistry: + if spec.name in self._tools: + raise ValueError(f"tool already registered: {spec.name}") + self._tools[spec.name] = spec + return self + + def tool( + self, + *, + name: str, + input_model: type[BaseModel], + output_model: type[BaseModel], + description: str | None = None, + ): + """Decorator to register a tool handler.""" + + def decorator(func: Callable[..., Any]): + spec = ToolSpec( + name=name, + description=description, + input_model=input_model, + output_model=output_model, + handler=func, + ) + self.register(spec) + return func + + return decorator + + def get(self, name: str) -> ToolSpec | None: + return self._tools.get(name) + + def list(self) -> list[ToolSpec]: + return list(self._tools.values()) + + def llm_tools(self) -> list[dict[str, Any]]: + return [spec.llm_spec() for spec in self.list()] + + def validate_input(self, name: str, payload: Any) -> BaseModel: + spec = self._tools.get(name) + if spec is None: + raise KeyError(f"tool not found: {name}") + return spec.validate_input(payload) + + def validate_output(self, name: str, payload: Any) -> BaseModel: + spec = self._tools.get(name) + if spec is None: + raise KeyError(f"tool not found: {name}") + return spec.validate_output(payload) + + def validate_call(self, name: str, payload: Any) -> tuple[BaseModel, ToolSpec]: + """Validate input and return (validated, spec).""" + validated = self.validate_input(name, payload) + return validated, self._tools[name] + + async def execute(self, name: str, payload: Any, ctx: Any | None = None) -> BaseModel: + """Validate inputs, execute handler, validate output.""" + start = time.time() + validated, spec = self.validate_call(name, payload) + if spec.handler is None: + raise ValueError(f"tool has no handler: {name}") + tracer = None + step_id = None + if ctx is not None: + runtime = getattr(ctx, "runtime", None) + tracer = getattr(runtime, "tracer", None) + step_id = getattr(runtime, "step_id", None) + try: + result = spec.handler(ctx, validated) + if inspect.isawaitable(result): + result = await result + validated_output = spec.validate_output(result) + if tracer: + tracer.emit( + "tool_call", + data={ + "tool_name": name, + "inputs": validated.model_dump(), + "outputs": validated_output.model_dump(), + "success": True, + "duration_ms": int((time.time() - start) * 1000), + }, + step_id=step_id, + ) + return validated_output + except Exception as exc: + if tracer: + tracer.emit( + "tool_call", + data={ + "tool_name": name, + "inputs": validated.model_dump(), + "success": False, + "error": str(exc), + "duration_ms": int((time.time() - start) * 1000), + }, + step_id=step_id, + ) + raise diff --git a/tests/test_actions.py b/tests/test_actions.py index d18f02a..847d8c2 100644 --- a/tests/test_actions.py +++ b/tests/test_actions.py @@ -2,7 +2,10 @@ Tests for actions (click, type, press, click_rect) """ +import pytest + from sentience import ( + AsyncSentienceBrowser, SentienceBrowser, back, check, @@ -12,7 +15,11 @@ find, press, scroll_to, + search, + search_async, select_option, + send_keys, + send_keys_async, snapshot, submit, type_text, @@ -65,6 +72,106 @@ def test_press(): assert result.duration_ms > 0 +def test_send_keys(): + """Test send_keys helper""" + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + result = send_keys(browser, "CTRL+L") + assert result.success is True + assert result.duration_ms > 0 + + +def test_send_keys_empty_sequence() -> None: + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + with pytest.raises(ValueError, match="empty"): + send_keys(browser, "") + + +def test_send_keys_braced_sequence() -> None: + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + result = send_keys(browser, "{CTRL+L}") + assert result.success is True + + +def test_send_keys_multi_sequence_and_alias() -> None: + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + result = send_keys(browser, "Tab Tab Enter") + assert result.success is True + result = send_keys(browser, "CMD+C") + assert result.success is True + + +def test_search_builds_url() -> None: + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + result = search(browser, "sentience sdk", engine="duckduckgo") + assert result.success is True + assert result.duration_ms > 0 + + result = search(browser, "sentience sdk", engine="google") + assert result.success is True + + result = search(browser, "sentience sdk", engine="bing") + assert result.success is True + + result = search(browser, "sentience sdk", engine="google.com") + assert result.success is True + + +def test_search_empty_query() -> None: + with SentienceBrowser() as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + with pytest.raises(ValueError, match="empty"): + search(browser, "") + + +def test_search_disallowed_domain() -> None: + with SentienceBrowser(allowed_domains=["example.com"]) as browser: + browser.page.goto("https://example.com") + browser.page.wait_for_load_state("networkidle") + + with pytest.raises(ValueError, match="domain not allowed"): + search(browser, "sentience sdk", engine="duckduckgo") + + +@pytest.mark.asyncio +async def test_search_async() -> None: + async with AsyncSentienceBrowser() as browser: + await browser.page.goto("https://example.com") + await browser.page.wait_for_load_state("networkidle") + + result = await search_async( + browser, "sentience sdk", engine="duckduckgo", take_snapshot=True + ) + assert result.success is True + assert result.snapshot_after is not None + + +@pytest.mark.asyncio +async def test_send_keys_async() -> None: + async with AsyncSentienceBrowser() as browser: + await browser.page.goto("https://example.com") + await browser.page.wait_for_load_state("networkidle") + + result = await send_keys_async(browser, "CTRL+L") + assert result.success is True + + def test_click_rect(): """Test click_rect with rect dict""" with SentienceBrowser() as browser: diff --git a/tests/test_agent_runtime.py b/tests/test_agent_runtime.py index c7f1b06..130ce5c 100644 --- a/tests/test_agent_runtime.py +++ b/tests/test_agent_runtime.py @@ -10,7 +10,7 @@ import pytest from sentience.agent_runtime import AgentRuntime -from sentience.models import SnapshotOptions +from sentience.models import EvaluateJsRequest, SnapshotOptions, TabInfo from sentience.verification import AssertContext, AssertOutcome @@ -54,6 +54,21 @@ async def type_text(self, text: str) -> None: async def wait_ready_state(self, state="interactive", timeout_ms=15000) -> None: pass + async def list_tabs(self): + return [ + TabInfo(tab_id="tab-1", url="https://example.com", is_active=True), + TabInfo(tab_id="tab-2", url="https://example.com/2", is_active=False), + ] + + async def open_tab(self, url: str): + return TabInfo(tab_id="tab-new", url=url, is_active=True) + + async def switch_tab(self, tab_id: str): + return TabInfo(tab_id=tab_id, url="https://example.com/2", is_active=True) + + async def close_tab(self, tab_id: str): + return TabInfo(tab_id=tab_id, url="https://example.com/2", is_active=False) + class MockTracer: """Mock Tracer for testing.""" @@ -130,6 +145,52 @@ def test_init_with_api_key_and_options(self) -> None: assert runtime._snapshot_options.sentience_api_key == "sk_pro_key" assert runtime._snapshot_options.use_api is True + @pytest.mark.asyncio + async def test_evaluate_js_success(self) -> None: + backend = MockBackend() + tracer = MockTracer() + backend.eval_results["1 + 1"] = 2 + runtime = AgentRuntime(backend=backend, tracer=tracer) + + result = await runtime.evaluate_js(EvaluateJsRequest(code="1 + 1")) + + assert result.ok is True + assert result.value == 2 + assert result.text == "2" + + @pytest.mark.asyncio + async def test_evaluate_js_truncate(self) -> None: + backend = MockBackend() + tracer = MockTracer() + backend.eval_results["long"] = "x" * 50 + runtime = AgentRuntime(backend=backend, tracer=tracer) + + result = await runtime.evaluate_js(EvaluateJsRequest(code="long", max_output_chars=10)) + + assert result.ok is True + assert result.truncated is True + assert result.text == "x" * 10 + "..." + + @pytest.mark.asyncio + async def test_tab_operations(self) -> None: + backend = MockBackend() + tracer = MockTracer() + runtime = AgentRuntime(backend=backend, tracer=tracer) + + tabs = await runtime.list_tabs() + assert tabs.ok is True + assert len(tabs.tabs) == 2 + + opened = await runtime.open_tab("https://example.com/new") + assert opened.ok is True + assert opened.tab is not None + + switched = await runtime.switch_tab("tab-2") + assert switched.ok is True + + closed = await runtime.close_tab("tab-2") + assert closed.ok is True + class TestAgentRuntimeGetUrl: """Tests for get_url method.""" diff --git a/tests/unit/test_domain_policies.py b/tests/unit/test_domain_policies.py new file mode 100644 index 0000000..3116230 --- /dev/null +++ b/tests/unit/test_domain_policies.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from sentience.browser import _domain_matches, _extract_host, _is_domain_allowed + + +def test_domain_matches_suffix() -> None: + assert _domain_matches("sub.example.com", "example.com") is True + assert _domain_matches("example.com", "example.com") is True + assert _domain_matches("example.com", "*.example.com") is True + assert _domain_matches("other.com", "example.com") is False + assert _domain_matches("example.com", "https://example.com") is True + assert _domain_matches("localhost", "http://localhost:3000") is True + + +def test_domain_allowlist_denylist() -> None: + assert _is_domain_allowed("a.example.com", ["example.com"], []) is True + assert _is_domain_allowed("a.example.com", ["example.com"], ["bad.com"]) is True + assert _is_domain_allowed("bad.example.com", [], ["example.com"]) is False + assert _is_domain_allowed("x.com", ["example.com"], []) is False + assert _is_domain_allowed("example.com", ["https://example.com"], []) is True + + +def test_extract_host_handles_ports() -> None: + assert _extract_host("http://localhost:3000") == "localhost" + assert _extract_host("localhost:3000") == "localhost" + + +def test_keep_alive_skips_close() -> None: + from sentience.browser import SentienceBrowser + + class Dummy: + def __init__(self) -> None: + self.closed = False + + def close(self): + self.closed = True + + def stop(self): + self.closed = True + + browser = SentienceBrowser() + browser.keep_alive = True + dummy_context = Dummy() + dummy_playwright = Dummy() + browser.context = dummy_context + browser.playwright = dummy_playwright + browser._extension_path = None + + browser.close() + assert dummy_context.closed is False + assert dummy_playwright.closed is False diff --git a/tests/unit/test_extract.py b/tests/unit/test_extract.py new file mode 100644 index 0000000..42ad1ed --- /dev/null +++ b/tests/unit/test_extract.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import pytest +from pydantic import BaseModel + +from sentience.llm_provider import LLMProvider, LLMResponse +from sentience.read import extract + + +class LLMStub(LLMProvider): + def __init__(self, response: str): + super().__init__("stub") + self._response = response + + def generate(self, system_prompt: str, user_prompt: str, **kwargs) -> LLMResponse: + _ = system_prompt, user_prompt, kwargs + return LLMResponse(content=self._response, model_name="stub") + + def supports_json_mode(self) -> bool: + return True + + @property + def model_name(self) -> str: + return "stub" + + +class PageStub: + def __init__(self, content: str): + self._content = content + + def evaluate(self, _script: str, _opts: dict): + return { + "status": "success", + "url": "https://example.com", + "format": "markdown", + "content": self._content, + "length": len(self._content), + } + + +class AsyncPageStub: + def __init__(self, content: str): + self._content = content + + async def evaluate(self, _script: str, _opts: dict): + return { + "status": "success", + "url": "https://example.com", + "format": "markdown", + "content": self._content, + "length": len(self._content), + } + + +class BrowserStub: + def __init__(self, content: str): + self.page = PageStub(content) + + +class AsyncBrowserStub: + def __init__(self, content: str): + self.page = AsyncPageStub(content) + + +class ItemSchema(BaseModel): + name: str + price: str + + +def test_extract_schema_success() -> None: + browser = BrowserStub("Product: Widget") + llm = LLMStub('{"name":"Widget","price":"$10"}') + result = extract(browser, llm, query="Extract item", schema=ItemSchema) + assert result.ok is True + assert result.data is not None + assert result.data.name == "Widget" + + +def test_extract_schema_invalid_json() -> None: + browser = BrowserStub("Product: Widget") + llm = LLMStub("not json") + result = extract(browser, llm, query="Extract item", schema=ItemSchema) + assert result.ok is False + assert result.error is not None + + +@pytest.mark.asyncio +async def test_extract_async_schema_success() -> None: + browser = AsyncBrowserStub("Product: Widget") + llm = LLMStub('{"name":"Widget","price":"$10"}') + from sentience.read import extract_async + + result = await extract_async(browser, llm, query="Extract item", schema=ItemSchema) + assert result.ok is True + assert result.data is not None + assert result.data.name == "Widget" diff --git a/tests/unit/test_filesystem_tools.py b/tests/unit/test_filesystem_tools.py new file mode 100644 index 0000000..452a9a8 --- /dev/null +++ b/tests/unit/test_filesystem_tools.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from sentience.tools import FileSandbox, ToolContext, ToolRegistry, register_filesystem_tools + + +class RuntimeStub: + def __init__(self) -> None: + self.tracer = None + self.step_id = None + + def capabilities(self): + return None + + def can(self, _name: str) -> bool: + return True + + +def test_file_sandbox_prevents_traversal(tmp_path: Path) -> None: + sandbox = FileSandbox(tmp_path) + with pytest.raises(ValueError): + sandbox.read_text("../secret.txt") + + +def test_file_sandbox_prefix_edge_case(tmp_path: Path) -> None: + base = tmp_path / "sandbox" + sibling = tmp_path / "sandbox2" + base.mkdir() + sibling.mkdir() + sandbox = FileSandbox(base) + with pytest.raises(ValueError): + sandbox.read_text("../sandbox2/file.txt") + + +@pytest.mark.asyncio +async def test_filesystem_tools_read_write_append_replace(tmp_path: Path) -> None: + registry = ToolRegistry() + sandbox = FileSandbox(tmp_path) + ctx = ToolContext(RuntimeStub(), files=sandbox) + register_filesystem_tools(registry, sandbox) + + await registry.execute("write_file", {"path": "note.txt", "content": "hello"}, ctx=ctx) + result = await registry.execute("read_file", {"path": "note.txt"}, ctx=ctx) + assert result.content == "hello" + + await registry.execute("append_file", {"path": "note.txt", "content": " world"}, ctx=ctx) + result = await registry.execute("read_file", {"path": "note.txt"}, ctx=ctx) + assert result.content == "hello world" + + replaced = await registry.execute( + "replace_file", {"path": "note.txt", "old": "world", "new": "there"}, ctx=ctx + ) + assert replaced.replaced == 1 + result = await registry.execute("read_file", {"path": "note.txt"}, ctx=ctx) + assert result.content == "hello there" diff --git a/tests/unit/test_runtime_agent.py b/tests/unit/test_runtime_agent.py index f97ab03..2565319 100644 --- a/tests/unit/test_runtime_agent.py +++ b/tests/unit/test_runtime_agent.py @@ -6,7 +6,15 @@ from sentience.agent_runtime import AgentRuntime from sentience.llm_provider import LLMProvider, LLMResponse -from sentience.models import BBox, Element, Snapshot, SnapshotDiagnostics, Viewport, VisualCues +from sentience.models import ( + BBox, + Element, + Snapshot, + SnapshotDiagnostics, + StepHookContext, + Viewport, + VisualCues, +) from sentience.runtime_agent import RuntimeAgent, RuntimeStep, StepVerification from sentience.verification import AssertContext, AssertOutcome @@ -233,6 +241,43 @@ def pred(ctx: AssertContext) -> AssertOutcome: assert len(vision.calls) == 1 +@pytest.mark.asyncio +async def test_runtime_agent_hooks_called() -> None: + backend = MockBackend() + tracer = MockTracer() + runtime = AgentRuntime(backend=backend, tracer=tracer) + executor = ProviderStub(responses=["CLICK(1)"]) + + agent = RuntimeAgent(runtime=runtime, executor=executor) + step = RuntimeStep(goal="click first", verifications=[], max_snapshot_attempts=1) + + started: list[StepHookContext] = [] + ended: list[StepHookContext] = [] + + async def on_start(ctx: StepHookContext): + started.append(ctx) + + async def on_end(ctx: StepHookContext): + ended.append(ctx) + + snapshot = make_snapshot(url="https://example.com/start", elements=[make_clickable_element(1)]) + + async def fake_snapshot(**_kwargs): + runtime.last_snapshot = snapshot + return snapshot + + runtime.snapshot = AsyncMock(side_effect=fake_snapshot) # type: ignore[method-assign] + + await agent.run_step(task_goal="task", step=step, on_step_start=on_start, on_step_end=on_end) + + assert len(started) == 1 + assert len(ended) == 1 + assert started[0].goal == "click first" + assert ended[0].success is True + assert ended[0].outcome == "ok" + assert ended[0].error is None + + @pytest.mark.asyncio async def test_snapshot_limit_ramp_increases_limit_on_low_confidence() -> None: backend = MockBackend() diff --git a/tests/unit/test_tool_registry.py b/tests/unit/test_tool_registry.py new file mode 100644 index 0000000..8210a1c --- /dev/null +++ b/tests/unit/test_tool_registry.py @@ -0,0 +1,233 @@ +from __future__ import annotations + +import pytest +from pydantic import BaseModel, ValidationError + +from sentience.tools import ( + ToolContext, + ToolRegistry, + ToolSpec, + UnsupportedCapabilityError, + register_default_tools, +) + + +class EchoInput(BaseModel): + message: str + + +class EchoOutput(BaseModel): + echoed: str + + +def test_register_and_list_tools() -> None: + registry = ToolRegistry() + spec = ToolSpec( + name="echo", + description="Echo a message", + input_model=EchoInput, + output_model=EchoOutput, + ) + registry.register(spec) + assert registry.get("echo") is spec + assert len(registry.list()) == 1 + + +def test_validate_input_and_output() -> None: + registry = ToolRegistry() + registry.register( + ToolSpec( + name="echo", + description="Echo a message", + input_model=EchoInput, + output_model=EchoOutput, + ) + ) + + validated = registry.validate_input("echo", {"message": "hi"}) + assert isinstance(validated, EchoInput) + assert validated.message == "hi" + + out = registry.validate_output("echo", {"echoed": "hi"}) + assert isinstance(out, EchoOutput) + assert out.echoed == "hi" + + with pytest.raises(ValidationError): + registry.validate_input("echo", {"message": 123}) + + +def test_llm_spec_generation() -> None: + registry = ToolRegistry() + + @registry.tool( + name="echo", + input_model=EchoInput, + output_model=EchoOutput, + description="Echo a message", + ) + def _echo(_ctx, params: EchoInput) -> EchoOutput: + return EchoOutput(echoed=params.message) + + tools = registry.llm_tools() + assert tools[0]["name"] == "echo" + assert tools[0]["description"] == "Echo a message" + assert "parameters" in tools[0] + + +def test_register_default_tools_adds_core_tools() -> None: + registry = ToolRegistry() + + class RuntimeStub: + async def snapshot(self, **_kwargs): + return None + + last_snapshot = None + backend = None + + ctx = ToolContext(RuntimeStub()) + register_default_tools(registry, ctx) + names = {spec.name for spec in registry.list()} + assert { + "snapshot_state", + "click", + "type", + "scroll", + "scroll_to_element", + "click_rect", + "press", + "evaluate_js", + } <= names + + +@pytest.mark.asyncio +async def test_registry_execute_validates_and_runs() -> None: + registry = ToolRegistry() + + class TracerStub: + def __init__(self) -> None: + self.events: list[dict] = [] + + def emit(self, event_type: str, data: dict, step_id: str | None = None) -> None: + self.events.append({"type": event_type, "data": data, "step_id": step_id}) + + class RuntimeStub: + def __init__(self) -> None: + self.tracer = TracerStub() + self.step_id = "step-1" + + def capabilities(self): + return None + + def can(self, _name: str) -> bool: + return True + + runtime = RuntimeStub() + ctx = ToolContext(runtime) + + @registry.tool( + name="echo", + input_model=EchoInput, + output_model=EchoOutput, + description="Echo", + ) + async def _echo(_ctx, params: EchoInput) -> EchoOutput: + return EchoOutput(echoed=params.message) + + result = await registry.execute("echo", {"message": "hi"}, ctx=ctx) + assert isinstance(result, EchoOutput) + assert result.echoed == "hi" + assert runtime.tracer.events[0]["type"] == "tool_call" + assert runtime.tracer.events[0]["data"]["success"] is True + + +def test_tool_context_require_raises_on_missing_capability() -> None: + class RuntimeStub: + def capabilities(self): + return None + + def can(self, _name: str) -> bool: + return False + + ctx = ToolContext(RuntimeStub()) + with pytest.raises(UnsupportedCapabilityError) as excinfo: + ctx.require("tabs") + assert excinfo.value.error == "unsupported_capability" + + +@pytest.mark.asyncio +async def test_tool_call_emits_error_on_failure() -> None: + registry = ToolRegistry() + + class TracerStub: + def __init__(self) -> None: + self.events: list[dict] = [] + + def emit(self, event_type: str, data: dict, step_id: str | None = None) -> None: + self.events.append({"type": event_type, "data": data, "step_id": step_id}) + + class RuntimeStub: + def __init__(self) -> None: + self.tracer = TracerStub() + self.step_id = "step-1" + + def capabilities(self): + return None + + def can(self, _name: str) -> bool: + return True + + runtime = RuntimeStub() + ctx = ToolContext(runtime) + + @registry.tool( + name="boom", + input_model=EchoInput, + output_model=EchoOutput, + description="Boom", + ) + async def _boom(_ctx, _params: EchoInput) -> EchoOutput: + raise RuntimeError("bad") + + with pytest.raises(RuntimeError, match="bad"): + await registry.execute("boom", {"message": "x"}, ctx=ctx) + + assert runtime.tracer.events[0]["type"] == "tool_call" + assert runtime.tracer.events[0]["data"]["success"] is False + assert "error" in runtime.tracer.events[0]["data"] + + +@pytest.mark.asyncio +async def test_default_tools_capability_checks() -> None: + registry = ToolRegistry() + + class RuntimeStub: + def __init__(self) -> None: + self.tracer = None + self.step_id = None + + def capabilities(self): + return None + + def can(self, name: str) -> bool: + return name != "keyboard" and name != "evaluate_js" + + async def snapshot(self, **_kwargs): + return None + + last_snapshot = None + backend = None + + ctx = ToolContext(RuntimeStub()) + register_default_tools(registry, ctx) + + with pytest.raises(UnsupportedCapabilityError) as excinfo: + await registry.execute("press", {"key": "Enter"}, ctx=ctx) + assert excinfo.value.error == "unsupported_capability" + + with pytest.raises(UnsupportedCapabilityError) as excinfo: + await registry.execute( + "scroll_to_element", + {"element_id": 1, "behavior": "instant", "block": "center"}, + ctx=ctx, + ) + assert excinfo.value.error == "unsupported_capability"