diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h index 53f92c8f23d2..87cdf9daee79 100644 --- a/cpp/src/arrow/util/hashing.h +++ b/cpp/src/arrow/util/hashing.h @@ -534,13 +534,16 @@ class ScalarMemoTable : public MemoTable { // Merge entries from `other_table` into `this->hash_table_`. Status MergeTable(const ScalarMemoTable& other_table) { const HashTableType& other_hashtable = other_table.hash_table_; + Status status = Status::OK(); - other_hashtable.VisitEntries([this](const HashTableEntry* other_entry) { + other_hashtable.VisitEntries([this, &status](const HashTableEntry* other_entry) { + if (ARROW_PREDICT_FALSE(!status.ok())) { + return; + } int32_t unused; - ARROW_DCHECK_OK(this->GetOrInsert(other_entry->payload.value, &unused)); + status = this->GetOrInsert(other_entry->payload.value, &unused); }); - // TODO: ARROW-17074 - implement proper error handling - return Status::OK(); + return status; } }; @@ -899,11 +902,15 @@ class BinaryMemoTable : public MemoTable { public: Status MergeTable(const BinaryMemoTable& other_table) { - other_table.VisitValues(0, [this](std::string_view other_value) { + Status status = Status::OK(); + other_table.VisitValues(0, [this, &status](std::string_view other_value) { + if (ARROW_PREDICT_FALSE(!status.ok())) { + return; + } int32_t unused; - ARROW_DCHECK_OK(this->GetOrInsert(other_value, &unused)); + status = this->GetOrInsert(other_value, &unused); }); - return Status::OK(); + return status; } }; diff --git a/cpp/src/arrow/util/hashing_test.cc b/cpp/src/arrow/util/hashing_test.cc index f6ada0acd2d0..1c5223900563 100644 --- a/cpp/src/arrow/util/hashing_test.cc +++ b/cpp/src/arrow/util/hashing_test.cc @@ -30,6 +30,7 @@ #include "arrow/array/builder_primitive.h" #include "arrow/array/concatenate.h" +#include "arrow/memory_pool.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/bit_util.h" #include "arrow/util/hashing.h" @@ -376,6 +377,32 @@ TEST(ScalarMemoTable, StressInt64) { 
ASSERT_EQ(table.size(), map.size()); } +TEST(ScalarMemoTable, MergeTablePropagatesInsertError) { + int64_t bytes_allocated_limit = 0; + { + ProxyMemoryPool probe(default_memory_pool()); + ScalarMemoTable<int64_t> target(&probe, 0); + for (int64_t value = 0; value < 15; ++value) { + AssertGetOrInsert(target, value, static_cast<int32_t>(value)); + } + bytes_allocated_limit = probe.bytes_allocated(); + } + ASSERT_GT(bytes_allocated_limit, 0); + + ScalarMemoTable<int64_t> source(default_memory_pool(), 0); + AssertGetOrInsert(source, 15, 0); + + ProxyMemoryPool proxy(default_memory_pool()); + CappedMemoryPool pool(&proxy, bytes_allocated_limit); + ScalarMemoTable<int64_t> target(&pool, 0); + for (int64_t value = 0; value < 15; ++value) { + AssertGetOrInsert(target, value, static_cast<int32_t>(value)); + } + ASSERT_EQ(proxy.bytes_allocated(), bytes_allocated_limit); + + ASSERT_RAISES(OutOfMemory, target.MergeTable(source)); +} + TEST(BinaryMemoTable, Basics) { std::string A = "", B = "a", C = "foo", D = "bar", E, F; E += '\0';