diff --git a/include/xtensor/io/xcsv.hpp b/include/xtensor/io/xcsv.hpp index 2fd347745..70472625e 100644 --- a/include/xtensor/io/xcsv.hpp +++ b/include/xtensor/io/xcsv.hpp @@ -274,6 +274,47 @@ namespace xt } }; + template + void dump_csv(std::ostream& stream, const xexpression& e, const xcsv_config& config) + { + using size_type = typename E::size_type; + const E& ex = e.derived_cast(); + if (ex.dimension() == 1) + { + const size_type n = ex.shape()[0]; + for (size_type i = 0; i != n; ++i) + { + stream << ex(i); + if (i != n - 1) + { + stream << config.delimiter; + } + } + stream << std::endl; + } + else if (ex.dimension() == 2) + { + const size_type nbrows = ex.shape()[0]; + const size_type nbcols = ex.shape()[1]; + for (size_type r = 0; r != nbrows; ++r) + { + for (size_type c = 0; c != nbcols; ++c) + { + stream << ex(r, c); + if (c != nbcols - 1) + { + stream << config.delimiter; + } + } + stream << std::endl; + } + } + else + { + XTENSOR_THROW(std::runtime_error, "Only 1-D and 2-D expressions can be serialized to CSV"); + } + } + template void load_file(std::istream& stream, xexpression& e, const xcsv_config& config) { @@ -287,9 +328,9 @@ namespace xt } template - void dump_file(std::ostream& stream, const xexpression& e, const xcsv_config&) + void dump_file(std::ostream& stream, const xexpression& e, const xcsv_config& config) { - dump_csv(stream, e); + dump_csv(stream, e, config); } } diff --git a/test/test_xcsv.cpp b/test/test_xcsv.cpp index 148b40cd3..00fcd4519 100644 --- a/test/test_xcsv.cpp +++ b/test/test_xcsv.cpp @@ -162,4 +162,28 @@ namespace xt XT_EXPECT_THROW(dump_csv(res, data), std::runtime_error); } + + TEST(xcsv, dump_with_config) + { + xtensor data{{1.0, 2.0, 3.0, 4.0}, {10.0, 12.0, 15.0, 18.0}}; + + std::stringstream res; + + xcsv_config config; + config.delimiter = ' '; + dump_csv(res, data, config); + ASSERT_EQ("1 2 3 4\n10 12 15 18\n", res.str()); + } + + TEST(xcsv, dump_file_with_config) + { + xtensor data{{1.0, 2.0, 3.0, 4.0}, {10.0, 12.0, 15.0, 18.0}}; + + std::stringstream res; + + xcsv_config config; + config.delimiter = ';'; + dump_file(res, data, config); + ASSERT_EQ("1;2;3;4\n10;12;15;18\n", res.str()); + } }