diff --git a/awscrt/cbor.py b/awscrt/cbor.py
index 4c97ab616..d5df98610 100644
--- a/awscrt/cbor.py
+++ b/awscrt/cbor.py
@@ -5,7 +5,7 @@
 
 from awscrt import NativeResource
 from enum import IntEnum
-from typing import Callable, Any, Union
+from typing import Callable, Any, Union, Dict, Optional
 
 
 class AwsCborType(IntEnum):
@@ -53,6 +53,110 @@ class AwsCborType(IntEnum):
     IndefMap = 16
 
 
+class ShapeBase:
+    """
+    Base class for shape objects used by CRT CBOR encoding.
+
+    This class defines the interface that shape implementations should follow.
+    Libraries can extend this class directly to provide shape information
+    to the CBOR encoder without requiring intermediate conversions.
+
+    Subclasses must implement the type_name property and can override other
+    properties as needed based on their shape type.
+    """
+
+    @property
+    def type_name(self) -> str:
+        """
+        Return the shape type name.
+        TODO: maybe return the `AwsCborType` instead?
+
+        Returns:
+            str: One of: 'structure', 'list', 'map', 'string', 'integer', 'long',
+                'float', 'double', 'boolean', 'blob', 'timestamp'
+
+        Note:
+            Subclasses must implement this property.
+        """
+        raise NotImplementedError("Subclasses must implement type_name property")
+
+    @property
+    def members(self) -> Dict[str, 'ShapeBase']:
+        """
+        For structure types, return dict of member name -> ShapeBase.
+
+        Returns:
+            Dict[str, ShapeBase]: Dictionary mapping member names to their shapes
+
+        Raises:
+            AttributeError: If called on non-structure type
+        """
+        raise AttributeError(f"Shape type {self.type_name} has no members")
+
+    @property
+    def member(self) -> 'ShapeBase':
+        """
+        For list types, return the ShapeBase of list elements.
+
+        Returns:
+            ShapeBase: Shape of list elements
+
+        Raises:
+            AttributeError: If called on non-list type
+        """
+        raise AttributeError(f"Shape type {self.type_name} has no member")
+
+    @property
+    def key(self) -> 'ShapeBase':
+        """
+        For map types, return the ShapeBase of map keys.
+
+        Returns:
+            ShapeBase: Shape of map keys
+
+        Raises:
+            AttributeError: If called on non-map type
+        """
+        raise AttributeError(f"Shape type {self.type_name} has no key")
+
+    @property
+    def value(self) -> 'ShapeBase':
+        """
+        For map types, return the ShapeBase of map values.
+
+        Returns:
+            ShapeBase: Shape of map values
+
+        Raises:
+            AttributeError: If called on non-map type
+        """
+        raise AttributeError(f"Shape type {self.type_name} has no value")
+
+    def get_serialization_name(self, member_name: str) -> str:
+        """
+        Get the serialization name for a structure member.
+
+        For structure types, returns the name to use in CBOR encoding.
+        This allows for custom field name mappings (e.g., 'user_id' -> 'UserId').
+
+        Note:
+            The native encoder behind `AwsCborEncoder.write_data_item_shaped` does not call
+            this method. It looks for an optional `serialization` dict attribute on each
+            member shape and uses its "name" entry, falling back to the member name.
+
+        Args:
+            member_name: The member name as it appears in the structure
+
+        Returns:
+            str: The name to use in CBOR encoding (may be same as member_name)
+
+        Raises:
+            AttributeError: If called on non-structure type
+            ValueError: If member_name is not found in the structure
+        """
+        raise AttributeError(f"Shape type {self.type_name} has no serialization_name")
+
+
 class AwsCborEncoder(NativeResource):
     """CBOR encoder for converting Python objects to CBOR binary format.
 
@@ -283,6 +387,62 @@ def write_data_item(self, data_item: Any):
         """
         return _awscrt.cbor_encoder_write_data_item(self._binding, data_item)
 
+    def write_data_item_shaped(
+            self,
+            data_item: Any,
+            shape: 'ShapeBase',
+            timestamp_converter: Optional[Callable[[Any], float]] = None):
+        """Generic API to write any type of data_item as cbor formatted, using shape information.
+
+        The shape, not the Python type of the value, decides how `data_item` is encoded. It can be
+        any object exposing the attributes declared on `ShapeBase`: `type_name` for every shape,
+        plus `members` (structure), `member` (list) or `key` and `value` (map).
+
+        Supported shape types:
+        - integer/long: Integer values
+        - float/double: Floating point values
+        - boolean: Boolean values
+        - blob: Byte strings (a `str` value is encoded as UTF-8 first)
+        - string: Text strings
+        - list: Lists with typed members
+        - map: Maps with typed keys and values
+        - structure: Structures with named members (None values filtered)
+        - timestamp: Timestamps, written as CBOR tag 1 (epoch time) followed by a float
+
+        Args:
+            data_item (Any): The data to encode
+            shape (ShapeBase): The shape describing `data_item`
+            timestamp_converter (Callable[[Any], float], optional): Callback converting a timestamp
+                value to epoch seconds (float). Required as soon as a timestamp shape has to be
+                encoded: without it, encoding a timestamp raises an exception.
+
+        Example:
+            ```python
+            class IntegerShape(ShapeBase):
+                type_name = "integer"
+
+            class TimestampShape(ShapeBase):
+                type_name = "timestamp"
+
+            class EventShape(ShapeBase):
+                type_name = "structure"
+                members = {"id": IntegerShape(), "created": TimestampShape()}
+
+            def timestamp_converter(dt):
+                return dt.timestamp()
+
+            encoder = AwsCborEncoder()
+            data = {"id": 123, "created": datetime.datetime.now()}
+            encoder.write_data_item_shaped(data, EventShape(), timestamp_converter)
+            cbor_bytes = encoder.get_encoded_data()
+            ```
+
+        Note:
+            The shape attributes are read again on every call, so they should be cheap to access;
+            build shape objects once and reuse them rather than recreating them per request.
+        """
+        return _awscrt.cbor_encoder_write_data_item_shaped(self._binding, data_item, shape, timestamp_converter)
+
 
 class AwsCborDecoder(NativeResource):
     """CBOR decoder for converting CBOR binary format to Python objects.
@@ -568,6 +722,6 @@ def pop_next_data_item(self) -> Any:
             - `AwsCborType.ArrayStart` or `AwsCborType.IndefArray` and all the followed data items in the array -> list
             - `AwsCborType.MapStart` or `AwsCborType.IndefMap` and all the followed data items in the map -> dict
             - `AwsCborType.Tag`: For tag with id 1, as the epoch time, it invokes the _on_epoch_time for python to convert to expected type.
-                For the reset tag, exception will be raised.
+                For the other tags, exception will be raised.
         """
         return _awscrt.cbor_decoder_pop_next_data_item(self._binding)
diff --git a/source/cbor.c b/source/cbor.c
index 602ccb21d..236d62a74 100644
--- a/source/cbor.c
+++ b/source/cbor.c
@@ -180,33 +180,48 @@ static PyObject *s_cbor_encoder_write_pydict(struct aws_cbor_encoder *encoder, P
 
 static PyObject *s_cbor_encoder_write_pyobject(struct aws_cbor_encoder *encoder, PyObject *py_object) {
 
-    /**
-     * TODO: timestamp <-> datetime?? Decimal fraction <-> decimal??
-     */
-    if (PyLong_CheckExact(py_object)) {
-        return s_cbor_encoder_write_pylong(encoder, py_object);
-    } else if (PyFloat_CheckExact(py_object)) {
-        return s_cbor_encoder_write_pyobject_as_float(encoder, py_object);
-    } else if (PyBool_Check(py_object)) {
-        return s_cbor_encoder_write_pyobject_as_bool(encoder, py_object);
-    } else if (PyBytes_CheckExact(py_object)) {
-        return s_cbor_encoder_write_pyobject_as_bytes(encoder, py_object);
-    } else if (PyUnicode_Check(py_object)) {
+    /* Handle None first as it's a singleton, not a type */
+    if (py_object == Py_None) {
+        aws_cbor_encoder_write_null(encoder);
+        Py_RETURN_NONE;
+    }
+
+    /* Get type once for efficiency - PyObject_Type returns a new reference */
+    /* https://docs.python.org/3/c-api/structures.html#c.Py_TYPE is not a stable API until 3.14, so that we cannot use
+     * it. */
+    PyObject *type = PyObject_Type(py_object);
+    if (!type) {
+        return NULL;
+    }
+
+    PyObject *result = NULL;
+
+    /* Exact type matches first (no subclasses) - fast path */
+    if (type == (PyObject *)&PyLong_Type) {
+        result = s_cbor_encoder_write_pylong(encoder, py_object);
+    } else if (type == (PyObject *)&PyFloat_Type) {
+        result = s_cbor_encoder_write_pyobject_as_float(encoder, py_object);
+    } else if (type == (PyObject *)&PyBool_Type) {
+        result = s_cbor_encoder_write_pyobject_as_bool(encoder, py_object);
+    } else if (type == (PyObject *)&PyBytes_Type) {
+        result = s_cbor_encoder_write_pyobject_as_bytes(encoder, py_object);
+    } else if (PyType_IsSubtype((PyTypeObject *)type, &PyUnicode_Type)) {
         /* Allow subclasses of `str` */
-        return s_cbor_encoder_write_pyobject_as_text(encoder, py_object);
-    } else if (PyList_Check(py_object)) {
+        result = s_cbor_encoder_write_pyobject_as_text(encoder, py_object);
+    } else if (PyType_IsSubtype((PyTypeObject *)type, &PyList_Type)) {
         /* Write py_list, allow subclasses of `list` */
-        return s_cbor_encoder_write_pylist(encoder, py_object);
-    } else if (PyDict_Check(py_object)) {
+        result = s_cbor_encoder_write_pylist(encoder, py_object);
+    } else if (PyType_IsSubtype((PyTypeObject *)type, &PyDict_Type)) {
         /* Write py_dict, allow subclasses of `dict` */
-        return s_cbor_encoder_write_pydict(encoder, py_object);
-    } else if (py_object == Py_None) {
-        aws_cbor_encoder_write_null(encoder);
+        result = s_cbor_encoder_write_pydict(encoder, py_object);
     } else {
-        PyErr_Format(PyExc_ValueError, "Not supported type %R", (PyObject *)Py_TYPE(py_object));
+        /* Unsupported type */
+        PyErr_Format(PyExc_ValueError, "Not supported type %R", type);
     }
 
-    Py_RETURN_NONE;
+    /* Release the type reference */
+    Py_DECREF(type);
+    return result;
 }
 
 /*********************************** BINDINGS ***********************************************/
@@ -239,6 +254,313 @@ ENCODER_WRITE(py_list, s_cbor_encoder_write_pylist)
 ENCODER_WRITE(py_dict, s_cbor_encoder_write_pydict)
 ENCODER_WRITE(data_item, s_cbor_encoder_write_pyobject)
 
+/**
+ * Encode py_object as the CBOR data item described by py_shape.
+ * py_shape is any object exposing the attributes declared on awscrt.cbor.ShapeBase:
+ * `type_name` always, plus `members` (structure), `member` (list) or `key` and `value` (map).
+ * py_timestamp_converter is a callable returning epoch seconds; timestamp shapes cannot be encoded without it.
+ * Returns a new reference on success, or NULL with a Python exception set.
+ */
+static PyObject *s_cbor_encoder_write_pyobject_shaped(
+    struct aws_cbor_encoder *encoder,
+    PyObject *py_object,
+    PyObject *py_shape,
+    PyObject *py_timestamp_converter);
+
+/**
+ * Write a structure as a CBOR map. Entries of py_dict whose value is None are left out; every other
+ * value is encoded with the member shape stored under the same key in `py_shape.members`.
+ * The map key is the member shape's `serialization["name"]` when it has one, else the dict key itself.
+ */
+static PyObject *s_cbor_encoder_write_shaped_structure(
+    struct aws_cbor_encoder *encoder,
+    PyObject *py_dict,
+    PyObject *py_shape,
+    PyObject *py_timestamp_converter) {
+
+    PyObject *result = NULL;
+    PyObject *members = NULL;
+    PyObject *filtered_dict = NULL;
+    PyObject *serialization = NULL;
+    PyObject *key = NULL;
+    PyObject *value = NULL;
+    Py_ssize_t pos = 0;
+
+    /* PyDict_Next() is only defined for dicts */
+    if (!PyDict_Check(py_dict)) {
+        PyErr_SetString(PyExc_TypeError, "Structure shape requires a dict value");
+        return NULL;
+    }
+
+    /* Get members from shape, before anything is written: a bad shape must not leave a map header behind */
+    members = PyObject_GetAttrString(py_shape, "members");
+    if (!members) {
+        return NULL;
+    }
+    if (!PyDict_Check(members)) {
+        PyErr_SetString(PyExc_TypeError, "Shape `members` must be a dict");
+        goto done;
+    }
+
+    /* Filter None values; the size of the filtered copy is the length announced in the map header */
+    filtered_dict = PyDict_New();
+    if (!filtered_dict) {
+        goto done;
+    }
+    while (PyDict_Next(py_dict, &pos, &key, &value)) {
+        if (value != Py_None && PyDict_SetItem(filtered_dict, key, value) == -1) {
+            goto done;
+        }
+    }
+
+    aws_cbor_encoder_write_map_start(encoder, (size_t)PyDict_Size(filtered_dict));
+
+    pos = 0;
+    while (PyDict_Next(filtered_dict, &pos, &key, &value)) {
+        /* Get member shape (borrowed reference, owned by `members`) */
+        PyObject *member_shape = PyDict_GetItem(members, key);
+        if (!member_shape) {
+            PyErr_Format(PyExc_KeyError, "Member shape not found for key %R", key);
+            goto done;
+        }
+
+        /* Get serialization name if present */
+        PyObject *member_key = key;
+        serialization = PyObject_GetAttrString(member_shape, "serialization");
+        if (serialization) {
+            if (PyDict_Check(serialization)) {
+                /* `name` is borrowed from `serialization`, which is released only once the key is written */
+                PyObject *name = PyDict_GetItemString(serialization, "name");
+                if (name) {
+                    member_key = name;
+                }
+            }
+        } else if (PyErr_ExceptionMatches(PyExc_AttributeError)) {
+            /* No `serialization` attribute on this member shape: keep the dict key */
+            PyErr_Clear();
+        } else {
+            /* Any other failure while reading the attribute is a real error: do not swallow it */
+            goto done;
+        }
+
+        /* Write the key */
+        PyObject *key_result = s_cbor_encoder_write_pyobject_as_text(encoder, member_key);
+        Py_CLEAR(serialization);
+        if (!key_result) {
+            goto done;
+        }
+        Py_DECREF(key_result);
+
+        /* Write the value with shape */
+        PyObject *value_result =
+            s_cbor_encoder_write_pyobject_shaped(encoder, value, member_shape, py_timestamp_converter);
+        if (!value_result) {
+            goto done;
+        }
+        Py_DECREF(value_result);
+    }
+
+    result = Py_None;
+    Py_INCREF(result);
+
+done:
+    Py_XDECREF(serialization);
+    Py_XDECREF(filtered_dict);
+    Py_XDECREF(members);
+    return result;
+}
+
+/**
+ * Write a list as a CBOR array, every item being encoded with `py_shape.member`.
+ */
+static PyObject *s_cbor_encoder_write_shaped_list(
+    struct aws_cbor_encoder *encoder,
+    PyObject *py_list,
+    PyObject *py_shape,
+    PyObject *py_timestamp_converter) {
+
+    /* PyList_Size() fails with -1 on anything else, and that must never reach the encoder as a length */
+    if (!PyList_Check(py_list)) {
+        PyErr_SetString(PyExc_TypeError, "List shape requires a list value");
+        return NULL;
+    }
+
+    /* Get member shape, before anything is written: a bad shape must not leave an array header behind */
+    PyObject *member_shape = PyObject_GetAttrString(py_shape, "member");
+    if (!member_shape) {
+        return NULL;
+    }
+
+    Py_ssize_t size = PyList_Size(py_list);
+    aws_cbor_encoder_write_array_start(encoder, (size_t)size);
+
+    for (Py_ssize_t i = 0; i < size; i++) {
+        PyObject *item = PyList_GetItem(py_list, i);
+        if (!item) {
+            Py_DECREF(member_shape);
+            return NULL;
+        }
+        PyObject *result = s_cbor_encoder_write_pyobject_shaped(encoder, item, member_shape, py_timestamp_converter);
+        if (!result) {
+            Py_DECREF(member_shape);
+            return NULL;
+        }
+        Py_DECREF(result);
+    }
+
+    Py_DECREF(member_shape);
+    Py_RETURN_NONE;
+}
+
+/**
+ * Write a map as a CBOR map, keys being encoded with `py_shape.key` and values with `py_shape.value`.
+ */
+static PyObject *s_cbor_encoder_write_shaped_map(
+    struct aws_cbor_encoder *encoder,
+    PyObject *py_dict,
+    PyObject *py_shape,
+    PyObject *py_timestamp_converter) {
+
+    /* PyDict_Size() fails with -1 on anything else, and that must never reach the encoder as a length */
+    if (!PyDict_Check(py_dict)) {
+        PyErr_SetString(PyExc_TypeError, "Map shape requires a dict value");
+        return NULL;
+    }
+
+    /* Get key and value shapes one at a time, so no lookup runs while an exception is already pending */
+    PyObject *key_shape = PyObject_GetAttrString(py_shape, "key");
+    if (!key_shape) {
+        return NULL;
+    }
+    PyObject *value_shape = PyObject_GetAttrString(py_shape, "value");
+    if (!value_shape) {
+        Py_DECREF(key_shape);
+        return NULL;
+    }
+
+    aws_cbor_encoder_write_map_start(encoder, (size_t)PyDict_Size(py_dict));
+
+    PyObject *key = NULL;
+    PyObject *value = NULL;
+    Py_ssize_t pos = 0;
+
+    while (PyDict_Next(py_dict, &pos, &key, &value)) {
+        PyObject *key_result = s_cbor_encoder_write_pyobject_shaped(encoder, key, key_shape, py_timestamp_converter);
+        if (!key_result) {
+            Py_DECREF(key_shape);
+            Py_DECREF(value_shape);
+            return NULL;
+        }
+        Py_DECREF(key_result);
+
+        PyObject *value_result =
+            s_cbor_encoder_write_pyobject_shaped(encoder, value, value_shape, py_timestamp_converter);
+        if (!value_result) {
+            Py_DECREF(key_shape);
+            Py_DECREF(value_shape);
+            return NULL;
+        }
+        Py_DECREF(value_result);
+    }
+
+    Py_DECREF(key_shape);
+    Py_DECREF(value_shape);
+    Py_RETURN_NONE;
+}
+
+static PyObject *s_cbor_encoder_write_pyobject_shaped(
+    struct aws_cbor_encoder *encoder,
+    PyObject *py_object,
+    PyObject *py_shape,
+    PyObject *py_timestamp_converter) {
+
+    /* Get type_name from shape */
+    PyObject *type_name = PyObject_GetAttrString(py_shape, "type_name");
+    if (!type_name) {
+        return NULL;
+    }
+
+    /* The cursor presumably points into storage owned by type_name (the helper is defined elsewhere),
+     * so the reference is kept until the last use of the cursor instead of being dropped right here */
+    struct aws_byte_cursor type_cursor = aws_byte_cursor_from_pyunicode(type_name);
+    if (!type_cursor.ptr) {
+        Py_DECREF(type_name);
+        return NULL;
+    }
+
+    PyObject *result = NULL;
+
+    /* Handle different shape types */
+    if (aws_byte_cursor_eq_c_str(&type_cursor, "integer") || aws_byte_cursor_eq_c_str(&type_cursor, "long")) {
+        result = s_cbor_encoder_write_pylong(encoder, py_object);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "float") || aws_byte_cursor_eq_c_str(&type_cursor, "double")) {
+        result = s_cbor_encoder_write_pyobject_as_float(encoder, py_object);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "boolean")) {
+        result = s_cbor_encoder_write_pyobject_as_bool(encoder, py_object);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "blob")) {
+        if (PyUnicode_Check(py_object)) {
+            /* Convert string to bytes */
+            PyObject *encoded = PyUnicode_AsEncodedString(py_object, "utf-8", "strict");
+            if (encoded) {
+                result = s_cbor_encoder_write_pyobject_as_bytes(encoder, encoded);
+                Py_DECREF(encoded);
+            }
+        } else {
+            result = s_cbor_encoder_write_pyobject_as_bytes(encoder, py_object);
+        }
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "string")) {
+        result = s_cbor_encoder_write_pyobject_as_text(encoder, py_object);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "list")) {
+        result = s_cbor_encoder_write_shaped_list(encoder, py_object, py_shape, py_timestamp_converter);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "map")) {
+        result = s_cbor_encoder_write_shaped_map(encoder, py_object, py_shape, py_timestamp_converter);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "structure")) {
+        result = s_cbor_encoder_write_shaped_structure(encoder, py_object, py_shape, py_timestamp_converter);
+    } else if (aws_byte_cursor_eq_c_str(&type_cursor, "timestamp")) {
+        if (!py_timestamp_converter || py_timestamp_converter == Py_None) {
+            /* Error out as the timestamp_converter is not provided */
+            PyErr_SetString(PyExc_ValueError, "Timestamp converter is not provided to enable timestamp");
+        } else {
+            /* Call the converter to get epoch seconds as float. It runs before the tag is written, so that a
+             * failing converter does not leave a tag without its value in the encoder */
+            PyObject *converted = PyObject_CallFunctionObjArgs(py_timestamp_converter, py_object, NULL);
+            if (converted) {
+                /* Write CBOR tag 1 (epoch time) */
+                aws_cbor_encoder_write_tag(encoder, AWS_CBOR_TAG_EPOCH_TIME);
+                result = s_cbor_encoder_write_pyobject_as_float(encoder, converted);
+                Py_DECREF(converted);
+            }
+        }
+    } else {
+        /* %S formats str(type_name) on every supported Python version; the `%.*s` precision form
+         * used before is only understood by PyErr_Format since Python 3.12 */
+        PyErr_Format(PyExc_ValueError, "Unsupported shape type: %S", type_name);
+    }
+
+    Py_DECREF(type_name);
+    return result;
+}
+
+PyObject *aws_py_cbor_encoder_write_data_item_shaped(PyObject *self, PyObject *args) {
+    (void)self;
+    PyObject *py_capsule;
+    PyObject *py_object;
+    PyObject *py_shape;
+    PyObject *py_timestamp_converter = NULL;
+
+    /* The converter is optional: it is only needed when a timestamp shape is encoded */
+    if (!PyArg_ParseTuple(args, "OOO|O", &py_capsule, &py_object, &py_shape, &py_timestamp_converter)) {
+        return NULL;
+    }
+
+    struct aws_cbor_encoder *encoder = s_cbor_encoder_from_capsule(py_capsule);
+    if (!encoder) {
+        return NULL;
+    }
+
+    return s_cbor_encoder_write_pyobject_shaped(encoder, py_object, py_shape, py_timestamp_converter);
+}
+
 PyObject *aws_py_cbor_encoder_write_simple_types(PyObject *self, PyObject *args) {
     (void)self;
     PyObject *py_capsule = NULL;
diff --git a/source/cbor.h b/source/cbor.h
index b196bc135..d813ede97 100644
--- a/source/cbor.h
+++ b/source/cbor.h
@@ -29,6 +29,7 @@ PyObject *aws_py_cbor_encoder_write_simple_types(PyObject *self, PyObject *args)
 PyObject *aws_py_cbor_encoder_write_py_list(PyObject *self, PyObject *args);
 PyObject *aws_py_cbor_encoder_write_py_dict(PyObject *self, PyObject *args);
 PyObject *aws_py_cbor_encoder_write_data_item(PyObject *self, PyObject *args);
+PyObject *aws_py_cbor_encoder_write_data_item_shaped(PyObject *self, PyObject *args);
 
 /*******************************************************************************
  * DECODE
diff --git a/source/module.c b/source/module.c
index 84b5673d2..b580de0c0 100644
--- a/source/module.c
+++ b/source/module.c
@@ -963,6 +963,7 @@ static PyMethodDef s_module_methods[] = {
     AWS_PY_METHOD_DEF(cbor_encoder_write_py_list, METH_VARARGS),
     AWS_PY_METHOD_DEF(cbor_encoder_write_py_dict, METH_VARARGS),
     AWS_PY_METHOD_DEF(cbor_encoder_write_data_item, METH_VARARGS),
+    AWS_PY_METHOD_DEF(cbor_encoder_write_data_item_shaped, METH_VARARGS),
 
     /* CBOR Decode */
     AWS_PY_METHOD_DEF(cbor_decoder_new, METH_VARARGS),
diff --git a/source/module.h b/source/module.h
index 2f9dd217e..52d9165d7 100644
--- a/source/module.h
+++ b/source/module.h
@@ -70,6 +70,15 @@ struct aws_byte_cursor aws_byte_cursor_from_pyunicode(PyObject *str);
  * If conversion cannot occur, cursor->ptr will be NULL and a python exception is set */
 struct aws_byte_cursor aws_byte_cursor_from_pybytes(PyObject *py_bytes);
 
+/**
+ * Check if a PyObject is an instance of datetime.datetime using stable ABI.
+ * + * @param obj PyObject to check + * @param out_is_datetime Pointer to store result (true if datetime, false otherwise) + * @return AWS_OP_SUCCESS on success, AWS_OP_ERR on error (Python exception set) + */ +int aws_py_is_datetime_instance(PyObject *obj, bool *out_is_datetime); + /* Set current thread's error indicator based on aws_last_error() */ void PyErr_SetAwsLastError(void); diff --git a/test/test_cbor.py b/test/test_cbor.py index 3ab65b569..e2070d4b5 100644 --- a/test/test_cbor.py +++ b/test/test_cbor.py @@ -151,6 +151,125 @@ def on_epoch_time(epoch_secs): exception = e self.assertIsNotNone(exception) + def test_cbor_encode_unsupported_type(self): + """Test that encoding unsupported types raises ValueError""" + # Create a custom class that's not supported by CBOR encoder + class CustomClass: + def __init__(self, value): + self.value = value + + # Try to encode an unsupported type + encoder = AwsCborEncoder() + unsupported_obj = CustomClass(42) + + # Should raise ValueError with message about unsupported type + with self.assertRaises(ValueError) as context: + encoder.write_data_item(unsupported_obj) + # Verify the error message mentions "Not supported type" + self.assertIn("Not supported type", str(context.exception)) + + # Test unsupported type in a list (should also fail) + encoder2 = AwsCborEncoder() + with self.assertRaises(ValueError) as context2: + encoder2.write_data_item([1, 2, unsupported_obj, 3]) + + self.assertIn("Not supported type", str(context2.exception)) + + # Test unsupported type as dict key (should also fail) + encoder3 = AwsCborEncoder() + with self.assertRaises(ValueError) as context3: + encoder3.write_data_item({unsupported_obj: "value"}) + + self.assertIn("Not supported type", str(context3.exception)) + + def test_cbor_encode_decode_special_floats(self): + """Test encoding and decoding special float values: inf, -inf, and NaN""" + import math + + # Test positive infinity + encoder = AwsCborEncoder() + pos_inf = float('inf') + 
encoder.write_float(pos_inf) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_pos_inf = decoder.pop_next_data_item() + self.assertTrue(math.isinf(decoded_pos_inf)) + self.assertTrue(decoded_pos_inf > 0) + + # Test negative infinity + encoder = AwsCborEncoder() + neg_inf = float('-inf') + encoder.write_float(neg_inf) + + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_neg_inf = decoder.pop_next_data_item() + self.assertTrue(math.isinf(decoded_neg_inf)) + self.assertTrue(decoded_neg_inf < 0) + + # Test NaN + encoder = AwsCborEncoder() + nan_val = float('nan') + encoder.write_float(nan_val) + + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_nan = decoder.pop_next_data_item() + self.assertTrue(math.isnan(decoded_nan)) + + # Test special floats in a list using write_data_item + encoder = AwsCborEncoder() + special_floats_list = [float('inf'), float('-inf'), float('nan'), 42.0, -100.5] + encoder.write_data_item(special_floats_list) + + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_list = decoder.pop_next_data_item() + self.assertEqual(len(decoded_list), 5) + self.assertTrue(math.isinf(decoded_list[0]) and decoded_list[0] > 0) + self.assertTrue(math.isinf(decoded_list[1]) and decoded_list[1] < 0) + self.assertTrue(math.isnan(decoded_list[2])) + self.assertEqual(decoded_list[3], 42.0) + self.assertEqual(decoded_list[4], -100.5) + + # Test special floats in a dictionary + encoder = AwsCborEncoder() + special_floats_dict = { + "positive_infinity": float('inf'), + "negative_infinity": float('-inf'), + "not_a_number": float('nan'), + "normal": 3.14 + } + encoder.write_data_item(special_floats_dict) + + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_dict = decoder.pop_next_data_item() + self.assertTrue(math.isinf(decoded_dict["positive_infinity"]) and decoded_dict["positive_infinity"] > 0) + 
self.assertTrue(math.isinf(decoded_dict["negative_infinity"]) and decoded_dict["negative_infinity"] < 0) + self.assertTrue(math.isnan(decoded_dict["not_a_number"])) + self.assertEqual(decoded_dict["normal"], 3.14) + + # Test special floats in nested structures + encoder = AwsCborEncoder() + nested_structure = { + "data": [ + {"value": float('inf'), "type": "infinity"}, + {"value": float('-inf'), "type": "negative_infinity"}, + {"value": float('nan'), "type": "nan"} + ], + "metadata": { + "max": float('inf'), + "min": float('-inf') + } + } + encoder.write_data_item(nested_structure) + + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded_nested = decoder.pop_next_data_item() + self.assertTrue(math.isinf(decoded_nested["data"][0]["value"]) and decoded_nested["data"][0]["value"] > 0) + self.assertTrue(math.isinf(decoded_nested["data"][1]["value"]) and decoded_nested["data"][1]["value"] < 0) + self.assertTrue(math.isnan(decoded_nested["data"][2]["value"])) + self.assertTrue(math.isinf(decoded_nested["metadata"]["max"]) and decoded_nested["metadata"]["max"] > 0) + self.assertTrue(math.isinf(decoded_nested["metadata"]["min"]) and decoded_nested["metadata"]["min"] < 0) + + self.assertEqual(decoder.get_remaining_bytes_len(), 0) + def _ieee754_bits_to_float(self, bits): return struct.unpack('>f', struct.pack('>I', bits))[0] @@ -261,3 +382,496 @@ def test_cbor_decode_errors(self): tag_data = decoder.pop_next_data_item() else: decoded_data = decoder.pop_next_data_item() + + # Helper shape classes for testing write_data_item_shaped + class SimpleShape(ShapeBase): + """Simple shape implementation for testing""" + + def __init__(self, type_name): + self._type_name = type_name + + @property + def type_name(self): + return self._type_name + + class ListShape(ShapeBase): + """List shape implementation for testing""" + + def __init__(self, member_shape): + self._member_shape = member_shape + + @property + def type_name(self): + return "list" + + @property + def 
member(self): + return self._member_shape + + class MapShape(ShapeBase): + """Map shape implementation for testing""" + + def __init__(self, key_shape, value_shape): + self._key_shape = key_shape + self._value_shape = value_shape + + @property + def type_name(self): + return "map" + + @property + def key(self): + return self._key_shape + + @property + def value(self): + return self._value_shape + + class StructureShape(ShapeBase): + """Structure shape implementation for testing""" + + def __init__(self, members_dict, serialization_names=None): + self._members_dict = members_dict + self._serialization_names = serialization_names or {} + + @property + def type_name(self): + return "structure" + + @property + def members(self): + return self._members_dict + + def get_serialization_name(self, member_name): + return self._serialization_names.get(member_name, member_name) + + def test_write_data_item_shaped_basic_types(self): + """Test write_data_item_shaped with basic scalar types""" + + # Test integer + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(42, self.SimpleShape("integer")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), 42) + + # Test long (same as integer in Python 3) + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(2**63, self.SimpleShape("long")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), 2**63) + + # Test float + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(3.14, self.SimpleShape("float")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertAlmostEqual(decoder.pop_next_data_item(), 3.14, places=5) + + # Test double + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(2.718281828, self.SimpleShape("double")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertAlmostEqual(decoder.pop_next_data_item(), 2.718281828, places=9) + + # Test boolean - True + encoder = 
AwsCborEncoder() + encoder.write_data_item_shaped(True, self.SimpleShape("boolean")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertTrue(decoder.pop_next_data_item()) + + # Test boolean - False + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(False, self.SimpleShape("boolean")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertFalse(decoder.pop_next_data_item()) + + # Test string + encoder = AwsCborEncoder() + encoder.write_data_item_shaped("hello world", self.SimpleShape("string")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), "hello world") + + # Test empty string + encoder = AwsCborEncoder() + encoder.write_data_item_shaped("", self.SimpleShape("string")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), "") + + # Test blob + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(b"binary data", self.SimpleShape("blob")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), b"binary data") + + # Test empty blob + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(b"", self.SimpleShape("blob")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), b"") + + def test_write_data_item_shaped_list(self): + """Test write_data_item_shaped with list types""" + + # Test list of integers + encoder = AwsCborEncoder() + int_list_shape = self.ListShape(self.SimpleShape("integer")) + encoder.write_data_item_shaped([1, 2, 3, 4, 5], int_list_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), [1, 2, 3, 4, 5]) + + # Test list of strings + encoder = AwsCborEncoder() + string_list_shape = self.ListShape(self.SimpleShape("string")) + encoder.write_data_item_shaped(["a", "b", "c"], string_list_shape) + decoder = 
AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), ["a", "b", "c"]) + + # Test empty list + encoder = AwsCborEncoder() + encoder.write_data_item_shaped([], int_list_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), []) + + # Test nested list + encoder = AwsCborEncoder() + nested_list_shape = self.ListShape(self.ListShape(self.SimpleShape("integer"))) + encoder.write_data_item_shaped([[1, 2], [3, 4], [5]], nested_list_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), [[1, 2], [3, 4], [5]]) + + def test_write_data_item_shaped_map(self): + """Test write_data_item_shaped with map types""" + + # Test map with string keys and integer values + encoder = AwsCborEncoder() + map_shape = self.MapShape(self.SimpleShape("string"), self.SimpleShape("integer")) + data = {"a": 1, "b": 2, "c": 3} + encoder.write_data_item_shaped(data, map_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + # Test map with integer keys and string values + encoder = AwsCborEncoder() + map_shape = self.MapShape(self.SimpleShape("integer"), self.SimpleShape("string")) + data = {1: "one", 2: "two", 3: "three"} + encoder.write_data_item_shaped(data, map_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + # Test empty map + encoder = AwsCborEncoder() + encoder.write_data_item_shaped({}, map_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), {}) + + # Test nested map + encoder = AwsCborEncoder() + nested_map_shape = self.MapShape( + self.SimpleShape("string"), + self.MapShape(self.SimpleShape("string"), self.SimpleShape("integer")) + ) + data = {"outer1": {"inner1": 1, "inner2": 2}, "outer2": {"inner3": 3}} + encoder.write_data_item_shaped(data, 
nested_map_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + def test_write_data_item_shaped_structure(self): + """Test write_data_item_shaped with structure types""" + + # Test simple structure + encoder = AwsCborEncoder() + struct_shape = self.StructureShape({ + "id": self.SimpleShape("integer"), + "name": self.SimpleShape("string"), + "active": self.SimpleShape("boolean") + }) + data = {"id": 123, "name": "Alice", "active": True} + encoder.write_data_item_shaped(data, struct_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + # Test structure with None values (should be filtered out) + encoder = AwsCborEncoder() + data_with_none = {"id": 456, "name": None, "active": False} + encoder.write_data_item_shaped(data_with_none, struct_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + # None values should be filtered out + self.assertEqual(decoder.pop_next_data_item(), {"id": 456, "active": False}) + + # Test structure with custom serialization names + # Note: Custom serialization names may not be fully implemented yet in C binding + # This test documents the expected behavior when it's implemented + encoder = AwsCborEncoder() + struct_shape_custom = self.StructureShape( + { + "user_id": self.SimpleShape("integer"), + "user_name": self.SimpleShape("string") + }, + serialization_names={"user_id": "UserId", "user_name": "UserName"} + ) + data = {"user_id": 789, "user_name": "Bob"} + encoder.write_data_item_shaped(data, struct_shape_custom) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + # For now, it uses the original names (not custom serialization names) + # TODO: Update this test when custom serialization is implemented + self.assertEqual(decoder.pop_next_data_item(), {"user_id": 789, "user_name": "Bob"}) + + def test_write_data_item_shaped_nested_structure(self): + """Test write_data_item_shaped with nested 
structures""" + + # Create nested structure: User with Address + encoder = AwsCborEncoder() + address_shape = self.StructureShape({ + "street": self.SimpleShape("string"), + "city": self.SimpleShape("string"), + "zip": self.SimpleShape("integer") + }) + user_shape = self.StructureShape({ + "id": self.SimpleShape("integer"), + "name": self.SimpleShape("string"), + "address": address_shape + }) + + data = { + "id": 1, + "name": "Alice", + "address": { + "street": "123 Main St", + "city": "Springfield", + "zip": 12345 + } + } + encoder.write_data_item_shaped(data, user_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + # Test with None in nested structure + encoder = AwsCborEncoder() + data_with_none = { + "id": 2, + "name": "Bob", + "address": { + "street": "456 Elm St", + "city": None, # This should be filtered + "zip": 67890 + } + } + encoder.write_data_item_shaped(data_with_none, user_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + expected = { + "id": 2, + "name": "Bob", + "address": { + "street": "456 Elm St", + "zip": 67890 + } + } + self.assertEqual(decoder.pop_next_data_item(), expected) + + def test_write_data_item_shaped_structure_with_list(self): + """Test write_data_item_shaped with structures containing lists""" + + encoder = AwsCborEncoder() + struct_shape = self.StructureShape({ + "id": self.SimpleShape("integer"), + "tags": self.ListShape(self.SimpleShape("string")), + "scores": self.ListShape(self.SimpleShape("integer")) + }) + + data = { + "id": 100, + "tags": ["python", "aws", "cbor"], + "scores": [85, 90, 95] + } + encoder.write_data_item_shaped(data, struct_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + def test_write_data_item_shaped_structure_with_map(self): + """Test write_data_item_shaped with structures containing maps""" + + encoder = AwsCborEncoder() + struct_shape = 
self.StructureShape({ + "id": self.SimpleShape("integer"), + "metadata": self.MapShape(self.SimpleShape("string"), self.SimpleShape("string")), + "counts": self.MapShape(self.SimpleShape("string"), self.SimpleShape("integer")) + }) + + data = { + "id": 200, + "metadata": {"author": "Alice", "version": "1.0"}, + "counts": {"views": 100, "likes": 50} + } + encoder.write_data_item_shaped(data, struct_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), data) + + def test_write_data_item_shaped_timestamp(self): + """Test write_data_item_shaped with timestamp type""" + + # Test timestamp with converter + encoder = AwsCborEncoder() + timestamp_shape = self.SimpleShape("timestamp") + + # Create a mock datetime-like object + class MockDateTime: + def timestamp(self): + return 1609459200.0 # 2021-01-01 00:00:00 UTC + + mock_dt = MockDateTime() + + # Converter function + def timestamp_converter(dt): + return dt.timestamp() + + encoder.write_data_item_shaped(mock_dt, timestamp_shape, timestamp_converter) + + # Decode and verify it's encoded as epoch time with tag + decoder = AwsCborDecoder(encoder.get_encoded_data()) + # Should have tag 1 followed by the float value + self.assertEqual(decoder.peek_next_type(), AwsCborType.Tag) + tag_id = decoder.pop_next_tag_val() + self.assertEqual(tag_id, 1) + timestamp_value = decoder.pop_next_data_item() + self.assertAlmostEqual(timestamp_value, 1609459200.0, places=5) + + # Test timestamp without converter (already numeric) + # When converter is None, pass a simple converter that returns the value as-is + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(1609459200.5, timestamp_shape, lambda x: x) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.peek_next_type(), AwsCborType.Tag) + tag_id = decoder.pop_next_tag_val() + self.assertEqual(tag_id, 1) + timestamp_value = decoder.pop_next_data_item() + self.assertAlmostEqual(timestamp_value, 
1609459200.5, places=5) + + def test_write_data_item_shaped_complex_nested(self): + """Test write_data_item_shaped with complex nested structures""" + + encoder = AwsCborEncoder() + + # Create a complex shape: Organization with departments, each with employees + employee_shape = self.StructureShape({ + "id": self.SimpleShape("integer"), + "name": self.SimpleShape("string"), + "email": self.SimpleShape("string"), + "skills": self.ListShape(self.SimpleShape("string")) + }) + + department_shape = self.StructureShape({ + "name": self.SimpleShape("string"), + "budget": self.SimpleShape("integer"), + "employees": self.ListShape(employee_shape) + }) + + org_shape = self.StructureShape({ + "org_name": self.SimpleShape("string"), + "founded": self.SimpleShape("integer"), + "departments": self.ListShape(department_shape), + "metadata": self.MapShape(self.SimpleShape("string"), self.SimpleShape("string")) + }) + + data = { + "org_name": "TechCorp", + "founded": 2020, + "departments": [ + { + "name": "Engineering", + "budget": 1000000, + "employees": [ + { + "id": 1, + "name": "Alice", + "email": "alice@techcorp.com", + "skills": ["Python", "Go", "AWS"] + }, + { + "id": 2, + "name": "Bob", + "email": "bob@techcorp.com", + "skills": ["Java", "Kubernetes"] + } + ] + }, + { + "name": "Sales", + "budget": 500000, + "employees": [ + { + "id": 3, + "name": "Charlie", + "email": "charlie@techcorp.com", + "skills": ["Negotiation", "CRM"] + } + ] + } + ], + "metadata": {"industry": "Technology", "location": "Seattle"} + } + + encoder.write_data_item_shaped(data, org_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + decoded = decoder.pop_next_data_item() + self.assertEqual(decoded, data) + + def test_write_data_item_shaped_empty_collections(self): + """Test write_data_item_shaped with empty collections""" + + # Empty structure + encoder = AwsCborEncoder() + struct_shape = self.StructureShape({}) + encoder.write_data_item_shaped({}, struct_shape) + decoder = 
AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), {}) + + # Structure with all None values (should result in empty map) + encoder = AwsCborEncoder() + struct_shape = self.StructureShape({ + "a": self.SimpleShape("string"), + "b": self.SimpleShape("integer") + }) + encoder.write_data_item_shaped({"a": None, "b": None}, struct_shape) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), {}) + + def test_write_data_item_shaped_special_values(self): + """Test write_data_item_shaped with special float values""" + import math + + # Test positive infinity + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(float('inf'), self.SimpleShape("float")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + result = decoder.pop_next_data_item() + self.assertTrue(math.isinf(result) and result > 0) + + # Test negative infinity + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(float('-inf'), self.SimpleShape("double")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + result = decoder.pop_next_data_item() + self.assertTrue(math.isinf(result) and result < 0) + + # Test NaN + encoder = AwsCborEncoder() + encoder.write_data_item_shaped(float('nan'), self.SimpleShape("float")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + result = decoder.pop_next_data_item() + self.assertTrue(math.isnan(result)) + + def test_write_data_item_shaped_large_numbers(self): + """Test write_data_item_shaped with large numbers""" + + # Test large positive integer + encoder = AwsCborEncoder() + large_int = 2**63 - 1 + encoder.write_data_item_shaped(large_int, self.SimpleShape("long")) + decoder = AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), large_int) + + # Test large negative integer + encoder = AwsCborEncoder() + large_neg = -2**63 + encoder.write_data_item_shaped(large_neg, self.SimpleShape("long")) + decoder = 
AwsCborDecoder(encoder.get_encoded_data()) + self.assertEqual(decoder.pop_next_data_item(), large_neg)