|
23 | 23 | #include <arrow/array.h> |
24 | 24 | #include <arrow/util/variant.h> |
25 | 25 | #include <arrow/compute/kernel.h> |
| 26 | +#include <arrow/compute/api_aggregate.h> |
26 | 27 | #include <gandiva/selection_vector.h> |
27 | 28 | #include <cassert> |
28 | 29 | #include <fmt/format.h> |
@@ -864,6 +865,44 @@ auto select(T const& t, framework::expressions::Filter&& f) |
864 | 865 | t.asArrowTable()->schema())); |
865 | 866 | } |
866 | 867 |
|
| 868 | +namespace |
| 869 | +{ |
| 870 | +auto getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset) |
| 871 | +{ |
| 872 | + arrow::Datum value_counts; |
| 873 | + auto options = arrow::compute::CountOptions::Defaults(); |
| 874 | + ARROW_ASSIGN_OR_RAISE(value_counts, |
| 875 | + arrow::compute::CallFunction("value_counts", {input->GetColumnByName(key)}, |
| 876 | + &options)); |
| 877 | + auto pair = static_cast<arrow::StructArray>(value_counts.array()); |
| 878 | + auto values = static_cast<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data()); |
| 879 | + auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data()); |
| 880 | + |
| 881 | + int slice; |
| 882 | + for (slice = 0; slice < values.length(); ++slice) { |
| 883 | + if (values.Value(slice) == value) { |
| 884 | + offset = slice; |
| 885 | + output = input->Slice(slice, counts.Value(slice)); |
| 886 | + return arrow::Status::OK(); |
| 887 | + } |
| 888 | + } |
| 889 | + output = input->Slice(0, 0); |
| 890 | + return arrow::Status::OK(); |
| 891 | +} |
| 892 | +} // namespace |
| 893 | + |
| 894 | +template <typename T> |
| 895 | +auto sliceBy(T const& t, framework::expressions::BindingNode const& node, int value) |
| 896 | +{ |
| 897 | + uint64_t offset = 0; |
| 898 | + std::shared_ptr<arrow::Table> result = nullptr; |
| 899 | + auto status = getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset); |
| 900 | + if (status.ok()) { |
| 901 | + return T({result}, offset); |
| 902 | + } |
| 903 | + throw std::runtime_error("Failed to slice table"); |
| 904 | +} |
| 905 | + |
867 | 906 | /// A Table class which observes an arrow::Table and provides |
868 | 907 | /// It is templated on a set of Column / DynamicColumn types. |
869 | 908 | template <typename... C> |
@@ -1081,6 +1120,13 @@ class Table |
1081 | 1120 | return t; |
1082 | 1121 | } |
1083 | 1122 |
|
| 1123 | + auto sliceBy(framework::expressions::BindingNode const& node, int value) const |
| 1124 | + { |
| 1125 | + auto t = o2::soa::sliceBy(*this, node, value); |
| 1126 | + copyIndexBindings(t); |
| 1127 | + return t; |
| 1128 | + } |
| 1129 | + |
1084 | 1130 | private: |
1085 | 1131 | template <typename T> |
1086 | 1132 | arrow::ChunkedArray* lookupColumn() |
|
0 commit comments