def _make_json_serializer(json_serializer):
    """Build a ``_json_serializer`` callable from a user-supplied function.

    SQLAlchemy's ``create_engine(json_serializer=fn)`` convention expects a
    callable that replaces ``json.dumps`` entirely: it takes a Python object
    and returns a JSON string. The Spanner pipeline is different: values are
    wrapped in a :class:`JsonObject` first, and serialization happens later
    in ``_helpers._make_param_value_pb`` via ``obj.serialize()``.

    The gap is bridged with a serialize-then-wrap strategy:

    1. Call the user's ``json_serializer(value)`` to produce a JSON string
       with all custom types (``datetime``, etc.) already handled.
    2. Feed that string into ``JsonObject.from_str()``, which parses it back
       into a ``JsonObject`` containing only native Python types.
    3. When ``obj.serialize()`` is called later, the standard ``json.dumps``
       works because no custom types remain.

    This avoids subclassing or monkey-patching ``JsonObject``.

    Args:
        json_serializer: Either a ``JsonObject`` subclass (such as the
            class-level default) or a callable taking a Python object and
            returning its JSON text.

    Returns:
        ``json_serializer`` itself when it is a ``JsonObject`` subclass,
        otherwise a one-argument factory that returns a ``JsonObject``.
    """
    if isinstance(json_serializer, type) and issubclass(json_serializer, JsonObject):
        return json_serializer

    def _factory(value):
        # A value the caller already wrapped in JsonObject may carry null or
        # an array.  NOTE(review): JsonObject appears to keep such payloads
        # outside the underlying dict -- confirm against the installed
        # google-cloud-spanner version.  A json.dumps-style serializer would
        # then see an empty mapping and emit "{}", silently dropping the
        # data, so an empty-mapping JsonObject is passed through untouched.
        # Non-empty JsonObject mappings still go through the user's
        # serializer so custom types inside them keep being handled.
        if isinstance(value, JsonObject) and not value:
            return value
        return JsonObject.from_str(json_serializer(value))

    return _factory
def __init__(self, json_serializer=None, json_deserializer=None, **kwargs):
    """Create the dialect, optionally overriding JSON (de)serialization.

    Args:
        json_serializer: Optional replacement serializer. When given, it is
            passed through ``_make_json_serializer`` and the result is
            stored on this instance as ``_json_serializer``.
        json_deserializer: Optional replacement deserializer. When given, it
            is stored unchanged on this instance as ``_json_deserializer``.
        **kwargs: Forwarded untouched to the parent class initializer.

    An argument left as ``None`` sets nothing on the instance, so attribute
    lookup falls back to the class-level ``JsonObject`` default.
    """
    super().__init__(**kwargs)

    if json_serializer is not None:
        # Assign on the instance only; the class attribute stays untouched.
        wrapped = _make_json_serializer(json_serializer)
        self._json_serializer = wrapped

    if json_deserializer is None:
        return
    # NOTE(review): unlike the serializer, the deserializer is not adapted
    # in any way -- confirm the values it receives match what a
    # ``json.loads``-style callable expects.
    self._json_deserializer = json_deserializer
"""Unit tests for custom JSON serializer/deserializer support in SpannerDialect."""

import datetime
import json
import unittest

from google.cloud.sqlalchemy_spanner.sqlalchemy_spanner import (
    SpannerDialect,
    _make_json_serializer,
)
from google.cloud.spanner_v1.data_types import JsonObject
from sqlalchemy import types as sa_types
from sqlalchemy.util import get_cls_kwargs


def _datetime_default(obj):
    """Sample ``default`` hook for ``json.dumps``: ISO-format date-like objects."""
    if hasattr(obj, "isoformat"):
        return obj.isoformat()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _custom_serializer(obj):
    """Sample ``json_serializer`` that also handles datetime objects."""
    return json.dumps(obj, default=_datetime_default)


class TestMakeJsonSerializer(unittest.TestCase):
    """Tests for the ``_make_json_serializer`` factory."""

    def test_json_object_subclass_returned_directly(self):
        self.assertIs(_make_json_serializer(JsonObject), JsonObject)

    def test_custom_subclass_returned_directly(self):
        class MyJsonObject(JsonObject):
            pass

        self.assertIs(_make_json_serializer(MyJsonObject), MyJsonObject)

    def test_callable_produces_json_object(self):
        factory = _make_json_serializer(_custom_serializer)
        obj = factory({"key": "value", "num": 42})
        self.assertIsInstance(obj, JsonObject)
        self.assertEqual(json.loads(obj.serialize()), {"key": "value", "num": 42})

    def test_callable_handles_datetime(self):
        factory = _make_json_serializer(_custom_serializer)
        obj = factory({"ts": datetime.datetime(2023, 6, 15)})
        self.assertIsInstance(obj, JsonObject)
        parsed = json.loads(obj.serialize())
        self.assertEqual(parsed["ts"], "2023-06-15T00:00:00")

    def test_callable_handles_nested_datetimes(self):
        factory = _make_json_serializer(_custom_serializer)
        obj = factory(
            {
                "events": [
                    {"ts": datetime.datetime(2023, 1, 1), "action": "created"},
                    {"ts": datetime.datetime(2023, 6, 15), "action": "updated"},
                ]
            }
        )
        parsed = json.loads(obj.serialize())
        self.assertEqual(parsed["events"][0]["ts"], "2023-01-01T00:00:00")
        self.assertEqual(parsed["events"][1]["ts"], "2023-06-15T00:00:00")

    def test_callable_handles_arrays(self):
        factory = _make_json_serializer(_custom_serializer)
        obj = factory([1, 2, 3])
        self.assertIsInstance(obj, JsonObject)
        self.assertEqual(json.loads(obj.serialize()), [1, 2, 3])

    def test_callable_handles_null(self):
        factory = _make_json_serializer(json.dumps)
        obj = factory(None)
        self.assertIsInstance(obj, JsonObject)
        self.assertIsNone(obj.serialize())

    def test_no_custom_types_remain_in_json_object(self):
        """After serialize-then-wrap, the JsonObject contains only native types."""
        factory = _make_json_serializer(_custom_serializer)
        dt = datetime.datetime(2023, 6, 15, 9, 30, 0)
        obj = factory({"ts": dt, "name": "test"})
        self.assertIsInstance(obj["ts"], str)
        self.assertEqual(obj["ts"], "2023-06-15T09:30:00")


class TestSpannerDialectJsonSerializer(unittest.TestCase):
    """Tests for json_serializer/json_deserializer support in SpannerDialect."""

    def test_default_json_serializer_is_json_object(self):
        self.assertIs(SpannerDialect()._json_serializer, JsonObject)

    def test_default_json_deserializer_is_json_object(self):
        self.assertIs(SpannerDialect()._json_deserializer, JsonObject)

    def test_custom_json_serializer_produces_factory(self):
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        self.assertIsNot(dialect._json_serializer, JsonObject)
        obj = dialect._json_serializer({"ts": datetime.datetime(2023, 1, 1)})
        self.assertIsInstance(obj, JsonObject)
        parsed = json.loads(obj.serialize())
        self.assertEqual(parsed["ts"], "2023-01-01T00:00:00")

    def test_json_object_subclass_used_directly(self):
        dialect = SpannerDialect(json_serializer=JsonObject)
        self.assertIs(dialect._json_serializer, JsonObject)

    def test_custom_json_deserializer(self):
        def custom(raw):
            return json.loads(raw)

        dialect = SpannerDialect(json_deserializer=custom)
        self.assertIs(dialect._json_deserializer, custom)

    def test_class_attribute_unchanged_after_instance_override(self):
        # Constructing an overriding instance must not leak into the class.
        SpannerDialect(json_serializer=_custom_serializer)
        self.assertIs(SpannerDialect._json_serializer, JsonObject)

    def test_json_serializer_accepted_by_get_cls_kwargs(self):
        kwargs = get_cls_kwargs(SpannerDialect)
        self.assertIn("json_serializer", kwargs)
        self.assertIn("json_deserializer", kwargs)


class TestEndToEndJsonSerialization(unittest.TestCase):
    """End-to-end: SQLAlchemy JSON bind_processor -> serialize-then-wrap -> JsonObject.

    Simulates the full pipeline that occurs during a DML INSERT/UPDATE
    with a JSON column containing non-natively-serializable types.
    """

    def test_bind_processor_with_custom_serializer(self):
        """Simulate SQLAlchemy's JSON.bind_processor using the dialect."""
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        processor = sa_types.JSON().bind_processor(dialect)

        dt = datetime.datetime(2023, 6, 15, 9, 30, 0)
        result = processor({"event": "deploy", "timestamp": dt, "count": 42})

        self.assertIsInstance(result, JsonObject)
        parsed = json.loads(result.serialize())
        self.assertEqual(parsed["event"], "deploy")
        self.assertEqual(parsed["timestamp"], "2023-06-15T09:30:00")
        self.assertEqual(parsed["count"], 42)

    def test_bind_processor_with_nested_datetimes(self):
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        processor = sa_types.JSON().bind_processor(dialect)

        value = {
            "history": [
                {"ts": datetime.datetime(2023, 1, 1), "action": "created"},
                {"ts": datetime.datetime(2023, 6, 15), "action": "updated"},
            ]
        }
        parsed = json.loads(processor(value).serialize())
        self.assertEqual(parsed["history"][0]["ts"], "2023-01-01T00:00:00")
        self.assertEqual(parsed["history"][1]["ts"], "2023-06-15T00:00:00")

    def test_bind_processor_with_null_default(self):
        """With none_as_null=False (default), None becomes a null JsonObject."""
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        processor = sa_types.JSON().bind_processor(dialect)

        result = processor(None)
        self.assertIsInstance(result, JsonObject)
        self.assertIsNone(result.serialize())

    def test_bind_processor_with_null_as_sql_null(self):
        """With none_as_null=True, None becomes Python None (SQL NULL)."""
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        processor = sa_types.JSON(none_as_null=True).bind_processor(dialect)

        self.assertIsNone(processor(None))

    def test_spanner_helpers_pipeline(self):
        """Simulate _helpers._make_param_value_pb: isinstance check + bare serialize().

        _helpers.py checks isinstance(value, JsonObject) then calls
        value.serialize() with no arguments. Verify this works after
        the serialize-then-wrap round-trip.
        """
        dialect = SpannerDialect(json_serializer=_custom_serializer)
        factory = dialect._json_serializer

        dt = datetime.datetime(2023, 12, 25, 0, 0, 0)
        obj = factory({"holiday": "christmas", "date": dt})

        self.assertIsInstance(obj, JsonObject)
        serialized = obj.serialize()
        self.assertIsNotNone(serialized)
        self.assertEqual(json.loads(serialized)["date"], "2023-12-25T00:00:00")

    def test_default_dialect_unchanged(self):
        """Without json_serializer, the dialect uses plain JsonObject (no round-trip)."""
        dialect = SpannerDialect()
        processor = sa_types.JSON().bind_processor(dialect)

        result = processor({"name": "test", "count": 42})
        # Exact type check: the default path must not produce a subclass.
        self.assertIs(type(result), JsonObject)
        self.assertEqual(json.loads(result.serialize()), {"count": 42, "name": "test"})


if __name__ == "__main__":
    unittest.main()