diff --git a/lang/c++/include/avro/Specific.hh b/lang/c++/include/avro/Specific.hh index fc28b3f5e4b..edd2c92d867 100644 --- a/lang/c++/include/avro/Specific.hh +++ b/lang/c++/include/avro/Specific.hh @@ -22,6 +22,7 @@ #include "array" #include #include +#include #include #include @@ -29,6 +30,7 @@ #include "Config.hh" #include "Decoder.hh" #include "Encoder.hh" +#include "Exception.hh" /** * A bunch of templates and specializations for encoding and decoding @@ -165,6 +167,43 @@ struct codec_traits { } }; +/** +* codec_traits for Avro optional assumming that the schema is ["null", T]. +*/ +template +struct codec_traits> { + /** + * Encodes a given value. + */ + static void encode(Encoder &e, const std::optional &b) { + if (b) { + e.encodeUnionIndex(1); + avro::encode(e, b.value()); + } else { + e.encodeUnionIndex(0); + e.encodeNull(); + } + } + + /** + * Decodes into a given value. + */ + static void decode(Decoder &d, std::optional &s) { + size_t n = d.decodeUnionIndex(); + if (n >= 2) { throw avro::Exception("Union index too big for optional (expected 0 or 1, got " + std::to_string(n) + ")"); } + switch (n) { + case 0: { + d.decodeNull(); + s.reset(); + } break; + case 1: { + s.emplace(); + avro::decode(d, *s); + } break; + } + } +}; + /** * codec_traits for Avro string. */ diff --git a/lang/c++/test/SpecificTests.cc b/lang/c++/test/SpecificTests.cc index 72f2897e45b..91c5ce6d551 100644 --- a/lang/c++/test/SpecificTests.cc +++ b/lang/c++/test/SpecificTests.cc @@ -24,6 +24,7 @@ using std::array; using std::map; +using std::optional; using std::string; using std::unique_ptr; using std::vector; @@ -127,6 +128,18 @@ void testDouble() { BOOST_CHECK_CLOSE(b, n, 0.00000001); } +void testNonEmptyOptional() { + optional n = -109; + optional b = encodeAndDecode(n); + BOOST_CHECK_EQUAL(b.value(), n.value()); +} + +void testEmptyOptional() { + optional n; + optional b = encodeAndDecode(n); + BOOST_CHECK(!b.has_value()); +} + void testString() { string n = "abc"; string b = encodeAndDecode(n); @@ -191,6 +204,8 @@ init_unit_test_suite(int /*argc*/, char * /*argv*/[]) { ts->add(BOOST_TEST_CASE(avro::specific::testLong)); ts->add(BOOST_TEST_CASE(avro::specific::testFloat)); ts->add(BOOST_TEST_CASE(avro::specific::testDouble)); + ts->add(BOOST_TEST_CASE(avro::specific::testNonEmptyOptional)); + ts->add(BOOST_TEST_CASE(avro::specific::testEmptyOptional)); ts->add(BOOST_TEST_CASE(avro::specific::testString)); ts->add(BOOST_TEST_CASE(avro::specific::testBytes)); ts->add(BOOST_TEST_CASE(avro::specific::testFixed));