From 3c7020536b6b68c5c37b76ed90d5d737204b96eb Mon Sep 17 00:00:00 2001 From: Hung Nguyen Date: Sun, 26 Oct 2025 14:26:33 +0700 Subject: [PATCH 1/2] fix: Update DynamicPickleType to support MySQL dialect The `process_bind_param` and `process_result_value` methods in the `DynamicPickleType` class have been modified to handle MySQL dialect in addition to Spanner. This change ensures that pickled values are correctly processed for both database types. --- .../adk/sessions/database_session_service.py | 4 +- .../sessions/test_dynamic_pickle_type.py | 213 ++++++++++++++++++ 2 files changed, 215 insertions(+), 2 deletions(-) create mode 100644 tests/unittests/sessions/test_dynamic_pickle_type.py diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 215a5d8278..948fcbf8cf 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -125,14 +125,14 @@ def load_dialect_impl(self, dialect): def process_bind_param(self, value, dialect): """Ensures the pickled value is a bytes object before passing it to the database dialect.""" if value is not None: - if dialect.name == "spanner+spanner": + if dialect.name in ("spanner+spanner", "mysql"): return pickle.dumps(value) return value def process_result_value(self, value, dialect): """Ensures the raw bytes from the database are unpickled back into a Python object.""" if value is not None: - if dialect.name == "spanner+spanner": + if dialect.name in ("spanner+spanner", "mysql"): return pickle.loads(value) return value diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py new file mode 100644 index 0000000000..64c28ede7d --- /dev/null +++ b/tests/unittests/sessions/test_dynamic_pickle_type.py @@ -0,0 +1,213 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import pickle +from unittest import mock + +from google.adk.sessions.database_session_service import DynamicPickleType +import pytest +from sqlalchemy import create_engine +from sqlalchemy.dialects import mysql + + +@pytest.fixture +def pickle_type(): + """Fixture for DynamicPickleType instance.""" + return DynamicPickleType() + + +def test_load_dialect_impl_mysql(pickle_type): + """Test that MySQL dialect uses LONGBLOB.""" + # Mock the MySQL dialect + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + # Mock the return value of type_descriptor + mock_longblob_type = mock.Mock() + mock_dialect.type_descriptor.return_value = mock_longblob_type + + impl = pickle_type.load_dialect_impl(mock_dialect) + + # Verify type_descriptor was called once + assert mock_dialect.type_descriptor.call_count == 1 + # Verify it was called with mysql.LONGBLOB type + call_args = mock_dialect.type_descriptor.call_args[0][0] + assert call_args.__name__ == "LONGBLOB" + assert call_args == mysql.LONGBLOB + # Verify the return value is what we expect + assert impl == mock_longblob_type + + +def test_load_dialect_impl_spanner(pickle_type): + """Test that Spanner dialect uses SpannerPickleType.""" + # Mock the spanner dialect + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + with mock.patch( + "google.cloud.sqlalchemy_spanner.sqlalchemy_spanner.SpannerPickleType" + ) as mock_spanner_type: + pickle_type.load_dialect_impl(mock_dialect) + mock_dialect.type_descriptor.assert_called_once_with(mock_spanner_type) + + +def test_load_dialect_impl_default(pickle_type): + """Test that other dialects use default PickleType.""" + engine = create_engine("sqlite:///:memory:") + dialect = engine.dialect + impl = pickle_type.load_dialect_impl(dialect) + # Should return the default impl (PickleType) + assert impl == pickle_type.impl + + +def test_process_bind_param_mysql(pickle_type): + """Test that MySQL dialect pickles the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + test_data = {"key": "value", "nested": [1, 2, 3]} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should be pickled bytes + assert isinstance(result, bytes) + # Should be able to unpickle back to original + assert pickle.loads(result) == test_data + + +def test_process_bind_param_spanner(pickle_type): + """Test that Spanner dialect pickles the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + test_data = {"key": "value", "nested": [1, 2, 3]} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should be pickled bytes + assert isinstance(result, bytes) + # Should be able to unpickle back to original + assert pickle.loads(result) == test_data + + +def test_process_bind_param_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_bind_param(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_bind_param_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_bind_param(None, mock_dialect) + assert result is None + + +def test_process_result_value_mysql(pickle_type): + """Test that MySQL dialect unpickles the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + test_data = {"key": "value", "nested": [1, 2, 3]} + pickled_data = pickle.dumps(test_data) + + result = pickle_type.process_result_value(pickled_data, mock_dialect) + + # Should be unpickled back to original + assert result == test_data + + +def test_process_result_value_spanner(pickle_type): + """Test that Spanner dialect unpickles the value.""" + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + test_data = {"key": "value", "nested": [1, 2, 3]} + pickled_data = pickle.dumps(test_data) + + result = pickle_type.process_result_value(pickled_data, mock_dialect) + + # Should be unpickled back to original + assert result == test_data + + +def test_process_result_value_default(pickle_type): + """Test that other dialects return value as-is.""" + mock_dialect = mock.Mock() + mock_dialect.name = "sqlite" + + test_data = {"key": "value"} + result = pickle_type.process_result_value(test_data, mock_dialect) + + # Should return value unchanged (SQLAlchemy's PickleType handles it) + assert result == test_data + + +def test_process_result_value_none(pickle_type): + """Test that None values are handled correctly.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + result = pickle_type.process_result_value(None, mock_dialect) + assert result is None + + +def test_roundtrip_mysql(pickle_type): + """Test full roundtrip for MySQL: bind -> result.""" + mock_dialect = mock.Mock() + mock_dialect.name = "mysql" + + original_data = { + "string": "test", + "number": 42, + "list": [1, 2, 3], + "nested": {"a": 1, "b": 2}, + } + + # Simulate bind (Python -> DB) + bound_value = pickle_type.process_bind_param(original_data, mock_dialect) + assert isinstance(bound_value, bytes) + + # Simulate result (DB -> Python) + result_value = pickle_type.process_result_value(bound_value, mock_dialect) + assert result_value == original_data + + +def test_roundtrip_spanner(pickle_type): + """Test full roundtrip for Spanner: bind -> result.""" + mock_dialect = mock.Mock() + mock_dialect.name = "spanner+spanner" + + original_data = { + "string": "test", + "number": 42, + "list": [1, 2, 3], + "nested": {"a": 1, "b": 2}, + } + + # Simulate bind (Python -> DB) + bound_value = pickle_type.process_bind_param(original_data, mock_dialect) + assert isinstance(bound_value, bytes) + + # Simulate result (DB -> Python) + result_value = pickle_type.process_result_value(bound_value, mock_dialect) + assert result_value == original_data From 5ed3749220a7f041dd69d2450dfb29091de9a345 Mon Sep 17 00:00:00 2001 From: Hung Nguyen Date: Sun, 26 Oct 2025 14:52:05 +0700 Subject: [PATCH 2/2] refactor: Consolidate tests for DynamicPickleType to support multiple dialects Updated unit tests for `DynamicPickleType` to use parameterized testing for MySQL and Spanner dialects. This change simplifies the test structure and ensures consistent behavior across both dialects for binding and unbinding operations. --- .../sessions/test_dynamic_pickle_type.py | 96 +++++++------------ 1 file changed, 32 insertions(+), 64 deletions(-) diff --git a/tests/unittests/sessions/test_dynamic_pickle_type.py b/tests/unittests/sessions/test_dynamic_pickle_type.py index 64c28ede7d..e4eb084f88 100644 --- a/tests/unittests/sessions/test_dynamic_pickle_type.py +++ b/tests/unittests/sessions/test_dynamic_pickle_type.py @@ -41,12 +41,8 @@ def test_load_dialect_impl_mysql(pickle_type): impl = pickle_type.load_dialect_impl(mock_dialect) - # Verify type_descriptor was called once - assert mock_dialect.type_descriptor.call_count == 1 - # Verify it was called with mysql.LONGBLOB type - call_args = mock_dialect.type_descriptor.call_args[0][0] - assert call_args.__name__ == "LONGBLOB" - assert call_args == mysql.LONGBLOB + # Verify type_descriptor was called once with mysql.LONGBLOB + mock_dialect.type_descriptor.assert_called_once_with(mysql.LONGBLOB) # Verify the return value is what we expect assert impl == mock_longblob_type @@ -73,24 +69,17 @@ def test_load_dialect_impl_default(pickle_type): assert impl == pickle_type.impl -def test_process_bind_param_mysql(pickle_type): - """Test that MySQL dialect pickles the value.""" +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_bind_param_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects pickle the value.""" mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - test_data = {"key": "value", "nested": [1, 2, 3]} - result = pickle_type.process_bind_param(test_data, mock_dialect) - - # Should be pickled bytes - assert isinstance(result, bytes) - # Should be able to unpickle back to original - assert pickle.loads(result) == test_data - - -def test_process_bind_param_spanner(pickle_type): - """Test that Spanner dialect pickles the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = "spanner+spanner" + mock_dialect.name = dialect_name test_data = {"key": "value", "nested": [1, 2, 3]} result = pickle_type.process_bind_param(test_data, mock_dialect) @@ -122,24 +111,17 @@ def test_process_bind_param_none(pickle_type): assert result is None -def test_process_result_value_mysql(pickle_type): - """Test that MySQL dialect unpickles the value.""" - mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - test_data = {"key": "value", "nested": [1, 2, 3]} - pickled_data = pickle.dumps(test_data) - - result = pickle_type.process_result_value(pickled_data, mock_dialect) - - # Should be unpickled back to original - assert result == test_data - - -def test_process_result_value_spanner(pickle_type): - """Test that Spanner dialect unpickles the value.""" +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_process_result_value_pickle_dialects(pickle_type, dialect_name): + """Test that MySQL and Spanner dialects unpickle the value.""" mock_dialect = mock.Mock() - mock_dialect.name = "spanner+spanner" + mock_dialect.name = dialect_name test_data = {"key": "value", "nested": [1, 2, 3]} pickled_data = pickle.dumps(test_data) @@ -171,31 +153,17 @@ def test_process_result_value_none(pickle_type): assert result is None -def test_roundtrip_mysql(pickle_type): - """Test full roundtrip for MySQL: bind -> result.""" +@pytest.mark.parametrize( + "dialect_name", + [ + pytest.param("mysql", id="mysql"), + pytest.param("spanner+spanner", id="spanner"), + ], +) +def test_roundtrip_pickle_dialects(pickle_type, dialect_name): + """Test full roundtrip for MySQL and Spanner: bind -> result.""" mock_dialect = mock.Mock() - mock_dialect.name = "mysql" - - original_data = { - "string": "test", - "number": 42, - "list": [1, 2, 3], - "nested": {"a": 1, "b": 2}, - } - - # Simulate bind (Python -> DB) - bound_value = pickle_type.process_bind_param(original_data, mock_dialect) - assert isinstance(bound_value, bytes) - - # Simulate result (DB -> Python) - result_value = pickle_type.process_result_value(bound_value, mock_dialect) - assert result_value == original_data - - -def test_roundtrip_spanner(pickle_type): - """Test full roundtrip for Spanner: bind -> result.""" - mock_dialect = mock.Mock() - mock_dialect.name = "spanner+spanner" + mock_dialect.name = dialect_name original_data = { "string": "test",