Skip to content

Commit 1fb3c1b

Browse files
authored
DPL Analysis: generalized grouping (#7378)
1 parent 12f7dfa commit 1fb3c1b

File tree

5 files changed

+312
-119
lines changed

5 files changed

+312
-119
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2279,6 +2279,19 @@ struct IndexTable : Table<soa::Index<>, H, Ts...> {
22792279

22802280
template <typename T>
22812281
using is_soa_index_table_t = typename framework::is_base_of_template<soa::IndexTable, T>;
2282+
2283+
template <typename T>
2284+
struct SmallGroups : Filtered<T> {
2285+
SmallGroups(std::vector<std::shared_ptr<arrow::Table>>&& tables, SelectionVector&& selection, uint64_t offset = 0)
2286+
: Filtered<T>(std::move(tables), std::forward<SelectionVector>(selection), offset) {}
2287+
2288+
SmallGroups(std::vector<std::shared_ptr<arrow::Table>>&& tables, framework::expressions::Selection selection, uint64_t offset = 0)
2289+
: Filtered<T>(std::move(tables), selection, offset) {}
2290+
2291+
SmallGroups(std::vector<std::shared_ptr<arrow::Table>>&& tables, gandiva::NodePtr const& tree, uint64_t offset = 0)
2292+
: Filtered<T>(std::move(tables), tree, offset) {}
2293+
};
2294+
22822295
} // namespace o2::soa
22832296

22842297
#endif // O2_FRAMEWORK_ASOA_H_

Framework/Core/include/Framework/AnalysisTask.h

Lines changed: 197 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,85 @@ namespace o2::framework
5252
struct AnalysisTask {
5353
};
5454

55+
namespace
56+
{
57+
template <typename B, typename C>
58+
constexpr static bool isIndexTo()
59+
{
60+
if constexpr (soa::is_type_with_binding_v<C>) {
61+
if constexpr (soa::is_soa_index_table_t<std::decay_t<B>>::value) {
62+
using T = typename std::decay_t<B>::first_t;
63+
if constexpr (soa::is_type_with_originals_v<std::decay_t<T>>) {
64+
using TT = typename framework::pack_element_t<0, typename std::decay_t<T>::originals>;
65+
return std::is_same_v<typename C::binding_t, TT>;
66+
} else {
67+
using TT = std::decay_t<T>;
68+
return std::is_same_v<typename C::binding_t, TT>;
69+
}
70+
} else {
71+
if constexpr (soa::is_type_with_originals_v<std::decay_t<B>>) {
72+
using TT = typename framework::pack_element_t<0, typename std::decay_t<B>::originals>;
73+
return std::is_same_v<typename C::binding_t, TT>;
74+
} else {
75+
using TT = std::decay_t<B>;
76+
return std::is_same_v<typename C::binding_t, TT>;
77+
}
78+
}
79+
}
80+
return false;
81+
}
82+
83+
template <typename B, typename C>
84+
constexpr static bool isSortedIndexTo()
85+
{
86+
if constexpr (soa::is_type_with_binding_v<C>) {
87+
if constexpr (soa::is_soa_index_table_t<std::decay_t<B>>::value) {
88+
using T = typename std::decay_t<B>::first_t;
89+
if constexpr (soa::is_type_with_originals_v<std::decay_t<T>>) {
90+
using TT = typename framework::pack_element_t<0, typename std::decay_t<T>::originals>;
91+
return std::is_same_v<typename C::binding_t, TT> && C::sorted;
92+
} else {
93+
using TT = std::decay_t<T>;
94+
return std::is_same_v<typename C::binding_t, TT> && C::sorted;
95+
}
96+
} else {
97+
if constexpr (soa::is_type_with_originals_v<std::decay_t<B>>) {
98+
using TT = typename framework::pack_element_t<0, typename std::decay_t<B>::originals>;
99+
return std::is_same_v<typename C::binding_t, TT> && C::sorted;
100+
} else {
101+
using TT = std::decay_t<B>;
102+
return std::is_same_v<typename C::binding_t, TT> && C::sorted;
103+
}
104+
}
105+
}
106+
return false;
107+
}
108+
109+
template <typename B, typename... C>
110+
constexpr static bool hasIndexTo(framework::pack<C...>&&)
111+
{
112+
return (isIndexTo<B, C>() || ...);
113+
}
114+
115+
template <typename B, typename... C>
116+
constexpr static bool hasSortedIndexTo(framework::pack<C...>&&)
117+
{
118+
return (isSortedIndexTo<B, C>() || ...);
119+
}
120+
121+
template <typename B, typename Z>
122+
constexpr static bool relatedByIndex()
123+
{
124+
return hasIndexTo<B>(typename Z::persistent_columns_t{});
125+
}
126+
127+
template <typename B, typename Z>
128+
constexpr static bool relatedBySortedIndex()
129+
{
130+
return hasSortedIndexTo<B>(typename Z::persistent_columns_t{});
131+
}
132+
} // namespace
133+
55134
// Helper struct which builds a DataProcessorSpec from
56135
// the contents of an AnalysisTask...
57136

@@ -119,7 +198,7 @@ struct AnalysisDataProcessorBuilder {
119198
static void appendSomethingWithMetadata(const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos, size_t hash)
120199
{
121200
using dT = std::decay_t<T>;
122-
if constexpr (framework::is_specialization<dT, soa::Filtered>::value) {
201+
if constexpr (soa::is_soa_filtered_t<dT>::value) {
123202
eInfos.push_back({AI, hash, dT::hashes(), o2::soa::createSchemaFromColumns(typename dT::table_t::persistent_columns_t{}), nullptr});
124203
} else if constexpr (soa::is_soa_iterator_t<dT>::value) {
125204
if constexpr (std::is_same_v<typename dT::policy_t, soa::FilteredIndexPolicy>) {
@@ -129,12 +208,6 @@ struct AnalysisDataProcessorBuilder {
129208
doAppendInputWithMetadata(soa::make_originals_from_type<dT>(), name, value, inputs);
130209
}
131210

132-
// template <typename... T>
133-
// static void inputsFromArgsTuple(std::tuple<T...>& processTuple, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos)
134-
// {
135-
// (inputsFromArgs<o2::framework::has_type_at_v<T>(pack<T...>{})>(std::get<T>(processTuple), inputs, eInfos), ...);
136-
// }
137-
138211
template <typename R, typename C, typename... Args>
139212
static void inputsFromArgs(R (C::*)(Args...), const char* name, bool value, std::vector<InputSpec>& inputs, std::vector<ExpressionInfo>& eInfos)
140213
{
@@ -189,9 +262,15 @@ struct AnalysisDataProcessorBuilder {
189262
static auto extractFilteredFromRecord(InputRecord& record, ExpressionInfo const& info, pack<Os...> const&)
190263
{
191264
if constexpr (soa::is_soa_iterator_t<T>::value) {
192-
return typename T::parent_t(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, info.tree);
265+
if (info.tree != nullptr) {
266+
return typename T::parent_t(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, info.tree);
267+
}
268+
return typename T::parent_t(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, soa::SelectionVector{});
193269
} else {
194-
return T(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, info.tree);
270+
if (info.tree != nullptr) {
271+
return T(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, info.tree);
272+
}
273+
return T(std::vector<std::shared_ptr<arrow::Table>>{extractTableFromRecord<Os>(record)...}, soa::SelectionVector{});
195274
}
196275
}
197276

@@ -275,95 +354,82 @@ struct AnalysisDataProcessorBuilder {
275354
}
276355
}
277356

357+
template <typename T>
358+
auto splittingFunction(T&& table)
359+
{
360+
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
361+
if constexpr (relatedByIndex<std::decay_t<G>, std::decay_t<T>>()) {
362+
auto name = getLabelFromType<std::decay_t<T>>();
363+
if constexpr (!framework::is_specialization<std::decay_t<T>, soa::SmallGroups>::value) {
364+
if (table.size() == 0) {
365+
return;
366+
}
367+
// use presorted splitting approach
368+
auto result = o2::framework::sliceByColumn(mIndexColumnName.c_str(),
369+
name.c_str(),
370+
table.asArrowTable(),
371+
static_cast<int32_t>(mGt->tableSize()),
372+
&groups[index],
373+
&offsets[index],
374+
&sizes[index]);
375+
if (result.ok() == false) {
376+
throw runtime_error("Cannot split collection");
377+
}
378+
if (groups[index].size() > mGt->tableSize()) {
379+
throw runtime_error_f("Splitting collection %s resulted in a larger group number (%d) than there is rows in the grouping table (%d).", name.c_str(), groups[index].size(), mGt->tableSize());
380+
};
381+
} else {
382+
if (table.tableSize() == 0) {
383+
return;
384+
}
385+
// use generic splitting approach
386+
o2::framework::sliceByColumnGeneric(mIndexColumnName.c_str(),
387+
name.c_str(),
388+
table.asArrowTable(),
389+
static_cast<int32_t>(mGt->tableSize()),
390+
&filterGroups[index]);
391+
}
392+
}
393+
}
394+
395+
template <typename T>
396+
auto extractingFunction(T&& table)
397+
{
398+
if constexpr (soa::is_soa_filtered_t<std::decay_t<T>>::value) {
399+
constexpr auto index = framework::has_type_at_v<std::decay_t<T>>(associated_pack_t{});
400+
selections[index] = &table.getSelectedRows();
401+
starts[index] = selections[index]->begin();
402+
offsets[index].push_back(table.tableSize());
403+
}
404+
}
405+
278406
GroupSlicerIterator(G& gt, std::tuple<A...>& at)
279-
: mAt{&at},
407+
: mIndexColumnName{std::string("fIndex") + getLabelFromType<G>()},
408+
mGt{&gt},
409+
mAt{&at},
280410
mGroupingElement{gt.begin()},
281411
position{0}
282412
{
283413
if constexpr (soa::is_soa_filtered_t<std::decay_t<G>>::value) {
284-
groupSelection = &gt.getSelectedRows();
414+
groupSelection = &mGt->getSelectedRows();
285415
}
286-
auto indexColumnName = std::string("fIndex") + getLabelFromType<G>();
416+
287417
/// prepare slices and offsets for all associated tables that have index
288418
/// to grouping table
289419
///
290-
auto splitter = [&](auto&& x) {
291-
using xt = std::decay_t<decltype(x)>;
292-
constexpr auto index = framework::has_type_at_v<std::decay_t<decltype(x)>>(associated_pack_t{});
293-
if (hasIndexTo<std::decay_t<G>>(typename xt::persistent_columns_t{})) {
294-
if (x.size() != 0) {
295-
auto name = getLabelFromType<decltype(x)>();
296-
auto result = o2::framework::sliceByColumn(indexColumnName.c_str(),
297-
name.c_str(),
298-
x.asArrowTable(),
299-
static_cast<int32_t>(gt.tableSize()),
300-
&groups[index],
301-
&offsets[index],
302-
&sizes[index]);
303-
if (result.ok() == false) {
304-
throw runtime_error("Cannot split collection");
305-
}
306-
if (groups[index].size() > gt.tableSize()) {
307-
throw runtime_error_f("Splitting collection %s resulted in a larger group number (%d) than there is rows in the grouping table (%d).", name.c_str(), groups[index].size(), gt.tableSize());
308-
};
309-
}
310-
}
311-
};
312-
313420
std::apply(
314421
[&](auto&&... x) -> void {
315-
(splitter(x), ...);
422+
(splittingFunction(x), ...);
316423
},
317424
at);
318425
/// extract selections from filtered associated tables
319-
auto extractor = [&](auto&& x) {
320-
using xt = std::decay_t<decltype(x)>;
321-
if constexpr (soa::is_soa_filtered_t<xt>::value) {
322-
constexpr auto index = framework::has_type_at_v<std::decay_t<decltype(x)>>(associated_pack_t{});
323-
selections[index] = &x.getSelectedRows();
324-
starts[index] = selections[index]->begin();
325-
offsets[index].push_back(std::get<xt>(at).tableSize());
326-
}
327-
};
328426
std::apply(
329427
[&](auto&&... x) -> void {
330-
(extractor(x), ...);
428+
(extractingFunction(x), ...);
331429
},
332430
at);
333431
}
334432

335-
template <typename B, typename... C>
336-
constexpr static bool hasIndexTo(framework::pack<C...>&&)
337-
{
338-
return (isIndexTo<B, C>() || ...);
339-
}
340-
341-
template <typename B, typename C>
342-
constexpr static bool isIndexTo()
343-
{
344-
if constexpr (soa::is_type_with_binding_v<C>) {
345-
if constexpr (soa::is_soa_index_table_t<std::decay_t<B>>::value) {
346-
using T = typename std::decay_t<B>::first_t;
347-
if constexpr (soa::is_type_with_originals_v<std::decay_t<T>>) {
348-
using TT = typename framework::pack_element_t<0, typename std::decay_t<T>::originals>;
349-
return std::is_same_v<typename C::binding_t, TT>;
350-
} else {
351-
using TT = std::decay_t<T>;
352-
return std::is_same_v<typename C::binding_t, TT>;
353-
}
354-
} else {
355-
if constexpr (soa::is_type_with_originals_v<std::decay_t<B>>) {
356-
using TT = typename framework::pack_element_t<0, typename std::decay_t<B>::originals>;
357-
return std::is_same_v<typename C::binding_t, TT>;
358-
} else {
359-
using TT = std::decay_t<B>;
360-
return std::is_same_v<typename C::binding_t, TT>;
361-
}
362-
}
363-
}
364-
return false;
365-
}
366-
367433
GroupSlicerIterator& operator++()
368434
{
369435
++position;
@@ -400,47 +466,75 @@ struct AnalysisDataProcessorBuilder {
400466
auto prepareArgument()
401467
{
402468
constexpr auto index = framework::has_type_at_v<A1>(associated_pack_t{});
403-
if (std::get<A1>(*mAt).size() == 0) {
404-
return std::get<A1>(*mAt);
405-
}
406-
if (hasIndexTo<std::decay_t<G>>(typename std::decay_t<A1>::persistent_columns_t{})) {
469+
auto& originalTable = std::get<A1>(*mAt);
470+
471+
if constexpr (relatedByIndex<std::decay_t<G>, std::decay_t<A1>>()) {
407472
uint64_t pos;
408473
if constexpr (soa::is_soa_filtered_t<std::decay_t<G>>::value) {
409474
pos = (*groupSelection)[position];
410475
} else {
411476
pos = position;
412477
}
413-
if constexpr (soa::is_soa_filtered_t<std::decay_t<A1>>::value) {
414-
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(((groups[index])[pos]).value);
415-
416-
// for each grouping element we need to slice the selection vector
417-
auto start_iterator = std::lower_bound(starts[index], selections[index]->end(), (offsets[index])[pos]);
418-
auto stop_iterator = std::lower_bound(start_iterator, selections[index]->end(), (offsets[index])[pos] + (sizes[index])[pos]);
419-
starts[index] = stop_iterator;
420-
soa::SelectionVector slicedSelection{start_iterator, stop_iterator};
421-
std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
422-
[&](int64_t idx) {
423-
return idx - static_cast<int64_t>((offsets[index])[pos]);
424-
});
425-
426-
std::decay_t<A1> typedTable{{groupedElementsTable}, std::move(slicedSelection), (offsets[index])[pos]};
427-
typedTable.bindInternalIndicesTo(&std::get<A1>(*mAt));
428-
return typedTable;
478+
if constexpr (!framework::is_specialization<std::decay_t<A1>, soa::SmallGroups>::value) {
479+
if (originalTable.size() == 0) {
480+
return originalTable;
481+
}
482+
// optimized split
483+
if constexpr (soa::is_soa_filtered_t<std::decay_t<A1>>::value) {
484+
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(((groups[index])[pos]).value);
485+
486+
// for each grouping element we need to slice the selection vector
487+
auto start_iterator = std::lower_bound(starts[index], selections[index]->end(), (offsets[index])[pos]);
488+
auto stop_iterator = std::lower_bound(start_iterator, selections[index]->end(), (offsets[index])[pos] + (sizes[index])[pos]);
489+
starts[index] = stop_iterator;
490+
soa::SelectionVector slicedSelection{start_iterator, stop_iterator};
491+
std::transform(slicedSelection.begin(), slicedSelection.end(), slicedSelection.begin(),
492+
[&](int64_t idx) {
493+
return idx - static_cast<int64_t>((offsets[index])[pos]);
494+
});
495+
496+
std::decay_t<A1> typedTable{{groupedElementsTable}, std::move(slicedSelection), (offsets[index])[pos]};
497+
typedTable.bindInternalIndicesTo(&originalTable);
498+
return typedTable;
499+
} else {
500+
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(((groups[index])[pos]).value);
501+
std::decay_t<A1> typedTable{{groupedElementsTable}, (offsets[index])[pos]};
502+
typedTable.bindInternalIndicesTo(&originalTable);
503+
return typedTable;
504+
}
429505
} else {
430-
auto groupedElementsTable = arrow::util::get<std::shared_ptr<arrow::Table>>(((groups[index])[pos]).value);
431-
std::decay_t<A1> typedTable{{groupedElementsTable}, (offsets[index])[pos]};
432-
typedTable.bindInternalIndicesTo(&std::get<A1>(*mAt));
433-
return typedTable;
506+
//generic split
507+
if constexpr (soa::is_soa_filtered_t<std::decay_t<A1>>::value) {
508+
if (originalTable.tableSize() == 0) {
509+
return originalTable;
510+
}
511+
// intersect selections
512+
o2::soa::SelectionVector s;
513+
if (selections[index]->empty()) {
514+
std::copy((filterGroups[index])[pos].begin(), (filterGroups[index])[pos].end(), std::back_inserter(s));
515+
} else {
516+
std::set_intersection((filterGroups[index])[pos].begin(), (filterGroups[index])[pos].end(), selections[index]->begin(), selections[index]->end(), std::back_inserter(s));
517+
}
518+
std::decay_t<A1> typedTable{{originalTable.asArrowTable()}, std::move(s)};
519+
typedTable.bindInternalIndicesTo(&originalTable);
520+
return typedTable;
521+
} else {
522+
throw runtime_error("Unsorted grouped table needs to be used with soa::SmallGroups<>");
523+
}
434524
}
525+
} else {
526+
return std::get<A1>(*mAt);
435527
}
436-
return std::get<A1>(*mAt);
437528
}
438529

530+
std::string mIndexColumnName;
531+
G const* mGt;
439532
std::tuple<A...>* mAt;
440533
typename grouping_t::iterator mGroupingElement;
441534
uint64_t position = 0;
442535
soa::SelectionVector const* groupSelection = nullptr;
443536
std::array<std::vector<arrow::Datum>, sizeof...(A)> groups;
537+
std::array<ListVector, sizeof...(A)> filterGroups;
444538
std::array<std::vector<uint64_t>, sizeof...(A)> offsets;
445539
std::array<std::vector<int>, sizeof...(A)> sizes;
446540
std::array<soa::SelectionVector const*, sizeof...(A)> selections;

0 commit comments

Comments
 (0)