Skip to content
This repository was archived by the owner on Mar 31, 2026. It is now read-only.

Commit 133ead1

Browse files
committed
feat: Propagate client_context in Session and update tests
- Update Session.transaction to accept client_context. - Update unit tests to support client_context propagation. - Update mock objects in tests to match the expected attribute hierarchy. - Clean up nested imports in test files.
1 parent 55b213b commit 133ead1

File tree

11 files changed

+56
-30
lines changed

11 files changed

+56
-30
lines changed

google/cloud/spanner_v1/session.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -472,9 +472,14 @@ def batch(self):
472472

473473
return Batch(self)
474474

475-
def transaction(self) -> Transaction:
475+
def transaction(self, client_context=None) -> Transaction:
476476
"""Create a transaction to perform a set of reads with shared staleness.
477477
478+
:type client_context: :class:`~google.cloud.spanner_v1.types.ClientContext`
479+
or :class:`dict`
480+
:param client_context: (Optional) Client context to use for all requests made
481+
by this transaction.
482+
478483
:rtype: :class:`~google.cloud.spanner_v1.transaction.Transaction`
479484
:returns: a transaction bound to this session
480485
@@ -483,7 +488,7 @@ def transaction(self) -> Transaction:
483488
if self._session_id is None:
484489
raise ValueError("Session has not been created.")
485490

486-
return Transaction(self)
491+
return Transaction(self, client_context=client_context)
487492

488493
def run_in_transaction(self, func, *args, **kw):
489494
"""Perform a unit of work in a transaction, retrying on abort.
@@ -512,6 +517,8 @@ def run_in_transaction(self, func, *args, **kw):
512517
the DDL option `allow_txn_exclusion` being false or unset.
513518
"isolation_level" sets the isolation level for the transaction.
514519
"read_lock_mode" sets the read lock mode for the transaction.
520+
"client_context" (Optional) Client context to use for all requests made
521+
by this transaction.
515522
516523
:rtype: Any
517524
:returns: The return value of ``func``.
@@ -529,6 +536,7 @@ def run_in_transaction(self, func, *args, **kw):
529536
)
530537
isolation_level = kw.pop("isolation_level", None)
531538
read_lock_mode = kw.pop("read_lock_mode", None)
539+
client_context = kw.pop("client_context", None)
532540

533541
database = self._database
534542
log_commit_stats = database.log_commit_stats
@@ -554,7 +562,7 @@ def run_in_transaction(self, func, *args, **kw):
554562
previous_transaction_id: Optional[bytes] = None
555563

556564
while True:
557-
txn = self.transaction()
565+
txn = self.transaction(client_context=client_context)
558566
txn.transaction_tag = transaction_tag
559567
txn.exclude_txn_from_change_streams = exclude_txn_from_change_streams
560568
txn.isolation_level = isolation_level

tests/unit/spanner_dbapi/test_connection.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,7 @@ class _Client(object):
872872
def __init__(self, project="project_id"):
873873
self.project = project
874874
self.project_name = "projects/" + self.project
875+
self._client_context = None
875876

876877
def instance(self, instance_id="instance_id"):
877878
return _Instance(name=instance_id, client=self)

tests/unit/test_backup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ class _Client(object):
679679
def __init__(self, project=TestBackup.PROJECT_ID):
680680
self.project = project
681681
self.project_name = "projects/" + self.project
682+
self._client_context = None
682683

683684

684685
class _Instance(object):

tests/unit/test_batch.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -806,6 +806,9 @@ class _Database(object):
806806

807807
def __init__(self, enable_end_to_end_tracing=False):
808808
self.name = "testing"
809+
self._instance = mock.Mock()
810+
self._instance._client = mock.Mock()
811+
self._instance._client._client_context = None
809812
self._route_to_leader_enabled = True
810813
if enable_end_to_end_tracing:
811814
self.observability_options = dict(enable_end_to_end_tracing=True)

tests/unit/test_database.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
RequestOptions,
3131
DirectedReadOptions,
3232
DefaultTransactionOptions,
33+
ExecuteSqlRequest,
3334
)
3435
from google.cloud.spanner_v1._helpers import (
3536
AtomicCounter,
@@ -2599,6 +2600,7 @@ def test__get_snapshot_new_wo_staleness(self):
25992600
exact_staleness=None,
26002601
multi_use=True,
26012602
transaction_id=None,
2603+
client_context=None,
26022604
)
26032605
snapshot.begin.assert_called_once_with()
26042606

@@ -2614,6 +2616,7 @@ def test__get_snapshot_w_read_timestamp(self):
26142616
exact_staleness=None,
26152617
multi_use=True,
26162618
transaction_id=None,
2619+
client_context=None,
26172620
)
26182621
snapshot.begin.assert_called_once_with()
26192622

@@ -2629,6 +2632,7 @@ def test__get_snapshot_w_exact_staleness(self):
26292632
exact_staleness=duration,
26302633
multi_use=True,
26312634
transaction_id=None,
2635+
client_context=None,
26322636
)
26332637
snapshot.begin.assert_called_once_with()
26342638

@@ -3540,6 +3544,7 @@ def __init__(
35403544
self.directed_read_options = directed_read_options
35413545
self.default_transaction_options = default_transaction_options
35423546
self.observability_options = observability_options
3547+
self._client_context = None
35433548
self._nth_client_id = _Client.NTH_CLIENT.increment()
35443549
self._nth_request = AtomicCounter()
35453550

@@ -3589,6 +3594,13 @@ class _Database(object):
35893594
def __init__(self, name, instance=None):
35903595
self.name = name
35913596
self.database_id = name.rsplit("/", 1)[1]
3597+
if instance is None:
3598+
instance = mock.Mock()
3599+
instance._client = mock.Mock()
3600+
instance._client._client_context = None
3601+
instance._client._query_options = ExecuteSqlRequest.QueryOptions(
3602+
optimizer_version="1"
3603+
)
35923604
self._instance = instance
35933605
from logging import Logger
35943606

tests/unit/test_instance.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,7 @@ def __init__(self, project, timeout_seconds=None):
10231023
self.route_to_leader_enabled = True
10241024
self.directed_read_options = None
10251025
self.default_transaction_options = DefaultTransactionOptions()
1026+
self._client_context = None
10261027

10271028
def copy(self):
10281029
from copy import deepcopy

tests/unit/test_pool.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,21 @@
1919
from datetime import datetime, timedelta
2020

2121
import mock
22+
from google.cloud.spanner_v1 import pool as MUT
2223
from google.cloud.spanner_v1 import _opentelemetry_tracing
24+
from google.cloud.spanner_v1 import ExecuteSqlRequest
25+
from google.cloud.spanner_v1 import BatchCreateSessionsResponse
26+
from google.cloud.spanner_v1 import Session
27+
from google.cloud.spanner_v1 import SpannerClient
28+
from google.cloud.spanner_v1.database import Database
29+
from google.cloud.spanner_v1.pool import AbstractSessionPool
30+
from google.cloud.spanner_v1.pool import SessionCheckout
31+
from google.cloud.spanner_v1.pool import FixedSizePool
32+
from google.cloud.spanner_v1.pool import BurstyPool
33+
from google.cloud.spanner_v1.pool import PingingPool
34+
from google.cloud.spanner_v1.transaction import Transaction
35+
from google.cloud.exceptions import NotFound
36+
from google.cloud._testing import _Monkey
2337
from google.cloud.spanner_v1._helpers import (
2438
_metadata_with_request_id,
2539
_metadata_with_request_id_and_req_id,
@@ -40,20 +54,17 @@
4054

4155

4256
def _make_database(name="name"):
43-
from google.cloud.spanner_v1.database import Database
4457

4558
return mock.create_autospec(Database, instance=True)
4659

4760

4861
def _make_session():
49-
from google.cloud.spanner_v1.database import Session
5062

5163
return mock.create_autospec(Session, instance=True)
5264

5365

5466
class TestAbstractSessionPool(unittest.TestCase):
5567
def _getTargetClass(self):
56-
from google.cloud.spanner_v1.pool import AbstractSessionPool
5768

5869
return AbstractSessionPool
5970

@@ -129,7 +140,6 @@ def test__new_session_w_database_role(self):
129140
self.assertEqual(new_session.database_role, database_role)
130141

131142
def test_session_wo_kwargs(self):
132-
from google.cloud.spanner_v1.pool import SessionCheckout
133143

134144
pool = self._make_one()
135145
checkout = pool.session()
@@ -139,7 +149,6 @@ def test_session_wo_kwargs(self):
139149
self.assertEqual(checkout._kwargs, {})
140150

141151
def test_session_w_kwargs(self):
142-
from google.cloud.spanner_v1.pool import SessionCheckout
143152

144153
pool = self._make_one()
145154
checkout = pool.session(foo="bar")
@@ -164,7 +173,6 @@ class TestFixedSizePool(OpenTelemetryBase):
164173
enrich_with_otel_scope(BASE_ATTRIBUTES)
165174

166175
def _getTargetClass(self):
167-
from google.cloud.spanner_v1.pool import FixedSizePool
168176

169177
return FixedSizePool
170178

@@ -559,7 +567,6 @@ class TestBurstyPool(OpenTelemetryBase):
559567
enrich_with_otel_scope(BASE_ATTRIBUTES)
560568

561569
def _getTargetClass(self):
562-
from google.cloud.spanner_v1.pool import BurstyPool
563570

564571
return BurstyPool
565572

@@ -850,7 +857,6 @@ class TestPingingPool(OpenTelemetryBase):
850857
enrich_with_otel_scope(BASE_ATTRIBUTES)
851858

852859
def _getTargetClass(self):
853-
from google.cloud.spanner_v1.pool import PingingPool
854860

855861
return PingingPool
856862

@@ -946,8 +952,6 @@ def test_get_hit_no_ping(self, mock_region):
946952
)
947953
def test_get_hit_w_ping(self, mock_region):
948954
import datetime
949-
from google.cloud._testing import _Monkey
950-
from google.cloud.spanner_v1 import pool as MUT
951955

952956
pool = self._make_one(size=4)
953957
database = _Database("name")
@@ -974,8 +978,6 @@ def test_get_hit_w_ping(self, mock_region):
974978
)
975979
def test_get_hit_w_ping_expired(self, mock_region):
976980
import datetime
977-
from google.cloud._testing import _Monkey
978-
from google.cloud.spanner_v1 import pool as MUT
979981

980982
pool = self._make_one(size=4)
981983
database = _Database("name")
@@ -1097,8 +1099,6 @@ def test_spans_put_full(self, mock_region):
10971099
)
10981100
def test_put_non_full(self, mock_region):
10991101
import datetime
1100-
from google.cloud._testing import _Monkey
1101-
from google.cloud.spanner_v1 import pool as MUT
11021102

11031103
pool = self._make_one(size=1)
11041104
session_queue = pool._sessions = _Queue()
@@ -1172,8 +1172,6 @@ def test_ping_oldest_fresh(self, mock_region):
11721172
)
11731173
def test_ping_oldest_stale_but_exists(self, mock_region):
11741174
import datetime
1175-
from google.cloud._testing import _Monkey
1176-
from google.cloud.spanner_v1 import pool as MUT
11771175

11781176
pool = self._make_one(size=1)
11791177
database = _Database("name")
@@ -1193,8 +1191,6 @@ def test_ping_oldest_stale_but_exists(self, mock_region):
11931191
)
11941192
def test_ping_oldest_stale_and_not_exists(self, mock_region):
11951193
import datetime
1196-
from google.cloud._testing import _Monkey
1197-
from google.cloud.spanner_v1 import pool as MUT
11981194

11991195
pool = self._make_one(size=1)
12001196
database = _Database("name")
@@ -1257,7 +1253,6 @@ def test_spans_get_and_leave_empty_pool(self, mock_region):
12571253

12581254
class TestSessionCheckout(unittest.TestCase):
12591255
def _getTargetClass(self):
1260-
from google.cloud.spanner_v1.pool import SessionCheckout
12611256

12621257
return SessionCheckout
12631258

@@ -1314,7 +1309,6 @@ def test_context_manager_w_kwargs(self):
13141309

13151310

13161311
def _make_transaction(*args, **kw):
1317-
from google.cloud.spanner_v1.transaction import Transaction
13181312

13191313
txn = mock.create_autospec(Transaction)(*args, **kw)
13201314
txn.committed = None
@@ -1352,14 +1346,12 @@ def exists(self):
13521346
return self._exists
13531347

13541348
def ping(self):
1355-
from google.cloud.exceptions import NotFound
13561349

13571350
self._pinged = True
13581351
if not self._exists:
13591352
raise NotFound("expired session")
13601353

13611354
def delete(self):
1362-
from google.cloud.exceptions import NotFound
13631355

13641356
self._deleted = True
13651357
if not self._exists:
@@ -1391,9 +1383,6 @@ def mock_batch_create_sessions(
13911383
metadata=[],
13921384
labels={},
13931385
):
1394-
from google.cloud.spanner_v1 import BatchCreateSessionsResponse
1395-
from google.cloud.spanner_v1 import Session
1396-
13971386
database_role = request.session_template.creator_role if request else None
13981387
if request.session_count < 2:
13991388
response = BatchCreateSessionsResponse(
@@ -1408,10 +1397,15 @@ def mock_batch_create_sessions(
14081397
)
14091398
return response
14101399

1411-
from google.cloud.spanner_v1 import SpannerClient
1412-
14131400
self.spanner_api = mock.create_autospec(SpannerClient, instance=True)
14141401
self.spanner_api.batch_create_sessions.side_effect = mock_batch_create_sessions
1402+
self._instance = mock.Mock()
1403+
self._instance._client = mock.Mock()
1404+
self._instance._client._client_context = None
1405+
self._instance._client.spanner_api = self.spanner_api
1406+
self._instance._client._query_options = ExecuteSqlRequest.QueryOptions(
1407+
optimizer_version="1"
1408+
)
14151409

14161410
@property
14171411
def database_role(self):

tests/unit/test_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,9 @@ def _make_database(
194194
database.database_role = database_role
195195
database._route_to_leader_enabled = True
196196
database.default_transaction_options = default_transaction_options
197+
database._instance = mock.Mock()
198+
database._instance._client = mock.Mock()
199+
database._instance._client._client_context = None
197200
inject_into_mock_database(database)
198201

199202
return database

tests/unit/test_snapshot.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2182,6 +2182,7 @@ def __init__(self):
21822182
from google.cloud.spanner_v1 import ExecuteSqlRequest
21832183

21842184
self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
2185+
self._client_context = None
21852186
self._nth_client_id = _Client.NTH_CLIENT.increment()
21862187
self._nth_request = AtomicCounter()
21872188

tests/unit/test_spanner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,6 +1280,7 @@ def __init__(self):
12801280
self._query_options = ExecuteSqlRequest.QueryOptions(optimizer_version="1")
12811281
self.directed_read_options = None
12821282
self.default_transaction_options = DefaultTransactionOptions()
1283+
self._client_context = None
12831284
self._nth_client_id = _Client.NTH_CLIENT.increment()
12841285
self._nth_request = AtomicCounter()
12851286

0 commit comments

Comments
 (0)