Skip to content

Commit 7af85a2

Browse files
squash
1 parent e7b7990 commit 7af85a2

26 files changed

+243
-145
lines changed

GPU/Common/GPUCommonAlgorithm.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#define GPUCOMMONALGORITHM_H
1717

1818
#include "GPUCommonDef.h"
19-
#include "MemLayout.h"
2019

2120
#if !defined(GPUCA_GPUCODE) // Could also enable custom search on the CPU, but it is not always faster, so we stick to std::sort
2221
#include <algorithm>

GPU/Common/GPUCommonAlgorithmThrust.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
#pragma GCC diagnostic push
2020
#pragma GCC diagnostic ignored "-Wshadow"
2121
#include <thrust/sort.h>
22-
#include <thrust/iterator/iterator_traits.h>
2322
#include <thrust/execution_policy.h>
2423
#include <thrust/device_ptr.h>
2524
#pragma GCC diagnostic pop

GPU/Common/MemLayout.h

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@ template <class T> using const_reference = const T&;
1212
template <class T> using pointer = T*;
1313
template <class T> using const_pointer = const T*;
1414

15+
template <class T> using reference_restrict = T& __restrict__;
16+
template <class T> using const_reference_restrict = const T& __restrict__;
17+
template <class T> using pointer_restrict = T* __restrict__;
18+
template <class T> using const_pointer_restrict = const T* __restrict__;
19+
1520
template <class SF>
1621
struct RandomAccessAt {
1722
MemLayout::size_t i;
@@ -79,6 +84,8 @@ struct wrapper : public S<F> {
7984
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
8085
template <template <class> class F_other>
8186
constexpr wrapper(S<F_other>& other) : Base{other.apply(AggregateConstructor<Base>{})} {}
87+
template <template <class> class F_other>
88+
constexpr wrapper(const S<F_other>& other) : Base{other.apply(AggregateConstructor<Base>{})} {}
8289

8390
constexpr wrapper<S, reference> operator[] (size_t i) { return Base::apply(RandomAccessAt<S<reference>>{i}); }
8491
constexpr wrapper<S, const_reference> operator[] (size_t i) const { return Base::apply(RandomAccessAt<S<const_reference>>{i}); }
@@ -94,7 +101,7 @@ struct wrapper<S, value> : public S<value> {
94101
constexpr wrapper() = default;
95102
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
96103
constexpr wrapper(const S<reference>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
97-
constexpr wrapper(const S<const_reference> other) : Base(other.apply(AggregateConstructor<Base>{})) {}
104+
constexpr wrapper(const S<const_reference>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
98105

99106
//constexpr S<MemLayout::pointer> operator& () { return apply(GetPointer<S<MemLayout::pointer>>{}); }
100107
//constexpr S<MemLayout::const_pointer> operator& () const { return apply(GetPointer<S<MemLayout::const_pointer>>{}); }
@@ -107,9 +114,14 @@ struct wrapper<S, reference> : public S<reference> {
107114
constexpr wrapper() = delete;
108115
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
109116
constexpr wrapper(S<value>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
117+
constexpr wrapper(S<reference_restrict> other) : Base(other.apply(AggregateConstructor<Base>{})) {}
110118

111119
constexpr wrapper(const wrapper& other) = default;
112120

121+
constexpr wrapper& operator=(const wrapper<S, value>& other) {
122+
Base::apply(other, CopyAssignment{});
123+
return *this;
124+
}
113125
constexpr wrapper& operator=(const wrapper& other) {
114126
Base::apply(other, CopyAssignment{});
115127
return *this;
@@ -118,17 +130,63 @@ struct wrapper<S, reference> : public S<reference> {
118130
Base::apply(other, CopyAssignment{});
119131
return *this;
120132
}
133+
constexpr wrapper& operator=(const wrapper<S, reference_restrict>& other) {
134+
Base::apply(other, CopyAssignment{});
135+
return *this;
136+
}
137+
constexpr wrapper& operator=(const wrapper<S, const_reference_restrict>& other) {
138+
Base::apply(other, CopyAssignment{});
139+
return *this;
140+
}
141+
142+
constexpr wrapper(wrapper&& other) = default;
143+
144+
constexpr wrapper& operator=(wrapper&& other) { return operator=(other); }
145+
146+
constexpr wrapper<S, pointer> operator&() { return Base::apply(GetPointer<S<pointer>>{}); }
147+
//constexpr wrapper<S, const_pointer> operator&() const { return Base::apply(GetPointer<S<const_pointer>>{}); }
148+
constexpr pointer<wrapper<S, reference>> operator->() { return this; }
149+
};
150+
151+
template <template <template <class> class> class S>
152+
struct wrapper<S, reference_restrict> : public S<reference_restrict> {
153+
using Base = S<reference_restrict>;
154+
155+
constexpr wrapper() = delete;
156+
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
157+
constexpr wrapper(S<value>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
158+
constexpr wrapper(S<reference> other) : Base(other.apply(AggregateConstructor<Base>{})) {}
159+
160+
constexpr wrapper(const wrapper& other) = default;
161+
121162
constexpr wrapper& operator=(const wrapper<S, value>& other) {
122163
Base::apply(other, CopyAssignment{});
123164
return *this;
124165
}
166+
constexpr wrapper& operator=(const wrapper& other) {
167+
Base::apply(other, CopyAssignment{});
168+
return *this;
169+
}
170+
constexpr wrapper& operator=(const wrapper<S, reference>& other) {
171+
Base::apply(other, CopyAssignment{});
172+
return *this;
173+
}
174+
constexpr wrapper& operator=(const wrapper<S, const_reference>& other) {
175+
Base::apply(other, CopyAssignment{});
176+
return *this;
177+
}
178+
constexpr wrapper& operator=(const wrapper<S, const_reference_restrict>& other) {
179+
Base::apply(other, CopyAssignment{});
180+
return *this;
181+
}
125182

126183
constexpr wrapper(wrapper&& other) = default;
127184

128185
constexpr wrapper& operator=(wrapper&& other) { return operator=(other); }
129186

130-
constexpr wrapper<S, pointer> operator& () { return Base::apply(GetPointer<S<pointer>>{}); }
131-
//constexpr wrapper<S, const_pointer> operator& () const { return Base::apply(GetPointer<S<const_pointer>>{}); }
187+
constexpr wrapper<S, pointer> operator&() { return Base::apply(GetPointer<S<pointer>>{}); }
188+
//constexpr wrapper<S, const_pointer> operator&() const { return Base::apply(GetPointer<S<const_pointer>>{}); }
189+
constexpr pointer<wrapper<S, reference>> operator->() { return this; }
132190
};
133191

134192
template <template <template <class> class> class S>
@@ -139,8 +197,26 @@ struct wrapper<S, const_reference> : public S<const_reference> {
139197
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
140198
constexpr wrapper(const S<value>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
141199
constexpr wrapper(const S<reference>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
200+
constexpr wrapper(const S<reference_restrict>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
201+
constexpr wrapper(const S<const_reference_restrict>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
142202

143203
constexpr wrapper<S, const_pointer> operator&() const { return Base::apply(GetPointer<S<const_pointer>>{}); }
204+
constexpr const_pointer<wrapper<S, const_reference>> operator->() const { return this; }
205+
};
206+
207+
template <template <template <class> class> class S>
208+
struct wrapper<S, const_reference_restrict> : public S<const_reference_restrict> {
209+
using Base = S<const_reference_restrict>;
210+
211+
constexpr wrapper() = delete;
212+
constexpr wrapper(Base b) : Base{static_cast<Base&&>(b)} {}
213+
constexpr wrapper(const S<value>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
214+
constexpr wrapper(const S<reference>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
215+
constexpr wrapper(const S<reference_restrict>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
216+
constexpr wrapper(const S<const_reference>& other) : Base(other.apply(AggregateConstructor<Base>{})) {}
217+
218+
constexpr wrapper<S, const_pointer> operator&() const { return Base::apply(GetPointer<S<const_pointer>>{}); }
219+
constexpr const_pointer<wrapper<S, const_reference>> operator->() const { return this; }
144220
};
145221

146222
template <template <template <class> class> class S>
@@ -155,6 +231,8 @@ struct wrapper<S, pointer> : public S<pointer> {
155231

156232
constexpr wrapper<S, reference> operator*() { return operator[](0); }
157233
constexpr wrapper<S, const_reference> operator*() const { return operator[](0); }
234+
constexpr wrapper<S, reference> operator->() { return operator[](0); }
235+
constexpr wrapper<S, const_reference> operator->() const { return operator[](0); }
158236

159237
constexpr bool operator==(const wrapper& other) const { return Base::apply(FirstMember{}) == other.apply(FirstMember{}); }
160238
constexpr bool operator!=(const wrapper& other) const { return !this->operator==(other); }
@@ -180,6 +258,20 @@ struct wrapper<S, const_pointer> : public S<const_pointer> {
180258

181259
constexpr wrapper<S, const_reference> operator[] (size_t i) const { return Base::apply(RandomAccessAt<S<const_reference>>{i}); }
182260
constexpr wrapper<S, const_reference> operator*() const { return operator[](0); }
261+
constexpr wrapper<S, const_reference> operator->() const { return operator[](0); }
262+
263+
constexpr bool operator==(const wrapper& other) const { return Base::apply(FirstMember{}) == other.apply(FirstMember{}); }
264+
constexpr bool operator!=(const wrapper& other) const { return !this->operator==(other); }
265+
constexpr bool operator<(const wrapper& other) const { return Base::apply(FirstMember{}) < other.apply(FirstMember{}); }
266+
267+
constexpr wrapper operator+(ptrdiff_t i) const { return Base::apply(Advance<Base>{i}); }
268+
constexpr wrapper operator-(ptrdiff_t i) const { return operator+(-i); }
269+
constexpr ptrdiff_t operator-(const wrapper& other) const { return Base::apply(FirstMember{}) - other.apply(FirstMember{}); }
270+
271+
constexpr wrapper& operator++() { Base::apply(PreIncrement<Base>{}); return *this; }
272+
constexpr wrapper& operator+=(ptrdiff_t i) { return *this = *this + i; }
273+
constexpr wrapper& operator--() { Base::apply(PreDecrement<Base>{}); return *this; }
274+
constexpr wrapper& operator-=(ptrdiff_t i) { return *this = *this - i; }
183275
};
184276

185277
enum Flag { soa, aos };
@@ -204,6 +296,10 @@ struct interface<S, F, Flag::soa> { using type = wrapper<S, F>; };
204296
#define MEMLAYOUT_EXPAND(m) f(m, other.m)
205297

206298
#define MEMLAYOUT_APPLY_BINARY(STRUCT_NAME, ...)\
299+
template <template <class> class F_other, class Function>\
300+
constexpr STRUCT_NAME apply(STRUCT_NAME<F_other>& other, Function&& f) { return {__VA_ARGS__}; }\
301+
template <template <class> class F_other, class Function>\
302+
constexpr STRUCT_NAME apply(STRUCT_NAME<F_other>& other, Function&& f) const { return {__VA_ARGS__}; }\
207303
template <template <class> class F_other, class Function>\
208304
constexpr STRUCT_NAME apply(const STRUCT_NAME<F_other>& other, Function&& f) { return {__VA_ARGS__}; }\
209305
template <template <class> class F_other, class Function>\

GPU/GPUTracking/DataCompression/GPUTPCCompressionTrackModel.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727

2828
#elif defined(GPUCA_COMPRESSION_TRACK_MODEL_SECTORTRACKER)
2929
#include "GPUTPCTrackParam.h"
30+
#include "MemLayout.h"
3031

3132
#else // Default internal track model for compression
3233
#endif
@@ -121,7 +122,7 @@ class GPUTPCCompressionTrackModel
121122
const GPUParam* mParam;
122123

123124
#elif defined(GPUCA_COMPRESSION_TRACK_MODEL_SECTORTRACKER)
124-
GPUTPCTrackParam mTrk;
125+
GPUTPCTrackParamSkeleton<MemLayout::value> mTrk;
125126
float mAlpha;
126127
const GPUParam* mParam;
127128

GPU/GPUTracking/DataTypes/GPUDataTypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,7 @@ struct GPUTrackingInOutPointers {
225225
const AliHLTTPCRawCluster* rawClusters[NSECTORS] = {nullptr};
226226
uint32_t nRawClusters[NSECTORS] = {0};
227227
const o2::tpc::ClusterNativeAccess* clustersNative = nullptr;
228-
MemLayout::wrapper<GPUTPCTrackSkeleton, MemLayout::const_pointer> sectorTracks[NSECTORS];// = {{nullptr, nullptr, nullptr, {nullptr, nullptr, nullptr, nullptr}}};
228+
MemLayout::interface<GPUTPCTrackSkeleton, MemLayout::const_pointer, GPUTPCTrackLayout>::type sectorTracks[NSECTORS];
229229
uint32_t nSectorTracks[NSECTORS] = {0};
230230
const GPUTPCHitId* sectorClusters[NSECTORS] = {nullptr};
231231
uint32_t nSectorClusters[NSECTORS] = {0};

GPU/GPUTracking/Global/GPUChainTracking.cxx

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -549,23 +549,33 @@ void GPUChainTracking::ClearIOPointers()
549549
new (&mIOMem) InOutMemory;
550550
}
551551

552+
namespace {
553+
554+
template <class Function>
555+
struct ApplyRecursive {
556+
Function f;
557+
558+
template <class T>
559+
const T * operator()(const T * & aosIOPtr, GPUChainTracking::unique_ptr_array<T>& aosIOMem) const { return f(aosIOPtr, aosIOMem); }
560+
561+
template <template <template <class> class> class S>
562+
S<MemLayout::const_pointer> operator()(S<MemLayout::const_pointer>& soaIOPtr, S<GPUChainTracking::unique_ptr_array>& soaIOMem) const {
563+
return soaIOPtr.apply(soaIOMem, ApplyRecursive{f});
564+
}
565+
};
566+
567+
}
568+
552569
void GPUChainTracking::AllocateIOMemory()
553570
{
554571
for (uint32_t i = 0; i < NSECTORS; i++) {
555572
AllocateIOMemoryHelper(mIOPtrs.nClusterData[i], mIOPtrs.clusterData[i], mIOMem.clusterData[i]);
556573
AllocateIOMemoryHelper(mIOPtrs.nRawClusters[i], mIOPtrs.rawClusters[i], mIOMem.rawClusters[i]);
557-
558-
//AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i], mIOMem.sectorTracks[i]);
559-
560-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mFirstHitID, mIOMem.sectorTracks[i].mFirstHitID);
561-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mNHits, mIOMem.sectorTracks[i].mNHits);
562-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mLocalTrackId, mIOMem.sectorTracks[i].mLocalTrackId);
563-
564-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mParam.mX, mIOMem.sectorTracks[i].mParam.mX);
565-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mParam.mC, mIOMem.sectorTracks[i].mParam.mC);
566-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mParam.mZOffset, mIOMem.sectorTracks[i].mParam.mZOffset);
567-
AllocateIOMemoryHelper(mIOPtrs.nSectorTracks[i], mIOPtrs.sectorTracks[i].mParam.mP, mIOMem.sectorTracks[i].mParam.mP);
568-
574+
auto sectorTrackAllocator = [this, nSectorTrack = this->mIOPtrs.nSectorTracks[i]](auto& IOPtrsTrack, auto& mIOMemTrack) {
575+
AllocateIOMemoryHelper(nSectorTrack, IOPtrsTrack, mIOMemTrack);
576+
return IOPtrsTrack;
577+
};
578+
ApplyRecursive{sectorTrackAllocator}(mIOPtrs.sectorTracks[i], mIOMem.sectorTracks[i]);
569579
AllocateIOMemoryHelper(mIOPtrs.nSectorClusters[i], mIOPtrs.sectorClusters[i], mIOMem.sectorClusters[i]);
570580
}
571581
mIOMem.clusterNativeAccess.reset(new ClusterNativeAccess);

GPU/GPUTracking/Global/GPUChainTracking.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717

1818
#include "GPUChain.h"
1919
#include "GPUDataTypes.h"
20+
#include "GPUTPCTrack.h"
2021
#include "MemLayout.h"
22+
2123
#include <atomic>
2224
#include <mutex>
2325
#include <functional>
@@ -91,6 +93,9 @@ class GPUChainTracking : public GPUChain
9193
// Structures for input and output data
9294
GPUTrackingInOutPointers& mIOPtrs;
9395

96+
template <class T>
97+
using unique_ptr_array = std::unique_ptr<T[]>;
98+
9499
struct InOutMemory {
95100
InOutMemory();
96101
~InOutMemory();
@@ -108,12 +113,7 @@ class GPUChainTracking : public GPUChain
108113
std::unique_ptr<AliHLTTPCRawCluster[]> rawClusters[NSECTORS];
109114
std::unique_ptr<o2::tpc::ClusterNative[]> clustersNative;
110115
std::unique_ptr<o2::tpc::ClusterNativeAccess> clusterNativeAccess;
111-
112-
template <class T>
113-
using unique_ptr_array = std::unique_ptr<T[]>;
114-
115-
MemLayout::wrapper<GPUTPCTrackSkeleton, unique_ptr_array> sectorTracks[NSECTORS];
116-
116+
MemLayout::interface<GPUTPCTrackSkeleton, unique_ptr_array, GPUTPCTrackLayout>::type sectorTracks[NSECTORS];
117117
std::unique_ptr<GPUTPCHitId[]> sectorClusters[NSECTORS];
118118
std::unique_ptr<AliHLTTPCClusterMCLabel[]> mcLabelsTPC;
119119
std::unique_ptr<GPUTPCMCInfo[]> mcInfosTPC;

0 commit comments

Comments
 (0)