|
5 | 5 |
|
6 | 6 | import asyncio |
7 | 7 | import hashlib |
| 8 | +import inspect |
| 9 | +import logging |
8 | 10 | import time |
| 11 | +from collections.abc import Callable |
9 | 12 | from typing import TYPE_CHECKING, Any, Optional, Union |
10 | 13 |
|
11 | 14 | from .action_executor import ActionExecutor |
|
23 | 26 | ScreenshotConfig, |
24 | 27 | Snapshot, |
25 | 28 | SnapshotOptions, |
| 29 | + StepHookContext, |
26 | 30 | TokenStats, |
27 | 31 | ) |
28 | 32 | from .protocols import AsyncBrowserProtocol, BrowserProtocol |
@@ -65,6 +69,51 @@ def _safe_tracer_call( |
65 | 69 | print(f"⚠️ Tracer error (non-fatal): {tracer_error}") |
66 | 70 |
|
67 | 71 |
|
| 72 | +def _safe_hook_call_sync( |
| 73 | + hook: Callable[[StepHookContext], Any] | None, |
| 74 | + ctx: StepHookContext, |
| 75 | + verbose: bool, |
| 76 | +) -> None: |
| 77 | + if not hook: |
| 78 | + return |
| 79 | + try: |
| 80 | + result = hook(ctx) |
| 81 | + if inspect.isawaitable(result): |
| 82 | + try: |
| 83 | + loop = asyncio.get_running_loop() |
| 84 | + except RuntimeError: |
| 85 | + asyncio.run(result) |
| 86 | + else: |
| 87 | + loop.create_task(result) |
| 88 | + except Exception as hook_error: |
| 89 | + if verbose: |
| 90 | + print(f"⚠️ Hook error (non-fatal): {hook_error}") |
| 91 | + else: |
| 92 | + logging.getLogger(__name__).warning( |
| 93 | + "Hook error (non-fatal): %s", hook_error |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +async def _safe_hook_call_async( |
| 98 | + hook: Callable[[StepHookContext], Any] | None, |
| 99 | + ctx: StepHookContext, |
| 100 | + verbose: bool, |
| 101 | +) -> None: |
| 102 | + if not hook: |
| 103 | + return |
| 104 | + try: |
| 105 | + result = hook(ctx) |
| 106 | + if inspect.isawaitable(result): |
| 107 | + await result |
| 108 | + except Exception as hook_error: |
| 109 | + if verbose: |
| 110 | + print(f"⚠️ Hook error (non-fatal): {hook_error}") |
| 111 | + else: |
| 112 | + logging.getLogger(__name__).warning( |
| 113 | + "Hook error (non-fatal): %s", hook_error |
| 114 | + ) |
| 115 | + |
| 116 | + |
68 | 117 | class SentienceAgent(BaseAgent): |
69 | 118 | """ |
70 | 119 | High-level agent that combines Sentience SDK with any LLM provider. |
@@ -181,6 +230,8 @@ def act( # noqa: C901 |
181 | 230 | goal: str, |
182 | 231 | max_retries: int = 2, |
183 | 232 | snapshot_options: SnapshotOptions | None = None, |
| 233 | + on_step_start: Callable[[StepHookContext], Any] | None = None, |
| 234 | + on_step_end: Callable[[StepHookContext], Any] | None = None, |
184 | 235 | ) -> AgentActionResult: |
185 | 236 | """ |
186 | 237 | Execute a high-level goal using observe → think → act loop |
@@ -224,6 +275,18 @@ def act( # noqa: C901 |
224 | 275 | pre_url=pre_url, |
225 | 276 | ) |
226 | 277 |
|
| 278 | + _safe_hook_call_sync( |
| 279 | + on_step_start, |
| 280 | + StepHookContext( |
| 281 | + step_id=step_id, |
| 282 | + step_index=self._step_count, |
| 283 | + goal=goal, |
| 284 | + attempt=0, |
| 285 | + url=pre_url, |
| 286 | + ), |
| 287 | + self.verbose, |
| 288 | + ) |
| 289 | + |
227 | 290 | # Track data collected during step execution for step_end emission on failure |
228 | 291 | _step_snap_with_diff: Snapshot | None = None |
229 | 292 | _step_pre_url: str | None = None |
@@ -396,8 +459,8 @@ def act( # noqa: C901 |
396 | 459 | _step_duration_ms = duration_ms |
397 | 460 |
|
398 | 461 | # Emit action execution trace event if tracer is enabled |
| 462 | + post_url = self.browser.page.url if self.browser.page else None |
399 | 463 | if self.tracer: |
400 | | - post_url = self.browser.page.url if self.browser.page else None |
401 | 464 |
|
402 | 465 | # Include element data for live overlay visualization |
403 | 466 | elements_data = [ |
@@ -454,7 +517,6 @@ def act( # noqa: C901 |
454 | 517 | if self.tracer: |
455 | 518 | # Get pre_url from step_start (stored in tracer or use current) |
456 | 519 | pre_url = snap.url |
457 | | - post_url = self.browser.page.url if self.browser.page else None |
458 | 520 |
|
459 | 521 | # Compute snapshot digest (simplified - use URL + timestamp) |
460 | 522 | snapshot_digest = f"sha256:{self._compute_hash(f'{pre_url}{snap.timestamp}')}" |
@@ -561,6 +623,20 @@ def act( # noqa: C901 |
561 | 623 | step_id=step_id, |
562 | 624 | ) |
563 | 625 |
|
| 626 | + _safe_hook_call_sync( |
| 627 | + on_step_end, |
| 628 | + StepHookContext( |
| 629 | + step_id=step_id, |
| 630 | + step_index=self._step_count, |
| 631 | + goal=goal, |
| 632 | + attempt=attempt, |
| 633 | + url=post_url, |
| 634 | + success=result.success, |
| 635 | + outcome=result.outcome, |
| 636 | + error=result.error, |
| 637 | + ), |
| 638 | + self.verbose, |
| 639 | + ) |
564 | 640 | return result |
565 | 641 |
|
566 | 642 | except Exception as e: |
@@ -660,6 +736,20 @@ def act( # noqa: C901 |
660 | 736 | "duration_ms": 0, |
661 | 737 | } |
662 | 738 | ) |
| 739 | + _safe_hook_call_sync( |
| 740 | + on_step_end, |
| 741 | + StepHookContext( |
| 742 | + step_id=step_id, |
| 743 | + step_index=self._step_count, |
| 744 | + goal=goal, |
| 745 | + attempt=attempt, |
| 746 | + url=_step_pre_url, |
| 747 | + success=False, |
| 748 | + outcome="exception", |
| 749 | + error=str(e), |
| 750 | + ), |
| 751 | + self.verbose, |
| 752 | + ) |
663 | 753 | raise RuntimeError(f"Failed after {max_retries} retries: {e}") |
664 | 754 |
|
665 | 755 | def _track_tokens(self, goal: str, llm_response: LLMResponse): |
@@ -833,6 +923,8 @@ async def act( # noqa: C901 |
833 | 923 | goal: str, |
834 | 924 | max_retries: int = 2, |
835 | 925 | snapshot_options: SnapshotOptions | None = None, |
| 926 | + on_step_start: Callable[[StepHookContext], Any] | None = None, |
| 927 | + on_step_end: Callable[[StepHookContext], Any] | None = None, |
836 | 928 | ) -> AgentActionResult: |
837 | 929 | """ |
838 | 930 | Execute a high-level goal using observe → think → act loop (async) |
@@ -873,6 +965,18 @@ async def act( # noqa: C901 |
873 | 965 | pre_url=pre_url, |
874 | 966 | ) |
875 | 967 |
|
| 968 | + await _safe_hook_call_async( |
| 969 | + on_step_start, |
| 970 | + StepHookContext( |
| 971 | + step_id=step_id, |
| 972 | + step_index=self._step_count, |
| 973 | + goal=goal, |
| 974 | + attempt=0, |
| 975 | + url=pre_url, |
| 976 | + ), |
| 977 | + self.verbose, |
| 978 | + ) |
| 979 | + |
876 | 980 | # Track data collected during step execution for step_end emission on failure |
877 | 981 | _step_snap_with_diff: Snapshot | None = None |
878 | 982 | _step_pre_url: str | None = None |
@@ -1209,6 +1313,21 @@ async def act( # noqa: C901 |
1209 | 1313 | step_id=step_id, |
1210 | 1314 | ) |
1211 | 1315 |
|
| 1316 | + post_url = self.browser.page.url if self.browser.page else None |
| 1317 | + await _safe_hook_call_async( |
| 1318 | + on_step_end, |
| 1319 | + StepHookContext( |
| 1320 | + step_id=step_id, |
| 1321 | + step_index=self._step_count, |
| 1322 | + goal=goal, |
| 1323 | + attempt=attempt, |
| 1324 | + url=post_url, |
| 1325 | + success=result.success, |
| 1326 | + outcome=result.outcome, |
| 1327 | + error=result.error, |
| 1328 | + ), |
| 1329 | + self.verbose, |
| 1330 | + ) |
1212 | 1331 | return result |
1213 | 1332 |
|
1214 | 1333 | except Exception as e: |
@@ -1308,6 +1427,20 @@ async def act( # noqa: C901 |
1308 | 1427 | "duration_ms": 0, |
1309 | 1428 | } |
1310 | 1429 | ) |
| 1430 | + await _safe_hook_call_async( |
| 1431 | + on_step_end, |
| 1432 | + StepHookContext( |
| 1433 | + step_id=step_id, |
| 1434 | + step_index=self._step_count, |
| 1435 | + goal=goal, |
| 1436 | + attempt=attempt, |
| 1437 | + url=_step_pre_url, |
| 1438 | + success=False, |
| 1439 | + outcome="exception", |
| 1440 | + error=str(e), |
| 1441 | + ), |
| 1442 | + self.verbose, |
| 1443 | + ) |
1311 | 1444 | raise RuntimeError(f"Failed after {max_retries} retries: {e}") |
1312 | 1445 |
|
1313 | 1446 | def _track_tokens(self, goal: str, llm_response: LLMResponse): |
|
0 commit comments