diff --git a/CHANGELOG.md b/CHANGELOG.md index 7b06366..495e0a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file. The format is based on Keep a Changelog and this project adheres to Semantic Versioning. +## [0.1.82] - 2026-04-21 + +- Refine custom ADK instrumentation to produce a cleaner trace hierarchy, include sufficient metadata, and eliminate duplicate spans. + + ## [0.1.81] - 2026-04-16 - Fix root span attachment issue in tracer provider @@ -239,4 +244,4 @@ The format is based on Keep a Changelog and this project adheres to Semantic Ver - Added utility to set input and output data for any active span in a trace -[0.1.81]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main +[0.1.82]: https://github.com/KeyValueSoftwareSystems/netra-sdk-py/tree/main diff --git a/examples/07_custom_metrics/custom_metrics.py b/examples/07_custom_metrics/custom_metrics.py index 2ab4194..60a7900 100644 --- a/examples/07_custom_metrics/custom_metrics.py +++ b/examples/07_custom_metrics/custom_metrics.py @@ -24,7 +24,6 @@ from typing import Any, Dict, List from dotenv import load_dotenv - from opentelemetry.metrics import Observation from netra import Netra @@ -44,6 +43,7 @@ # 1. Initialise the SDK with metrics enabled # --------------------------------------------------------------------------- + def init_sdk() -> None: Netra.init( app_name="custom-metrics-example", @@ -61,6 +61,7 @@ def init_sdk() -> None: # 2. 
Create instruments from a named Meter # --------------------------------------------------------------------------- + def create_instruments(): """Return a dict of OTel instruments scoped to an 'ai_service' meter.""" meter = Netra.get_meter("ai_service") @@ -195,9 +196,7 @@ def run_inference( results.append({"prompt": prompt[:40], "error": str(exc)}) successes = sum(1 for r in results if "error" not in r) - logger.info( - "Inference workflow complete: %d/%d succeeded", successes, len(results) - ) + logger.info("Inference workflow complete: %d/%d succeeded", successes, len(results)) return results @@ -244,6 +243,7 @@ def _gpu_utilization_callback(_): # 5. Async example — concurrent requests with shared instruments # --------------------------------------------------------------------------- + @task(name="async_llm_call") # type: ignore[arg-type] async def async_llm_call( prompt: str, @@ -285,10 +285,7 @@ async def run_async_inference( logger.info("Starting async inference for %d prompts", len(prompts)) instruments["queue_depth"].add(len(prompts), {"source": "async_batch"}) - tasks = [ - async_llm_call(prompt, select_model(prompt), instruments) # type: ignore[misc] - for prompt in prompts - ] + tasks = [async_llm_call(prompt, select_model(prompt), instruments) for prompt in prompts] # type: ignore[misc] results = await asyncio.gather(*tasks, return_exceptions=True) ok = [r for r in results if isinstance(r, dict)] @@ -300,6 +297,7 @@ async def run_async_inference( # 6. 
Dedicated-meter example — separate meter per domain # --------------------------------------------------------------------------- + def demonstrate_multiple_meters(instruments: Dict[str, Any]) -> None: """Show that different parts of an app can own independent meters.""" billing_meter = Netra.get_meter("billing_service") @@ -333,6 +331,7 @@ def demonstrate_multiple_meters(instruments: Dict[str, Any]) -> None: # Main # --------------------------------------------------------------------------- + def main() -> None: print("=" * 65) print(" Netra SDK — Custom Metrics Example") @@ -367,18 +366,12 @@ def main() -> None: # --- Async fan-out --- print("\n--- Async concurrent inference ---") async_prompts = [f"Async prompt #{i}: tell me something interesting" for i in range(8)] - async_results = asyncio.run( - run_async_inference(async_prompts, instruments) # type: ignore[misc] - ) + async_results = asyncio.run(run_async_inference(async_prompts, instruments)) # type: ignore[misc] for r in async_results: if "error" in r: print(f" FAIL: {r}") else: - print( - f" {r['model']:>15} " - f"tokens={r['tokens']:<5} " - f"latency={r['latency_ms']}ms" - ) + print(f" {r['model']:>15} " f"tokens={r['tokens']:<5} " f"latency={r['latency_ms']}ms") # --- Multiple meters --- print("\n--- Multiple independent meters ---") diff --git a/netra/__init__.py b/netra/__init__.py index d272f6e..0ecdfec 100644 --- a/netra/__init__.py +++ b/netra/__init__.py @@ -12,7 +12,7 @@ from netra.dashboard import Dashboard from netra.evaluation import Evaluation from netra.instrumentation import init_instrumentations -from netra.instrumentation.instruments import NetraInstruments +from netra.instrumentation.instruments import DEFAULT_INSTRUMENTS_FOR_ROOT, NetraInstruments from netra.logging_utils import configure_package_logging from netra.meter import MetricsSetup from netra.meter import get_meter as _get_meter @@ -23,13 +23,6 @@ from netra.tracer import Tracer from netra.usage import Usage -__all__ = [ - 
"Netra", - "UsageModel", - "ActionModel", - "Prompts", -] - logger = logging.getLogger(__name__) @@ -46,6 +39,12 @@ class Netra: _root_ctx_token = None _metrics_enabled = False + evaluation: Optional["Evaluation"] = None + usage: Optional["Usage"] = None + dashboard: Optional["Dashboard"] = None + prompts: Optional["Prompts"] = None + simulation: Optional["Simulation"] = None + @classmethod def is_initialized(cls) -> bool: """ @@ -61,7 +60,7 @@ def is_initialized(cls) -> bool: def init( cls, app_name: Optional[str] = None, - headers: Optional[str] = None, + headers: Optional[str | Dict[str, str]] = None, disable_batch: Optional[bool] = None, trace_content: Optional[bool] = None, debug_mode: Optional[bool] = None, @@ -75,6 +74,7 @@ def init( enable_metrics: Optional[bool] = None, metrics_export_interval_ms: Optional[int] = None, export_auto_metrics: Optional[bool] = None, + root_instruments: Optional[Set[NetraInstruments]] = None, ) -> None: """ Thread-safe initialization of Netra. @@ -95,6 +95,12 @@ def init( enable_metrics: Whether to enable OTLP custom metrics export (default: False) metrics_export_interval_ms: Metrics push interval in milliseconds (default: 60000) export_auto_metrics: Whether to export OTel auto-instrumented metrics (default: False) + root_instruments: Set of instruments allowed to produce root-level + spans. When a root span is blocked, its entire subtree is + discarded. Resolution priority: + 1. Explicit ``root_instruments`` value if provided. + 2. The ``instruments`` value if provided (but ``root_instruments`` is not). + 3. ``DEFAULT_INSTRUMENTS_FOR_ROOT`` if neither is provided. Returns: None @@ -124,8 +130,17 @@ def init( # Configure logging based on debug mode configure_package_logging(debug_mode=cfg.debug_mode) + # Resolve root_instruments → set of instrumentation-name strings. 
+ resolved_root: Optional[Set[str]] = None + if root_instruments is not None: + resolved_root = {m.value for m in root_instruments} + elif instruments is not None: + resolved_root = {m.value for m in instruments} + else: + resolved_root = {m.value for m in DEFAULT_INSTRUMENTS_FOR_ROOT} + # Initialize tracer (OTLP exporter, span processor, resource) - Tracer(cfg) + Tracer(cfg, root_instrument_names=resolved_root) # Initialize metrics pipeline when explicitly enabled if cfg.enable_metrics: @@ -137,38 +152,38 @@ def init( # Initialize evaluation client and expose as class attribute try: - cls.evaluation = Evaluation(cfg) # type:ignore[attr-defined] + cls.evaluation = Evaluation(cfg) except Exception as e: logger.warning("Failed to initialize evaluation client: %s", e, exc_info=True) - cls.evaluation = None # type:ignore[attr-defined] + cls.evaluation = None # Initialize usage client and expose as class attribute try: - cls.usage = Usage(cfg) # type:ignore[attr-defined] + cls.usage = Usage(cfg) except Exception as e: logger.warning("Failed to initialize usage client: %s", e, exc_info=True) - cls.usage = None # type:ignore[attr-defined] + cls.usage = None # Initialize dashboard client and expose as class attribute try: - cls.dashboard = Dashboard(cfg) # type:ignore[attr-defined] + cls.dashboard = Dashboard(cfg) except Exception as e: logger.warning("Failed to initialize dashboard client: %s", e, exc_info=True) - cls.dashboard = None # type:ignore[attr-defined] + cls.dashboard = None # Initialize prompts client and expose as class attribute try: - cls.prompts = Prompts(cfg) # type:ignore[attr-defined] + cls.prompts = Prompts(cfg) except Exception as e: logger.warning("Failed to initialize prompts client: %s", e, exc_info=True) - cls.prompts = None # type:ignore[attr-defined] + cls.prompts = None # Initialize simulation client and expose as class attribute try: - cls.simulation = Simulation(cfg) # type:ignore[attr-defined] + cls.simulation = Simulation(cfg) except 
Exception as e: logger.warning("Failed to initialize simulation client: %s", e, exc_info=True) - cls.simulation = None # type:ignore[attr-defined] + cls.simulation = None # Instrument all supported modules init_instrumentations( @@ -241,6 +256,33 @@ def shutdown(cls) -> None: except Exception: pass + # Close HTTP clients to release connection-pool resources + if cls.evaluation is not None: + try: + cls.evaluation._client.close() + except Exception: + pass + if cls.usage is not None: + try: + cls.usage._client.close() + except Exception: + pass + if cls.dashboard is not None: + try: + cls.dashboard._client.close() + except Exception: + pass + if cls.prompts is not None: + try: + cls.prompts._client.close() + except Exception: + pass + if cls.simulation is not None: + try: + cls.simulation._client.close() + except Exception: + pass + @classmethod def get_meter(cls, name: str = "netra", version: Optional[str] = None) -> otel_metrics.Meter: """ @@ -370,12 +412,52 @@ def add_conversation(cls, conversation_type: ConversationType, role: str, conten """ SessionManager.add_conversation(conversation_type=conversation_type, role=role, content=content) + @classmethod + def set_input(cls, value: Any) -> None: + """ + Set the input attribute on the current active span. + + Args: + value: The input value to record + """ + SessionManager.set_input(value) + + @classmethod + def set_output(cls, value: Any) -> None: + """ + Set the output attribute on the current active span. + + Args: + value: The output value to record + """ + SessionManager.set_output(value) + + @classmethod + def set_root_input(cls, value: Any) -> None: + """ + Set the input attribute on the root span of the current trace. + + Args: + value: The input value to record + """ + SessionManager.set_root_input(value) + + @classmethod + def set_root_output(cls, value: Any) -> None: + """ + Set the output attribute on the root span of the current trace. 
+ + Args: + value: The output value to record + """ + SessionManager.set_root_output(value) + @classmethod def start_span( cls, name: str, attributes: Optional[Dict[str, str]] = None, - module_name: str = "combat_sdk", + module_name: str = Config.SDK_NAME, as_type: Optional[SpanType] = SpanType.SPAN, ) -> SpanWrapper: """ @@ -392,5 +474,15 @@ def start_span( """ return SpanWrapper(name, attributes, module_name, as_type=as_type) + @classmethod + def get_trace_id(cls) -> Optional[str]: + """ + Return the trace ID of the currently active span. + + Returns: + str: 32-character lowercase hex trace ID, or None if no active span exists. + """ + return SessionManager.get_trace_id() + -__all__ = ["Netra", "UsageModel", "ActionModel", "SpanType", "EvaluationScore", "Prompts", "ConversationType"] +__all__ = ["Netra", "UsageModel", "ActionModel", "SpanType", "Prompts", "ConversationType"] diff --git a/netra/client.py b/netra/client.py new file mode 100644 index 0000000..a201a99 --- /dev/null +++ b/netra/client.py @@ -0,0 +1,168 @@ +"""Base HTTP client shared by all Netra API sub-clients.""" + +import logging +import os +from typing import Any, Dict, Optional + +import httpx + +from netra.config import Config + +logger = logging.getLogger(__name__) + +_TELEMETRY_SUFFIX = "/telemetry" +_API_KEY_HEADER = "x-api-key" + + +class BaseNetraClient: + """Shared foundation for every Netra HTTP sub-client. + + Provides endpoint resolution, header construction, timeout parsing, + httpx client creation, and safe error-message extraction so that + sub-clients only need to define domain-specific endpoints. + + Args: + config: Netra SDK configuration. + log_prefix: Short prefix used in all log messages (e.g. ``"netra.dashboard"``). + timeout_env_var: Name of the environment variable that overrides the + default timeout (e.g. ``"NETRA_DASHBOARD_TIMEOUT"``). + default_timeout: Fallback timeout in seconds when the env var is unset. 
+ extra_headers: Additional headers merged on top of the standard set. + """ + + def __init__( + self, + config: Config, + *, + log_prefix: str, + timeout_env_var: str, + default_timeout: float = 10.0, + extra_headers: Optional[Dict[str, str]] = None, + ) -> None: + self._log_prefix = log_prefix + self._timeout_env_var = timeout_env_var + self._default_timeout = default_timeout + self._extra_headers = extra_headers or {} + self._client: Optional[httpx.Client] = self._create_client(config) + + def _create_client(self, config: Config) -> Optional[httpx.Client]: + """Build an ``httpx.Client`` from the shared configuration. + + Args: + config: Netra SDK configuration. + + Returns: + A configured client, or ``None`` if the endpoint is missing or + client creation fails. + """ + endpoint = (config.otlp_endpoint or "").strip() + if not endpoint: + logger.error("%s: NETRA_OTLP_ENDPOINT is required", self._log_prefix) + return None + + base_url = self._resolve_base_url(endpoint) + headers = self._build_headers(config) + timeout = self._get_timeout() + + try: + return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) + except Exception as exc: + logger.error("%s: Failed to create HTTP client: %s", self._log_prefix, exc) + return None + + def _resolve_base_url(self, endpoint: str) -> str: + """Strip trailing slash and ``/telemetry`` suffix from an endpoint URL. + + Args: + endpoint: The raw endpoint URL. + + Returns: + The cleaned base URL. + """ + base_url = endpoint.rstrip("/") + if base_url.endswith(_TELEMETRY_SUFFIX): + base_url = base_url[: -len(_TELEMETRY_SUFFIX)] + return base_url + + def _build_headers(self, config: Config) -> Dict[str, str]: + """Construct request headers from configuration. + + Args: + config: Netra SDK configuration. + + Returns: + A dictionary of HTTP headers. 
+ """ + headers: Dict[str, str] = dict(config.headers or {}) + if config.api_key: + headers[_API_KEY_HEADER] = config.api_key + headers.update(self._extra_headers) + return headers + + def _get_timeout(self) -> float: + """Read timeout from the environment or fall back to the default. + + Returns: + Timeout value in seconds. + """ + raw = os.getenv(self._timeout_env_var) + if not raw: + return self._default_timeout + try: + return float(raw) + except ValueError: + logger.warning( + "%s: Invalid %s value '%s', using default %.1f", + self._log_prefix, + self._timeout_env_var, + raw, + self._default_timeout, + ) + return self._default_timeout + + def close(self) -> None: + """Close the underlying HTTP client and release connection-pool resources. + + Safe to call multiple times or when the client was never created. + """ + if self._client is not None: + try: + self._client.close() + except Exception: + logger.debug("%s: Error closing HTTP client", self._log_prefix, exc_info=True) + finally: + self._client = None + + def __enter__(self) -> "BaseNetraClient": + """Support ``with`` blocks for short-lived client usage.""" + return self + + def __exit__(self, *args: Any) -> None: + """Close the client when exiting a ``with`` block.""" + self.close() + + def _extract_error_message(self, exc: Exception) -> str: + """Derive a human-readable error string from an exception. + + For HTTP status errors whose ``response`` attribute carries a body, + this tries to extract the backend JSON error payload. Falls back to + ``str(exc)`` in all other cases. + + Args: + exc: The exception that was raised. + + Returns: + A descriptive error message. 
+ """ + response: Any = getattr(exc, "response", None) + if response is not None: + try: + body = response.json() + error_data = body.get("error", {}) + if isinstance(error_data, dict): + msg = error_data.get("message") + if msg: + return str(msg) + except Exception: + pass + return str(exc) diff --git a/netra/config.py b/netra/config.py index e72c6dc..fe233a9 100644 --- a/netra/config.py +++ b/netra/config.py @@ -23,7 +23,7 @@ class Config: def __init__( self, app_name: Optional[str] = None, - headers: Optional[str] = None, + headers: Optional[str | Dict[str, str]] = None, disable_batch: Optional[bool] = None, trace_content: Optional[bool] = None, debug_mode: Optional[bool] = None, @@ -88,11 +88,13 @@ def _get_otlp_endpoint(self) -> str | None: """Get OTLP endpoint from environment variables.""" return os.getenv("NETRA_OTLP_ENDPOINT") or os.getenv("OTEL_EXPORTER_OTLP_ENDPOINT") - def _parse_headers(self, headers: Optional[str]) -> Dict[str, str] | Any: + def _parse_headers(self, headers: Optional[str] | Dict[str, str]) -> Dict[str, str] | Any: """Parse headers from parameter or environment variable.""" headers = headers or os.getenv("NETRA_HEADERS") if isinstance(headers, str): return parse_env_headers(headers) + elif isinstance(headers, dict): + return headers return {} def _validate_api_key(self) -> None: diff --git a/netra/dashboard/client.py b/netra/dashboard/client.py index 29d74ba..d9ab63a 100644 --- a/netra/dashboard/client.py +++ b/netra/dashboard/client.py @@ -1,9 +1,7 @@ import logging -import os from typing import Any, Dict, List, Optional -import httpx - +from netra.client import BaseNetraClient from netra.config import Config from netra.dashboard.models import ( ChartType, @@ -19,8 +17,10 @@ logger = logging.getLogger(__name__) +_LOG_PREFIX = "netra.dashboard" + -class DashboardHttpClient: +class DashboardHttpClient(BaseNetraClient): """Internal HTTP client for Dashboard APIs.""" def __init__(self, config: Config) -> None: @@ -28,85 +28,15 @@ def 
__init__(self, config: Config) -> None: Initialize the dashboard HTTP client. Args: - config: Configuration object with dashboard settings - """ - self._client: Optional[httpx.Client] = self._create_client(config) - - def _create_client(self, config: Config) -> Optional[httpx.Client]: - """ - Create an HTTP client for dashboard endpoints. - - Args: - config: The configuration object. - - Returns: - An HTTP client for dashboard endpoints, or None if creation fails. - """ - endpoint = (config.otlp_endpoint or "").strip() - if not endpoint: - logger.error("netra.dashboard: NETRA_OTLP_ENDPOINT is required for dashboard APIs") - return None - - base_url = self._resolve_base_url(endpoint) - headers = self._build_headers(config) - timeout = self._get_timeout() - - try: - return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) - except Exception as exc: - logger.error("netra.dashboard: Failed to initialize dashboard HTTP client: %s", exc) - return None - - def _resolve_base_url(self, endpoint: str) -> str: + config: Configuration object with dashboard settings. """ - Resolve base URL from endpoint. - - Args: - endpoint: The endpoint to resolve. - - Returns: - The resolved base URL. - """ - base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] - return base_url - - def _build_headers(self, config: Config) -> Dict[str, str]: - """ - Build Headers for Dashboard Client. - - Args: - config: The configuration object. - - Returns: - The headers for dashboard client. - """ - headers: Dict[str, str] = dict(config.headers or {}) - api_key = config.api_key - if api_key: - headers["x-api-key"] = api_key - headers["Content-Type"] = "application/json" - return headers - - def _get_timeout(self) -> float: - """ - Get timeout for dashboard client. - - Returns: - The timeout for dashboard client. 
- """ - timeout_env = os.getenv("NETRA_DASHBOARD_TIMEOUT") - if not timeout_env: - return 30.0 - try: - return float(timeout_env) - except ValueError: - logger.warning( - "netra.dashboard: Invalid NETRA_DASHBOARD_TIMEOUT value '%s', using default 30.0", - timeout_env, - ) - return 30.0 + super().__init__( + config, + log_prefix=_LOG_PREFIX, + timeout_env_var="NETRA_DASHBOARD_TIMEOUT", + default_timeout=30.0, + extra_headers={"Content-Type": "application/json"}, + ) def query_data( self, @@ -130,7 +60,7 @@ def query_data( The query response data or None on error. """ if not self._client: - logger.error("netra.dashboard: Dashboard client is not initialized; cannot execute query") + logger.error("%s: Client is not initialized; cannot execute query", _LOG_PREFIX) return None try: @@ -174,13 +104,12 @@ def query_data( response = self._client.post(url, json=payload) response.raise_for_status() - data = response.json() - return data - except Exception: - response_json = response.json() + return response.json() + except Exception as exc: logger.error( - "netra.dashboard: Failed to execute dashboard query: %s", - response_json.get("error").get("message", ""), + "%s: Failed to execute dashboard query: %s", + _LOG_PREFIX, + self._extract_error_message(exc), ) return None @@ -202,7 +131,7 @@ def get_session_stats( end_time: End time in ISO 8601 UTC format. filters: Optional list of session filters. limit: Maximum number of results per page. - page: Page number for pagination. + cursor: Cursor for pagination. sort_field: Field to sort by. sort_order: Sort order (asc/desc). @@ -210,7 +139,7 @@ def get_session_stats( The session stats response data or None on error. 
""" if not self._client: - logger.error("netra.dashboard: Dashboard client is not initialized; cannot fetch session stats") + logger.error("%s: Client is not initialized; cannot fetch session stats", _LOG_PREFIX) return None try: @@ -244,13 +173,12 @@ def get_session_stats( response = self._client.post(url, json=payload) response.raise_for_status() - data = response.json() - return data - except Exception: - response_json = response.json() + return response.json() + except Exception as exc: logger.error( - "netra.dashboard: Failed to fetch session stats: %s", - response_json.get("error").get("message", ""), + "%s: Failed to fetch session stats: %s", + _LOG_PREFIX, + self._extract_error_message(exc), ) return None @@ -267,7 +195,7 @@ def get_session_summary(self, start_time: str, end_time: str, filters: Optional[ The session summary response data or None on error. """ if not self._client: - logger.error("netra.dashboard: Dashboard client is not initialized; cannot execute query") + logger.error("%s: Client is not initialized; cannot fetch session summary", _LOG_PREFIX) return None try: @@ -292,12 +220,11 @@ def get_session_summary(self, start_time: str, end_time: str, filters: Optional[ response = self._client.post(url, json=payload) response.raise_for_status() - data = response.json() - return data - except Exception: - response_json = response.json() + return response.json() + except Exception as exc: logger.error( - "netra.dashboard: Failed to fetch session summary: %s", - response_json.get("error").get("message", ""), + "%s: Failed to fetch session summary: %s", + _LOG_PREFIX, + self._extract_error_message(exc), ) return None diff --git a/netra/decorators.py b/netra/decorators.py index f875d8e..d75a7ce 100644 --- a/netra/decorators.py +++ b/netra/decorators.py @@ -89,25 +89,54 @@ def _add_span_attributes( input_data[key] = _serialize_value(value) if input_data: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.input", json.dumps(input_data)) + 
span.set_attribute("input", json.dumps(input_data)) except Exception as e: - span.set_attribute(f"{Config.LIBRARY_NAME}.input_error", str(e)) + span.set_attribute("input_error", str(e)) + + +def _span_has_output(span: trace.Span) -> bool: + """Return True if the span already has a non-empty ``output`` attribute set. + + Checks both the public ``attributes`` property and the internal + ``_attributes`` dict for compatibility with different span implementations. + Returns False on any error to ensure output is never silently dropped. + + Args: + span: The span to inspect. + + Returns: + True if a non-empty ``output`` attribute is already present on the span. + """ + try: + for attr_name in ("attributes", "_attributes"): + attrs = getattr(span, attr_name, None) + if attrs is not None and "output" in attrs and attrs["output"]: + return True + except Exception: + logger.debug("_span_has_output: error inspecting span attributes", exc_info=True) + return False def _add_output_attributes(span: trace.Span, result: Any) -> None: """ Helper function to add output attributes to span. + Skips setting ``output`` if the user already set it manually via + ``Netra.set_output()`` inside the decorated function — the explicit + value takes priority over the auto-captured return value. + Args: span: The OpenTelemetry span to add attributes to. result: The result to serialize and add as an attribute. 
""" try: + if _span_has_output(span): + return serialized_output = _serialize_value(result) - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.output", serialized_output) + span.set_attribute("output", serialized_output) except Exception as e: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.output_error", str(e)) + span.set_attribute("output_error", str(e)) def _is_streaming_response(obj: Any) -> bool: @@ -154,6 +183,40 @@ def _is_sync_generator(obj: Any) -> bool: return inspect.isgenerator(obj) +def _finalize_span( + span: trace.Span, + span_name: str, + entity_type: str, + error: Optional[Exception] = None, +) -> None: + """End a span, unregister it from SessionManager, and pop its entity. + + This is the single teardown path shared by every generator wrapper and + the non-streaming finalization in ``_create_function_wrapper``. + + When *error* is provided the exception is recorded on the span before + it is ended. + + Args: + span: The span to finalize. + span_name: Name the span was registered under. + entity_type: Entity type used for the SessionManager stack. + error: Optional exception to record before ending the span. 
+ """ + if error is not None: + try: + span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(error)) + span.record_exception(error) + except Exception: + logger.debug("Failed to record error on span '%s'", span_name, exc_info=True) + try: + span.end() + SessionManager.unregister_span(span_name, span) + SessionManager.pop_entity(entity_type) + except Exception as e: + logger.exception("Failed to unregister span '%s' and pop entity '%s': %s", span_name, entity_type, e) + + def _wrap_async_generator_with_span( span: trace.Span, agen: AsyncGenerator[Any, None], @@ -180,26 +243,10 @@ async def _wrapped() -> AsyncGenerator[Any, None]: async for item in agen: yield item except Exception as e: - try: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - finally: - span.end() - # De-register and pop entity at the very end for streaming lifecycle - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise else: - # Normal completion - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) return _wrapped() @@ -229,24 +276,10 @@ def _wrapped() -> Generator[Any, None, None]: for item in gen: yield item except Exception as e: - try: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - finally: - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise else: - span.end() 
- try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) return _wrapped() @@ -282,24 +315,10 @@ async def _aiter_wrapper(): # type: ignore[no-untyped-def] async for chunk in body_iter: yield chunk except Exception as e: - try: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - finally: - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise else: - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) resp.body_iterator = _aiter_wrapper() # type: ignore[no-untyped-call] return resp @@ -313,24 +332,10 @@ def _iter_wrapper(): # type: ignore[no-untyped-def] for chunk in body_iter: yield chunk except Exception as e: - try: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - finally: - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise else: - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) resp.body_iterator = _iter_wrapper() # type: 
ignore[no-untyped-call] return resp @@ -365,16 +370,14 @@ def _create_function_wrapper( @functools.wraps(func) async def async_wrapper(*args: Any, **kwargs: Any) -> Any: - # Push entity before span starts so processors can capture it + if not isinstance(as_type, SpanType): + logger.error("Invalid span type: %s", as_type) + return await cast(Awaitable[Any], func(*args, **kwargs)) + SessionManager.push_entity(entity_type, span_name) tracer = trace.get_tracer(module_name) span = tracer.start_span(span_name) - # Set span type if provided - - if not isinstance(as_type, SpanType): - logger.error("Invalid span type: %s", as_type) - return try: span.set_attribute("netra.span.type", as_type.value) except Exception: @@ -391,17 +394,9 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: try: result = await cast(Awaitable[Any], func(*args, **kwargs)) except Exception as e: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise - # If result is streaming, defer span end to when stream completes if _is_streaming_response(result): return _wrap_streaming_response_with_span(span, result, span_name, entity_type) if _is_async_generator(result): @@ -409,14 +404,8 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: if _is_sync_generator(result): return _wrap_sync_generator_with_span(span, result, span_name, entity_type) - # Non-streaming: finalize now _add_output_attributes(span, result) - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) return 
result return cast(Callable[P, R], async_wrapper) @@ -425,21 +414,19 @@ async def async_wrapper(*args: Any, **kwargs: Any) -> Any: @functools.wraps(func) def sync_wrapper(*args: Any, **kwargs: Any) -> Any: - # Push entity before span starts so processors can capture it + if not isinstance(as_type, SpanType): + logger.error("Invalid span type: %s", as_type) + return func(*args, **kwargs) + SessionManager.push_entity(entity_type, span_name) tracer = trace.get_tracer(module_name) span = tracer.start_span(span_name) - # Set span type if provided - if as_type is not None: - if not isinstance(as_type, SpanType): - logger.error("Invalid span type: %s", as_type) - return - try: - span.set_attribute("netra.span.type", as_type.value) - except Exception: - pass - # Register and activate span + try: + span.set_attribute("netra.span.type", as_type.value) + except Exception: + pass + try: SessionManager.register_span(span_name, span) SessionManager.set_current_span(span) @@ -451,14 +438,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: try: result = func(*args, **kwargs) except Exception as e: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(e)) - span.record_exception(e) - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type, error=e) raise # If result is streaming, defer span end to when stream completes @@ -469,14 +449,8 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: if _is_sync_generator(result): return _wrap_sync_generator_with_span(span, result, span_name, entity_type) # type: ignore[arg-type] - # Non-streaming: finalize now _add_output_attributes(span, result) - span.end() - try: - SessionManager.unregister_span(span_name, span) - except Exception: - logger.exception("Failed to unregister span '%s' from SessionManager", span_name) - 
SessionManager.pop_entity(entity_type) + _finalize_span(span, span_name, entity_type) return result return cast(Callable[P, R], sync_wrapper) diff --git a/netra/evaluation/api.py b/netra/evaluation/api.py index 5b5441c..da5430f 100644 --- a/netra/evaluation/api.py +++ b/netra/evaluation/api.py @@ -184,6 +184,8 @@ def create_run( if evaluators_config: evaluators_config_dicts = [e.model_dump(by_alias=True) for e in evaluators_config] response = self._client.create_run(name=name, dataset_id=dataset_id, evaluators_config=evaluators_config_dicts) + if not response: + return None run_id = response.get("id", None) return run_id diff --git a/netra/evaluation/client.py b/netra/evaluation/client.py index 9756fa4..d9f6cee 100644 --- a/netra/evaluation/client.py +++ b/netra/evaluation/client.py @@ -1,21 +1,19 @@ import asyncio import logging -import os import time from typing import Any, Dict, List, Optional -import httpx - +from netra.client import BaseNetraClient from netra.config import Config from netra.evaluation.models import DatasetItem, TurnType logger = logging.getLogger(__name__) +_LOG_PREFIX = "netra.evaluation" + -class EvaluationHttpClient: - """ - Internal HTTP client for Evaluation APIs. - """ +class EvaluationHttpClient(BaseNetraClient): + """Internal HTTP client for Evaluation APIs.""" def __init__(self, config: Config) -> None: """ @@ -24,88 +22,17 @@ def __init__(self, config: Config) -> None: Args: config: The configuration object. """ - self._client: Optional[httpx.Client] = self._create_client(config) - - def _create_client(self, config: Config) -> Optional[httpx.Client]: - """ - Create an HTTP client for evaluation endpoints. - - Args: - config: The configuration object. - - Returns: - An HTTP client for evaluation endpoints, or None if creation fails. 
- """ - endpoint = (config.otlp_endpoint or "").strip() - if not endpoint: - logger.error("netra.evaluation: NETRA_OTLP_ENDPOINT is required for evaluation APIs") - return None - - base_url = self._resolve_base_url(endpoint) - headers = self._build_headers(config) - timeout = self._get_timeout() - - try: - return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) - except Exception as exc: - logger.error("netra.evaluation: Failed to initialize evaluation HTTP client: %s", exc) - return None - - def _resolve_base_url(self, endpoint: str) -> str: - """ - Resolve base URL from endpoint. - - Args: - endpoint: The endpoint to resolve. - - Returns: - The resolved base URL. - """ - base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] - return base_url - - def _build_headers(self, config: Config) -> Dict[str, str]: - """ - Build Headers for Evaluation Client - - Args: - config: The configuration object. - - Returns: - The headers for evaluation client. - """ - headers: Dict[str, str] = dict(config.headers or {}) - api_key = config.api_key - if api_key: - headers["x-api-key"] = api_key - return headers - - def _get_timeout(self) -> float: - """ - Get timeout for evaluation client. - - Returns: - The timeout for evaluation client. - """ - timeout_env = os.getenv("NETRA_EVALUATION_TIMEOUT") - if not timeout_env: - return 10.0 - try: - return float(timeout_env) - except ValueError: - logger.warning( - "netra.evaluation: Invalid NETRA_EVALUATION_TIMEOUT value '%s', using default 10.0", - timeout_env, - ) - return 10.0 + super().__init__( + config, + log_prefix=_LOG_PREFIX, + timeout_env_var="NETRA_EVALUATION_TIMEOUT", + ) def create_dataset( self, name: Optional[str], tags: Optional[List[str]] = None, turn_type: TurnType = TurnType.SINGLE ) -> Any: """ - Create an empty dataset + Create an empty dataset. Args: name: The name of the dataset. 
@@ -113,11 +40,10 @@ def create_dataset( turn_type: The turn type of the dataset, either "single" or "multi". Defaults to "single". Returns: - A backend JSON response containing dataset info (id, name, tags, etc.) on success, - or None if creation fails. + A backend JSON response containing dataset info on success, or None if creation fails. """ if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot create dataset") + logger.error("%s: Client is not initialized; cannot create dataset", _LOG_PREFIX) return None try: url = "/evaluations/dataset" @@ -126,29 +52,26 @@ def create_dataset( response.raise_for_status() data = response.json() if isinstance(data, dict) and "data" in data: - logger.info("netra.evaluation: Dataset created successfully") + logger.info("%s: Dataset created successfully", _LOG_PREFIX) return data.get("data", {}) - except Exception: - response_json = response.json() - logger.error( - "netra.evaluation: Failed to create dataset: %s", response_json.get("error").get("message", "") - ) + except Exception as exc: + logger.error("%s: Failed to create dataset: %s", _LOG_PREFIX, self._extract_error_message(exc)) return None def add_dataset_item(self, dataset_id: str, item: DatasetItem) -> Any: """ - Add a single item to an existing dataset and return backend data (e.g., new item id). + Add a single item to an existing dataset. Args: dataset_id: The id of the dataset to which the item will be added. - item_payload: The dataset item to add. + item: The dataset item to add. Returns: - A backend JSON response on success or {"success": False} on error. + A backend JSON response on success or None on error. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot add item to dataset") - return {"success": False} + logger.error("%s: Client is not initialized; cannot add item to dataset", _LOG_PREFIX) + return None try: url = f"/evaluations/dataset/{dataset_id}/items" item_payload: Dict[str, Any] = { @@ -161,14 +84,14 @@ def add_dataset_item(self, dataset_id: str, item: DatasetItem) -> Any: response.raise_for_status() data = response.json() if isinstance(data, dict) and "data" in data: - logger.info("netra.evaluation: Dataset item added successfully") + logger.info("%s: Dataset item added successfully", _LOG_PREFIX) return data.get("data", {}) - except Exception: - response_json = response.json() + except Exception as exc: logger.error( - "netra.evaluation: Failed to add item to dataset '%s': %s", + "%s: Failed to add item to dataset '%s': %s", + _LOG_PREFIX, dataset_id, - response_json.get("error").get("message", ""), + self._extract_error_message(exc), ) return None @@ -180,25 +103,25 @@ def get_dataset(self, dataset_id: str) -> Any: dataset_id: The id of the dataset to fetch. Returns: - A list of dataset items. + A list of dataset items, or None on error. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot fetch dataset") - return {"success": False} + logger.error("%s: Client is not initialized; cannot fetch dataset", _LOG_PREFIX) + return None try: url = f"/evaluations/dataset/{dataset_id}" response = self._client.get(url) response.raise_for_status() data = response.json() if isinstance(data, dict) and "data" in data: - logger.info("netra.evaluation: Dataset fetched successfully") + logger.info("%s: Dataset fetched successfully", _LOG_PREFIX) return data.get("data", []) - except Exception: - response_json = response.json() + except Exception as exc: logger.error( - "netra.evaluation: Failed to fetch dataset '%s': %s", + "%s: Failed to fetch dataset '%s': %s", + _LOG_PREFIX, dataset_id, - response_json.get("error").get("message", ""), + self._extract_error_message(exc), ) return None @@ -209,7 +132,7 @@ def create_run( evaluators_config: Optional[List[Dict[str, Any]]] = None, ) -> Any: """ - Create a new run based on the provided name, dataset_id, and evaluators_config. + Create a new evaluation run. Args: name: The name of the run. @@ -217,13 +140,13 @@ def create_run( evaluators_config: Optional list of evaluators to be used for the run. Returns: - A backend JSON response containing run_id + A backend JSON response containing run_id, or None on failure. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot create run") - return {"success": False} + logger.error("%s: Client is not initialized; cannot create run", _LOG_PREFIX) + return None try: - url = f"/evaluations/test_run" + url = "/evaluations/test_run" payload: Dict[str, Any] = { "name": name, "datasetId": dataset_id if dataset_id else None, @@ -234,12 +157,10 @@ def create_run( data = response.json() if isinstance(data, dict) and "data" in data: return data.get("data", {}) - except Exception: - response_json = response.json() - logger.error( - "netra.evaluation: Failed to create run '%s': %s", name, response_json.get("error").get("message", "") - ) - return {"success": False} + return data + except Exception as exc: + logger.error("%s: Failed to create run '%s': %s", _LOG_PREFIX, name, self._extract_error_message(exc)) + return None def post_run_item(self, run_id: str, payload: Dict[str, Any]) -> Any: """ @@ -250,11 +171,11 @@ def post_run_item(self, run_id: str, payload: Dict[str, Any]) -> Any: payload: The run item to add. Returns: - A backend JSON response on success or {"success": False} on error. + The run item id on success, or None on failure. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot post run item") - return {"success": False} + logger.error("%s: Client is not initialized; cannot post run item", _LOG_PREFIX) + return None try: url = f"/evaluations/run/{run_id}/item" response = self._client.post(url, json=payload) @@ -265,23 +186,23 @@ def post_run_item(self, run_id: str, payload: Dict[str, Any]) -> Any: run_item_id = run_item.get("id") return run_item_id return data - except Exception: - response_json = response.json() + except Exception as exc: logger.error( - "netra.evaluation: Failed to post run item for run '%s': %s", + "%s: Failed to post run item for run '%s': %s", + _LOG_PREFIX, run_id, - response_json.get("error").get("message", ""), + self._extract_error_message(exc), ) - return {"success": False} + return None def submit_local_evaluations( self, run_id: str, test_run_item_id: str, evaluator_results: List[Dict[str, Any]] ) -> Any: """ - Submit local evaluations result + Submit local evaluations result. Args: - run_id: The id of the run to which the item will be added. + run_id: The id of the run. test_run_item_id: The id of the test run item. evaluator_results: The evaluator results to submit. @@ -289,8 +210,8 @@ def submit_local_evaluations( A backend JSON response containing confirmation of the submission. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot submit local evaluations") - return {"success": False} + logger.error("%s: Client is not initialized; cannot submit local evaluations", _LOG_PREFIX) + return None try: url = f"/evaluations/run/{run_id}/item/{test_run_item_id}/local-evaluations" payload: Dict[str, Any] = {"evaluatorResults": evaluator_results} @@ -300,30 +221,30 @@ def submit_local_evaluations( if isinstance(data, dict) and "data" in data: return data.get("data", {}) return data - except Exception: - response_json = response.json() + except Exception as exc: logger.error( - "netra.evaluation: Failed to submit local evaluations for run '%s', item '%s': %s", + "%s: Failed to submit local evaluations for run '%s', item '%s': %s", + _LOG_PREFIX, run_id, test_run_item_id, - response_json.get("error").get("message", ""), + self._extract_error_message(exc), ) - return {"success": False} + return None def post_run_status(self, run_id: str, status: str) -> Any: """ - Submit the run status + Submit the run status. Args: - run_id: The id of the run to which the item will be added. - status: The status of the run. + run_id: The id of the run. + status: The status of the run. - Returns: - A backend JSON response containing confirmation of the submission. + Returns: + A backend JSON response containing confirmation of the submission. 
""" if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot post run status") - return {"success": False} + logger.error("%s: Client is not initialized; cannot post run status", _LOG_PREFIX) + return None try: url = f"/evaluations/run/{run_id}/status" payload: Dict[str, Any] = {"status": status} @@ -331,17 +252,17 @@ def post_run_status(self, run_id: str, status: str) -> Any: response.raise_for_status() data = response.json() if isinstance(data, dict) and "data" in data: - logger.info("netra.evaluation: Completed test run successfully") + logger.info("%s: Completed test run successfully", _LOG_PREFIX) return data.get("data", {}) return data - except Exception: - response_json = response.json() + except Exception as exc: logger.error( - "netra.evaluation: Failed to post run status for run '%s': %s", + "%s: Failed to post run status for run '%s': %s", + _LOG_PREFIX, run_id, - response_json.get("error").get("message", ""), + self._extract_error_message(exc), ) - return {"success": False} + return None def get_span_by_id(self, span_id: str) -> Any: """ @@ -354,17 +275,18 @@ def get_span_by_id(self, span_id: str) -> Any: The span data if found, None otherwise. """ if not self._client: - logger.error("netra.evaluation: Evaluation client is not initialized; cannot get span") + logger.error("%s: Client is not initialized; cannot get span", _LOG_PREFIX) return None try: - url = f"sdk/traces/spans/{span_id}" + url = f"/sdk/traces/spans/{span_id}" response = self._client.get(url) response.raise_for_status() data = response.json() if isinstance(data, dict): return data.get("data", data) return data - except Exception: + except Exception as exc: + logger.error("%s: Failed to get span '%s': %s", _LOG_PREFIX, span_id, self._extract_error_message(exc)) return None async def wait_for_span_ingestion( @@ -377,9 +299,6 @@ async def wait_for_span_ingestion( """ Wait until a span is available in the backend. 
- Polls the GET /spans/:id endpoint to verify span availability - before running evaluators. - Args: span_id: The span ID to poll for. timeout_seconds: Maximum time to wait for span ingestion. diff --git a/netra/exporters/filtering_span_exporter.py b/netra/exporters/filtering_span_exporter.py index cea41e1..8df9c1e 100644 --- a/netra/exporters/filtering_span_exporter.py +++ b/netra/exporters/filtering_span_exporter.py @@ -9,7 +9,7 @@ from netra.exporters.utils import add_blocked_trace_id, get_trace_id, is_trace_id_blocked, is_trial_blocked from netra.processors.local_filtering_span_processor import ( - BLOCKED_LOCAL_PARENT_MAP, + blocked_local_parent_map_snapshot, ) logger = logging.getLogger(__name__) @@ -113,12 +113,7 @@ def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: # Merge with registry of locally blocked spans captured by processor to handle # cases where children export before their blocked parent (SimpleSpanProcessor) - merged_map: Dict[Any, Any] = {} - try: - if BLOCKED_LOCAL_PARENT_MAP: - merged_map.update(BLOCKED_LOCAL_PARENT_MAP) - except Exception: - pass + merged_map: Dict[Any, Any] = blocked_local_parent_map_snapshot() merged_map.update(blocked_parent_map) if merged_map: diff --git a/netra/instrumentation/__init__.py b/netra/instrumentation/__init__.py index d6c0de9..c5e70cc 100644 --- a/netra/instrumentation/__init__.py +++ b/netra/instrumentation/__init__.py @@ -7,7 +7,7 @@ from traceloop.sdk import Instruments, Telemetry from traceloop.sdk.utils.package_check import is_package_installed -from netra.instrumentation.instruments import CustomInstruments, NetraInstruments +from netra.instrumentation.instruments import DEFAULT_INSTRUMENTS, CustomInstruments, NetraInstruments def init_instrumentations( @@ -18,12 +18,15 @@ def init_instrumentations( ) -> None: from traceloop.sdk.tracing.tracing import init_instrumentations + # When the user does not pass instruments, fall back to the curated default set. 
+ resolved_instruments = instruments if instruments is not None else DEFAULT_INSTRUMENTS + traceloop_instruments = set() traceloop_block_instruments = set() netra_custom_instruments = set() netra_custom_block_instruments = set() - if instruments: - for instrument in instruments: + if resolved_instruments: + for instrument in resolved_instruments: if instrument.origin == CustomInstruments: # type: ignore[attr-defined] netra_custom_instruments.add(getattr(CustomInstruments, instrument.name)) else: @@ -36,18 +39,13 @@ def init_instrumentations( traceloop_block_instruments.add(getattr(Instruments, instrument.name)) # If no instruments in traceloop are provided for instrumentation - if instruments and not traceloop_instruments and not traceloop_block_instruments: + if resolved_instruments and not traceloop_instruments and not traceloop_block_instruments: traceloop_block_instruments = set(Instruments) # If no custom instruments in netra are provided for instrumentation - if instruments and not netra_custom_instruments and not netra_custom_block_instruments: + if resolved_instruments and not netra_custom_instruments and not netra_custom_block_instruments: netra_custom_block_instruments = set(CustomInstruments) - # If no instruments are provided for instrumentation, instrument all instruments - if not instruments and not block_instruments: - traceloop_instruments = set(Instruments) - netra_custom_instruments = set(CustomInstruments) - netra_custom_instruments = netra_custom_instruments - netra_custom_block_instruments traceloop_instruments = traceloop_instruments - traceloop_block_instruments if not traceloop_instruments: @@ -1369,4 +1367,4 @@ def init_claude_agent_sdk_instrumentation() -> bool: except Exception as e: logging.error(f"Error initializing Claude Agent SDK instrumentor: {e}") Telemetry().log_exception(e) - return False \ No newline at end of file + return False diff --git a/netra/instrumentation/claude_agent_sdk/__init__.py 
b/netra/instrumentation/claude_agent_sdk/__init__.py index 894137b..47a52b8 100644 --- a/netra/instrumentation/claude_agent_sdk/__init__.py +++ b/netra/instrumentation/claude_agent_sdk/__init__.py @@ -1,22 +1,21 @@ -import wrapt import logging -from opentelemetry.trace import Tracer, get_tracer -from opentelemetry.instrumentation.utils import unwrap +from typing import Any + +import wrapt from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.trace import Tracer, get_tracer from netra.instrumentation.claude_agent_sdk.version import __version__ -from netra.instrumentation.claude_agent_sdk.wrappers import ( - client_query_wrapper, - client_response_wrapper, - query_wrapper -) +from netra.instrumentation.claude_agent_sdk.wrappers import client_query_wrapper, client_response_wrapper, query_wrapper logger = logging.getLogger(__name__) -_instruments = ("claude_agent_sdk >= 0.1.0", ) +_instruments = ("claude_agent_sdk >= 0.1.0",) + -class NetraClaudeAgentSDKInstrumentor(BaseInstrumentor): - def instrumentation_dependencies(self): +class NetraClaudeAgentSDKInstrumentor(BaseInstrumentor): # type: ignore[misc] + def instrumentation_dependencies(self) -> tuple[str, ...]: """ Return the list of packages required for this instrumentation to function. @@ -27,8 +26,8 @@ def instrumentation_dependencies(self): tuple: A tuple of pip requirement strings for the instrumented library. """ return _instruments - - def _instrument(self, **kwargs): + + def _instrument(self, **kwargs: Any) -> None: """ Set up OpenTelemetry instrumentation for the Claude Agent SDK. @@ -52,7 +51,7 @@ def _instrument(self, **kwargs): self._instrument_client_query(tracer) self._instrument_client_response(tracer) - def _uninstrument(self, **kwargs): + def _uninstrument(self, **kwargs: Any) -> None: """ Remove all custom instrumentation wrappers from the Claude Agent SDK. 
@@ -66,7 +65,7 @@ def _uninstrument(self, **kwargs): self._uninstrument_client_query() self._uninstrument_client_response() - def _instrument_query(self, tracer: Tracer): + def _instrument_query(self, tracer: Tracer) -> None: """ Wrap InternalClient.process_query with a tracing wrapper. @@ -78,14 +77,12 @@ def _instrument_query(self, tracer: Tracer): """ try: wrapt.wrap_function_wrapper( - "claude_agent_sdk._internal.client", - "InternalClient.process_query", - query_wrapper(tracer) + "claude_agent_sdk._internal.client", "InternalClient.process_query", query_wrapper(tracer) ) except Exception as e: logger.error(f"Failed to instrument claude-agent-sdk query: {e}") - def _instrument_client_query(self, tracer: Tracer): + def _instrument_client_query(self, tracer: Tracer) -> None: """ Wrap ClaudeSDKClient.query to capture the prompt for downstream tracing. @@ -96,15 +93,11 @@ def _instrument_client_query(self, tracer: Tracer): None """ try: - wrapt.wrap_function_wrapper( - "claude_agent_sdk.client", - "ClaudeSDKClient.query", - client_query_wrapper() - ) + wrapt.wrap_function_wrapper("claude_agent_sdk.client", "ClaudeSDKClient.query", client_query_wrapper()) except Exception as e: logger.error(f"Failed to instrument claude-sdk-client query: {e}") - def _instrument_client_response(self, tracer: Tracer): + def _instrument_client_response(self, tracer: Tracer) -> None: """ Wrap ClaudeSDKClient.receive_messages with a tracing wrapper. @@ -116,14 +109,12 @@ def _instrument_client_response(self, tracer: Tracer): """ try: wrapt.wrap_function_wrapper( - "claude_agent_sdk.client", - "ClaudeSDKClient.receive_messages", - client_response_wrapper(tracer) + "claude_agent_sdk.client", "ClaudeSDKClient.receive_messages", client_response_wrapper(tracer) ) except Exception as e: logger.error(f"Failed to instrument claude-sdk-client response: {e}") - def _uninstrument_query(self): + def _uninstrument_query(self) -> None: """ Remove the tracing wrapper from InternalClient.process_query. 
@@ -138,7 +129,7 @@ def _uninstrument_query(self): except (AttributeError, ModuleNotFoundError): logger.error(f"Failed to uninstrument claude-agent-sdk query") - def _uninstrument_client_query(self): + def _uninstrument_client_query(self) -> None: """ Remove the tracing wrapper from ClaudeSDKClient.query. @@ -153,7 +144,7 @@ def _uninstrument_client_query(self): except (AttributeError, ModuleNotFoundError): logger.error(f"Failed to uninstrument claude-sdk-client query") - def _uninstrument_client_response(self): + def _uninstrument_client_response(self) -> None: """ Remove the tracing wrapper from ClaudeSDKClient.receive_messages. @@ -166,4 +157,4 @@ def _uninstrument_client_response(self): try: unwrap("claude_agent_sdk.client", "ClaudeSDKClient.receive_messages") except (AttributeError, ModuleNotFoundError): - logger.error(f"Failed to uninstrument claude-sdk-client response") \ No newline at end of file + logger.error(f"Failed to uninstrument claude-sdk-client response") diff --git a/netra/instrumentation/claude_agent_sdk/utils.py b/netra/instrumentation/claude_agent_sdk/utils.py index f63af73..02735c7 100644 --- a/netra/instrumentation/claude_agent_sdk/utils.py +++ b/netra/instrumentation/claude_agent_sdk/utils.py @@ -2,20 +2,22 @@ import logging import threading from typing import Any -from opentelemetry.context import Context -from opentelemetry.trace import Span, Tracer -from opentelemetry.semconv_ai import SpanAttributes + from claude_agent_sdk import ( - ClaudeAgentOptions, - SystemMessage, AssistantMessage, - UserMessage, + ClaudeAgentOptions, ResultMessage, + SystemMessage, TextBlock, ThinkingBlock, - ToolUseBlock, ToolResultBlock, + ToolUseBlock, + UserMessage, ) +from opentelemetry.context import Context +from opentelemetry.semconv_ai import SpanAttributes +from opentelemetry.trace import Span, Tracer + from netra.config import Config logger = logging.getLogger(__name__) @@ -67,7 +69,7 @@ def _set_conversation(span: Span, role: str, content: str, 
prompt_index: int = 0 return prompt_index -def _set_usage(span: Span, usage: dict) -> None: +def _set_usage(span: Span, usage: dict[str, Any]) -> None: """ Write token usage attributes to the span. @@ -117,10 +119,10 @@ def _set_tool_result(tracer: Tracer, parent_ctx: Context, block: ToolResultBlock with tracer.start_as_current_span(tool_name, parent_ctx) as span: try: if tool_call: - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.input", _serialize(tool_call.get("input", {}))) + span.set_attribute("input", _serialize(tool_call.get("input", {}))) span.set_attribute(f"{Config.LIBRARY_NAME}.span.type", "TOOL") - span.set_attribute(f"{Config.LIBRARY_NAME}.entity.output", _serialize(block.content)) + span.set_attribute(f"output", _serialize(block.content)) except Exception as e: logger.error(f"Cannot set tool result attributes for tool_use_id={block.tool_use_id}: {e}") diff --git a/netra/instrumentation/claude_agent_sdk/version.py b/netra/instrumentation/claude_agent_sdk/version.py index d538f87..5becc17 100644 --- a/netra/instrumentation/claude_agent_sdk/version.py +++ b/netra/instrumentation/claude_agent_sdk/version.py @@ -1 +1 @@ -__version__ = "1.0.0" \ No newline at end of file +__version__ = "1.0.0" diff --git a/netra/instrumentation/claude_agent_sdk/wrappers.py b/netra/instrumentation/claude_agent_sdk/wrappers.py index 25eb13d..cf6cb78 100644 --- a/netra/instrumentation/claude_agent_sdk/wrappers.py +++ b/netra/instrumentation/claude_agent_sdk/wrappers.py @@ -1,17 +1,18 @@ import logging from typing import Any, AsyncIterator, Callable, Tuple + +from claude_agent_sdk import AssistantMessage, ResultMessage, SystemMessage, UserMessage from opentelemetry import trace from opentelemetry.context import Context from opentelemetry.trace import Span, SpanKind, Tracer from opentelemetry.trace.status import Status, StatusCode -from claude_agent_sdk import SystemMessage, AssistantMessage, UserMessage, ResultMessage from netra.instrumentation.claude_agent_sdk.utils import 
( + set_assistant_message_attributes, set_request_attributes, + set_result_message_attributes, set_system_message_attributes, - set_assistant_message_attributes, set_user_message_attributes, - set_result_message_attributes, ) logger = logging.getLogger(__name__) @@ -24,9 +25,9 @@ async def _dispatch_messages( tracer: Tracer, root_span: Span, root_ctx: Context, - aiterator: AsyncIterator, + aiterator: AsyncIterator[Any], prompt_index: int, -) -> AsyncIterator: +) -> AsyncIterator[Any]: """ Dispatch each incoming SDK message to its span attribute handler and yield it. @@ -61,7 +62,7 @@ async def _dispatch_messages( yield message -def query_wrapper(tracer: Tracer): +def query_wrapper(tracer: Tracer) -> Callable[..., Any]: """ Return a wrapper that traces a single InternalClient.process_query call with child spans per message. @@ -71,6 +72,7 @@ def query_wrapper(tracer: Tracer): Returns: Callable: An async generator wrapper function for InternalClient.process_query. """ + async def wrapper( wrapped: Callable[..., Any], instance: Any, @@ -126,7 +128,7 @@ async def wrapper( return wrapper -def client_query_wrapper(): +def client_query_wrapper() -> Callable[..., Any]: """ Return a wrapper that captures the prompt from ClaudeSDKClient.query for later tracing. @@ -136,6 +138,7 @@ def client_query_wrapper(): Returns: Callable: An async wrapper function for ClaudeSDKClient.query. """ + async def wrapper( wrapped: Callable[..., Any], instance: Any, @@ -160,7 +163,7 @@ async def wrapper( return wrapper -def client_response_wrapper(tracer: Tracer): +def client_response_wrapper(tracer: Tracer) -> Callable[..., Any]: """ Return a wrapper that traces a full ClaudeSDKClient.receive_messages call covering all messages. @@ -170,6 +173,7 @@ def client_response_wrapper(tracer: Tracer): Returns: Callable: An async generator wrapper function for ClaudeSDKClient.receive_messages. 
""" + async def wrapper( wrapped: Callable[..., Any], instance: Any, diff --git a/netra/instrumentation/google_adk/__init__.py b/netra/instrumentation/google_adk/__init__.py index 21a1c63..255d5cc 100644 --- a/netra/instrumentation/google_adk/__init__.py +++ b/netra/instrumentation/google_adk/__init__.py @@ -10,28 +10,44 @@ from netra.instrumentation.google_adk.version import __version__ from netra.instrumentation.google_adk.wrappers import ( NoOpTracer, - adk_trace_call_llm_wrapper, - adk_trace_send_data_wrapper, - adk_trace_tool_call_wrapper, - adk_trace_tool_response_wrapper, base_agent_run_async_wrapper, base_llm_flow_call_llm_async_wrapper, call_tool_async_wrapper, - finalize_model_response_event_wrapper, ) logger = logging.getLogger(__name__) _instruments = ("google-adk >= 0.1.0",) +_ADK_TRACER_MODULES = ( + "google.adk.agents.base_agent", + "google.adk.flows.llm_flows.base_llm_flow", + "google.adk.flows.llm_flows.functions", + "google.adk.models.gemini_context_cache_manager", + "google.adk.models.google_llm", + "google.adk.plugins.bigquery_agent_analytics_plugin", + "google.adk.runners", + "google.adk.telemetry", + "google.adk.telemetry.tracing", + "google.cloud.bigquery.opentelemetry_tracing", + "google.cloud.pubsub_v1.open_telemetry.publish_message_wrapper", + "google.cloud.pubsub_v1.open_telemetry.subscribe_opentelemetry", + "google.cloud.pubsub_v1.publisher._batch.thread", + "google.cloud.spanner_v1._opentelemetry_tracing", + "google.cloud.sqlalchemy_spanner._opentelemetry_tracing", + "google.cloud.storage._opentelemetry_tracing", +) + class NetraGoogleADKInstrumentor(BaseInstrumentor): # type: ignore[misc] """Custom Google ADK instrumentor for Netra SDK.""" def instrumentation_dependencies(self) -> Collection[str]: + """Return the package requirements for this instrumentor.""" return _instruments def _instrument(self, **kwargs) -> Any: # type: ignore[no-untyped-def] + """Patch ADK with Netra spans and replace ADK's own tracers with NoOps to avoid 
duplicates.""" try: tracer_provider = kwargs.get("tracer_provider") tracer = get_tracer(__name__, __version__, tracer_provider) @@ -39,20 +55,8 @@ def _instrument(self, **kwargs) -> Any: # type: ignore[no-untyped-def] logger.error(f"Failed to initialize tracer: {e}") return - # Set ADK tracer to NoOpTracer to prevent ADK from creating its own spans - try: - import google.adk.telemetry as adk_telemetry - - adk_telemetry.tracer = NoOpTracer() - except Exception as e: - logger.debug(f"Unable to replace ADK tracer: {e}") - - for module_name in ( - "google.adk.runners", - "google.adk.agents.base_agent", - "google.adk.flows.llm_flows.base_llm_flow", - "google.adk.flows.llm_flows.functions", - ): + # Replace ADK's own tracers with NoOp to prevent duplicate spans + for module_name in _ADK_TRACER_MODULES: try: if module_name in sys.modules: module = sys.modules[module_name] @@ -70,29 +74,14 @@ def _instrument(self, **kwargs) -> Any: # type: ignore[no-untyped-def] except Exception as e: logger.error(f"Failed to instrument BaseAgent.run_async: {e}") - try: - wrap_function_wrapper("google.adk.telemetry", "trace_tool_call", adk_trace_tool_call_wrapper(tracer)) - wrap_function_wrapper( - "google.adk.telemetry", "trace_tool_response", adk_trace_tool_response_wrapper(tracer) - ) - wrap_function_wrapper("google.adk.telemetry", "trace_call_llm", adk_trace_call_llm_wrapper(tracer)) - wrap_function_wrapper("google.adk.telemetry", "trace_send_data", adk_trace_send_data_wrapper(tracer)) - except Exception as e: - logger.error(f"Failed to instrument ADK telemetry functions: {e}") - try: wrap_function_wrapper( "google.adk.flows.llm_flows.base_llm_flow", "BaseLlmFlow._call_llm_async", base_llm_flow_call_llm_async_wrapper(tracer), ) - wrap_function_wrapper( - "google.adk.flows.llm_flows.base_llm_flow", - "BaseLlmFlow._finalize_model_response_event", - finalize_model_response_event_wrapper(tracer), - ) except Exception as e: - logger.error(f"Failed to instrument BaseLlmFlow methods: {e}") + 
logger.error(f"Failed to instrument BaseLlmFlow._call_llm_async: {e}") try: wrap_function_wrapper( @@ -104,44 +93,21 @@ def _instrument(self, **kwargs) -> Any: # type: ignore[no-untyped-def] logger.error(f"Failed to instrument __call_tool_async: {e}") def _uninstrument(self, **kwargs) -> None: # type: ignore[no-untyped-def] - # Unwrap in reverse order + """Remove Netra wrappers. Note: replaced ADK tracers are intentionally left as NoOps.""" try: unwrap("google.adk.flows.llm_flows.functions", "__call_tool_async") except (AttributeError, ModuleNotFoundError): logger.debug("Failed to uninstrument __call_tool_async") try: - unwrap( - "google.adk.flows.llm_flows.base_llm_flow", - "BaseLlmFlow._finalize_model_response_event", - ) - unwrap( - "google.adk.flows.llm_flows.base_llm_flow", - "BaseLlmFlow._call_llm_async", - ) - except (AttributeError, ModuleNotFoundError): - logger.debug("Failed to uninstrument BaseLlmFlow methods") - - try: - unwrap("google.adk.telemetry", "trace_send_data") - unwrap("google.adk.telemetry", "trace_call_llm") - unwrap("google.adk.telemetry", "trace_tool_response") - unwrap("google.adk.telemetry", "trace_tool_call") + unwrap("google.adk.flows.llm_flows.base_llm_flow", "BaseLlmFlow._call_llm_async") except (AttributeError, ModuleNotFoundError): - logger.debug("Failed to uninstrument ADK telemetry functions") + logger.debug("Failed to uninstrument BaseLlmFlow._call_llm_async") try: unwrap("google.adk.agents.base_agent", "BaseAgent.run_async") except (AttributeError, ModuleNotFoundError): logger.debug("Failed to uninstrument BaseAgent.run_async") - try: - import google.adk.telemetry as adk_telemetry - from opentelemetry import trace as otel_trace - - adk_telemetry.tracer = otel_trace.get_tracer("gcp.vertex.agent") - except Exception: - pass - __all__ = ["NetraGoogleADKInstrumentor"] diff --git a/netra/instrumentation/google_adk/utils.py b/netra/instrumentation/google_adk/utils.py index 104cf46..fcb5a52 100644 --- 
a/netra/instrumentation/google_adk/utils.py +++ b/netra/instrumentation/google_adk/utils.py @@ -1,25 +1,54 @@ import json -from typing import Any, Dict +import logging +from typing import Any, Dict, List from opentelemetry.semconv_ai import SpanAttributes +logger = logging.getLogger(__name__) -def _build_llm_request_for_trace(llm_request: Any) -> Dict[str, Any]: +NETRA_SPAN_TYPE = "netra.span.type" + + +def build_llm_request_for_trace(llm_request: Any) -> Dict[str, Any]: + """Serialize an ADK LlmRequest into a plain dict suitable for tracing, stripping binary/schema fields. + + Args: + llm_request: The ADK LlmRequest object to serialize. + + Returns: + A dictionary with model, config, and contents fields suitable for tracing. + """ from google.genai import types + model = getattr(llm_request, "model", None) + request_config = getattr(llm_request, "config", None) + try: + if request_config: + request_config = request_config.model_dump(exclude_none=True, exclude={"response_schema"}) + except Exception as e: + logger.warning("Failed to model dump LLM request config: %s", e) + result: Dict[str, Any] = { - "model": llm_request.model, - "config": llm_request.config.model_dump(exclude_none=True, exclude="response_schema"), + "model": model, + "config": request_config, "contents": [], } - for content in llm_request.contents: + for content in llm_request.contents or []: parts = [part for part in content.parts if not hasattr(part, "inline_data") or not part.inline_data] result["contents"].append(types.Content(role=content.role, parts=parts).model_dump(exclude_none=True)) return result -def _extract_llm_attributes(llm_request_dict: Dict[str, Any], llm_response: Any) -> Dict[str, Any]: +def extract_llm_request_attributes(llm_request_dict: Dict[str, Any]) -> Dict[str, Any]: + """Convert a serialized LLM request dict into OpenTelemetry span attributes. + + Args: + llm_request_dict: Serialized LLM request dict from build_llm_request_for_trace. 
+ + Returns: + A dictionary of OpenTelemetry span attributes for the LLM request. + """ attributes: Dict[str, Any] = {} if "model" in llm_request_dict: @@ -50,17 +79,22 @@ def _extract_llm_attributes(llm_request_dict: Dict[str, Any], llm_response: Any) attributes["gen_ai.request.response_mime_type"] = config["response_mime_type"] if "tools" in config: - for i, tool in enumerate(config["tools"]): + func_index = 0 + for tool in config["tools"]: if "function_declarations" in tool: - for j, func in enumerate(tool["function_declarations"]): - attributes[f"gen_ai.request.tools.{j}.name"] = func.get("name", "") - attributes[f"gen_ai.request.tools.{j}.description"] = func.get("description", "") + for func in tool["function_declarations"]: + attributes[f"gen_ai.request.tools.{func_index}.name"] = func.get("name", "") + attributes[f"gen_ai.request.tools.{func_index}.description"] = func.get("description", "") + func_index += 1 message_index = 0 + all_inputs: List[Dict[str, Any]] = [] + if "config" in llm_request_dict and "system_instruction" in llm_request_dict["config"]: system_instruction = llm_request_dict["config"]["system_instruction"] attributes[f"{SpanAttributes.LLM_PROMPTS}.{message_index}.role"] = "system" attributes[f"{SpanAttributes.LLM_PROMPTS}.{message_index}.content"] = system_instruction + all_inputs.append({"role": "system", "content": system_instruction}) message_index += 1 if "contents" in llm_request_dict: @@ -71,7 +105,10 @@ def _extract_llm_attributes(llm_request_dict: Dict[str, Any], llm_response: Any) attributes[f"{SpanAttributes.LLM_PROMPTS}.{message_index}.role"] = role - text_parts = [] + text_parts: List[str] = [] + func_calls: List[Dict[str, Any]] = [] + func_responses: List[Dict[str, Any]] = [] + for part in parts: if "text" in part and part.get("text") is not None: text_parts.append(str(part["text"])) @@ -83,6 +120,10 @@ def _extract_llm_attributes(llm_request_dict: Dict[str, Any], llm_response: Any) ) if "id" in func_call: 
attributes[f"gen_ai.prompt.{message_index}.function_call.id"] = func_call["id"] + entry: Dict[str, Any] = {"name": func_call.get("name", ""), "args": func_call.get("args", {})} + if "id" in func_call: + entry["id"] = func_call["id"] + func_calls.append(entry) elif "function_response" in part: func_resp = part["function_response"] attributes[f"gen_ai.prompt.{message_index}.function_response.name"] = func_resp.get("name", "") @@ -91,63 +132,175 @@ def _extract_llm_attributes(llm_request_dict: Dict[str, Any], llm_response: Any) ) if "id" in func_resp: attributes[f"gen_ai.prompt.{message_index}.function_response.id"] = func_resp["id"] + resp_entry: Dict[str, Any] = { + "name": func_resp.get("name", ""), + "result": func_resp.get("response", {}), + } + if "id" in func_resp: + resp_entry["id"] = func_resp["id"] + func_responses.append(resp_entry) + msg: Dict[str, Any] = {"role": role} if text_parts: - attributes[f"{SpanAttributes.LLM_PROMPTS}.{message_index}.content"] = "\n".join(text_parts) + content_str = "\n".join(text_parts) + attributes[f"{SpanAttributes.LLM_PROMPTS}.{message_index}.content"] = content_str + msg["content"] = content_str + if func_calls: + msg["function_calls"] = func_calls + if func_responses: + msg["function_responses"] = func_responses + all_inputs.append(msg) message_index += 1 - if llm_response: - try: - response_dict = json.loads(llm_response) if isinstance(llm_response, str) else llm_response - - if "model" in response_dict: - attributes[SpanAttributes.LLM_RESPONSE_MODEL] = response_dict["model"] - - if "content" in response_dict and "parts" in response_dict["content"]: - parts = response_dict["content"]["parts"] - attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.role"] = "assistant" - - text_parts = [] - tool_call_index = 0 - for part in parts: - if "text" in part and part.get("text") is not None: - text_parts.append(str(part["text"])) - elif "function_call" in part: - func_call = part["function_call"] - 
attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.name"] = func_call.get( - "name", "" - ) - attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.arguments"] = json.dumps( - func_call.get("args", {}) - ) - if "id" in func_call: - attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.id"] = func_call["id"] - tool_call_index += 1 - - if text_parts: - attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.content"] = "\n".join(text_parts) - - if "finish_reason" in response_dict: - attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] = response_dict["finish_reason"] - - if "id" in response_dict: - attributes[SpanAttributes.LLM_RESPONSE_ID] = response_dict["id"] + try: + attributes["input"] = json.dumps(all_inputs) + except Exception: + logger.exception("Failed to serialize LLM request inputs to JSON") + attributes["input"] = str(all_inputs) + + return attributes + + +def _extract_event_attributes(event: Any) -> Dict[str, Any]: + """Extract non-content attributes from an ADK event. + + Args: + event: The event object to extract attributes from. + + Returns: + A dictionary of attributes. 
+ """ + attributes: Dict[str, Any] = {} + + if model_version := getattr(event, "model_version", None): + attributes[SpanAttributes.LLM_REQUEST_MODEL] = model_version + if author := getattr(event, "author", None): + attributes["gen_ai.event.author"] = author + + if invocation_id := getattr(event, "invocation_id", None): + attributes["gen_ai.invocation.id"] = invocation_id + + if (timestamp := getattr(event, "timestamp", None)) is not None: + attributes["gen_ai.event.timestamp"] = timestamp + + if (finish_reason := getattr(event, "finish_reason", None)) is not None: + attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] = ( + finish_reason.value if hasattr(finish_reason, "value") else str(finish_reason) + ) + + if (error_code := getattr(event, "error_code", None)) is not None: + attributes["gen_ai.error.code"] = error_code + + if error_message := getattr(event, "error_message", None): + attributes["gen_ai.error.message"] = error_message + + if usage := getattr(event, "usage_metadata", None): + if (v := getattr(usage, "prompt_token_count", None)) is not None: + attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] = v + + # Calculate completion tokens + output = 0 + if isinstance(ct := getattr(usage, "candidates_token_count", None), int): + output += ct + if isinstance(tt := getattr(usage, "thoughts_token_count", None), int): + output += tt + if output > 0: + attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] = output + + if (v := getattr(usage, "total_token_count", None)) is not None: + attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] = v + if (v := getattr(usage, "cached_content_token_count", None)) is not None: + attributes[SpanAttributes.LLM_USAGE_CACHE_READ_INPUT_TOKENS] = v + + if (avg_logprobs := getattr(event, "avg_logprobs", None)) is not None: + attributes["gen_ai.response.avg_logprobs"] = avg_logprobs + + if actions := getattr(event, "actions", None): + if v := getattr(actions, "transfer_to_agent", None): + attributes["gen_ai.actions.transfer_to_agent"] 
= v + if v := getattr(actions, "state_delta", None): + attributes["gen_ai.actions.state_delta"] = json.dumps(v) + if (v := getattr(actions, "escalate", None)) is not None: + attributes["gen_ai.actions.escalate"] = v + + return attributes + + +def extract_llm_response_attributes(last_response: Any, accumulated_text: List[str]) -> Dict[str, Any]: + """Build span attributes from the final LLM response event and accumulated streamed text. + + Args: + last_response: The last ADK event from the LLM response stream. + accumulated_text: List of text strings accumulated during streaming. + + Returns: + A dictionary of OpenTelemetry span attributes for the LLM response. + """ + attributes: Dict[str, Any] = {} + attributes.update(_extract_event_attributes(last_response)) + + content = getattr(last_response, "content", None) + parts = (content.parts or []) if content else [] + + text_parts: List[str] = accumulated_text if accumulated_text else [] + tool_calls: List[Dict[str, Any]] = [] + tool_call_index = 0 + + for part in parts: + if func_call := getattr(part, "function_call", None): + entry: Dict[str, Any] = {} + if call_id := getattr(func_call, "id", None): + entry["id"] = call_id + attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.id"] = call_id + if call_name := getattr(func_call, "name", None): + entry["name"] = call_name + attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.name"] = call_name + args = getattr(func_call, "args", {}) + entry["arguments"] = args + attributes[f"gen_ai.completions.0.tool_calls.{tool_call_index}.arguments"] = json.dumps(args) + tool_calls.append(entry) + tool_call_index += 1 + elif not accumulated_text and (text := getattr(part, "text", None)) is not None: + text_parts.append(str(text)) + + attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.role"] = "assistant" + + output: Dict[str, Any] = {"role": "assistant"} + if text_parts: + full_text = "\n".join(text_parts) + attributes[f"{SpanAttributes.LLM_COMPLETIONS}.0.content"] 
= full_text + output["content"] = full_text + if tool_calls: + output["tool_calls"] = tool_calls + + if len(output) > 1: + try: + attributes["output"] = json.dumps(output) except Exception: - pass + logger.exception("Failed to serialize LLM response output to JSON") + attributes["output"] = str(output) return attributes def extract_agent_attributes(instance: Any) -> Dict[str, Any]: + """Extract agent metadata attributes from a BaseAgent instance, including nested sub-agents. + + Args: + instance: The BaseAgent instance to extract attributes from. + + Returns: + A dictionary of agent metadata attributes for tracing. + """ attributes: Dict[str, Any] = {} attributes["gen_ai.agent.name"] = getattr(instance, "name", "unknown") - if hasattr(instance, "description"): + if hasattr(instance, "description") and instance.description: attributes["gen_ai.agent.description"] = instance.description - if hasattr(instance, "model"): + if hasattr(instance, "model") and instance.model: attributes["gen_ai.agent.model"] = instance.model - if hasattr(instance, "instruction"): + if hasattr(instance, "instruction") and instance.instruction: attributes["gen_ai.agent.instruction"] = instance.instruction if hasattr(instance, "tools"): for idx, tool in enumerate(instance.tools): @@ -155,11 +308,12 @@ def extract_agent_attributes(instance: Any) -> Dict[str, Any]: attributes[f"gen_ai.agent.tools.{idx}.name"] = tool.name if hasattr(tool, "description"): attributes[f"gen_ai.agent.tools.{idx}.description"] = tool.description - if hasattr(instance, "output_key"): + if hasattr(instance, "output_key") and instance.output_key: attributes["gen_ai.agent.output_key"] = instance.output_key if hasattr(instance, "sub_agents"): for i, sub_agent in enumerate(instance.sub_agents): sub_attrs = extract_agent_attributes(sub_agent) for key, value in sub_attrs.items(): - attributes[f"gen_ai.agent.sub_agents.{i}.{key}"] = value + if value: + attributes[f"gen_ai.agent.sub_agents.{i}.{key}"] = value return attributes 
diff --git a/netra/instrumentation/google_adk/wrappers.py b/netra/instrumentation/google_adk/wrappers.py index 299a475..4cecef5 100644 --- a/netra/instrumentation/google_adk/wrappers.py +++ b/netra/instrumentation/google_adk/wrappers.py @@ -1,298 +1,417 @@ import json import logging -from typing import Any, AsyncIterator, Callable, Dict, Tuple, cast +import time +from typing import Any, AsyncIterator, Callable, Dict, List, Tuple, cast -import wrapt +from opentelemetry import context as opentelemetry_context from opentelemetry import trace as opentelemetry_api_trace from opentelemetry.semconv_ai import SpanAttributes -from opentelemetry.trace import SpanKind, Tracer +from opentelemetry.trace import SpanKind, StatusCode, Tracer +from netra.config import Config from netra.instrumentation.google_adk.utils import ( - _build_llm_request_for_trace, - _extract_llm_attributes, + NETRA_SPAN_TYPE, + build_llm_request_for_trace, extract_agent_attributes, + extract_llm_request_attributes, + extract_llm_response_attributes, ) +from netra.instrumentation.utils import record_span_timing +from netra.span_wrapper import SpanType + +TIME_TO_FIRST_TOKEN = "gen_ai.performance.time_to_first_token" +RELATIVE_TIME_TO_FIRST_TOKEN = "gen_ai.performance.relative_time_to_first_token" logger = logging.getLogger(__name__) class NoOpSpan: - def __init__(self, *args: Any, **kwargs: Any) -> None: - pass + """Span implementation that silently discards all operations, used to suppress ADK's built-in tracing.""" + + def __init__(self, *_args: Any, **_kwargs: Any) -> None: + """Initialize the NoOpSpan, discarding all arguments.""" def __enter__(self) -> "NoOpSpan": + """Enter the context manager. + + Returns: The NoOpSpan instance. 
+ """ return self - def __exit__(self, *args: Any) -> None: - pass + def __exit__(self, *_args: Any) -> None: + """Exit the context manager, discarding all arguments.""" - def set_attribute(self, *args: Any, **kwargs: Any) -> None: - pass + def set_attribute(self, *_args: Any, **_kwargs: Any) -> None: + """Set a span attribute (no-op), discarding all arguments.""" - def set_attributes(self, *args: Any, **kwargs: Any) -> None: - pass + def set_attributes(self, *_args: Any, **_kwargs: Any) -> None: + """Set multiple span attributes (no-op), discarding all arguments.""" - def add_event(self, *args: Any, **kwargs: Any) -> None: - pass + def add_event(self, *_args: Any, **_kwargs: Any) -> None: + """Add a span event (no-op), discarding all arguments.""" - def set_status(self, *args: Any, **kwargs: Any) -> None: - pass + def set_status(self, *_args: Any, **_kwargs: Any) -> None: + """Set the span status (no-op), discarding all arguments.""" - def update_name(self, *args: Any, **kwargs: Any) -> None: - pass + def update_name(self, *_args: Any, **_kwargs: Any) -> None: + """Update the span name (no-op), discarding all arguments.""" def is_recording(self) -> bool: + """Check whether the span is recording. + + Returns: Always False for a NoOpSpan. 
+ """ return False - def end(self, *args: Any, **kwargs: Any) -> None: - pass + def end(self, *_args: Any, **_kwargs: Any) -> None: + """End the span (no-op), discarding all arguments.""" - def record_exception(self, *args: Any, **kwargs: Any) -> None: - pass + def record_exception(self, *_args: Any, **_kwargs: Any) -> None: + """Record an exception on the span (no-op), discarding all arguments.""" class NoOpTracer: - def start_as_current_span(self, *args: Any, **kwargs: Any) -> NoOpSpan: - return NoOpSpan() + """Tracer that returns NoOpSpans, injected into ADK modules to prevent duplicate span emission.""" - def start_span(self, *args: Any, **kwargs: Any) -> NoOpSpan: - return NoOpSpan() + def start_as_current_span(self, *_args: Any, **_kwargs: Any) -> NoOpSpan: + """Start a span as the current span (no-op). - def use_span(self, *args: Any, **kwargs: Any) -> NoOpSpan: + Returns: A NoOpSpan instance. + """ return NoOpSpan() + def start_span(self, *_args: Any, **_kwargs: Any) -> NoOpSpan: + """Start a new span (no-op). -def base_agent_run_async_wrapper(tracer: Tracer) -> Callable[..., Any]: - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - async def new_function() -> AsyncIterator[Any]: - agent_name = instance.name if hasattr(instance, "name") else "unknown" - span_name = f"ADK.Agent.{agent_name}" - - with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span: - span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent") - span.set_attribute("gen_ai.entity", "agent") - span.set_attribute("netra.span.type", "span") - - span.set_attributes(extract_agent_attributes(instance)) - if len(args) > 0 and hasattr(args[0], "invocation_id"): - span.set_attribute("adk.invocation_id", args[0].invocation_id) + Returns: A NoOpSpan instance. 
+ """ + return NoOpSpan() - async_gen = wrapped(*args, **kwargs) - async for item in async_gen: - yield item + def use_span(self, *_args: Any, **_kwargs: Any) -> NoOpSpan: + """Use an existing span as context (no-op). - return new_function() + Returns: A NoOpSpan instance. + """ + return NoOpSpan() - return cast(Callable[..., Any], wrapper) +class _SpanScope: + """Manages span lifecycle for async generators (start, attach context, record errors, detach, end).""" + + def __init__(self, tracer: Tracer, name: str, kind: SpanKind = SpanKind.CLIENT) -> None: + """Start a span and attach it as the current context. + + Args: + tracer: The OpenTelemetry tracer used to start the span. + name: The span name. + kind: The span kind; defaults to SpanKind.CLIENT. + """ + self.span = tracer.start_span(name, kind=kind) + ctx = opentelemetry_api_trace.set_span_in_context(self.span) + self._token = opentelemetry_context.attach(ctx) + + def record_error(self, exc: Exception) -> None: + """Record an exception on the span and set its status to ERROR. + + Args: + exc: The exception to record. + """ + try: + self.span.set_attribute(f"{Config.LIBRARY_NAME}.entity.error", str(exc)) + self.span.record_exception(exc) + self.span.set_status(StatusCode.ERROR, str(exc)) + except Exception as e: + logger.warning("Failed to record error on span: %s", e) + + def end(self) -> None: + """Detach the span from the active context and end it.""" + try: + opentelemetry_context.detach(self._token) + except Exception as e: + logger.warning("Failed to detach span context: %s", e) + try: + self.span.end() + except Exception as e: + logger.warning("Failed to end span: %s", e) + + +def base_agent_run_async_wrapper(tracer: Tracer) -> Callable[..., AsyncIterator[Any]]: + """Return a wrapt wrapper that creates an agent span around BaseAgent.run_async. + + Args: + tracer: The OpenTelemetry tracer used to create agent spans. + + Returns: + A wrapt-compatible wrapper function for BaseAgent.run_async. 
+ """ + + def wrapper( + wrapped: Callable[..., AsyncIterator[Any]], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + """Wrap a single BaseAgent.run_async call, creating and closing an agent span. + + Args: + wrapped: The original BaseAgent.run_async coroutine. + instance: The BaseAgent instance being called. + args: Positional arguments passed to run_async. + kwargs: Keyword arguments passed to run_async. + + Returns: + An async generator that yields ADK events with agent span instrumentation. + """ -def base_llm_flow_call_llm_async_wrapper(tracer: Tracer) -> Callable[..., Any]: - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: async def new_function() -> AsyncIterator[Any]: - model_name = "unknown" - llm_request = None - - if len(args) > 1: - llm_request = args[1] - if hasattr(llm_request, "model"): - model_name = llm_request.model - - span_name = f"ADK.LLM.{model_name}" - - with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span: - span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent") - span.set_attribute("gen_ai.entity", "request") - span.set_attribute("netra.span.type", "generation") - - if llm_request: - llm_request_dict = _build_llm_request_for_trace(llm_request) - llm_attrs = _extract_llm_attributes(llm_request_dict, None) - for key, value in llm_attrs.items(): - span.set_attribute(key, value) + agent_name = getattr(instance, "name", "unknown") + try: + scope = _SpanScope(tracer, f"ADK.Agent.{agent_name}") + except Exception as e: + logger.warning("Failed to start agent span: %s", e) + async for event in wrapped(*args, **kwargs): + yield event + return + + try: + span = scope.span + span.set_attribute(SpanAttributes.LLM_SYSTEM, "adk") + span.set_attribute("gen_ai.entity", "agent") + span.set_attribute(NETRA_SPAN_TYPE, SpanType.AGENT) + span.set_attributes(extract_agent_attributes(instance)) - async_gen = wrapped(*args, **kwargs) - async 
for item in async_gen: - yield item + invocation_context = args[0] if args else None + if invocation_context: + if hasattr(invocation_context, "invocation_id"): + span.set_attribute("adk.invocation_id", invocation_context.invocation_id) + user_content = getattr(invocation_context, "user_content", None) + if user_content: + parts = getattr(user_content, "parts", []) or [] + user_texts = [ + str(getattr(p, "text", "")) for p in parts if getattr(p, "text", None) is not None + ] + if user_texts: + span.set_attribute("input", "\n".join(user_texts)) + except Exception as e: + logger.warning("Failed to set agent span attributes: %s", e) + + last_text_output: List[str] = [] + try: + async for event in wrapped(*args, **kwargs): + try: + parts = event.content.parts if event.content and event.content.parts else [] + texts = [str(getattr(p, "text", "")) for p in parts if getattr(p, "text", None) is not None] + if texts: + last_text_output = texts + except Exception as e: + logger.warning("Failed to extract agent event text: %s", e) + yield event + except Exception as e: + scope.record_error(e) + raise + finally: + if last_text_output: + try: + span.set_attribute("output", "\n".join(last_text_output)) + except Exception as e: + logger.warning("Failed to set agent output attribute: %s", e) + scope.end() return new_function() return cast(Callable[..., Any], wrapper) -def finalize_model_response_event_wrapper(tracer: Tracer) -> Callable[..., Any]: - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - result = wrapped(*args, **kwargs) +def base_llm_flow_call_llm_async_wrapper(tracer: Tracer) -> Callable[..., AsyncIterator[Any]]: + """Return a wrapt wrapper that creates an LLM generation span around BaseLlmFlow._call_llm_async. 
- llm_request = args[0] if len(args) > 0 else kwargs.get("llm_request") - llm_response = args[1] if len(args) > 1 else kwargs.get("llm_response") + Args: + tracer: The OpenTelemetry tracer used to create LLM generation spans. - current_span = opentelemetry_api_trace.get_current_span() - if current_span.is_recording() and llm_request and llm_response: - span_name = getattr(current_span, "name", "") - if "ADK.LLM" in span_name: - llm_request_dict = _build_llm_request_for_trace(llm_request) - llm_response_json = llm_response.model_dump_json(exclude_none=True) - llm_attrs = _extract_llm_attributes(llm_request_dict, llm_response_json) + Returns: + A wrapt-compatible wrapper function for BaseLlmFlow._call_llm_async. + """ - for key, value in llm_attrs.items(): - if "usage" in key or "completion" in key or "response" in key: - current_span.set_attribute(key, value) + def wrapper( + wrapped: Callable[..., AsyncIterator[Any]], _instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any] + ) -> Any: + """Wrap a single _call_llm_async call, creating and closing an LLM span. - return result - - return cast(Callable[..., Any], wrapper) + Args: + wrapped: The original _call_llm_async coroutine. + _instance: The BaseLlmFlow instance. + args: Positional arguments passed to _call_llm_async. + kwargs: Keyword arguments passed to _call_llm_async. + Returns: + An async generator that yields ADK events with LLM span instrumentation. 
+ """ -def adk_trace_tool_call_wrapper(tracer: Tracer) -> Callable[..., Any]: - @wrapt.decorator # type: ignore[misc] - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - result = wrapped(*args, **kwargs) - - tool_args = args[0] if args else kwargs.get("args") - current_span = opentelemetry_api_trace.get_current_span() - if current_span.is_recording() and tool_args is not None: - current_span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent") - current_span.set_attribute("gcp.vertex.agent.tool_call_args", str(tool_args)) - return result + async def new_function() -> AsyncIterator[Any]: + llm_request = args[1] if len(args) > 1 else None + model_name = getattr(llm_request, "model", "unknown") if llm_request else "unknown" + + try: + scope = _SpanScope(tracer, model_name) + except Exception as e: + logger.warning("Failed to start LLM span: %s", e) + async for item in wrapped(*args, **kwargs): + yield item + return - return cast(Callable[..., Any], wrapper) + try: + span = scope.span + span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent") + span.set_attribute("gen_ai.entity", "request") + span.set_attribute(NETRA_SPAN_TYPE, SpanType.GENERATION) + if llm_request: + llm_request_dict = build_llm_request_for_trace(llm_request) + span.set_attributes(extract_llm_request_attributes(llm_request_dict)) + + except Exception as e: + logger.warning("Failed to set LLM span attributes: %s", e) + + accumulated_text: List[str] = [] + last_response = None + + # Peek-ahead buffer: hold back one item so the inner generator is + # fully exhausted — and the span closed — before the last item + # reaches the caller. Without this, the caller processes the last + # item (potentially launching sub-agent / tool spans) while the LLM + # span is still open, inflating its duration. 
+ prev_item = None + span_ended = False + first_token_recorded = False + + try: + # Capture timestamp right before the LLM call to measure TIME_TO_FIRST_TOKEN accurately. + # ADK uses iterator start time as the closest approximation (span start introduces larger variance). + # GenAI uses span start time since delay to actual call is negligible. + # Keeps latency metrics reasonably aligned across both. + llm_call_start = time.time() + async for item in wrapped(*args, **kwargs): + if not first_token_recorded: + first_token_time = time.time() + record_span_timing(span, TIME_TO_FIRST_TOKEN, first_token_time, reference_time=llm_call_start) + record_span_timing(span, RELATIVE_TIME_TO_FIRST_TOKEN, first_token_time, use_root_span=True) + first_token_recorded = True + last_response = item + try: + parts = item.content.parts if item.content and item.content.parts else [] + for part in parts: + if (text := getattr(part, "text", None)) is not None: + accumulated_text.append(str(text)) + except Exception as e: + logger.warning("Failed to extract LLM response text: %s", e) + if prev_item is not None: + yield prev_item + prev_item = item + + # Inner generator is exhausted. Close the span now, before + # handing the last item to the caller. 
+ if last_response is not None: + try: + response_attrs = extract_llm_response_attributes(last_response, accumulated_text) + span.set_attributes(response_attrs) + except Exception as e: + logger.warning("Failed to set LLM response attributes: %s", e) + + scope.end() + span_ended = True + + # Yield the last item after ending the span + if prev_item is not None: + yield prev_item + + except Exception as e: + if not span_ended: + scope.record_error(e) + scope.end() + raise -def adk_trace_tool_response_wrapper(tracer: Tracer) -> Callable[..., Any]: - @wrapt.decorator # type: ignore[misc] - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - result = wrapped(*args, **kwargs) - - invocation_context = args[0] if len(args) > 0 else kwargs.get("invocation_context") - event_id = args[1] if len(args) > 1 else kwargs.get("event_id") - function_response_event = args[2] if len(args) > 2 else kwargs.get("function_response_event") - - current_span = opentelemetry_api_trace.get_current_span() - if current_span.is_recording(): - current_span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent") - if invocation_context: - current_span.set_attribute("gcp.vertex.agent.invocation_id", invocation_context.invocation_id) - if event_id: - current_span.set_attribute("gcp.vertex.agent.event_id", event_id) - if function_response_event: - current_span.set_attribute( - "gcp.vertex.agent.tool_response", function_response_event.model_dump_json(exclude_none=True) - ) - current_span.set_attribute("gcp.vertex.agent.llm_request", "{}") - current_span.set_attribute("gcp.vertex.agent.llm_response", "{}") - return result + return new_function() return cast(Callable[..., Any], wrapper) -def adk_trace_call_llm_wrapper(tracer: Tracer) -> Callable[..., Any]: - @wrapt.decorator # type: ignore[misc] - def wrapper(wrapped: Callable[..., Any], instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: - result = wrapped(*args, 
def _serialize_tool_payload(value: Any) -> str:
    """Render a tool input/output value as a span-attribute string.

    Dicts are JSON-encoded; anything else (or a dict that cannot be
    JSON-encoded) falls back to ``str(value)``.

    Args:
        value: The tool arguments or tool result to serialise.

    Returns:
        A string representation suitable for a span attribute.
    """
    if isinstance(value, dict):
        try:
            return json.dumps(value)
        except (TypeError, ValueError, RecursionError):
            # Non-JSON-serialisable or cyclic payload: degrade to repr-like text.
            return str(value)
    return str(value)


def call_tool_async_wrapper(tracer: Tracer) -> Callable[..., Any]:
    """Return a wrapt wrapper that creates a tool span around __call_tool_async.

    Instrumentation is strictly best-effort: a failure to start the span, to
    set request attributes, or to record the tool output is logged and never
    changes what the wrapped tool call returns or raises.

    Args:
        tracer: The OpenTelemetry tracer used to create tool spans.

    Returns:
        A wrapt-compatible wrapper function for __call_tool_async.
    """

    def wrapper(wrapped: Callable[..., Any], _instance: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
        """Wrap a single __call_tool_async call, creating and closing a tool span.

        Args:
            wrapped: The original __call_tool_async coroutine.
            _instance: The BaseLlmFlow instance.
            args: Positional arguments passed to __call_tool_async.
            kwargs: Keyword arguments passed to __call_tool_async.

        Returns:
            A coroutine that awaits the tool call with tool span instrumentation.
        """

        async def new_function() -> Any:
            tool = args[0] if args else kwargs.get("tool")
            tool_args = args[1] if len(args) > 1 else kwargs.get("args", {})
            tool_context = args[2] if len(args) > 2 else kwargs.get("tool_context")
            tool_name = getattr(tool, "name", "unknown_tool")

            try:
                scope = _SpanScope(tracer, tool_name)
            except Exception as e:
                # No span could be started: run the tool uninstrumented.
                logger.warning("Failed to start tool span: %s", e)
                return await wrapped(*args, **kwargs)

            try:
                span = scope.span
                span.set_attribute(SpanAttributes.LLM_SYSTEM, "gcp.vertex.agent")
                span.set_attribute("gen_ai.entity", "tool")
                span.set_attribute(NETRA_SPAN_TYPE, SpanType.TOOL)
                span.set_attribute("gen_ai.tool.name", tool_name)

                if tool is not None and hasattr(tool, "description"):
                    span.set_attribute("gen_ai.tool.description", tool.description)
                if tool is not None and hasattr(tool, "is_long_running"):
                    span.set_attribute("gen_ai.tool.is_long_running", tool.is_long_running)

                if tool_args is not None:
                    span.set_attribute("input", _serialize_tool_payload(tool_args))

                if tool_context:
                    if hasattr(tool_context, "function_call_id"):
                        span.set_attribute("tool.call_id", tool_context.function_call_id)
                    if hasattr(tool_context, "invocation_context"):
                        span.set_attribute("adk.invocation_id", tool_context.invocation_context.invocation_id)
            except Exception as e:
                logger.warning("Failed to set tool span attributes: %s", e)

            try:
                result = await wrapped(*args, **kwargs)
            except Exception as e:
                # Only genuine tool failures are recorded on the span and re-raised.
                scope.record_error(e)
                raise
            else:
                if result is not None:
                    # Kept outside the tool-call ``try`` so that a failure to
                    # record the output can never be reported as a tool error
                    # or discard a successful result.
                    try:
                        scope.span.set_attribute("output", _serialize_tool_payload(result))
                    except Exception as e:
                        logger.warning("Failed to set tool output attribute: %s", e)
                return result
            finally:
                scope.end()

        return new_function()

    return cast(Callable[..., Any], wrapper)
CustomInstruments) + AIO_PIKA = ("aio_pika", CustomInstruments) + AIOKAFKA = ("aiokafka", CustomInstruments) + AIOPG = ("aiopg", CustomInstruments) + ALEPHALPHA = ("alephalpha", Instruments) + ANTHROPIC = ("anthropic", Instruments) + ASGI = ("asgi", CustomInstruments) + ASYNCCLICK = ("asyncclick", CustomInstruments) + ASYNCIO = ("asyncio", CustomInstruments) + ASYNCPG = ("asyncpg", CustomInstruments) + AWS_LAMBDA = ("aws_lambda", CustomInstruments) + BEDROCK = ("bedrock", Instruments) + BOTO = ("boto", CustomInstruments) + BOTO3SQS = ("boto3sqs", CustomInstruments) + BOTOCORE = ("botocore", CustomInstruments) + CARTESIA = ("cartesia", CustomInstruments) + CASSANDRA = ("cassandra", CustomInstruments) + CEREBRAS = ("cerebras", CustomInstruments) + CELERY = ("celery", CustomInstruments) + CHROMA = ("chroma", Instruments) + CLICK = ("click", CustomInstruments) + COHEREAI = ("cohere_ai", CustomInstruments) + CONFLUENT_KAFKA = ("confluent_kafka", CustomInstruments) + CREW = ("crew", Instruments) + DEEPGRAM = ("deepgram", CustomInstruments) + DBAPI = ("dbapi", CustomInstruments) + DJANGO = ("django", CustomInstruments) + DSPY = ("dspy", CustomInstruments) + ELASTICSEARCH = ("elasticsearch", CustomInstruments) + ELEVENLABS = ("elevenlabs", CustomInstruments) + FALCON = ("falcon", CustomInstruments) + FASTAPI = ("fastapi", CustomInstruments) + FLASK = ("flask", CustomInstruments) + GOOGLE_GENERATIVEAI = ("google_genai", CustomInstruments) + GROQ = ("groq", CustomInstruments) + GRPC = ("grpc", CustomInstruments) + HAYSTACK = ("haystack", Instruments) + HTTPX = ("httpx", CustomInstruments) + JINJA2 = ("jinja2", CustomInstruments) + KAFKA_PYTHON = ("kafka_python", CustomInstruments) + LANCEDB = ("lancedb", Instruments) + LANGCHAIN = ("langchain", Instruments) + LITELLM = ("litellm", CustomInstruments) + LLAMA_INDEX = ("llama_index", Instruments) + LOGGING = ("logging", CustomInstruments) + MARQO = ("marqo", Instruments) + MCP = ("mcp", Instruments) + MILVUS = ("milvus", 
Instruments) + MISTRALAI = ("mistral_ai", CustomInstruments) + MYSQL = ("mysql", CustomInstruments) + MYSQLCLIENT = ("mysqlclient", CustomInstruments) + OLLAMA = ("ollama", Instruments) + OPENAI = ("openai", CustomInstruments) + OPENAI_AGENTS = ("openai_agents", Instruments) + PIKA = ("pika", CustomInstruments) + PINECONE = ("pinecone", Instruments) + PSYCOPG = ("psycopg", CustomInstruments) + PSYCOPG2 = ("psycopg2", CustomInstruments) + PYDANTIC_AI = ("pydantic_ai", CustomInstruments) + PYMEMCACHE = ("pymemcache", CustomInstruments) + PYMONGO = ("pymongo", CustomInstruments) + PYMSSQL = ("pymssql", CustomInstruments) + PYMYSQL = ("pymysql", CustomInstruments) + PYRAMID = ("pyramid", CustomInstruments) + QDRANTDB = ("qdrant_db", CustomInstruments) + REDIS = ("redis", CustomInstruments) + REMOULADE = ("remoulade", CustomInstruments) + REPLICATE = ("replicate", Instruments) + REQUESTS = ("requests", CustomInstruments) + SAGEMAKER = ("sagemaker", Instruments) + SQLALCHEMY = ("sqlalchemy", CustomInstruments) + SQLITE3 = ("sqlite3", CustomInstruments) + STARLETTE = ("starlette", CustomInstruments) + SYSTEM_METRICS = ("system_metrics", CustomInstruments) + THREADING = ("threading", CustomInstruments) + TOGETHER = ("together", Instruments) + TORNADO = ("tornado", CustomInstruments) + TORTOISEORM = ("tortoiseorm", CustomInstruments) + TRANSFORMERS = ("transformers", Instruments) + URLLIB = ("urllib", CustomInstruments) + URLLIB3 = ("urllib3", CustomInstruments) + VERTEXAI = ("vertexai", Instruments) + WATSONX = ("watsonx", Instruments) + WEAVIATEDB = ("weaviate_db", CustomInstruments) + WRITER = ("writer", Instruments) + WSGI = ("wsgi", CustomInstruments) + -merged_members = {} +NetraInstruments = InstrumentSet -for member in Instruments: - merged_members[member.name] = (member.value, Instruments) -for member in CustomInstruments: - merged_members[member.name] = (member.value, CustomInstruments) +# Curated default instrument set used for root_instruments when the user does 
# Instruments enabled for ``root_instruments`` when the caller supplies no
# explicit value: the core LLM/AI providers and agent frameworks.
DEFAULT_INSTRUMENTS_FOR_ROOT = {
    InstrumentSet.ADK,
    InstrumentSet.ALEPHALPHA,
    InstrumentSet.ANTHROPIC,
    InstrumentSet.BEDROCK,
    InstrumentSet.CARTESIA,
    InstrumentSet.CEREBRAS,
    InstrumentSet.COHEREAI,
    InstrumentSet.CREW,
    InstrumentSet.DEEPGRAM,
    InstrumentSet.DSPY,
    InstrumentSet.ELEVENLABS,
    InstrumentSet.GOOGLE_GENERATIVEAI,
    InstrumentSet.GROQ,
    InstrumentSet.HAYSTACK,
    InstrumentSet.LANGCHAIN,
    InstrumentSet.LITELLM,
    InstrumentSet.LLAMA_INDEX,
    InstrumentSet.MISTRALAI,
    InstrumentSet.OLLAMA,
    InstrumentSet.OPENAI,
    InstrumentSet.PYDANTIC_AI,
    InstrumentSet.REPLICATE,
    InstrumentSet.TOGETHER,
    InstrumentSet.VERTEXAI,
    InstrumentSet.WATSONX,
}

# Instruments enabled for the ``instruments`` parameter when the caller
# supplies no explicit value: everything in the root defaults, extended with
# common vector databases, HTTP clients, and database client/ORM libraries.
_DEFAULT_NON_ROOT_EXTRAS = {
    InstrumentSet.CHROMA,
    InstrumentSet.HTTPX,
    InstrumentSet.LANCEDB,
    InstrumentSet.MARQO,
    InstrumentSet.MILVUS,
    InstrumentSet.PINECONE,
    InstrumentSet.PYMYSQL,
    InstrumentSet.QDRANTDB,
    InstrumentSet.REQUESTS,
    InstrumentSet.SQLALCHEMY,
    InstrumentSet.WEAVIATEDB,
}
DEFAULT_INSTRUMENTS = DEFAULT_INSTRUMENTS_FOR_ROOT | _DEFAULT_NON_ROOT_EXTRAS
Refer this for usage within Netra SDK: class InstrumentSet(Enum): - ADK = "adk" + ADK = "google_adk" AIOHTTP = "aiohttp" AIO_PIKA = "aio_pika" AIOKAFKA = "aiokafka" diff --git a/netra/instrumentation/utils.py b/netra/instrumentation/utils.py index c96b895..1eeec68 100644 --- a/netra/instrumentation/utils.py +++ b/netra/instrumentation/utils.py @@ -46,10 +46,12 @@ def record_span_timing( attribute: str, event_time: Optional[float] = None, use_root_span: bool = False, + reference_time: Optional[float] = None, ) -> bool: """Compute elapsed time for an event and set it as a span attribute. Elapsed time is measured from: + - ``reference_time`` (seconds since epoch) if provided explicitly. - ``use_root_span=False`` (default): the start time of the given span. - ``use_root_span=True``: the start time of the root span of the given span. @@ -59,13 +61,21 @@ def record_span_timing( event_time: The event timestamp in seconds since epoch. Defaults to ``time.time()`` if not provided. use_root_span: If True, elapsed time is measured from the root span's - start time instead of the given span's start time. + start time instead of the given span's start time. Ignored when + ``reference_time`` is provided. + reference_time: Optional explicit reference timestamp in seconds since + epoch. When provided, elapsed is computed as + ``event_time - reference_time``, bypassing span start-time lookup. Returns: True if the timing attribute was successfully set, False if the elapsed time could not be computed (e.g. missing start time or root span). 
""" t = event_time if event_time is not None else time.time() + + if reference_time is not None: + return _safe_set_attribute(span, attribute, t - reference_time) + start_time = None if not use_root_span: diff --git a/netra/meter.py b/netra/meter.py index 7744cfe..50bf1e5 100644 --- a/netra/meter.py +++ b/netra/meter.py @@ -1,7 +1,7 @@ import json import logging import threading -from typing import Any, Optional +from typing import Any, List, Optional from google.protobuf.json_format import MessageToDict from opentelemetry import metrics @@ -21,6 +21,7 @@ MetricsData, PeriodicExportingMetricReader, ) +from opentelemetry.sdk.metrics.view import DropAggregation, View from opentelemetry.sdk.resources import Resource from netra.config import Config @@ -32,12 +33,20 @@ RESOURCE_ATTR_SERVICE_NAME = "service" RESOURCE_ATTR_DEPLOYMENT_ENVIRONMENT = "environment" +# Glob patterns for OTel SDK internal metrics emitted by TracerProvider / +# BatchSpanProcessor (e.g. otel.sdk.span.live, otel.sdk.span.started). +# These are suppressed via DropAggregation views when export_auto_metrics +# is disabled so only user-defined metrics reach the backend. +_OTEL_SDK_AUTO_METRIC_PATTERNS: List[str] = [ + "otel.sdk.*", +] + # Map every OTel instrument type to DELTA so the backend receives # incremental values on each export cycle, matching standard # observability platform behavior (Datadog, Prometheus pull model, etc.) # NOTE: Keys must be the SDK instrument classes (opentelemetry.sdk.metrics), # not the public API classes (opentelemetry.metrics). -_DELTA_TEMPORALITY: dict = { +_DELTA_TEMPORALITY: dict[type, AggregationTemporality] = { Counter: AggregationTemporality.DELTA, UpDownCounter: AggregationTemporality.DELTA, Histogram: AggregationTemporality.DELTA, @@ -47,7 +56,7 @@ } -class _JsonOTLPMetricExporter(OTLPMetricExporter): +class _JsonOTLPMetricExporter(OTLPMetricExporter): # type: ignore[misc] """Thin wrapper that sends OTLP metrics as JSON instead of protobuf. 
The upstream ``OTLPMetricExporter`` serialises to protobuf and sets @@ -127,6 +136,26 @@ def __init__(self, cfg: Config) -> None: self.cfg = cfg self._setup_meter() + def _build_views(self) -> List[View]: + """Build the list of metric Views for the MeterProvider. + + When ``export_auto_metrics`` is disabled, OTel SDK internal metrics + (e.g. ``otel.sdk.span.live``, ``otel.sdk.span.started``) are dropped + via ``DropAggregation`` so only user-defined metrics are exported. + + Returns: + A list of ``View`` instances to pass to the ``MeterProvider``. + """ + views: List[View] = [] + if not self.cfg.export_auto_metrics: + for pattern in _OTEL_SDK_AUTO_METRIC_PATTERNS: + views.append(View(instrument_name=pattern, aggregation=DropAggregation())) + logger.debug( + "Auto-metrics export disabled; dropping patterns: %s", + _OTEL_SDK_AUTO_METRIC_PATTERNS, + ) + return views + def _setup_meter(self) -> None: """Install a global MeterProvider with an OTLP exporter.""" if not self.cfg.otlp_endpoint: @@ -163,7 +192,13 @@ def _setup_meter(self) -> None: export_interval_millis=self.cfg.metrics_export_interval_ms, ) - provider = MeterProvider(resource=resource, metric_readers=[reader]) + views = self._build_views() + + provider = MeterProvider( + resource=resource, + metric_readers=[reader], + views=views, + ) metrics.set_meter_provider(provider) logger.info( diff --git a/netra/processors/__init__.py b/netra/processors/__init__.py index 2c8b8f3..8891e11 100644 --- a/netra/processors/__init__.py +++ b/netra/processors/__init__.py @@ -1,15 +1,19 @@ from netra.processors.instrumentation_span_processor import InstrumentationSpanProcessor from netra.processors.llm_trace_identifier_span_processor import LlmTraceIdentifierSpanProcessor from netra.processors.local_filtering_span_processor import LocalFilteringSpanProcessor +from netra.processors.root_instrument_filter_processor import RootInstrumentFilterProcessor from netra.processors.root_span_processor import RootSpanProcessor from 
netra.processors.scrubbing_span_processor import ScrubbingSpanProcessor from netra.processors.session_span_processor import SessionSpanProcessor +from netra.processors.span_io_processor import SpanIOProcessor __all__ = [ "SessionSpanProcessor", "InstrumentationSpanProcessor", + "SpanIOProcessor", "LlmTraceIdentifierSpanProcessor", "ScrubbingSpanProcessor", "LocalFilteringSpanProcessor", + "RootInstrumentFilterProcessor", "RootSpanProcessor", ] diff --git a/netra/processors/instrumentation_span_processor.py b/netra/processors/instrumentation_span_processor.py index 35e9c81..dd249fe 100644 --- a/netra/processors/instrumentation_span_processor.py +++ b/netra/processors/instrumentation_span_processor.py @@ -62,7 +62,7 @@ def _get_blocked_url_patterns() -> frozenset[str]: # Pre-computed allowed instrumentation names -_ALLOWED_INSTRUMENTATION_NAMES: Set[str] = {member.value for member in InstrumentSet} # type: ignore[attr-defined] +_ALLOWED_INSTRUMENTATION_NAMES: Set[str] = {member.value for member in InstrumentSet} class InstrumentationSpanProcessor(SpanProcessor): # type: ignore[misc] @@ -115,10 +115,7 @@ def _wrap_set_attribute(self, span: Span) -> None: span: The span whose set_attribute method will be wrapped. 
# span_id -> parent SpanContext for every span that is currently blocked
# locally. Exporters consult it to reparent children, which may be exported
# before their blocked parent. The dict is private: use the three accessor
# functions below, which serialise access through the lock.
_blocked_local_parent_map: Dict[Any, Any] = {}
_blocked_local_parent_lock = threading.Lock()


def blocked_local_parent_map_put(span_id: Any, parent_context: Any) -> None:
    """Record the parent context of a locally-blocked span.

    Args:
        span_id: Identifier of the blocked span (used as the registry key).
        parent_context: Parent ``SpanContext`` that children of the blocked
            span should be reparented to.
    """
    _blocked_local_parent_lock.acquire()
    try:
        _blocked_local_parent_map[span_id] = parent_context
    finally:
        _blocked_local_parent_lock.release()


def blocked_local_parent_map_pop(span_id: Any) -> None:
    """Forget a span in the blocked-parent registry; unknown ids are ignored.

    Args:
        span_id: Identifier of the span to remove.
    """
    _blocked_local_parent_lock.acquire()
    try:
        _blocked_local_parent_map.pop(span_id, None)
    finally:
        _blocked_local_parent_lock.release()


def blocked_local_parent_map_snapshot() -> Dict[Any, Any]:
    """Copy the blocked-parent registry while holding the lock.

    Returns:
        A shallow copy of the registry; mutating it does not affect the
        shared state, and it can be iterated without holding the lock.
    """
    _blocked_local_parent_lock.acquire()
    try:
        copied = dict(_blocked_local_parent_map)
    finally:
        _blocked_local_parent_lock.release()
    return copied
logger = logging.getLogger(__name__)

# Attribute written on blocked spans so that the FilteringSpanExporter drops them.
_LOCAL_BLOCKED_ATTR = "netra.local_blocked"

# Scope-name prefixes that identify auto-instrumentation libraries.
_INSTRUMENTATION_PREFIXES = ("opentelemetry.instrumentation.", "netra.instrumentation.")


class RootInstrumentFilterProcessor(SpanProcessor):  # type: ignore[misc]
    """Blocks root spans (and their entire subtree) from instrumentations not in
    the allowed *root_instruments* set.

    The set stores the **instrumentation name values** (e.g. ``"openai"``,
    ``"google_adk"``, ``"google_genai"``) that are permitted to create
    root-level spans. Any root span whose instrumentation name is *not* in
    this set is marked with ``netra.local_blocked = True`` and its ``span_id``
    is recorded. Child spans whose parent ``span_id`` appears in the blocked
    registry inherit the block.

    A span's registry entry is removed in ``on_end``, so a child only inherits
    the block while its parent is still open.

    Args:
        allowed_root_instrument_names: Set of instrumentation-name strings
            (matching ``InstrumentSet`` member values) that are allowed to
            produce root spans.
    """

    def __init__(self, allowed_root_instrument_names: Set[str]) -> None:
        """Initialize the processor with a set of allowed root instrument names.

        Args:
            allowed_root_instrument_names: Set of instrumentation-name strings
                (matching ``InstrumentSet`` member values) that are allowed to
                produce root spans.
        """
        self._allowed: frozenset[str] = frozenset(allowed_root_instrument_names)
        # span_id -> True for every open span that belongs to a blocked root tree.
        self._blocked_span_ids: dict[int, bool] = {}
        self._lock = threading.Lock()

    def on_start(
        self,
        span: Span,
        parent_context: Optional[otel_context.Context] = None,
    ) -> None:
        """Classify a starting span; never raises into the tracing pipeline.

        Args:
            span: The span that is being started.
            parent_context: The context the span was started with. ``None``
                means the caller did not pass one and the current context
                applies.
        """
        try:
            self._process_span_start(span, parent_context)
        except Exception:
            logger.debug("RootInstrumentFilterProcessor.on_start failed", exc_info=True)

    def on_end(self, span: ReadableSpan) -> None:
        """Drop the ended span from the blocked registry.

        Args:
            span: The span that is being ended.
        """
        try:
            span_id = self._get_span_id(span)
            if span_id is not None:
                with self._lock:
                    self._blocked_span_ids.pop(span_id, None)
        except Exception:
            # Best-effort cleanup; a failure here must not break span export.
            logger.debug("RootInstrumentFilterProcessor.on_end failed", exc_info=True)

    def shutdown(self) -> None:
        """Clear the blocked registry when the processor is shut down."""
        with self._lock:
            self._blocked_span_ids.clear()

    def force_flush(self, timeout_millis: int = 30000) -> bool:
        """Nothing is buffered, so flushing always succeeds.

        Args:
            timeout_millis: The timeout in milliseconds (unused).

        Returns:
            Always ``True``.
        """
        return True

    def _process_span_start(
        self,
        span: Span,
        parent_context: Optional[otel_context.Context],
    ) -> None:
        """Decide whether a starting span is blocked and record the decision.

        Args:
            span: The span that is being started.
            parent_context: The context the span was started with, or ``None``
                for the current context.
        """
        parent_span_id = self._resolve_parent_span_id(parent_context)

        if parent_span_id is not None and parent_span_id != INVALID_SPAN_ID:
            # Child span: inherit blocked status from the parent and never
            # apply the root allow-list to it.
            own_id = self._get_span_id(span)
            with self._lock:
                inherited = parent_span_id in self._blocked_span_ids
                if inherited and own_id is not None:
                    self._blocked_span_ids[own_id] = True
            if inherited:
                # Marked outside the lock: set_attribute may run other
                # processors' wrappers and needs no registry access.
                self._mark_blocked(span)
            return

        # Root span: only apply the allow-list to auto-instrumentation spans.
        # Spans created directly through netra (decorators / Netra.start_span)
        # use arbitrary tracer names and must never be blocked.
        if not self._is_from_instrumentation_library(span):
            return

        instr_name = self._extract_instrumentation_name(span)
        if instr_name is not None and instr_name not in self._allowed:
            own_id = self._get_span_id(span)
            if own_id is not None:
                with self._lock:
                    self._blocked_span_ids[own_id] = True
            self._mark_blocked(span)

    @staticmethod
    def _resolve_parent_span_id(
        parent_context: Optional[otel_context.Context],
    ) -> Optional[int]:
        """Return the parent span's ``span_id``, or ``None`` if unavailable.

        ``parent_context`` is ``None`` whenever the span was started without
        an explicit context (the common ``start_as_current_span(name)`` case);
        the OpenTelemetry API then resolves the parent from the *current*
        context. ``trace.get_current_span(None)`` implements exactly that, so
        ``None`` is passed through instead of being treated as "no parent",
        which would misclassify ordinary child spans as roots.

        Args:
            parent_context: The context supplied to ``on_start``, or ``None``
                for the current context.

        Returns:
            The parent span's ``span_id`` (``INVALID_SPAN_ID`` when there is
            no active parent), or ``None`` if no span context is available.
        """
        parent_span = trace.get_current_span(parent_context)
        if parent_span is None:
            return None
        sc = parent_span.get_span_context()
        if sc is None:
            return None
        return cast(Optional[int], sc.span_id)

    @staticmethod
    def _get_span_id(span: object) -> Optional[int]:
        """Get the span ID from a span.

        Args:
            span: The span to read; either a ``context`` attribute or a
                ``get_span_context()`` method is used, whichever is present.

        Returns:
            The span ID, or ``None`` when no span context can be obtained.
        """
        ctx = getattr(span, "context", None) or getattr(span, "get_span_context", lambda: None)()
        if ctx is None:
            return None
        return cast(Optional[int], getattr(ctx, "span_id", None))

    @staticmethod
    def _mark_blocked(span: Span) -> None:
        """Flag the span with ``netra.local_blocked = True`` (best-effort).

        Args:
            span: The span to mark as blocked.
        """
        try:
            span.set_attribute(_LOCAL_BLOCKED_ATTR, True)
        except Exception:
            logger.debug("RootInstrumentFilterProcessor: failed to mark span blocked", exc_info=True)

    @staticmethod
    def _is_from_instrumentation_library(span: Span) -> bool:
        """Return ``True`` if the span originates from a known auto-instrumentation library.

        Spans created by netra decorators or ``Netra.start_span`` use arbitrary
        tracer names that do not match the instrumentation naming convention and
        will return ``False``.

        Args:
            span: The span to check.

        Returns:
            ``True`` when the span's instrumentation scope starts with a known
            instrumentation prefix, ``False`` otherwise.
        """
        scope = getattr(span, "instrumentation_scope", None)
        if scope is None:
            return False
        name = getattr(scope, "name", None)
        if not isinstance(name, str) or not name:
            return False
        return name.startswith(_INSTRUMENTATION_PREFIXES)

    @staticmethod
    def _extract_instrumentation_name(span: Span) -> Optional[str]:
        """Extract the short instrumentation name from the span's scope.

        For scope names under a known instrumentation prefix, the text after
        the last ``"."`` is returned (e.g. ``"netra.instrumentation.openai"``
        yields ``"openai"``); other scope names are returned unchanged.

        Args:
            span: The span to extract the instrumentation name from.

        Returns:
            The instrumentation name, or ``None`` when the span has no usable
            scope name.
        """
        scope = getattr(span, "instrumentation_scope", None)
        if scope is None:
            return None
        name = getattr(scope, "name", None)
        if not isinstance(name, str) or not name:
            return None
        if name.startswith(_INSTRUMENTATION_PREFIXES):
            base = name.rsplit(".", 1)[-1].strip()
            return base if base else name
        return name
+ + Args: + index_map: Mapping of integer index to partial message dict. + + Returns: + JSON string of the ordered message list. + """ + return json.dumps([index_map[i] for i in sorted(index_map)]) + + +def _extract_traceloop_input(raw: Any) -> str: + """Extract the ``inputs`` payload from a traceloop entity input value. + + Traceloop serialises entity inputs as: + '{"inputs": {...}, "tags": [...], "metadata": {...}, "kwargs": {...}}' + + We surface only the ``inputs`` dict as the canonical ``input`` attribute. + If parsing fails the raw value is returned as-is. + + Args: + raw: The raw attribute value (expected to be a JSON string). + + Returns: + Serialized string of the inputs payload. + """ + try: + parsed = json.loads(raw) if isinstance(raw, str) else raw + payload = parsed.get("inputs", parsed) + return json.dumps(payload) if not isinstance(payload, str) else payload + except Exception: + return str(raw) + + +def _extract_traceloop_output(raw: Any) -> str: + """Extract the ``outputs`` payload from a traceloop entity output value. + + Traceloop serialises entity outputs as: + '{"outputs": {...}, "kwargs": {...}}' + + We surface only the ``outputs`` value as the canonical ``output`` attribute. + If parsing fails the raw value is returned as-is. + + Args: + raw: The raw attribute value (expected to be a JSON string). + + Returns: + Serialized string of the outputs payload. + """ + try: + parsed = json.loads(raw) if isinstance(raw, str) else raw + payload = parsed.get("outputs", parsed) + + return json.dumps(payload) if not isinstance(payload, str) else payload + except Exception: + return str(raw) + + +class SpanIOProcessor(SpanProcessor): # type: ignore[misc] + """Normalises ``input`` / ``output`` attributes and remaps ``traceloop.*`` + keys to ``netra.*`` on all spans. + + All interception is done in ``on_start`` via a per-span closure that wraps + ``span.set_attribute``, following the same pattern as + ``InstrumentationSpanProcessor``. 
+ """ + + def on_start( + self, + span: Span, + parent_context: Optional[otel_context.Context] = None, + ) -> None: + """Wrap the span's ``set_attribute`` to intercept and normalise writes. + + Args: + span: The span that was started. + parent_context: The parent context (unused). + """ + try: + attrs = span.attributes or {} + if "input" not in attrs: + span.set_attribute("input", "") + if "output" not in attrs: + span.set_attribute("output", "") + self._wrap_set_attribute(span) + except Exception: + logger.exception("SpanIOProcessor.on_start failed") + + def on_end(self, span: ReadableSpan) -> None: + """No-op. All attribute normalisation is applied eagerly via the set_attribute wrapper installed in on_start. + + Args: + span: The span that has ended. + """ + + def force_flush(self, timeout_millis: int = 30000) -> bool: + """No-op flush. + + Args: + timeout_millis: Maximum time to wait (unused). + + Returns: + Always True. + """ + return True + + def shutdown(self) -> None: + """No-op shutdown.""" + + @staticmethod + def _wrap_set_attribute(span: Span) -> None: + """Replace ``span.set_attribute`` with a normalising closure. + + Per-span accumulators for gen_ai prompts/completions are closure-scoped + so each span owns its own independent state. + + Args: + span: The span whose ``set_attribute`` will be replaced. + """ + original: SetAttributeFunc = span.set_attribute + + # Per-span accumulators for gen_ai indexed attributes + prompts: Dict[int, Dict[str, str]] = {} + completions: Dict[int, Dict[str, str]] = {} + + def patched_set_attribute(key: str, value: Any) -> None: # noqa: C901 + try: + # 1. gen_ai.prompts.* / gen_ai.prompt.* → keep original + update input + prompt_match = _PROMPT_RE.match(key) + if prompt_match: + original(key, value) + idx = int(prompt_match.group(1)) + field = prompt_match.group(2) + prompts.setdefault(idx, {})[field] = str(value) + original("input", _build_messages(prompts)) + return + + # 2. 
gen_ai.completions.* / gen_ai.completion.* → keep original + update output + completion_match = _COMPLETION_RE.match(key) + if completion_match: + original(key, value) + idx = int(completion_match.group(1)) + field = completion_match.group(2) + completions.setdefault(idx, {})[field] = str(value) + original("output", _build_messages(completions)) + return + + # 3. traceloop.entity.input → input (no traceloop key written) + if key == "traceloop.entity.input": + original("input", _extract_traceloop_input(value)) + return + + # 4. traceloop.entity.output → output (no traceloop key written) + if key == "traceloop.entity.output": + original("output", _extract_traceloop_output(value)) + return + + # 5. Other traceloop.* → netra.* (no traceloop key written) + if key.startswith(_TRACELOOP_PREFIX): + new_key = _NETRA_PREFIX + key[len(_TRACELOOP_PREFIX) :] + original(new_key, value) + return + + # 6. Everything else — pass through unchanged + original(key, value) + + except Exception: + logger.debug("SpanIOProcessor: error processing key=%s", key, exc_info=True) + try: + original(key, value) + except Exception: + logger.debug("SpanIOProcessor: error calling set_attribute key=%s", key, exc_info=True) + + setattr(span, "set_attribute", patched_set_attribute) diff --git a/netra/prompts/client.py b/netra/prompts/client.py index f23de36..ff3c8be 100644 --- a/netra/prompts/client.py +++ b/netra/prompts/client.py @@ -1,120 +1,48 @@ import logging -import os -from typing import Any, Dict, Optional - -import httpx +from typing import Any, Dict +from netra.client import BaseNetraClient from netra.config import Config logger = logging.getLogger(__name__) +_LOG_PREFIX = "netra.prompts" + -class PromptsHttpClient: - """ - Internal HTTP client for prompts APIs. - """ +class PromptsHttpClient(BaseNetraClient): + """Internal HTTP client for prompts APIs.""" def __init__(self, config: Config) -> None: """ Initialize the prompts HTTP client. 
Args: - config: Configuration object containing API key and base URL - """ - self._client: Optional[httpx.Client] = self._create_client(config) - - def _create_client(self, config: Config) -> Optional[httpx.Client]: - """ - Create and configure the HTTP client. - - Args: - config: Configuration object containing API key and base URL - - Returns: - Configured HTTP client or None if initialization fails - """ - endpoint = (config.otlp_endpoint or "").strip() - if not endpoint: - logger.error("netra.prompts: NETRA_OTLP_ENDPOINT is required for prompts APIs") - return None - - base_url = self._resolve_base_url(endpoint) - headers = self._build_headers(config) - timeout = self._get_timeout() - - try: - return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) - except Exception as exc: - logger.error("netra.prompts: Failed to initialize prompts HTTP client: %s", exc) - return None - - def _resolve_base_url(self, endpoint: str) -> str: - """ - Resolve the base URL by removing /telemetry suffix if present. - - Args: - endpoint: The endpoint URL - - Returns: - Resolved base URL - """ - base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] - return base_url - - def _build_headers(self, config: Config) -> Dict[str, str]: - """ - Build HTTP headers for API requests. - - Args: - config: Configuration object containing API key and base URL - - Returns: - Dictionary of HTTP headers + config: Configuration object containing API key and base URL. """ - headers: Dict[str, str] = dict(config.headers or {}) - api_key = config.api_key - if api_key: - headers["x-api-key"] = api_key - return headers - - def _get_timeout(self) -> float: - """ - Get the timeout value from environment variable or use default. 
- - Returns: - Timeout value in seconds - """ - timeout_env = os.getenv("NETRA_PROMPTS_TIMEOUT") - if not timeout_env: - return 10.0 - try: - return float(timeout_env) - except ValueError: - logger.warning( - "netra.prompts: Invalid NETRA_PROMPTS_TIMEOUT value '%s', using default 10.0", - timeout_env, - ) - return 10.0 + super().__init__( + config, + log_prefix=_LOG_PREFIX, + timeout_env_var="NETRA_PROMPTS_TIMEOUT", + ) def get_prompt_version(self, prompt_name: str, label: str) -> Any: """ Fetch a prompt version by name and label. Args: - prompt_name: Name of the prompt - label: Label of the prompt version + prompt_name: Name of the prompt. + label: Label of the prompt version. Returns: - Prompt version data or empty dict if not found + Prompt version data or None if not found. """ if not self._client: logger.error( - "netra.prompts: Prompts client is not initialized; cannot fetch prompt version for '%s'", + "%s: Client is not initialized; cannot fetch prompt version for '%s'", + _LOG_PREFIX, prompt_name, ) - return {} + return None try: url = "/sdk/prompts/version" @@ -127,9 +55,10 @@ def get_prompt_version(self, prompt_name: str, label: str) -> Any: return data except Exception as exc: logger.error( - "netra.prompts: Failed to fetch prompt version for '%s' (label=%s): %s", + "%s: Failed to fetch prompt version for '%s' (label=%s): %s", + _LOG_PREFIX, prompt_name, label, - exc, + self._extract_error_message(exc), ) - return {} + return None diff --git a/netra/session_manager.py b/netra/session_manager.py index 9d19440..d1b8766 100644 --- a/netra/session_manager.py +++ b/netra/session_manager.py @@ -1,4 +1,6 @@ +import json import logging +import threading from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional, Union @@ -19,7 +21,14 @@ class ConversationType(str, Enum): class SessionManager: - """Manages session and user context for applications.""" + """Manages session and user context for applications. 
+ + All mutable class-level state is protected by ``_lock`` so that + concurrent threads (or ``asyncio.to_thread`` calls) cannot corrupt + the internal stacks and registries. + """ + + _lock = threading.Lock() # Class variable to track the current span _current_span: Optional[trace.Span] = None @@ -45,7 +54,8 @@ def set_current_span(cls, span: Optional[trace.Span]) -> None: Args: span: The current span to store """ - cls._current_span = span + with cls._lock: + cls._current_span = span @classmethod def get_current_span(cls) -> Optional[trace.Span]: @@ -55,7 +65,8 @@ def get_current_span(cls) -> Optional[trace.Span]: Returns: The stored current span or None if not set """ - return cls._current_span + with cls._lock: + return cls._current_span @classmethod def register_span(cls, name: str, span: trace.Span) -> None: @@ -67,13 +78,13 @@ def register_span(cls, name: str, span: trace.Span) -> None: span: The span to register """ try: - stack = cls._spans_by_name.get(name) - if stack is None: - cls._spans_by_name[name] = [span] - else: - stack.append(span) - # Track globally as active - cls._active_spans.append(span) + with cls._lock: + stack = cls._spans_by_name.get(name) + if stack is None: + cls._spans_by_name[name] = [span] + else: + stack.append(span) + cls._active_spans.append(span) except Exception: logger.exception("Failed to register span '%s'", name) @@ -87,24 +98,37 @@ def unregister_span(cls, name: str, span: trace.Span) -> None: span: The span to unregister """ try: - stack = cls._spans_by_name.get(name) - if not stack: - return - # Remove the last matching instance (normal case) - for i in range(len(stack) - 1, -1, -1): - if stack[i] is span: - stack.pop(i) - break - if not stack: - cls._spans_by_name.pop(name, None) - # Also remove from global active list (remove last matching instance) - for i in range(len(cls._active_spans) - 1, -1, -1): - if cls._active_spans[i] is span: - cls._active_spans.pop(i) - break + with cls._lock: + stack = 
cls._spans_by_name.get(name) + if not stack: + return + for i in range(len(stack) - 1, -1, -1): + if stack[i] is span: + stack.pop(i) + break + if not stack: + cls._spans_by_name.pop(name, None) + for i in range(len(cls._active_spans) - 1, -1, -1): + if cls._active_spans[i] is span: + cls._active_spans.pop(i) + break except Exception: logger.exception("Failed to unregister span '%s'", name) + @classmethod + def get_trace_id(cls) -> Optional[str]: + """ + Return the trace ID of the currently active span. + + Returns: + str: 32-character lowercase hex trace ID, or None if no active span exists. + """ + span = trace.get_current_span() + ctx = span.get_span_context() + if ctx.is_valid: + return format(ctx.trace_id, "032x") + return None + @classmethod def get_span_by_name(cls, name: str) -> Optional[trace.Span]: """ @@ -116,10 +140,11 @@ def get_span_by_name(cls, name: str) -> Optional[trace.Span]: Returns: The most recently registered span with the given name, or None if not found """ - stack = cls._spans_by_name.get(name) - if stack: - return stack[-1] - return None + with cls._lock: + stack = cls._spans_by_name.get(name) + if stack: + return stack[-1] + return None @classmethod def push_entity(cls, entity_type: str, entity_name: str) -> None: @@ -130,14 +155,15 @@ def push_entity(cls, entity_type: str, entity_name: str) -> None: entity_type: Type of entity (workflow, task, agent, span) entity_name: Name of the entity """ - if entity_type == "workflow": - cls._workflow_stack.append(entity_name) - elif entity_type == "task": - cls._task_stack.append(entity_name) - elif entity_type == "agent": - cls._agent_stack.append(entity_name) - elif entity_type == "span": - cls._span_stack.append(entity_name) + with cls._lock: + if entity_type == "workflow": + cls._workflow_stack.append(entity_name) + elif entity_type == "task": + cls._task_stack.append(entity_name) + elif entity_type == "agent": + cls._agent_stack.append(entity_name) + elif entity_type == "span": + 
cls._span_stack.append(entity_name) @classmethod def pop_entity(cls, entity_type: str) -> Optional[str]: @@ -150,15 +176,16 @@ def pop_entity(cls, entity_type: str) -> Optional[str]: Returns: Entity name or None if stack is empty """ - if entity_type == "workflow" and cls._workflow_stack: - return cls._workflow_stack.pop() - elif entity_type == "task" and cls._task_stack: - return cls._task_stack.pop() - elif entity_type == "agent" and cls._agent_stack: - return cls._agent_stack.pop() - elif entity_type == "span" and cls._span_stack: - return cls._span_stack.pop() - return None + with cls._lock: + if entity_type == "workflow" and cls._workflow_stack: + return cls._workflow_stack.pop() + elif entity_type == "task" and cls._task_stack: + return cls._task_stack.pop() + elif entity_type == "agent" and cls._agent_stack: + return cls._agent_stack.pop() + elif entity_type == "span" and cls._span_stack: + return cls._span_stack.pop() + return None @classmethod def get_current_entity_attributes(cls) -> Dict[str, str]: @@ -168,33 +195,31 @@ def get_current_entity_attributes(cls) -> Dict[str, str]: Returns: Dictionary of entity attributes to add to spans """ - attributes = {} + with cls._lock: + attributes = {} - # Add current workflow if exists - if cls._workflow_stack: - attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1] + if cls._workflow_stack: + attributes[f"{Config.LIBRARY_NAME}.workflow.name"] = cls._workflow_stack[-1] - # Add current task if exists - if cls._task_stack: - attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1] + if cls._task_stack: + attributes[f"{Config.LIBRARY_NAME}.task.name"] = cls._task_stack[-1] - # Add current agent if exists - if cls._agent_stack: - attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1] + if cls._agent_stack: + attributes[f"{Config.LIBRARY_NAME}.agent.name"] = cls._agent_stack[-1] - # Add current span if exists - if cls._span_stack: - 
attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1] + if cls._span_stack: + attributes[f"{Config.LIBRARY_NAME}.span.name"] = cls._span_stack[-1] - return attributes + return attributes @classmethod def clear_entity_stacks(cls) -> None: """Clear all entity stacks.""" - cls._workflow_stack.clear() - cls._task_stack.clear() - cls._agent_stack.clear() - cls._span_stack.clear() + with cls._lock: + cls._workflow_stack.clear() + cls._task_stack.clear() + cls._agent_stack.clear() + cls._span_stack.clear() @classmethod def get_stack_info(cls) -> Dict[str, List[str]]: @@ -204,12 +229,13 @@ def get_stack_info(cls) -> Dict[str, List[str]]: Returns: Dictionary containing all stack contents """ - return { - "workflows": cls._workflow_stack.copy(), - "tasks": cls._task_stack.copy(), - "agents": cls._agent_stack.copy(), - "spans": cls._span_stack.copy(), - } + with cls._lock: + return { + "workflows": cls._workflow_stack.copy(), + "tasks": cls._task_stack.copy(), + "agents": cls._agent_stack.copy(), + "spans": cls._span_stack.copy(), + } @staticmethod def set_session_context( @@ -303,13 +329,15 @@ def add_conversation(cls, conversation_type: ConversationType, role: str, conten span = trace.get_current_span() if not (span and getattr(span, "is_recording", lambda: False)()): # Fallback: use the most recent active span from SessionManager - if not cls._active_spans: + with cls._lock: + active_snapshot = list(cls._active_spans) + + if not active_snapshot: logger.warning("No active span to add conversation attribute.") return - # Find the most recent *recording* span (the last item can be a finished span) recording_span: Optional[trace.Span] = None - for span in reversed(cls._active_spans): + for span in reversed(active_snapshot): try: if span and getattr(span, "is_recording", lambda: False)(): recording_span = span @@ -372,6 +400,111 @@ def add_conversation(cls, conversation_type: ConversationType, role: str, conten except Exception as e: logger.exception("Failed to 
add conversation attribute: %s", e) + @classmethod + def set_input(cls, value: Any) -> None: + """Set the ``input`` attribute on the current active span. + + Accepts any value. Dicts and lists are JSON-serialised; primitives are + converted with ``str()``. The result is truncated to + ``Config.ATTRIBUTE_MAX_LEN`` characters. + + Args: + value: The input value to record. + """ + try: + if isinstance(value, (dict, list)): + serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] + else: + serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] + cls.set_attribute_on_active_span("input", serialized) + except Exception: + logger.exception("SessionManager.set_input: failed to set input attribute") + + @classmethod + def set_output(cls, value: Any) -> None: + """Set the ``output`` attribute on the current active span. + + Accepts any value. Dicts and lists are JSON-serialised; primitives are + converted with ``str()``. The result is truncated to + ``Config.ATTRIBUTE_MAX_LEN`` characters. + + Args: + value: The output value to record. + """ + try: + if isinstance(value, (dict, list)): + serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] + else: + serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] + cls.set_attribute_on_active_span("output", serialized) + except Exception: + logger.exception("SessionManager.set_output: failed to set output attribute") + + @classmethod + def set_root_input(cls, value: Any) -> None: + """Set the ``input`` attribute on the root span of the current trace. + + The root span is the oldest span registered via :meth:`register_span`. + If no such span exists, falls back to the current active OTel span. + + Args: + value: The input value to record. 
+ """ + try: + if isinstance(value, (dict, list)): + serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] + else: + serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] + cls.set_attribute_on_root_span("input", serialized) + except Exception: + logger.exception("SessionManager.set_root_input: failed to set input attribute") + + @classmethod + def set_root_output(cls, value: Any) -> None: + """Set the ``output`` attribute on the root span of the current trace. + + The root span is the oldest span registered via :meth:`register_span`. + If no such span exists, falls back to the current active OTel span. + + Args: + value: The output value to record. + """ + try: + if isinstance(value, (dict, list)): + serialized = json.dumps(value, default=str)[: Config.ATTRIBUTE_MAX_LEN] + else: + serialized = str(value)[: Config.ATTRIBUTE_MAX_LEN] + cls.set_attribute_on_root_span("output", serialized) + except Exception: + logger.exception("SessionManager.set_root_output: failed to set output attribute") + + @classmethod + def set_attribute_on_root_span(cls, attr_key: str, attr_value: Any) -> None: + """Set an attribute on the root span of the current trace. 
+ + + Args: + attr_key: Key for the attribute to set + attr_value: Value for the attribute to set + """ + try: + from netra.processors.root_span_processor import RootSpanProcessor + + span_ctx = trace.get_current_span().get_span_context() + if not span_ctx.is_valid: + logger.warning("set_attribute_on_root_span called outside any active span context") + return + + trace_id = span_ctx.trace_id + root_span = RootSpanProcessor.get_root_span_by_trace_id(trace_id) + if not root_span: + # Format as 32-character zero-padded lowercase hex + logger.warning(f"Cannot find root span for trace_id: {trace_id:032x}") + return + root_span.set_attribute(attr_key, attr_value) + except Exception: + logger.exception("Failed to set attribute '%s' on root span", attr_key) + @staticmethod def set_attribute_on_active_span(attr_key: str, attr_value: Any) -> None: """ diff --git a/netra/simulation/client.py b/netra/simulation/client.py index d495185..6ed56ef 100644 --- a/netra/simulation/client.py +++ b/netra/simulation/client.py @@ -1,28 +1,19 @@ """HTTP client for simulation API endpoints.""" import logging -import os -from typing import Any, Optional - -import httpx +from typing import Any, Dict, Optional +from netra.client import BaseNetraClient from netra.config import Config from netra.simulation.models import ConversationResponse, SimulationItem logger = logging.getLogger(__name__) -_DEFAULT_TIMEOUT = 10.0 _LOG_PREFIX = "netra.simulation" -class SimulationHttpClient: - """Internal HTTP client for simulation API endpoints. - - Attributes: - _client: The underlying httpx client instance. - """ - - __slots__ = ("_client",) +class SimulationHttpClient(BaseNetraClient): + """Internal HTTP client for simulation API endpoints.""" def __init__(self, config: Config) -> None: """Initialize the simulation HTTP client. @@ -30,86 +21,19 @@ def __init__(self, config: Config) -> None: Args: config: The Netra configuration object. 
""" - self._client: Optional[httpx.Client] = self._create_client(config) - - def _create_client(self, config: Config) -> Optional[httpx.Client]: - """Create and configure the HTTP client. - - Args: - config: The Netra configuration object. - - Returns: - Configured httpx client or None if creation fails. - """ - endpoint = (config.otlp_endpoint or "").strip() - if not endpoint: - logger.error("%s: NETRA_OTLP_ENDPOINT is required", _LOG_PREFIX) - return None - - base_url = self._resolve_base_url(endpoint) - headers = self._build_headers(config) - timeout = self._get_timeout() - - try: - return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) - except Exception as exc: - logger.error("%s: Failed to create HTTP client: %s", _LOG_PREFIX, exc) - return None - - def _resolve_base_url(self, endpoint: str) -> str: - """Extract base URL, removing telemetry suffix if present. - - Args: - endpoint: The raw endpoint URL. - - Returns: - The cleaned base URL. - """ - base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] - return base_url - - def _build_headers(self, config: Config) -> dict[str, str]: - """Build request headers from configuration. - - Args: - config: The Netra configuration object. - - Returns: - Dictionary of HTTP headers. - """ - headers: dict[str, str] = dict(config.headers or {}) - if config.api_key: - headers["x-api-key"] = config.api_key - return headers - - def _get_timeout(self) -> float: - """Get timeout from environment or use default. - - Returns: - The timeout value in seconds. 
- """ - timeout_str = os.getenv("NETRA_SIMULATION_TIMEOUT") - if not timeout_str: - return _DEFAULT_TIMEOUT - try: - return float(timeout_str) - except ValueError: - logger.warning( - "%s: Invalid timeout '%s', using default %.1f", - _LOG_PREFIX, - timeout_str, - _DEFAULT_TIMEOUT, - ) - return _DEFAULT_TIMEOUT + super().__init__( + config, + log_prefix=_LOG_PREFIX, + timeout_env_var="NETRA_SIMULATION_TIMEOUT", + default_timeout=500.0, + ) def create_run( self, name: str, dataset_id: str, - context: Optional[dict[str, Any]] = None, - ) -> Optional[dict[str, Any]]: + context: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: """Create a new simulation run for the specified dataset. Args: @@ -124,15 +48,14 @@ def create_run( logger.error("%s: Client not initialized", _LOG_PREFIX) return None - response: Optional[httpx.Response] = None try: url = "/evaluations/test_run/multi-turn" - payload: dict[str, Any] = { + payload: Dict[str, Any] = { "name": name, "datasetId": dataset_id, "context": context or {}, } - response = self._client.post(url, json=payload, timeout=500) + response = self._client.post(url, json=payload) response.raise_for_status() data = response.json() @@ -157,7 +80,7 @@ def create_run( } except Exception as exc: - error_msg = self._extract_error_message(response, exc) + error_msg = self._extract_error_message(exc) logger.error("%s: Failed to create simulation run: %s", _LOG_PREFIX, error_msg) return None @@ -183,17 +106,16 @@ def trigger_conversation( logger.error("%s: Client not initialized", _LOG_PREFIX) return None - response: Optional[httpx.Response] = None try: url = "/evaluations/turn/agent-response" - payload: dict[str, Any] = { + payload: Dict[str, Any] = { "turnId": turn_id, "agentResponse": {"message": message}, "sessionId": session_id, "traceId": trace_id, } - response = self._client.post(url, json=payload, timeout=500) + response = self._client.post(url, json=payload) response.raise_for_status() data = response.json() @@ 
-220,7 +142,7 @@ def trigger_conversation( ) except Exception as exc: - error_msg = self._extract_error_message(response, exc) + error_msg = self._extract_error_message(exc) logger.error("%s: Failed to trigger conversation: %s", _LOG_PREFIX, error_msg) raise @@ -236,16 +158,13 @@ def report_failure(self, run_id: str, run_item_id: str, error: str) -> None: logger.error("%s: Client not initialized", _LOG_PREFIX) return - response: Optional[httpx.Response] = None try: url = f"/evaluations/run/{run_id}/item/{run_item_id}/status" - payload: dict[str, Any] = {"status": "failed", "failureReason": error} - response = self._client.patch(url, json=payload) - response.raise_for_status() + payload: Dict[str, Any] = {"status": "failed", "failureReason": error} + self._client.patch(url, json=payload).raise_for_status() logger.info("%s: Reported failure - %s", _LOG_PREFIX, error) except Exception as exc: - error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to report failure: %s", _LOG_PREFIX, error_msg) + logger.error("%s: Failed to report failure: %s", _LOG_PREFIX, self._extract_error_message(exc)) def post_run_status(self, run_id: str, status: str) -> Any: """Submit the run status. @@ -255,16 +174,15 @@ def post_run_status(self, run_id: str, status: str) -> Any: status: The status of the run. Returns: - Backend JSON response containing confirmation, or error dict. + Backend JSON response containing confirmation, or None on failure. 
""" if not self._client: logger.error("%s: Client not initialized; cannot post run status", _LOG_PREFIX) - return {"success": False} + return None - response: Optional[httpx.Response] = None try: url = f"/evaluations/run/{run_id}/status" - payload: dict[str, Any] = {"status": status} + payload: Dict[str, Any] = {"status": status} response = self._client.post(url, json=payload) response.raise_for_status() data = response.json() @@ -273,30 +191,10 @@ def post_run_status(self, run_id: str, status: str) -> Any: return data.get("data", {}) return data except Exception as exc: - error_msg = self._extract_error_message(response, exc) - logger.error("%s: Failed to post run status for run '%s': %s", _LOG_PREFIX, run_id, error_msg) - return {"success": False} - - def _extract_error_message( - self, - response: Optional[httpx.Response], - exc: Exception, - ) -> Any: - """Extract error message from response or exception. - - Args: - response: The HTTP response object, if available. - exc: The exception that was raised. - - Returns: - A descriptive error message string. 
- """ - if response is not None: - try: - response_json = response.json() - error_data = response_json.get("error", {}) - if isinstance(error_data, dict): - return error_data.get("message", str(exc)) - except Exception: - pass - return str(exc) + logger.error( + "%s: Failed to post run status for run '%s': %s", + _LOG_PREFIX, + run_id, + self._extract_error_message(exc), + ) + return None diff --git a/netra/span_wrapper.py b/netra/span_wrapper.py index 1a6d0f4..984c2f3 100644 --- a/netra/span_wrapper.py +++ b/netra/span_wrapper.py @@ -9,7 +9,7 @@ from opentelemetry import context as otel_context from opentelemetry import trace from opentelemetry.trace import SpanKind, Status, StatusCode -from pydantic import BaseModel +from pydantic import BaseModel, Field from netra.config import Config from netra.session_manager import SessionManager @@ -21,7 +21,7 @@ class ActionModel(BaseModel): # type: ignore[misc] - start_time: str = str((datetime.now().timestamp() * 1_000_000_000)) + start_time: str = Field(default_factory=lambda: str(int(datetime.now().timestamp() * 1_000_000_000))) action: str action_type: str success: bool @@ -56,6 +56,12 @@ class SpanType(str, Enum): AGENT = "AGENT" +SPAN_TYPE_TO_ENTITY_TYPE: Dict[SpanType, str] = { + SpanType.AGENT: "agent", + SpanType.TOOL: "task", +} + + class SpanWrapper: """ Context manager for tracking observability data for external API calls. 
@@ -65,7 +71,7 @@ def __init__( self, name: str, attributes: Optional[Dict[str, str]] = None, - module_name: str = "combat_sdk", + module_name: str = Config.SDK_NAME, as_type: Optional[SpanType] = SpanType.SPAN, ): """ @@ -96,14 +102,20 @@ def __init__( if isinstance(as_type, SpanType): self.attributes["netra.span.type"] = as_type.value + self._entity_type: Optional[str] = SPAN_TYPE_TO_ENTITY_TYPE.get(as_type) else: logger.error("Invalid span type: %s", as_type) + self._entity_type = None return def __enter__(self) -> "SpanWrapper": """Start the span wrapper, begin time tracking, and create OpenTelemetry span.""" self.start_time = time.time() + # Push entity before span starts so SessionSpanProcessor can capture the name + if self._entity_type: + SessionManager.push_entity(self._entity_type, self.name) + # If user provided local blocked patterns in attributes, attach them as baggage try: patterns = None @@ -191,6 +203,10 @@ def __exit__(self, exc_type: Optional[type], exc_val: Optional[Exception], exc_t finally: self._local_block_token = None + # Pop entity from session stack so nested spans get correct parentage + if self._entity_type: + SessionManager.pop_entity(self._entity_type) + # Don't suppress exceptions return False @@ -311,3 +327,12 @@ def get_current_span(self) -> Optional[trace.Span]: The current OpenTelemetry span """ return self.span + + def get_trace_id(self) -> Optional[str]: + """ + Return the trace ID of this span. + + Returns: + str: 32-character lowercase hex trace ID, or None if the span is invalid. 
+ """ + return SessionManager.get_trace_id() diff --git a/netra/tracer.py b/netra/tracer.py index e4af2e9..f7d59e0 100644 --- a/netra/tracer.py +++ b/netra/tracer.py @@ -1,6 +1,6 @@ import logging import threading -from typing import Any, Dict +from typing import Any, Dict, Optional, Set from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter @@ -27,13 +27,18 @@ class Tracer: and appropriate span processors. """ - def __init__(self, cfg: Config) -> None: + def __init__(self, cfg: Config, root_instrument_names: Optional[Set[str]] = None) -> None: """Initialize the Netra tracer with the provided configuration. Args: cfg: Configuration object with tracer settings + root_instrument_names: Optional set of instrumentation-name strings + that are allowed to produce root-level spans. When provided, a + ``RootInstrumentFilterProcessor`` is installed that discards root + spans (and their entire subtree) from all other instrumentations. """ self.cfg = cfg + self._root_instrument_names = root_instrument_names self._setup_tracer() def _setup_tracer(self) -> None: @@ -93,14 +98,20 @@ def _setup_tracer(self) -> None: InstrumentationSpanProcessor, LlmTraceIdentifierSpanProcessor, LocalFilteringSpanProcessor, + RootInstrumentFilterProcessor, RootSpanProcessor, ScrubbingSpanProcessor, SessionSpanProcessor, + SpanIOProcessor, ) + if self._root_instrument_names is not None: + provider.add_span_processor(RootInstrumentFilterProcessor(self._root_instrument_names)) + provider.add_span_processor(LocalFilteringSpanProcessor()) provider.add_span_processor(InstrumentationSpanProcessor()) provider.add_span_processor(SessionSpanProcessor()) + provider.add_span_processor(SpanIOProcessor()) provider.add_span_processor(LlmTraceIdentifierSpanProcessor()) # Adding RootSpanProcessor after LlmTraceIdentifierSpanProcessor diff --git a/netra/usage/api.py b/netra/usage/api.py index 0b2cf80..3e00604 100644 --- a/netra/usage/api.py +++ 
b/netra/usage/api.py @@ -45,6 +45,8 @@ def get_session_usage( logger.error("netra.usage: start_time and end_time are required to fetch session usage") return None result = self._client.get_session_usage(session_id, start_time=start_time, end_time=end_time) + if not result: + return None session_id = result.get("session_id", "") if not session_id: return None @@ -83,6 +85,8 @@ def get_tenant_usage( logger.error("netra.usage: start_time and end_time are required to fetch tenant usage") return None result = self._client.get_tenant_usage(tenant_id, start_time=start_time, end_time=end_time) + if not result: + return None tenant_id = result.get("tenant_id", "") if not tenant_id: return None diff --git a/netra/usage/client.py b/netra/usage/client.py index dc67361..0aee975 100644 --- a/netra/usage/client.py +++ b/netra/usage/client.py @@ -1,15 +1,15 @@ import logging -import os from typing import Any, Dict, Optional -import httpx - +from netra.client import BaseNetraClient from netra.config import Config logger = logging.getLogger(__name__) +_LOG_PREFIX = "netra.usage" + -class UsageHttpClient: +class UsageHttpClient(BaseNetraClient): """Internal HTTP client for usage APIs.""" def __init__(self, config: Config) -> None: @@ -17,68 +17,31 @@ def __init__(self, config: Config) -> None: Initialize the usage HTTP client. Args: - config: Configuration object with usage settings + config: Configuration object with usage settings. 
""" - self._client: Optional[httpx.Client] = self._create_client(config) - - def _create_client(self, config: Config) -> Optional[httpx.Client]: - endpoint = (config.otlp_endpoint or "").strip() - if not endpoint: - logger.error("netra.usage: NETRA_OTLP_ENDPOINT is required for usage APIs") - return None - - base_url = self._resolve_base_url(endpoint) - headers = self._build_headers(config) - timeout = self._get_timeout() - - try: - return httpx.Client(base_url=base_url, headers=headers, timeout=timeout) - except Exception as exc: - logger.error("netra.usage: Failed to initialize usage HTTP client: %s", exc) - return None - - def _resolve_base_url(self, endpoint: str) -> str: - base_url = endpoint.rstrip("/") - if base_url.endswith("/telemetry"): - base_url = base_url[: -len("/telemetry")] - return base_url - - def _build_headers(self, config: Config) -> Dict[str, str]: - headers: Dict[str, str] = dict(config.headers or {}) - api_key = config.api_key - if api_key: - headers["x-api-key"] = api_key - return headers - - def _get_timeout(self) -> float: - timeout_env = os.getenv("NETRA_USAGE_TIMEOUT") - if not timeout_env: - return 10.0 - try: - return float(timeout_env) - except ValueError: - logger.warning( - "netra.usage: Invalid NETRA_USAGE_TIMEOUT value '%s', using default 10.0", - timeout_env, - ) - return 10.0 - - def get_session_usage(self, session_id: str, start_time: str | None = None, end_time: str | None = None) -> Any: + super().__init__( + config, + log_prefix=_LOG_PREFIX, + timeout_env_var="NETRA_USAGE_TIMEOUT", + ) + + def get_session_usage( + self, session_id: str, start_time: Optional[str] = None, end_time: Optional[str] = None + ) -> Any: """ Get session usage data. Args: - session_id: Session identifier + session_id: Session identifier. + start_time: Optional start time filter. + end_time: Optional end time filter. Returns: - Any: Session usage data + Session usage data. 
""" if not self._client: - logger.error( - "netra.usage: Usage client is not initialized; cannot fetch session usage '%s'", - session_id, - ) - return {} + logger.error("%s: Client is not initialized; cannot fetch session usage '%s'", _LOG_PREFIX, session_id) + return None try: url = f"/usage/sessions/{session_id}" @@ -94,25 +57,26 @@ def get_session_usage(self, session_id: str, start_time: str | None = None, end_ return data.get("data", {}) return data except Exception as exc: - logger.error("netra.usage: Failed to fetch session usage '%s': %s", session_id, exc) - return {} + logger.error( + "%s: Failed to fetch session usage '%s': %s", _LOG_PREFIX, session_id, self._extract_error_message(exc) + ) + return None - def get_tenant_usage(self, tenant_id: str, start_time: str | None = None, end_time: str | None = None) -> Any: + def get_tenant_usage(self, tenant_id: str, start_time: Optional[str] = None, end_time: Optional[str] = None) -> Any: """ Get tenant usage data. Args: - tenant_id: Tenant identifier + tenant_id: Tenant identifier. + start_time: Optional start time filter. + end_time: Optional end time filter. Returns: - Any: Tenant usage data + Tenant usage data. 
""" if not self._client: - logger.error( - "netra.usage: Usage client is not initialized; cannot fetch tenant usage '%s'", - tenant_id, - ) - return {} + logger.error("%s: Client is not initialized; cannot fetch tenant usage '%s'", _LOG_PREFIX, tenant_id) + return None try: url = f"/usage/tenants/{tenant_id}" @@ -128,45 +92,47 @@ def get_tenant_usage(self, tenant_id: str, start_time: str | None = None, end_ti return data.get("data", {}) return data except Exception as exc: - logger.error("netra.usage: Failed to fetch tenant usage '%s': %s", tenant_id, exc) - return {} + logger.error( + "%s: Failed to fetch tenant usage '%s': %s", _LOG_PREFIX, tenant_id, self._extract_error_message(exc) + ) + return None def list_traces( self, - start_time: str | None = None, - end_time: str | None = None, - trace_id: str | None = None, - session_id: str | None = None, - user_id: str | None = None, - tenant_id: str | None = None, - limit: int | None = None, - cursor: str | None = None, - direction: str | None = None, - sort_field: str | None = None, - sort_order: str | None = None, + start_time: Optional[str] = None, + end_time: Optional[str] = None, + trace_id: Optional[str] = None, + session_id: Optional[str] = None, + user_id: Optional[str] = None, + tenant_id: Optional[str] = None, + limit: Optional[int] = None, + cursor: Optional[str] = None, + direction: Optional[str] = None, + sort_field: Optional[str] = None, + sort_order: Optional[str] = None, ) -> Any: """ List all traces. 
Args: - start_time: Start time for the traces (in ISO 8601 UTC format) - end_time: End time for the traces (in ISO 8601 UTC format) - trace_id: Search based on trace_id, if provided - session_id: Search based on session_id, if provided - user_id: Search based on user_id, if provided - tenant_id: Search based on tenant_id, if provided - limit: Maximum number of traces to return - cursor: Cursor for pagination - direction: Direction of pagination - sort_field: Field to sort by - sort_order: Order to sort by + start_time: Start time for the traces (in ISO 8601 UTC format). + end_time: End time for the traces (in ISO 8601 UTC format). + trace_id: Search based on trace_id, if provided. + session_id: Search based on session_id, if provided. + user_id: Search based on user_id, if provided. + tenant_id: Search based on tenant_id, if provided. + limit: Maximum number of traces to return. + cursor: Cursor for pagination. + direction: Direction of pagination. + sort_field: Field to sort by. + sort_order: Order to sort by. Returns: - Any: Traces data + Traces data. 
""" if not self._client: - logger.error("netra.usage: Usage client is not initialized; cannot list traces") - return {} + logger.error("%s: Client is not initialized; cannot list traces", _LOG_PREFIX) + return None try: url = "/sdk/traces" @@ -207,36 +173,35 @@ def list_traces( response = self._client.post(url, json=payload or None) response.raise_for_status() - data = response.json() - return data + return response.json() except Exception as exc: - logger.error("netra.usage: Failed to list traces: %s", exc) - return {} + logger.error("%s: Failed to list traces: %s", _LOG_PREFIX, self._extract_error_message(exc)) + return None def list_spans_by_trace_id( self, trace_id: str, - cursor: str | None = None, - direction: str | None = None, - limit: int | None = None, - span_name: str | None = None, + cursor: Optional[str] = None, + direction: Optional[str] = None, + limit: Optional[int] = None, + span_name: Optional[str] = None, ) -> Any: """ List all spans for a given trace. Args: - trace_id: Trace identifier - cursor: Cursor for pagination - direction: Direction of pagination - limit: Maximum number of spans to return - span_name: Search query for the spans + trace_id: Trace identifier. + cursor: Cursor for pagination. + direction: Direction of pagination. + limit: Maximum number of spans to return. + span_name: Search query for the spans. Returns: - Any: Spans data + Spans data. 
""" if not self._client: - logger.error("netra.usage: Usage client is not initialized; cannot list spans for trace '%s'", trace_id) - return {} + logger.error("%s: Client is not initialized; cannot list spans for trace '%s'", _LOG_PREFIX, trace_id) + return None try: url = f"/sdk/traces/{trace_id}/spans" @@ -252,8 +217,9 @@ def list_spans_by_trace_id( response = self._client.get(url, params=params or None) response.raise_for_status() - data = response.json() - return data + return response.json() except Exception as exc: - logger.error("netra.usage: Failed to list spans for trace '%s': %s", trace_id, exc) - return {} + logger.error( + "%s: Failed to list spans for trace '%s': %s", _LOG_PREFIX, trace_id, self._extract_error_message(exc) + ) + return None diff --git a/netra/version.py b/netra/version.py index e5e0b9d..f192e9f 100644 --- a/netra/version.py +++ b/netra/version.py @@ -1 +1 @@ -__version__ = "0.1.81" +__version__ = "0.1.82dev0" diff --git a/pyproject.toml b/pyproject.toml index 8722445..7f6d056 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [project] name = "netra-sdk" -version = "0.1.81" +version = "0.1.82dev0" description = "A Python SDK for AI application observability that provides OpenTelemetry-based monitoring, tracing, and PII protection for LLM and vector database applications. Enables easy instrumentation, session tracking, and privacy-focused data collection for AI systems in production environments." 
authors = [ {name = "Sooraj Thomas",email = "sooraj@keyvalue.systems"} diff --git a/tests/test_decorators.py b/tests/test_decorators.py index 836d7b9..3f687ae 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -95,7 +95,7 @@ def test_func(arg1: str, arg2: int): # Check input attributes expected_input = json.dumps({"arg1": "hello", "arg2": "42"}) - mock_span.set_attribute.assert_any_call(f"{Config.LIBRARY_NAME}.entity.input", expected_input) + mock_span.set_attribute.assert_any_call("input", expected_input) def test_add_span_attributes_with_kwargs(self): """Test adding span attributes with keyword arguments.""" @@ -111,7 +111,7 @@ def test_func(arg1: str, arg2: int = 10): # Check input attributes include both args and kwargs expected_input = json.dumps({"arg1": "hello", "arg2": "42"}) - mock_span.set_attribute.assert_any_call(f"{Config.LIBRARY_NAME}.entity.input", expected_input) + mock_span.set_attribute.assert_any_call("input", expected_input) def test_add_span_attributes_with_self_parameter(self): """Test adding span attributes ignoring self parameter.""" @@ -124,7 +124,7 @@ def test_method(self, arg1: str): # Check that self parameter is ignored expected_input = json.dumps({"arg1": "hello"}) - mock_span.set_attribute.assert_any_call(f"{Config.LIBRARY_NAME}.entity.input", expected_input) + mock_span.set_attribute.assert_any_call("input", expected_input) def test_add_span_attributes_with_cls_parameter(self): """Test adding span attributes ignoring cls parameter.""" @@ -137,7 +137,7 @@ def test_classmethod(cls, arg1: str): # Check that cls parameter is ignored expected_input = json.dumps({"arg1": "hello"}) - mock_span.set_attribute.assert_any_call(f"{Config.LIBRARY_NAME}.entity.input", expected_input) + mock_span.set_attribute.assert_any_call("input", expected_input) def test_add_span_attributes_exception_handling(self): """Test span attribute addition with exception handling.""" @@ -151,7 +151,7 @@ def problematic_func(): 
_add_span_attributes(mock_span, problematic_func, (), {}, "workflow") # Check that error is recorded - mock_span.set_attribute.assert_any_call(f"{Config.LIBRARY_NAME}.input_error", "Signature error") + mock_span.set_attribute.assert_any_call("input_error", "Signature error") class TestAddOutputAttributes: @@ -164,7 +164,7 @@ def test_add_output_attributes_simple_result(self): _add_output_attributes(mock_span, result) - mock_span.set_attribute.assert_called_once_with(f"{Config.LIBRARY_NAME}.entity.output", "test_result") + mock_span.set_attribute.assert_called_once_with("output", "test_result") def test_add_output_attributes_complex_result(self): """Test adding output attributes for complex result.""" @@ -174,7 +174,7 @@ def test_add_output_attributes_complex_result(self): _add_output_attributes(mock_span, result) expected_output = '{"status": "success", "data": [1, 2, 3]}' - mock_span.set_attribute.assert_called_once_with(f"{Config.LIBRARY_NAME}.entity.output", expected_output) + mock_span.set_attribute.assert_called_once_with("output", expected_output) def test_add_output_attributes_exception_handling(self): """Test output attribute addition with exception handling.""" @@ -190,9 +190,7 @@ def __str__(self): with patch("netra.decorators._serialize_value", side_effect=ValueError("Cannot serialize")): _add_output_attributes(mock_span, result) - mock_span.set_attribute.assert_called_once_with( - f"{Config.LIBRARY_NAME}.entity.output_error", "Cannot serialize" - ) + mock_span.set_attribute.assert_called_once_with("output_error", "Cannot serialize") class TestCreateFunctionWrapper: