Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
790 changes: 782 additions & 8 deletions drift/instrumentation/psycopg/instrumentation.py

Large diffs are not rendered by default.

228 changes: 228 additions & 0 deletions drift/instrumentation/psycopg/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,3 +383,231 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def sync(self):
"""No-op sync for mock pipeline."""
pass


# ==================== ASYNC MOCKS ====================


class MockAsyncConnection:
"""Mock async database connection for REPLAY mode when postgres is not available.

Provides minimal async interface for FastAPI/asyncio apps to work without a real database.
All queries are mocked at the cursor.execute() level.
"""

def __init__(
self,
sdk: TuskDrift,
instrumentation: PsycopgInstrumentation,
cursor_factory,
row_factory=None,
):
self.sdk = sdk
self.instrumentation = instrumentation
self.cursor_factory = cursor_factory
self.row_factory = row_factory
self.closed = False
self.autocommit = False

# psycopg3 async connection attributes
self.isolation_level = None
self.encoding = "UTF8"
self.adapters = MockAdapters()
self.pgconn = None

class MockInfo:
vendor = "postgresql"
server_version = 150000
encoding = "UTF8"

def parameter_status(self, param):
if param == "TimeZone":
return "UTC"
elif param == "server_version":
return "15.0"
return None

self.info = MockInfo()

logger.debug("[MOCK_ASYNC_CONNECTION] Created mock async connection for REPLAY mode (psycopg3)")

def cursor(self, name=None, *, cursor_factory=None, **kwargs):
"""Create an async cursor using the instrumented cursor factory."""
cursor = MockAsyncCursor(self)

instrumentation = self.instrumentation
sdk = self.sdk

async def mock_execute(query, params=None, **kwargs):
# Use async execute handler
async def noop_execute(q, p, **kw):
return cursor

return await instrumentation._traced_async_execute(cursor, noop_execute, sdk, query, params, **kwargs)

async def mock_executemany(query, params_seq, **kwargs):
async def noop_executemany(q, ps, **kw):
return cursor

return await instrumentation._traced_async_executemany(
cursor, noop_executemany, sdk, query, params_seq, **kwargs
)

cursor.execute = mock_execute # type: ignore[method-assign]
cursor.executemany = mock_executemany # type: ignore[method-assign]

logger.debug("[MOCK_ASYNC_CONNECTION] Created async cursor (psycopg3)")
return cursor

async def commit(self):
"""Mock async commit - no-op in REPLAY mode."""
logger.debug("[MOCK_ASYNC_CONNECTION] commit() called (no-op)")
pass

async def rollback(self):
"""Mock async rollback - no-op in REPLAY mode."""
logger.debug("[MOCK_ASYNC_CONNECTION] rollback() called (no-op)")
pass

async def close(self):
"""Mock async close - no-op in REPLAY mode."""
logger.debug("[MOCK_ASYNC_CONNECTION] close() called (no-op)")
self.closed = True

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
await self.rollback()
else:
await self.commit()
return False

def transaction(self):
"""Return a mock async transaction context manager for REPLAY mode."""
return MockAsyncTransaction(self)

def pipeline(self):
"""Return a mock async pipeline context manager for REPLAY mode."""
return MockAsyncPipeline(self)


class MockAsyncCursor:
"""Mock async cursor for when we can't create a real async cursor.

This is a fallback when the async connection is completely mocked.
"""

def __init__(self, connection):
self.connection = connection
self.rowcount = -1
self._tusk_description = None
self.arraysize = 1
self._mock_rows = []
self._mock_index = 0
self._mock_result_sets = []
self._mock_result_set_index = 0
self.adapters = MockAdapters()
logger.debug("[MOCK_ASYNC_CURSOR] Created fallback mock async cursor (psycopg3)")

@property
def description(self):
return self._tusk_description

@property
def rownumber(self):
if self._mock_rows:
return self._mock_index
return None

@property
def statusmessage(self):
return getattr(self, "_mock_statusmessage", None)

async def execute(self, query, params=None, **kwargs):
"""Will be replaced by instrumentation."""
logger.debug(f"[MOCK_ASYNC_CURSOR] execute() called: {query[:100]}")
return self

async def executemany(self, query, params_seq, **kwargs):
"""Will be replaced by instrumentation."""
logger.debug(f"[MOCK_ASYNC_CURSOR] executemany() called: {query[:100]}")
return self

async def fetchone(self):
if self._mock_index < len(self._mock_rows):
row = self._mock_rows[self._mock_index]
self._mock_index += 1
return tuple(row) if isinstance(row, list) else row
return None

async def fetchmany(self, size=None):
if size is None:
size = self.arraysize
rows = []
for _ in range(size):
row = await self.fetchone()
if row is None:
break
rows.append(row)
return rows

async def fetchall(self):
rows = self._mock_rows[self._mock_index :]
self._mock_index = len(self._mock_rows)
return [tuple(row) if isinstance(row, list) else row for row in rows]

def __aiter__(self):
return self

async def __anext__(self):
if self._mock_index < len(self._mock_rows):
row = self._mock_rows[self._mock_index]
self._mock_index += 1
return tuple(row) if isinstance(row, list) else row
raise StopAsyncIteration

async def close(self):
pass

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
return False


class MockAsyncTransaction:
"""Mock async transaction context manager for REPLAY mode."""

def __init__(self, connection: MockAsyncConnection):
self._conn = connection

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
await self._conn.rollback()
else:
await self._conn.commit()
return False


class MockAsyncPipeline:
"""Mock async Pipeline for REPLAY mode."""

def __init__(self, connection: MockAsyncConnection):
self._conn = connection

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
return False

async def sync(self):
"""No-op async sync for mock pipeline."""
pass
92 changes: 84 additions & 8 deletions drift/instrumentation/psycopg2/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,8 @@ class MockCursor:
"""Mock cursor for when we can't create a real cursor from base class.

This is a fallback when the connection is completely mocked.
Provides all attributes that psycopg2 cursors have to ensure compatibility
with frameworks like Django that access cursor properties.
"""

def __init__(self, connection):
Expand All @@ -159,6 +161,14 @@ def __init__(self, connection):
self.arraysize = 1
self._mock_rows = []
self._mock_index = 0
# psycopg2 cursor attributes that Django/Flask may access
self.query = None # Last executed query string
self.statusmessage = None # Status message from last command
self.lastrowid = None # OID of last inserted row (if applicable)
self.closed = False
self.name = None # Server-side cursor name (None for client-side)
self.scrollable = None
self.withhold = False
logger.debug("[MOCK_CURSOR] Created fallback mock cursor")

def execute(self, query: Any, vars: Any = None) -> None:
Expand Down Expand Up @@ -210,6 +220,58 @@ def __init__(self, connection: Any, instrumentation: Psycopg2Instrumentation, sd

def cursor(self, name: str | None = None, cursor_factory: Any = None, *args: Any, **kwargs: Any) -> Any:
"""Intercept cursor creation to wrap user-provided cursor_factory."""
# In REPLAY mode, use MockCursor to have full control over cursor state
# This is necessary because psycopg2's cursor.description is a read-only
# C-level property that cannot be set after the cursor is created
if self._sdk.mode == TuskDriftMode.REPLAY:
cursor = MockCursor(self)
instrumentation = self._instrumentation
sdk = self._sdk

# Detect if user wants dict-style cursors (RealDictCursor, DictCursor)
is_dict_cursor = False
effective_cursor_factory = cursor_factory if cursor_factory is not None else self._default_cursor_factory
if effective_cursor_factory is not None:
try:
import psycopg2.extras

if effective_cursor_factory in (
psycopg2.extras.RealDictCursor,
psycopg2.extras.DictCursor,
) or (
isinstance(effective_cursor_factory, type)
and issubclass(
effective_cursor_factory, (psycopg2.extras.RealDictCursor, psycopg2.extras.DictCursor)
)
):
is_dict_cursor = True
except (ImportError, AttributeError):
pass

# Store cursor type info on the cursor for _mock_execute_with_data
cursor._is_dict_cursor = is_dict_cursor # type: ignore[attr-defined]

def mock_execute(query, vars=None):
def noop_execute(q, v):
return None

return instrumentation._traced_execute(cursor, noop_execute, sdk, query, vars)

def mock_executemany(query, vars_list):
def noop_executemany(q, vl):
return None

return instrumentation._traced_executemany(cursor, noop_executemany, sdk, query, vars_list)

cursor.execute = mock_execute # type: ignore[method-assign]
cursor.executemany = mock_executemany # type: ignore[method-assign]

logger.debug(
f"[INSTRUMENTED_CONNECTION] Created MockCursor for REPLAY mode (is_dict_cursor={is_dict_cursor})"
)
return cursor

# In RECORD mode, use real cursor with instrumentation
# Use cursor_factory from cursor() call, or fall back to connection's default
base_factory = cursor_factory if cursor_factory is not None else self._default_cursor_factory
# Create instrumented cursor factory (wrapping the base factory)
Expand Down Expand Up @@ -493,6 +555,9 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A
logger.warning("[PSYCOPG2_REPLAY] No mock found for pre-app-start query, returning empty result")
empty_mock = {"rowcount": 0, "rows": [], "description": None}
self._mock_execute_with_data(cursor, empty_mock)
# Set cursor.query to the executed query (psycopg2 cursor attribute)
if hasattr(cursor, "query"):
cursor.query = query_str.encode("utf-8") if isinstance(query_str, str) else query_str
span_info.span.end()
return None

Expand All @@ -503,6 +568,9 @@ def _replay_execute(self, cursor: Any, sdk: TuskDrift, query_str: str, params: A
)

self._mock_execute_with_data(cursor, mock_result)
# Set cursor.query to the executed query (psycopg2 cursor attribute)
if hasattr(cursor, "query"):
cursor.query = query_str.encode("utf-8") if isinstance(query_str, str) else query_str
span_info.span.end()
return None

Expand Down Expand Up @@ -621,6 +689,9 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param
)
empty_mock = {"rowcount": 0, "rows": [], "description": None}
self._mock_execute_with_data(cursor, empty_mock)
# Set cursor.query to the executed query (psycopg2 cursor attribute)
if hasattr(cursor, "query"):
cursor.query = query_str.encode("utf-8") if isinstance(query_str, str) else query_str
span_info.span.end()
return None

Expand All @@ -631,6 +702,9 @@ def _replay_executemany(self, cursor: Any, sdk: TuskDrift, query_str: str, param
)

self._mock_execute_with_data(cursor, mock_result)
# Set cursor.query to the executed query (psycopg2 cursor attribute)
if hasattr(cursor, "query"):
cursor.query = query_str.encode("utf-8") if isinstance(query_str, str) else query_str
span_info.span.end()
return None

Expand Down Expand Up @@ -777,15 +851,17 @@ def _mock_execute_with_data(self, cursor: Any, mock_data: dict[str, Any]) -> Non
# Deserialize datetime strings back to datetime objects for consistent Flask/Django serialization
mock_rows = [deserialize_db_value(row) for row in mock_rows]

# Check if this is a dict-cursor (like RealDictCursor) by checking if cursor class
# inherits from a dict-returning cursor type
is_dict_cursor = False
try:
import psycopg2.extras
# Check if this is a dict-cursor (like RealDictCursor)
# First check if cursor has _is_dict_cursor attribute (set by InstrumentedConnection.cursor())
# Then fall back to isinstance check for real dict cursors
is_dict_cursor = getattr(cursor, "_is_dict_cursor", False)
if not is_dict_cursor:
try:
import psycopg2.extras

is_dict_cursor = isinstance(cursor, (psycopg2.extras.RealDictCursor, psycopg2.extras.DictCursor))
except (ImportError, AttributeError):
pass
is_dict_cursor = isinstance(cursor, (psycopg2.extras.RealDictCursor, psycopg2.extras.DictCursor))
except (ImportError, AttributeError):
pass

# If it's a dict cursor and we have description, convert rows to dicts
if is_dict_cursor and description_data:
Expand Down
Loading