Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sentience/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Pydantic models for Sentience SDK - matches spec/snapshot.schema.json
"""

from typing import List, Literal, Optional, Union
from typing import Literal

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -114,6 +114,8 @@ class SnapshotOptions(BaseModel):
limit: int = Field(50, ge=1, le=500)
filter: SnapshotFilter | None = None
use_api: bool | None = None # Force API vs extension
save_trace: bool = False # Save raw_elements to JSON for benchmarking/training
trace_path: str | None = None # Path to save trace (default: "trace_{timestamp}.json")

class Config:
arbitrary_types_allowed = True
Expand Down
94 changes: 71 additions & 23 deletions sentience/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,38 @@
Snapshot functionality - calls window.sentience.snapshot() or server-side API
"""

import json
import os
import time
from typing import Any

import requests

from .browser import SentienceBrowser
from .models import Snapshot
from .models import Snapshot, SnapshotOptions


def _save_trace_to_file(raw_elements: list[dict[str, Any]], trace_path: str | None = None) -> None:
"""
Save raw_elements to a JSON file for benchmarking/training

Args:
raw_elements: Raw elements data from snapshot
trace_path: Path to save trace file. If None, uses "trace_{timestamp}.json"
"""
# Default filename if none provided
filename = trace_path or f"trace_{int(time.time())}.json"

# Ensure directory exists
directory = os.path.dirname(filename)
if directory:
os.makedirs(directory, exist_ok=True)

# Save the raw elements to JSON
with open(filename, "w") as f:
json.dump(raw_elements, f, indent=2)

print(f"[SDK] Trace saved to: {filename}")


def snapshot(
Expand All @@ -16,6 +42,8 @@ def snapshot(
limit: int | None = None,
filter: dict[str, Any] | None = None,
use_api: bool | None = None,
save_trace: bool = False,
trace_path: str | None = None,
) -> Snapshot:
"""
Take a snapshot of the current page
Expand All @@ -27,26 +55,38 @@ def snapshot(
filter: Filter options (min_area, allowed_roles, min_z_index)
use_api: Force use of server-side API if True, local extension if False.
If None, uses API if api_key is set, otherwise uses local extension.
save_trace: Whether to save raw_elements to JSON for benchmarking/training
trace_path: Path to save trace file. If None, uses "trace_{timestamp}.json"

Returns:
Snapshot object
"""
# Build SnapshotOptions from individual parameters
options = SnapshotOptions(
screenshot=screenshot if screenshot is not None else False,
limit=limit if limit is not None else 50,
filter=filter,
use_api=use_api,
save_trace=save_trace,
trace_path=trace_path,
)

# Determine if we should use server-side API
should_use_api = use_api if use_api is not None else (browser.api_key is not None)
should_use_api = (
options.use_api if options.use_api is not None else (browser.api_key is not None)
)

if should_use_api and browser.api_key:
# Use server-side API (Pro/Enterprise tier)
return _snapshot_via_api(browser, screenshot, limit, filter)
return _snapshot_via_api(browser, options)
else:
# Use local extension (Free tier)
return _snapshot_via_extension(browser, screenshot, limit, filter)
return _snapshot_via_extension(browser, options)


def _snapshot_via_extension(
browser: SentienceBrowser,
screenshot: bool | None,
limit: int | None,
filter: dict[str, Any] | None,
options: SnapshotOptions,
) -> Snapshot:
"""Take snapshot using local extension (Free tier)"""
if not browser.page:
Expand Down Expand Up @@ -77,14 +117,16 @@ def _snapshot_via_extension(
f"Is the extension loaded? Diagnostics: {diag}"
) from e

# Build options
options: dict[str, Any] = {}
if screenshot is not None:
options["screenshot"] = screenshot
if limit is not None:
options["limit"] = limit
if filter is not None:
options["filter"] = filter
# Build options dict for extension API (exclude save_trace/trace_path)
ext_options: dict[str, Any] = {}
if options.screenshot is not False:
ext_options["screenshot"] = options.screenshot
if options.limit != 50:
ext_options["limit"] = options.limit
if options.filter is not None:
ext_options["filter"] = (
options.filter.model_dump() if hasattr(options.filter, "model_dump") else options.filter
)

# Call extension API
result = browser.page.evaluate(
Expand All @@ -93,19 +135,21 @@ def _snapshot_via_extension(
return window.sentience.snapshot(options);
}
""",
options,
ext_options,
)

# Save trace if requested
if options.save_trace:
_save_trace_to_file(result.get("raw_elements", []), options.trace_path)

# Validate and parse with Pydantic
snapshot_obj = Snapshot(**result)
return snapshot_obj


def _snapshot_via_api(
browser: SentienceBrowser,
screenshot: bool | None,
limit: int | None,
filter: dict[str, Any] | None,
options: SnapshotOptions,
) -> Snapshot:
"""Take snapshot using server-side API (Pro/Enterprise tier)"""
if not browser.page:
Expand All @@ -128,8 +172,8 @@ def _snapshot_via_api(

# Step 1: Get raw data from local extension (always happens locally)
raw_options: dict[str, Any] = {}
if screenshot is not None:
raw_options["screenshot"] = screenshot
if options.screenshot is not False:
raw_options["screenshot"] = options.screenshot

raw_result = browser.page.evaluate(
"""
Expand All @@ -140,6 +184,10 @@ def _snapshot_via_api(
raw_options,
)

# Save trace if requested (save raw data before API processing)
if options.save_trace:
_save_trace_to_file(raw_result.get("raw_elements", []), options.trace_path)

# Step 2: Send to server for smart ranking/filtering
# Use raw_elements (raw data) instead of elements (processed data)
# Server validates API key and applies proprietary ranking logic
Expand All @@ -148,8 +196,8 @@ def _snapshot_via_api(
"url": raw_result.get("url", ""),
"viewport": raw_result.get("viewport"),
"options": {
"limit": limit,
"filter": filter,
"limit": options.limit,
"filter": options.filter.model_dump() if options.filter else None,
},
}

Expand Down