Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 0 additions & 17 deletions agentic-rag-authorization/.env.example

This file was deleted.

105 changes: 0 additions & 105 deletions agentic-rag-authorization/agentic_rag/authorization_helpers.py

This file was deleted.

104 changes: 8 additions & 96 deletions agentic-rag-authorization/agentic_rag/grpc_helpers.py
Original file line number Diff line number Diff line change
@@ -1,129 +1,41 @@
"""Helper functions for gRPC and SpiceDB authentication."""
"""Helper functions for SpiceDB client creation."""

import grpc
from threading import Lock
from typing import Optional

from authzed.api.v1 import InsecureClient

class BearerTokenInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor):
"""
gRPC interceptor that adds bearer token to all requests.

This is for local development with SpiceDB's --grpc-no-tls flag.
"""

def __init__(self, token: str):
self._token = token

def _add_authorization(self, client_call_details):
"""Add authorization metadata to the call."""
metadata = []
if client_call_details.metadata is not None:
metadata = list(client_call_details.metadata)
metadata.append(("authorization", f"Bearer {self._token}"))

return grpc._interceptor._ClientCallDetails(
client_call_details.method,
client_call_details.timeout,
metadata,
client_call_details.credentials,
client_call_details.wait_for_ready,
client_call_details.compression,
)

def intercept_unary_unary(self, continuation, client_call_details, request):
"""Intercept unary-unary calls."""
new_details = self._add_authorization(client_call_details)
return continuation(new_details, request)

def intercept_unary_stream(self, continuation, client_call_details, request):
"""Intercept unary-stream calls."""
new_details = self._add_authorization(client_call_details)
return continuation(new_details, request)


# Global singleton for SpiceDB client with thread-safe initialization
_spicedb_client: Optional["Client"] = None
_spicedb_client: Optional[InsecureClient] = None
_spicedb_lock = Lock()


def create_insecure_spicedb_client(endpoint: str, token: str):
def create_insecure_spicedb_client(endpoint: str, token: str) -> InsecureClient:
"""
Create a SpiceDB client for insecure connections (local development).

This is for SpiceDB running with --grpc-no-tls flag.

Args:
endpoint: The SpiceDB endpoint (e.g., "localhost:50051")
token: The bearer token (e.g., "devtoken")

Returns:
authzed.api.v1.Client configured for insecure connection
For SpiceDB running with --grpc-no-tls flag.
"""
from authzed.api.v1 import Client

# Create insecure channel with bearer token interceptor
channel = grpc.insecure_channel(endpoint)
interceptor = BearerTokenInterceptor(token)
intercepted_channel = grpc.intercept_channel(channel, interceptor)

# Create client bypassing __init__ and initialize with our channel
client = Client.__new__(Client)
client.init_stubs(intercepted_channel)
return InsecureClient(endpoint, token)

return client


def get_spicedb_client(endpoint: str, token: str):
def get_spicedb_client(endpoint: str, token: str) -> InsecureClient:
"""
Get or create reusable SpiceDB client (singleton, thread-safe).

This function provides connection pooling for SpiceDB by maintaining
a single client instance across requests, eliminating connection overhead.

Args:
endpoint: The SpiceDB endpoint (e.g., "localhost:50051")
token: The bearer token (e.g., "devtoken")

Returns:
authzed.api.v1.Client configured for insecure connection
"""
from authzed.api.v1 import Client

global _spicedb_client

# Fast path: client already exists
if _spicedb_client is not None:
return _spicedb_client

# Slow path: create new client with thread-safe lock
with _spicedb_lock:
# Double-check after acquiring lock
if _spicedb_client is None:
_spicedb_client = create_insecure_spicedb_client(endpoint, token)

return _spicedb_client


def reset_spicedb_client():
"""
Reset singleton (useful for testing).

This allows tests to clear the cached client and create a fresh one.
"""
"""Reset singleton (useful for testing)."""
global _spicedb_client
with _spicedb_lock:
_spicedb_client = None


# Backward compatibility - keep the old function name
def insecure_bearer_token_credentials(token: str):
"""
Deprecated: Use create_insecure_spicedb_client instead.

This function is kept for backward compatibility but doesn't work
with authzed Client for insecure connections.
"""
raise NotImplementedError(
"For insecure SpiceDB connections, use create_insecure_spicedb_client() instead"
)
91 changes: 48 additions & 43 deletions agentic-rag-authorization/agentic_rag/nodes/authorization_node.py
Original file line number Diff line number Diff line change
@@ -1,65 +1,70 @@
"""Authorization node - deterministic permission filtering via SpiceDB."""

from langchain_core.messages import SystemMessage
from langchain_spicedb.core import SpiceDBAuthorizer

from ..state import AgenticRAGState
from ..config import get_config
from ..grpc_helpers import get_spicedb_client
from ..logging_config import get_logger
from ..authorization_helpers import batch_check_permissions
from ..node_helpers import log_node_execution

logger = get_logger("nodes.authorization")

_authorizer: SpiceDBAuthorizer | None = None

def authorization_node(state: AgenticRAGState) -> dict:

def _get_authorizer() -> SpiceDBAuthorizer:
global _authorizer
if _authorizer is None:
config = get_config()
_authorizer = SpiceDBAuthorizer(
spicedb_endpoint=config.spicedb_endpoint,
spicedb_token=config.spicedb_token,
resource_type="document",
subject_type="user",
permission="view",
resource_id_key="doc_id",
)
return _authorizer


async def authorization_node(state: AgenticRAGState) -> dict:
"""
Deterministic authorization node - ALWAYS runs, cannot be bypassed.

This node filters retrieved documents based on SpiceDB permissions.
Filters retrieved documents through SpiceDB's CheckBulkPermissions API.
This is a security boundary - the agent cannot bypass this check.
"""
config = get_config()
authorizer = _get_authorizer()

with log_node_execution(
logger,
"authorization",
{
logger.info(
"Starting authorization",
extra={
"subject_id": state["subject_id"],
"document_count": len(state["retrieved_documents"]),
}
):
# Get or create SpiceDB client (reused across requests)
client = get_spicedb_client(
config.spicedb_endpoint,
config.spicedb_token,
)

# Batch check permissions using SpiceDB's bulk API
authorized_docs, denied_doc_ids = batch_check_permissions(
client,
state["subject_id"],
state["retrieved_documents"],
)
},
)

denied_count = len(denied_doc_ids)
result = await authorizer.filter_documents(
documents=state["retrieved_documents"],
subject_id=state["subject_id"],
)

logger.info(
"Authorization results",
extra={
"authorized": len(authorized_docs),
"denied": denied_count,
"denied_doc_ids": denied_doc_ids,
},
)
logger.info(
"Authorization results",
extra={
"authorized": result.total_authorized,
"denied": len(result.denied_resource_ids),
"denied_doc_ids": result.denied_resource_ids,
},
)

return {
"authorized_documents": authorized_docs,
"denied_count": denied_count,
"authorization_passed": len(authorized_docs) > 0,
"messages": [
SystemMessage(
content=f"Authorization: {len(authorized_docs)}/{len(state['retrieved_documents'])} documents authorized"
)
],
}
return {
"authorized_documents": result.authorized_documents,
"denied_count": len(result.denied_resource_ids),
"authorization_passed": result.total_authorized > 0,
"messages": [
SystemMessage(
content=f"Authorization: {result.total_authorized}/{result.total_retrieved} documents authorized"
)
],
}
1 change: 1 addition & 0 deletions agentic-rag-authorization/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
langchain>=0.1.0
langchain-openai>=0.1.0
langgraph>=0.0.20
langchain-spicedb>=0.2.0
weaviate-client>=3.26.0,<4.0 # v3 for REST API stability (no gRPC issues)
authzed>=0.7.0
python-dotenv>=1.0.0
Expand Down
Loading
Loading