Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/datajoint/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,45 @@ def close(self) -> None:
"""Close the database connection."""
self._conn.close()

def __enter__(self) -> "Connection":
"""
Enter context manager.

Returns
-------
Connection
This connection object.

Examples
--------
>>> with dj.Connection(host, user, password) as conn:
... schema = dj.schema('my_schema', connection=conn)
... # perform operations
... # connection automatically closed
"""
return self

def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
"""
Exit context manager and close connection.

Parameters
----------
exc_type : type or None
Exception type if an exception was raised.
exc_val : Exception or None
Exception instance if an exception was raised.
exc_tb : traceback or None
Traceback if an exception was raised.

Returns
-------
bool
False to propagate exceptions.
"""
self.close()
return False

def register(self, schema) -> None:
"""
Register a schema with this connection.
Expand Down
30 changes: 30 additions & 0 deletions tests/integration/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,36 @@ def test_dj_connection_class(connection_test):
assert connection_test.is_connected


def test_connection_context_manager(db_creds_test):
"""
Connection should support context manager protocol for automatic cleanup.
"""
# Test basic context manager usage
with dj.Connection(**db_creds_test) as conn:
assert conn.is_connected
# Verify we can use the connection
result = conn.query("SELECT 1").fetchone()
assert result[0] == 1

# Connection should be closed after exiting context
assert not conn.is_connected


def test_connection_context_manager_exception(db_creds_test):
"""
Connection should close even when exception is raised inside context.
"""
conn = None
with pytest.raises(ValueError):
with dj.Connection(**db_creds_test) as conn:
assert conn.is_connected
raise ValueError("Test exception")

# Connection should still be closed after exception
assert conn is not None
assert not conn.is_connected


def test_persistent_dj_conn(db_creds_root):
"""
conn() method should provide persistent connection across calls.
Expand Down
Loading