diff --git a/src/langtrace_python_sdk/instrumentation/__init__.py b/src/langtrace_python_sdk/instrumentation/__init__.py index 8e31fc17..c49a7800 100644 --- a/src/langtrace_python_sdk/instrumentation/__init__.py +++ b/src/langtrace_python_sdk/instrumentation/__init__.py @@ -22,6 +22,7 @@ from .llamaindex import LlamaindexInstrumentation from .milvus import MilvusInstrumentation from .mistral import MistralInstrumentation +from .neo4j_graphrag import Neo4jGraphRAGInstrumentation from .ollama import OllamaInstrumentor from .openai import OpenAIInstrumentation from .openai_agents import OpenAIAgentsInstrumentation @@ -59,6 +60,7 @@ "AWSBedrockInstrumentation", "CerebrasInstrumentation", "MilvusInstrumentation", + "Neo4jGraphRAGInstrumentation", "GoogleGenaiInstrumentation", "CrewaiToolsInstrumentation", "GraphlitInstrumentation", diff --git a/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/__init__.py b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/__init__.py new file mode 100644 index 00000000..56f74172 --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/__init__.py @@ -0,0 +1,3 @@ +from .instrumentation import Neo4jGraphRAGInstrumentation + +__all__ = ["Neo4jGraphRAGInstrumentation"] diff --git a/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/instrumentation.py b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/instrumentation.py new file mode 100644 index 00000000..56dcc71a --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/instrumentation.py @@ -0,0 +1,62 @@ +""" +Copyright (c) 2025 Scale3 Labs + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from typing import Collection +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry import trace +from wrapt import wrap_function_wrapper as _W +from importlib.metadata import version as v +from .patch import patch_graphrag_search, patch_kg_pipeline_run, \ +patch_kg_pipeline_run, patch_retriever_search + + +class Neo4jGraphRAGInstrumentation(BaseInstrumentor): + + def instrumentation_dependencies(self) -> Collection[str]: + return ["neo4j-graphrag>=1.6.0"] + + def _instrument(self, **kwargs): + tracer_provider = kwargs.get("tracer_provider") + tracer = trace.get_tracer(__name__, "", tracer_provider) + graphrag_version = v("neo4j-graphrag") + + try: + # instrument kg builder + _W( + "neo4j_graphrag.experimental.pipeline.kg_builder", + "SimpleKGPipeline.run_async", + patch_kg_pipeline_run("run_async", graphrag_version, tracer), + ) + + # Instrument GraphRAG + _W( + "neo4j_graphrag.generation.graphrag", + "GraphRAG.search", + patch_graphrag_search("search", graphrag_version, tracer), + ) + + # Instrument retrievers + _W( + "neo4j_graphrag.retrievers.vector", + "VectorRetriever.get_search_results", + patch_retriever_search("vector_search", graphrag_version, tracer), + ) + + except Exception as e: + print(f"Failed to instrument Neo4j GraphRAG: {e}") + + def _uninstrument(self, **kwargs): + pass \ No newline at end of file diff --git a/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/patch.py b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/patch.py new file mode 100644 index 00000000..38b69434 --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/neo4j_graphrag/patch.py @@ -0,0 +1,229 @@ +""" +Copyright (c) 2025 Scale3 Labs + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import json + +from importlib_metadata import version as v +from langtrace.trace_attributes import FrameworkSpanAttributes +from opentelemetry import baggage +from opentelemetry.trace import Span, SpanKind, Tracer +from opentelemetry.trace.status import Status, StatusCode + +from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME +from langtrace_python_sdk.constants.instrumentation.common import ( + LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY, SERVICE_PROVIDERS) +from langtrace_python_sdk.utils.llm import set_span_attributes +from langtrace_python_sdk.utils.misc import serialize_args, serialize_kwargs + + +def patch_kg_pipeline_run(operation_name: str, version: str, tracer: Tracer): + + async def async_traced_method(wrapped, instance, args, kwargs): + service_provider = SERVICE_PROVIDERS.get("NEO4J_GRAPHRAG", "neo4j_graphrag") + extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY) + + span_attributes = { + "langtrace.sdk.name": "langtrace-python-sdk", + "langtrace.service.name": service_provider, + "langtrace.service.type": "framework", + "langtrace.service.version": version, + "langtrace.version": v(LANGTRACE_SDK_NAME), + "neo4j.pipeline.type": "SimpleKGPipeline", + **(extra_attributes if extra_attributes is not None else {}), + } + + if len(args) > 0: + span_attributes["neo4j.pipeline.inputs"] = serialize_args(*args) + if kwargs: + span_attributes["neo4j.pipeline.kwargs"] = serialize_kwargs(**kwargs) + + file_path = kwargs.get("file_path", args[0] if len(args) > 0 else None) + text = kwargs.get("text", args[1] if len(args) > 1 else None) + if file_path: + span_attributes["neo4j.pipeline.file_path"] = file_path + if text: + span_attributes["neo4j.pipeline.text_length"] = len(text) + + if hasattr(instance, "runner") and hasattr(instance.runner, "config"): + config = instance.runner.config + if config: + span_attributes["neo4j.pipeline.from_pdf"] = getattr(config, "from_pdf", None) + span_attributes["neo4j.pipeline.perform_entity_resolution"] = getattr(config, "perform_entity_resolution", None) + + attributes = FrameworkSpanAttributes(**span_attributes) + + with tracer.start_as_current_span( + name=f"neo4j.pipeline.{operation_name}", + kind=SpanKind.CLIENT, + ) as span: + try: + set_span_attributes(span, attributes) + + result = await wrapped(*args, **kwargs) + + if result: + try: + if hasattr(result, "to_dict"): + result_dict = result.to_dict() + span.set_attribute("neo4j.pipeline.result", json.dumps(result_dict)) + elif hasattr(result, "model_dump"): + result_dict = result.model_dump() + span.set_attribute("neo4j.pipeline.result", json.dumps(result_dict)) + except Exception as e: + span.set_attribute("neo4j.pipeline.result_error", str(e)) + + span.set_status(Status(StatusCode.OK)) + return result + + except Exception as err: + span.record_exception(err) + span.set_status(Status(StatusCode.ERROR, str(err))) + raise + + return async_traced_method + + +def patch_graphrag_search(operation_name: str, version: str, tracer: Tracer): + + def traced_method(wrapped, instance, args, kwargs): + service_provider = SERVICE_PROVIDERS.get("NEO4J_GRAPHRAG", "neo4j_graphrag") + extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY) + + # Basic attributes + span_attributes = { + "langtrace.sdk.name": "langtrace-python-sdk", + "langtrace.service.name": service_provider, + "langtrace.service.type": "framework", + "langtrace.service.version": version, + "langtrace.version": v(LANGTRACE_SDK_NAME), + "neo4j_graphrag.operation": operation_name, + **(extra_attributes if extra_attributes is not None else {}), + } + + query_text = kwargs.get("query_text", args[0] if len(args) > 0 else None) + if query_text: + span_attributes["neo4j_graphrag.query_text"] = query_text + + retriever_config = kwargs.get("retriever_config", None) + if retriever_config: + span_attributes["neo4j_graphrag.retriever_config"] = json.dumps(retriever_config) + + if hasattr(instance, "retriever"): + span_attributes["neo4j_graphrag.retriever_type"] = instance.retriever.__class__.__name__ + + if hasattr(instance, "llm"): + span_attributes["neo4j_graphrag.llm_type"] = instance.llm.__class__.__name__ + + attributes = FrameworkSpanAttributes(**span_attributes) + + with tracer.start_as_current_span( + name=f"neo4j_graphrag.{operation_name}", + kind=SpanKind.CLIENT, + ) as span: + try: + set_span_attributes(span, attributes) + + result = wrapped(*args, **kwargs) + + if result and hasattr(result, "answer"): + span.set_attribute("neo4j_graphrag.answer", result.answer) + + if hasattr(result, "retriever_result") and result.retriever_result: + try: + retriever_items = len(result.retriever_result.items) + span.set_attribute("neo4j_graphrag.context_items", retriever_items) + except Exception: + pass + + span.set_status(Status(StatusCode.OK)) + return result + + except Exception as err: + span.record_exception(err) + span.set_status(Status(StatusCode.ERROR, str(err))) + raise + + return traced_method + + +def patch_retriever_search(operation_name: str, version: str, tracer: Tracer): + + def traced_method(wrapped, instance, args, kwargs): + service_provider = SERVICE_PROVIDERS.get("NEO4J_GRAPHRAG", "neo4j_graphrag") + extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY) + + # Basic attributes + span_attributes = { + "langtrace.sdk.name": "langtrace-python-sdk", + "langtrace.service.name": service_provider, + "langtrace.service.type": "framework", + "langtrace.service.version": version, + "langtrace.version": v(LANGTRACE_SDK_NAME), + "neo4j.retriever.operation": operation_name, + "neo4j.retriever.type": instance.__class__.__name__, + **(extra_attributes if extra_attributes is not None else {}), + } + + query_text = kwargs.get("query_text", args[0] if len(args) > 0 else None) + if query_text: + span_attributes["neo4j.retriever.query_text"] = query_text + + if hasattr(instance, "__class__") and hasattr(instance.__class__, "__name__"): + retriever_type = instance.__class__.__name__ + + if retriever_type == "VectorRetriever" and hasattr(instance, "index_name"): + span_attributes["neo4j.vector_retriever.index_name"] = instance.index_name + + if retriever_type == "KnowledgeGraphRetriever" and hasattr(instance, "cypher_query"): + span_attributes["neo4j.kg_retriever.cypher_query"] = instance.cypher_query + + for param in ["top_k", "similarity_threshold"]: + if param in kwargs: + span_attributes[f"neo4j.retriever.{param}"] = kwargs[param] + elif hasattr(instance, param): + span_attributes[f"neo4j.retriever.{param}"] = getattr(instance, param) + + attributes = FrameworkSpanAttributes(**span_attributes) + + with tracer.start_as_current_span( + name=f"neo4j.retriever.{operation_name}", + kind=SpanKind.CLIENT, + ) as span: + try: + set_span_attributes(span, attributes) + + result = wrapped(*args, **kwargs) + + if result: + if hasattr(result, "items") and isinstance(result.items, list): + span.set_attribute("neo4j.retriever.items_count", len(result.items)) + + try: + item_ids = [item.id for item in result.items[:5] if hasattr(item, "id")] + if item_ids: + span.set_attribute("neo4j.retriever.item_ids", json.dumps(item_ids)) + except Exception: + pass + + span.set_status(Status(StatusCode.OK)) + return result + + except Exception as err: + span.record_exception(err) + span.set_status(Status(StatusCode.ERROR, str(err))) + raise + + return traced_method diff --git a/src/langtrace_python_sdk/langtrace.py b/src/langtrace_python_sdk/langtrace.py index 2535926a..6f16f998 100644 --- a/src/langtrace_python_sdk/langtrace.py +++ b/src/langtrace_python_sdk/langtrace.py @@ -48,8 +48,8 @@ GeminiInstrumentation, GoogleGenaiInstrumentation, GraphlitInstrumentation, GroqInstrumentation, LangchainCommunityInstrumentation, LangchainCoreInstrumentation, LangchainInstrumentation, - LanggraphInstrumentation, LiteLLMInstrumentation, - LlamaindexInstrumentation, MilvusInstrumentation, MistralInstrumentation, + LanggraphInstrumentation, LiteLLMInstrumentation, LlamaindexInstrumentation, + MilvusInstrumentation, MistralInstrumentation, Neo4jGraphRAGInstrumentation, OllamaInstrumentor, OpenAIAgentsInstrumentation, OpenAIInstrumentation, PhiDataInstrumentation, PineconeInstrumentation, PyMongoInstrumentation, QdrantInstrumentation, VertexAIInstrumentation, WeaviateInstrumentation) @@ -284,6 +284,7 @@ def init( "phidata": PhiDataInstrumentation(), "agno": AgnoInstrumentation(), "mistralai": MistralInstrumentation(), + "neo4j-graphrag": Neo4jGraphRAGInstrumentation(), "boto3": AWSBedrockInstrumentation(), "autogen": AutogenInstrumentation(), "pymongo": PyMongoInstrumentation(), diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py index 08f7211d..a48af93e 100644 --- a/src/langtrace_python_sdk/version.py +++ b/src/langtrace_python_sdk/version.py @@ -1 +1 @@ -__version__ = "3.8.7" +__version__ = "3.8.8"