Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 54 additions & 19 deletions PWGDQ/Core/MuonMatchingMlResponse.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <Framework/Logger.h>

#include <concepts>
#include <cstdint>
#include <string>
#include <vector>
Expand All @@ -35,13 +36,47 @@
// Check if the index of mCachedIndices (index associated to a FEATURE)
// matches the entry in EnumInputFeatures associated to this FEATURE
// if so, the inputFeatures vector is filled with the FEATURE's value
// by calling the corresponding GETTER=FEATURE from track
// by calling the corresponding GETTER expression
#define CHECK_AND_FILL_FEATURE(FEATURE, GETTER) \
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
inputFeature = (GETTER); \
break; \
}

// Check if the index of mCachedIndices (index associated to a FEATURE)
// matches the entry in EnumInputFeatures associated to this FEATURE
// if so, and if OBJECT.GETTER() is a valid function invocation,
// the inputFeatures vector is filled with the FEATURE's value
// by calling the corresponding GETTER function
#define CHECK_AND_FILL_FEATURE_OPTIONAL_NO_EXPR(FEATURE, OBJECT, GETTER) \
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
if constexpr (requires(decltype(OBJECT) t) { { t.GETTER() } -> std::convertible_to<float>; }) { \
inputFeature = (OBJECT.GETTER()); \
} else { \
inputFeature = 0; \
} \
break; \
}

// Check if the index of mCachedIndices (index associated to a FEATURE)
// matches the entry in EnumInputFeatures associated to this FEATURE
// if so, and if OBJECT.FUNC() is a valid function invocation,
// the inputFeatures vector is filled with the FEATURE's value
// by calling the corresponding GETTER expression
#define CHECK_AND_FILL_FEATURE_OPTIONAL_WITH_EXPR(FEATURE, OBJECT, FUNC, GETTER) \
case static_cast<uint8_t>(InputFeaturesMFTMuonMatch::FEATURE): { \
if constexpr (requires(decltype(OBJECT) t) { { t.FUNC() } -> std::convertible_to<float>; }) { \
inputFeature = (GETTER); \
} else { \
inputFeature = 0; \
} \
break; \
}

#define __EXPAND(x) x
#define __GET_MACRO(_1, _2, _3, _4, name, ...) name
#define CHECK_AND_FILL_FEATURE_OPTIONAL(...) __EXPAND(__GET_MACRO(__VA_ARGS__, CHECK_AND_FILL_FEATURE_OPTIONAL_WITH_EXPR, CHECK_AND_FILL_FEATURE_OPTIONAL_NO_EXPR)(__VA_ARGS__))

namespace o2::analysis
{
// possible input features for ML
Expand Down Expand Up @@ -287,27 +322,27 @@ class MlResponseMFTMuonMatch : public MlResponse<TypeOutputScore>
CHECK_AND_FILL_FEATURE(posX, collision.posX());
CHECK_AND_FILL_FEATURE(posY, collision.posY());
CHECK_AND_FILL_FEATURE(posZ, collision.posZ());
CHECK_AND_FILL_FEATURE(numContrib, collision.numContrib());
CHECK_AND_FILL_FEATURE(trackOccupancyInTimeRange, collision.trackOccupancyInTimeRange());
CHECK_AND_FILL_FEATURE(ft0cOccupancyInTimeRange, collision.ft0cOccupancyInTimeRange());
CHECK_AND_FILL_FEATURE(multMFT, collision.mftNtracks());
CHECK_AND_FILL_FEATURE(multFT0A, collision.multFT0A());
CHECK_AND_FILL_FEATURE(multFT0C, collision.multFT0C());
CHECK_AND_FILL_FEATURE(multNTracksPV, collision.multNTracksPV());
CHECK_AND_FILL_FEATURE(multNTracksPVeta1, collision.multNTracksPVeta1());
CHECK_AND_FILL_FEATURE(multNTracksPVetaHalf, collision.multNTracksPVetaHalf());
CHECK_AND_FILL_FEATURE(isInelGt0, collision.isInelGt0());
CHECK_AND_FILL_FEATURE(isInelGt1, collision.isInelGt1());
CHECK_AND_FILL_FEATURE(multFT0M, collision.multFT0M());
CHECK_AND_FILL_FEATURE(centFT0M, collision.centFT0M());
CHECK_AND_FILL_FEATURE(centFT0A, collision.centFT0A());
CHECK_AND_FILL_FEATURE(centFT0C, collision.centFT0C());
CHECK_AND_FILL_FEATURE_OPTIONAL(numContrib, collision, numContrib);
CHECK_AND_FILL_FEATURE_OPTIONAL(trackOccupancyInTimeRange, collision, trackOccupancyInTimeRange);
CHECK_AND_FILL_FEATURE_OPTIONAL(ft0cOccupancyInTimeRange, collision, ft0cOccupancyInTimeRange);
CHECK_AND_FILL_FEATURE_OPTIONAL(multMFT, collision, mftNtracks);
CHECK_AND_FILL_FEATURE_OPTIONAL(multFT0A, collision, multFT0A);
CHECK_AND_FILL_FEATURE_OPTIONAL(multFT0C, collision, multFT0C);
CHECK_AND_FILL_FEATURE_OPTIONAL(multNTracksPV, collision, multNTracksPV);
CHECK_AND_FILL_FEATURE_OPTIONAL(multNTracksPVeta1, collision, multNTracksPVeta1);
CHECK_AND_FILL_FEATURE_OPTIONAL(multNTracksPVetaHalf, collision, multNTracksPVetaHalf);
CHECK_AND_FILL_FEATURE_OPTIONAL(isInelGt0, collision, isInelGt0);
CHECK_AND_FILL_FEATURE_OPTIONAL(isInelGt1, collision, isInelGt1);
CHECK_AND_FILL_FEATURE_OPTIONAL(multFT0M, collision, multFT0M);
CHECK_AND_FILL_FEATURE_OPTIONAL(centFT0M, collision, centFT0M);
CHECK_AND_FILL_FEATURE_OPTIONAL(centFT0A, collision, centFT0A);
CHECK_AND_FILL_FEATURE_OPTIONAL(centFT0C, collision, centFT0C);
// global forward track parameters
CHECK_AND_FILL_FEATURE(chi2MCHMFT, muon.chi2MatchMCHMFT());
CHECK_AND_FILL_FEATURE(chi2GlobMUON, muon.chi2());
CHECK_AND_FILL_FEATURE(dcaX, muon.fwdDcaX());
CHECK_AND_FILL_FEATURE(dcaY, muon.fwdDcaX());
CHECK_AND_FILL_FEATURE(isAmbig, (muon.compatibleCollIds().size() == 1) ? 0 : 1);
CHECK_AND_FILL_FEATURE_OPTIONAL(dcaX, muon, fwdDcaX);
CHECK_AND_FILL_FEATURE_OPTIONAL(dcaY, muon, fwdDcaX);
CHECK_AND_FILL_FEATURE_OPTIONAL(isAmbig, muon, compatibleCollIds, (muon.compatibleCollIds().size() == 1) ? 0 : 1);
}
return inputFeature;
}
Expand Down
14 changes: 10 additions & 4 deletions PWGDQ/TableProducer/tableMakerMC_withAssoc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1041,12 +1041,16 @@ struct TableMakerMC {
if (static_cast<int>(muon.trackType()) < 2) {
auto muonID = muon.matchMCHTrackId();
auto chi2 = muon.chi2MatchMCHMFT();
/* TODO: the getInputFeaturesTest() fuction has been removed
* Moreover, it is not foreseen to run ML models using only the information
* from the global muon track.
* Can this part be safely removed?
if (fConfigVariousOptions.fUseML.value) {
std::vector<float> output;
std::vector<float> inputML = matchingMlResponse.getInputFeaturesTest(muon);
matchingMlResponse.isSelectedMl(inputML, 0, output);
chi2 = output[0];
}
}*/
if (mCandidates.find(muonID) == mCandidates.end()) {
mCandidates[muonID] = {chi2, muon.globalIndex()};
} else {
Expand All @@ -1062,8 +1066,10 @@ struct TableMakerMC {
}

template <typename TMuons, typename TMFTTracks, typename TMFTCovs, typename TEvent>
void skimBestMuonMatchesML(TMuons const& muons, TMFTTracks const& /*mfttracks*/, TMFTCovs const& mfCovs, TEvent const& collision)
void skimBestMuonMatchesML(TMuons const& /*muons*/, TMFTTracks const& /*mfttracks*/, TMFTCovs const& /*mfCovs*/, TEvent const& /*collision*/)
{
return;
/* TODO: add missing tables in the tracks and events definitions
std::unordered_map<int, std::pair<float, int>> mCandidates;
for (const auto& muon : muons) {
if (static_cast<int>(muon.trackType()) < 2) {
Expand All @@ -1078,7 +1084,7 @@ struct TableMakerMC {
muonprop = VarManager::PropagateMuon(muontrack, collision, VarManager::kToMatching);
}
std::vector<float> output;
std::vector<float> inputML = matchingMlResponse.getInputFeaturesGlob(muon, muonprop, mftprop, collision);
std::vector<float> inputML = matchingMlResponse.getInputFeatures(muon, mfttrack, muontrack, mftprop, muonprop, collision);
matchingMlResponse.isSelectedMl(inputML, 0, output);
float score = output[0];
if (mCandidates.find(muonID) == mCandidates.end()) {
Expand All @@ -1092,7 +1098,7 @@ struct TableMakerMC {
}
for (auto& pairCand : mCandidates) {
fBestMatch[pairCand.second.second] = true;
}
}*/
}

template <uint32_t TMuonFillMap, uint32_t TMFTFillMap, typename TEvent, typename TMuons, typename TMFTTracks, typename TMFTCovs>
Expand Down
8 changes: 5 additions & 3 deletions PWGDQ/TableProducer/tableMaker_withAssoc.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -1478,8 +1478,10 @@ struct TableMaker {
}

template <typename TMuons, typename TMFTTracks, typename TMFTCovs, typename TEvent>
void skimBestMuonMatchesML(TMuons const& muons, TMFTTracks const& /*mfttracks*/, TMFTCovs const& mfCovs, TEvent const& collision)
void skimBestMuonMatchesML(TMuons const& /*muons*/, TMFTTracks const& /*mfttracks*/, TMFTCovs const& /*mfCovs*/, TEvent const& /*collision*/)
{
return;
/* TODO: add missing tables in the tracks and events definitions
std::unordered_map<int, std::pair<float, int>> mCandidates;
for (const auto& muon : muons) {
if (static_cast<int>(muon.trackType()) < 2) {
Expand All @@ -1494,7 +1496,7 @@ struct TableMaker {
muonprop = VarManager::PropagateMuon(muontrack, collision, VarManager::kToMatching);
}
std::vector<float> output;
std::vector<float> inputML = matchingMlResponse.getInputFeaturesGlob(muon, muonprop, mftprop, collision);
std::vector<float> inputML = matchingMlResponse.getInputFeatures(muon, mfttrack, muontrack, mftprop, muonprop, collision);
matchingMlResponse.isSelectedMl(inputML, 0, output);
float score = output[0];
if (mCandidates.find(muonID) == mCandidates.end()) {
Expand All @@ -1508,7 +1510,7 @@ struct TableMaker {
}
for (auto& pairCand : mCandidates) {
fBestMatch[pairCand.second.second] = true;
}
}*/
}

template <uint32_t TMuonFillMap, uint32_t TMFTFillMap, typename TEvent, typename TBCs, typename TMuons, typename TMFTTracks, typename TMFTCovs>
Expand Down
Loading