Skip to content

Commit 2be0482

Browse files
authored
DPL Analysis: table slicing helper (#4999)
1 parent f2719b5 commit 2be0482

File tree

2 files changed

+49
-3
lines changed

2 files changed

+49
-3
lines changed

Analysis/Tutorials/src/associatedExample.cxx

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ struct ZTask {
9696

9797
LOGF(INFO, "Bin 0-10");
9898
for (auto& col : multbin0_10) {
99-
auto groupedTracks = tracks.select(aod::track::collisionId == col.globalIndex());
99+
auto groupedTracks = tracks.sliceBy(aod::track::collisionId, col.globalIndex());
100100
LOGF(INFO, "Collision %d; Ntrk = %d vs %d", col.globalIndex(), col.mult(), groupedTracks.size());
101101
if (groupedTracks.size() > 0) {
102102
auto track = groupedTracks.begin();
@@ -106,13 +106,13 @@ struct ZTask {
106106

107107
LOGF(INFO, "Bin 10-30");
108108
for (auto& col : multbin10_30) {
109-
auto groupedTracks = tracks.select(aod::track::collisionId == col.globalIndex());
109+
auto groupedTracks = tracks.sliceBy(aod::track::collisionId, col.globalIndex());
110110
LOGF(INFO, "Collision %d; Ntrk = %d vs %d", col.globalIndex(), col.mult(), groupedTracks.size());
111111
}
112112

113113
LOGF(INFO, "Bin 30-100");
114114
for (auto& col : multbin30_100) {
115-
auto groupedTracks = tracks.select(aod::track::collisionId == col.globalIndex());
115+
auto groupedTracks = tracks.sliceBy(aod::track::collisionId, col.globalIndex());
116116
LOGF(INFO, "Collision %d; Ntrk = %d vs %d", col.globalIndex(), col.mult(), groupedTracks.size());
117117
}
118118
}

Framework/Core/include/Framework/ASoA.h

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <arrow/array.h>
2424
#include <arrow/util/variant.h>
2525
#include <arrow/compute/kernel.h>
26+
#include <arrow/compute/api_aggregate.h>
2627
#include <gandiva/selection_vector.h>
2728
#include <cassert>
2829
#include <fmt/format.h>
@@ -864,6 +865,44 @@ auto select(T const& t, framework::expressions::Filter&& f)
864865
t.asArrowTable()->schema()));
865866
}
866867

868+
namespace
869+
{
870+
auto getSliceFor(int value, char const* key, std::shared_ptr<arrow::Table> const& input, std::shared_ptr<arrow::Table>& output, uint64_t& offset)
871+
{
872+
arrow::Datum value_counts;
873+
auto options = arrow::compute::CountOptions::Defaults();
874+
ARROW_ASSIGN_OR_RAISE(value_counts,
875+
arrow::compute::CallFunction("value_counts", {input->GetColumnByName(key)},
876+
&options));
877+
auto pair = static_cast<arrow::StructArray>(value_counts.array());
878+
auto values = static_cast<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
879+
auto counts = static_cast<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
880+
881+
int slice;
882+
for (slice = 0; slice < values.length(); ++slice) {
883+
if (values.Value(slice) == value) {
884+
offset = slice;
885+
output = input->Slice(slice, counts.Value(slice));
886+
return arrow::Status::OK();
887+
}
888+
}
889+
output = input->Slice(0, 0);
890+
return arrow::Status::OK();
891+
}
892+
} // namespace
893+
894+
template <typename T>
895+
auto sliceBy(T const& t, framework::expressions::BindingNode const& node, int value)
896+
{
897+
uint64_t offset = 0;
898+
std::shared_ptr<arrow::Table> result = nullptr;
899+
auto status = getSliceFor(value, node.name.c_str(), t.asArrowTable(), result, offset);
900+
if (status.ok()) {
901+
return T({result}, offset);
902+
}
903+
throw std::runtime_error("Failed to slice table");
904+
}
905+
867906
/// A Table class which observes an arrow::Table and provides
868907
/// It is templated on a set of Column / DynamicColumn types.
869908
template <typename... C>
@@ -1081,6 +1120,13 @@ class Table
10811120
return t;
10821121
}
10831122

1123+
auto sliceBy(framework::expressions::BindingNode const& node, int value) const
1124+
{
1125+
auto t = o2::soa::sliceBy(*this, node, value);
1126+
copyIndexBindings(t);
1127+
return t;
1128+
}
1129+
10841130
private:
10851131
template <typename T>
10861132
arrow::ChunkedArray* lookupColumn()

0 commit comments

Comments
 (0)