Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 29 additions & 2 deletions PWGJE/Core/JetTaggingUtilities.h
Original file line number Diff line number Diff line change
Expand Up @@ -1095,9 +1095,9 @@ void analyzeJetTrackInfo4MLnoSV(AnalysisJet const& analysisJet, AnyTracks const&
std::sort(tracksParams.begin(), tracksParams.end(), compare);
}

// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
// Looping over the track info and putting them in the input vector, with extra input features (for GNN b-jet tagging)
template <typename AnalysisJet, typename AnyTracks, typename AnyOriginalTracks>
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, float trackDcaXYMax = 10.0, float trackDcaZMax = 10.0, int64_t nMaxConstit = 40)
void analyzeJetTrackInfo4GNNwExtra(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, float trackDcaXYMax = 10.0, float trackDcaZMax = 10.0, int64_t nMaxConstit = 40)
{
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {

Expand All @@ -1124,6 +1124,33 @@ void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*
}
}

// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
//
// For every constituent of analysisJet (read through tracks_as<AnyTracks>()) that has
// pt >= trackPtMin and is accepted by trackAcceptanceWithDca(constituent, trackDcaXYMax, trackDcaZMax),
// one feature vector is stored in tracksParams with the layout
//   [0] pt, [1] phi, [2] eta, [3] sign(), [4] |dcaXY| * geoSign, [5] sigmadcaXY,
//   [6] |dcaZ| * geoSign, [7] sigmadcaZ
// where geoSign is the value returned by getGeoSign(analysisJet, constituent).
//
// At most nMaxConstit entries are kept: once tracksParams holds nMaxConstit entries, a new
// constituent replaces the stored entry with the lowest signed DCA_XY significance ([4] / [5]),
// and only if its own significance is strictly higher.
//
// tracksParams is appended to, not cleared; entries already present count towards nMaxConstit.
// If nMaxConstit <= 0 and tracksParams is empty, nothing is stored.
// NOTE(review): the significance is computed by dividing by sigmadcaXY without a zero check,
// so a zero uncertainty yields inf/NaN in the comparison - confirm this cannot happen upstream.
template <typename AnalysisJet, typename AnyTracks>
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, float trackDcaXYMax = 10.0, float trackDcaZMax = 10.0, int64_t nMaxConstit = 40)
{
  // Signed DCA_XY significance of an already stored feature vector.
  const auto storedSignificance = [](std::vector<float> const& feat) { return feat[4] / feat[5]; };

  for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {

    if (constituent.pt() < trackPtMin || !trackAcceptanceWithDca(constituent, trackDcaXYMax, trackDcaZMax)) {
      continue;
    }

    const int sign = getGeoSign(analysisJet, constituent);
    const auto signedDcaXY = std::abs(constituent.dcaXY()) * sign;
    const auto sigmaDcaXY = constituent.sigmadcaXY();

    // Single place where the feature vector is assembled, so that the "append" and the
    // "replace" paths cannot drift apart.
    const auto makeTrackFeat = [&]() {
      return std::vector<float>{constituent.pt(), constituent.phi(), constituent.eta(), static_cast<float>(constituent.sign()), signedDcaXY, sigmaDcaXY, std::abs(constituent.dcaZ()) * sign, constituent.sigmadcaZ()};
    };

    if (static_cast<int64_t>(tracksParams.size()) < nMaxConstit) {
      tracksParams.emplace_back(makeTrackFeat());
      continue;
    }

    if (tracksParams.empty()) {
      // Only reachable when nMaxConstit <= 0: there is no stored entry to compare with or to
      // replace, and indexing the empty vector would be undefined behaviour.
      continue;
    }

    // If there are more than nMaxConstit constituents in the jet, select only nMaxConstit constituents with the highest DCA_XY significance.
    size_t minIdx = 0;
    for (size_t i = 1; i < tracksParams.size(); ++i) {
      if (storedSignificance(tracksParams[i]) < storedSignificance(tracksParams[minIdx])) {
        minIdx = i;
      }
    }
    if (signedDcaXY / sigmaDcaXY > storedSignificance(tracksParams[minIdx])) {
      tracksParams[minIdx] = makeTrackFeat();
    }
  }
}

// Discriminant value for GNN b-jet tagging
template <typename T>
T getDb(const std::vector<T>& logits, double fC = 0.018)
Expand Down
57 changes: 46 additions & 11 deletions PWGJE/TableProducer/jetTaggerHF.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ struct JetTaggerHFTask {
}
}
}
if (doprocessAlgorithmGNN) {
if (doprocessAlgorithmGNN || doprocessAlgorithmGNNwExtra) {
if (jet.pt() >= jetpTMin) {
float dbRange;
if (scoreML[jet.globalIndex()] < dbMin) {
Expand Down Expand Up @@ -513,7 +513,7 @@ struct JetTaggerHFTask {
}
}

if (doprocessAlgorithmML || doprocessAlgorithmGNN) {
if (doprocessAlgorithmML || doprocessAlgorithmGNN || doprocessAlgorithmGNNwExtra) {
bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
if (loadModelsFromCCDB) {
ccdbApi.init(ccdbUrl);
Expand All @@ -525,7 +525,7 @@ struct JetTaggerHFTask {
bMlResponse.init();
}

if (doprocessAlgorithmGNN) {
if (doprocessAlgorithmGNN || doprocessAlgorithmGNNwExtra) {
tensorAlloc = o2::analysis::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst, tfFuncTypeGNN.value);

registry.add("h2_count_db", "#it{D}_{b} underflow/overflow;Jet flavour;#it{D}_{b} range", {HistType::kTH2F, {{4, 0., 4.}, {3, 0., 3.}}});
Expand All @@ -538,10 +538,10 @@ struct JetTaggerHFTask {
h2CountDb->GetYaxis()->SetBinLabel(2, "in range");
h2CountDb->GetYaxis()->SetBinLabel(3, "overflow");

registry.add("h_db_b", "#it{D}_{b} b-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h_db_c", "#it{D}_{b} c-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h_db_lf", "#it{D}_{b} lf-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
registry.add("h2_pt_db", "#it{p}_{T} vs. #it{D}_{b};#it{p}_{T}^{ch jet} (GeV/#it{c}^{2});#it{D}_{b}", {HistType::kTH2F, {{100, 0., 200.}, {50, -10., 35.}}});
registry.add("h_db_b", "#it{D}_{b} b-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -20., 30.}}});
registry.add("h_db_c", "#it{D}_{b} c-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -20., 30.}}});
registry.add("h_db_lf", "#it{D}_{b} lf-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -20., 30.}}});
registry.add("h2_pt_db", "#it{p}_{T} vs. #it{D}_{b};#it{p}_{T}^{ch jet} (GeV/#it{c}^{2});#it{D}_{b}", {HistType::kTH2F, {{100, 0., 200.}, {50, -20., 30.}}});
}
}

Expand Down Expand Up @@ -612,11 +612,40 @@ struct JetTaggerHFTask {
}

template <typename AnyJets, typename AnyTracks, typename AnyOriginalTracks>
void analyzeJetAlgorithmGNN(AnyJets const& jets, AnyTracks const& tracks, AnyOriginalTracks const& origTracks)
void analyzeJetAlgorithmGNNwExtra(AnyJets const& jets, AnyTracks const& tracks, AnyOriginalTracks const& origTracks)
{
for (const auto& jet : jets) {
std::vector<std::vector<float>> trkFeat;
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, trackDcaXYMax, trackDcaZMax, nJetConst);
jettaggingutilities::analyzeJetTrackInfo4GNNwExtra(jet, tracks, origTracks, trkFeat, trackPtMin, trackDcaXYMax, trackDcaZMax, nJetConst);

std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};

if (trkFeat.size() > 0) {
std::vector<float> feat;
std::vector<Ort::Value> gnnInput;
tensorAlloc.getGNNInput(jetFeat, trkFeat, feat, gnnInput);

auto modelOutput = bMlResponse.getModelOutput(gnnInput, 0);
float db = jettaggingutilities::getDb(modelOutput, fC);
if (!std::isnan(db)) {
scoreML[jet.globalIndex()] = db;
} else {
scoreML[jet.globalIndex()] = 999.;
LOGF(debug, "doprocessAlgorithmGNNwExtra, Db is NaN (%d)", jet.globalIndex());
}
} else {
scoreML[jet.globalIndex()] = -999.;
LOGF(debug, "doprocessAlgorithmGNNwExtra, trkFeat.size() <= 0 (%d)", jet.globalIndex());
}
}
}

template <typename AnyJets, typename AnyTracks>
void analyzeJetAlgorithmGNN(AnyJets const& jets, AnyTracks const& tracks)
{
for (const auto& jet : jets) {
std::vector<std::vector<float>> trkFeat;
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, trkFeat, trackPtMin, trackDcaXYMax, trackDcaZMax, nJetConst);

std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};

Expand Down Expand Up @@ -684,9 +713,15 @@ struct JetTaggerHFTask {
}
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmMLnoSV, "Fill ML evaluation score for charged jets but without using SVs", false);

void processAlgorithmGNN(JetTable const& jets, JetTracksExt const& jtracks, OriginalTracks const& origTracks)
// Process function for the GNN tagger variant that uses extra input features.
// It only forwards the jet table, the jet tracks and the original tracks to
// analyzeJetAlgorithmGNNwExtra(), which does the per-jet work.
// Registered through PROCESS_SWITCH below with a default value of false.
void processAlgorithmGNNwExtra(JetTable const& jets, JetTracksExt const& jtracks, OriginalTracks const& origTracks)
{
analyzeJetAlgorithmGNNwExtra(jets, jtracks, origTracks);
}
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmGNNwExtra, "Fill GNN evaluation score (D_b) for charged jets with extra input features", false);

void processAlgorithmGNN(JetTable const& jets, JetTracksExt const& jtracks)
{
analyzeJetAlgorithmGNN(jets, jtracks, origTracks);
analyzeJetAlgorithmGNN(jets, jtracks);
}
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmGNN, "Fill GNN evaluation score (D_b) for charged jets", false);

Expand Down
66 changes: 61 additions & 5 deletions PWGJE/Tasks/bjetTaggingGnn.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
/// \author Changhwan Choi <changhwan.choi@cern.ch>, Pusan National University

#include "PWGJE/Core/JetDerivedDataUtilities.h"
#include "PWGJE/Core/JetFindingUtilities.h"
#include "PWGJE/Core/JetTaggingUtilities.h"
#include "PWGJE/DataModel/Jet.h"
#include "PWGJE/DataModel/JetReducedData.h"
#include "PWGJE/DataModel/JetTagging.h"

#include "Common/CCDB/TriggerAliases.h"
#include "Common/Core/RecoDecay.h"
#include "Common/Core/Zorro.h"
#include "Common/Core/ZorroSummary.h"
#include "Common/DataModel/EventSelection.h"
Expand Down Expand Up @@ -126,6 +128,7 @@ struct BjetTaggingGnn {
Configurable<float> pTHatExponent{"pTHatExponent", 6.0, "exponent of the event weight for the calculation of pTHat"};

// track level configurables
Configurable<std::string> trackSelections{"trackSelections", "QualityTracks", "set track selections"};
Configurable<float> trackPtMin{"trackPtMin", 0.15, "minimum track pT"};
Configurable<float> trackPtMax{"trackPtMax", 1000.0, "maximum track pT"};
Configurable<float> trackEtaMin{"trackEtaMin", -0.9, "minimum track eta"};
Expand Down Expand Up @@ -175,6 +178,8 @@ struct BjetTaggingGnn {
std::vector<int> eventSelectionBitsSelMC;
std::vector<int> eventSelectionBitsSel8;

int trackSelectionBits;

std::vector<double> jetRadiiValues;

void init(InitContext const&)
Expand All @@ -186,6 +191,8 @@ struct BjetTaggingGnn {
eventSelectionBitsSel8 = jetderiveddatautilities::initialiseEventSelectionBits("sel8");
eventSelectionBitsSelMC = jetderiveddatautilities::initialiseEventSelectionBits("selMC");

trackSelectionBits = jetderiveddatautilities::initialiseTrackSelection(static_cast<std::string>(trackSelections));

if (doprocessDataJetsTrig) {
zorroSummary.setObject(zorro.getZorroSummary());
}
Expand Down Expand Up @@ -236,7 +243,7 @@ struct BjetTaggingGnn {

const AxisSpec axisTrackpT{200, 0., 200., "#it{p}_{T} (GeV/#it{c})"};
const AxisSpec axisTrackpTFine{1000, 0., 10., "#it{p}_{T} (GeV/#it{c})"};
const AxisSpec axisJetpT{200, 0., 200., "#it{p}_{T} (GeV/#it{c})"};
const AxisSpec axisJetpT{250, 0., 250., "#it{p}_{T} (GeV/#it{c})"};
const AxisSpec axisJetEta{200, -0.8, 0.8, "#it{#eta}_{jet}"};
const AxisSpec axisDb{200, dbMin, dbMax, "#it{D}_{b}"};
const AxisSpec axisDbFine{dbNbins, dbMin, dbMax, "#it{D}_{b}"};
Expand All @@ -250,6 +257,7 @@ struct BjetTaggingGnn {
registry.add("h_jetMass", "", {HistType::kTH1F, {axisJetMass}});
registry.add("h_Db", "", {HistType::kTH1F, {axisDbFine}});
registry.add("h2_jetpT_Db", "", {HistType::kTH2F, {axisJetpT, axisDb}});
registry.add("h2_nTracks_Db", "", {HistType::kTH2F, {axisNTracks, axisDb}});

registry.add("h_gnnfeat_trackpT", "", {HistType::kTH1F, {{200, 0., 100., "#it{p}_{T} (GeV/#it{c})"}}});
registry.add("h_gnnfeat_trackPhi", "", {HistType::kTH1F, {{200, 0., 2. * M_PI, "#it{#phi}"}}});
Expand Down Expand Up @@ -278,6 +286,18 @@ struct BjetTaggingGnn {
registry.add("h_partpT_matched_fine", "", {HistType::kTH1F, {axisTrackpTFine}}, callSumw2);
registry.add("h_partpT", "", {HistType::kTH1F, {axisTrackpT}}, callSumw2);
registry.add("h_partpT_fine", "", {HistType::kTH1F, {axisTrackpTFine}}, callSumw2);
registry.add("h_dcaXY_coll_matched", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaXY_coll_matched_b", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaXY_coll_matched_c", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaXY_coll_matched_lf", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaXY_coll_mismatched", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaXY_npp", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{xy}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_coll_matched", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{z}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_coll_matched_b", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{z}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_coll_matched_c", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{z}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_coll_matched_lf", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{z}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_coll_mismatched", "", {HistType::kTH1F, {{200, 0., 4., "|DCA_#it{z}| (cm)"}}}, callSumw2);
registry.add("h_dcaZ_npp", "", {HistType::kTH1F, {{200, 0., 5., "|DCA_#it{z}| (cm)"}}}, callSumw2);
}

if (doprocessDataJetsSel || doprocessMCDJetsSel) {
Expand All @@ -302,6 +322,9 @@ struct BjetTaggingGnn {
registry.add("h2_jetpT_Db_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
registry.add("h2_jetpT_Db_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
registry.add("h2_jetpT_Db_lf", "lf-jet", {HistType::kTH2F, {axisJetpT, axisDb}});
registry.add("h2_nTracks_Db_b", "b-jet", {HistType::kTH2F, {axisNTracks, axisDb}});
registry.add("h2_nTracks_Db_c", "c-jet", {HistType::kTH2F, {axisNTracks, axisDb}});
registry.add("h2_nTracks_Db_lf", "lf-jet", {HistType::kTH2F, {axisNTracks, axisDb}});
registry.add("h2_Response_DetjetpT_PartjetpT", "", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
registry.add("h2_Response_DetjetpT_PartjetpT_b", "b-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
registry.add("h2_Response_DetjetpT_PartjetpT_c", "c-jet", {HistType::kTH2F, {axisJetpT, axisJetpT}}, callSumw2);
Expand Down Expand Up @@ -513,6 +536,7 @@ struct BjetTaggingGnn {
registry.fill(HIST("h_jetMass"), analysisJet.mass());
registry.fill(HIST("h_Db"), analysisJet.scoreML());
registry.fill(HIST("h2_jetpT_Db"), analysisJet.pt(), analysisJet.scoreML());
registry.fill(HIST("h2_nTracks_Db"), nTracks, analysisJet.scoreML());

if (doDataDriven) {
if (doDataDrivenExtra) {
Expand Down Expand Up @@ -567,19 +591,23 @@ struct BjetTaggingGnn {
registry.fill(HIST("h_jetMass"), analysisJet.mass(), weightEvt);
registry.fill(HIST("h_Db"), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_jetpT_Db"), analysisJet.pt(), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_nTracks_Db"), nTracks, analysisJet.scoreML(), weightEvt);

if (jetFlavor == JetTaggingSpecies::beauty) {
registry.fill(HIST("h_jetpT_b"), analysisJet.pt(), weightEvt);
registry.fill(HIST("h_Db_b"), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_jetpT_Db_b"), analysisJet.pt(), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_nTracks_Db_b"), nTracks, analysisJet.scoreML(), weightEvt);
} else if (jetFlavor == JetTaggingSpecies::charm) {
registry.fill(HIST("h_jetpT_c"), analysisJet.pt(), weightEvt);
registry.fill(HIST("h_Db_c"), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_jetpT_Db_c"), analysisJet.pt(), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_nTracks_Db_c"), nTracks, analysisJet.scoreML(), weightEvt);
} else {
registry.fill(HIST("h_jetpT_lf"), analysisJet.pt(), weightEvt);
registry.fill(HIST("h_Db_lf"), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_jetpT_Db_lf"), analysisJet.pt(), analysisJet.scoreML(), weightEvt);
registry.fill(HIST("h2_nTracks_Db_lf"), nTracks, analysisJet.scoreML(), weightEvt);
if (jetFlavor == JetTaggingSpecies::none) {
registry.fill(HIST("h2_jetpT_Db_lf_none"), analysisJet.pt(), analysisJet.scoreML(), weightEvt);
} else {
Expand Down Expand Up @@ -1061,7 +1089,7 @@ struct BjetTaggingGnn {
bool matchedMcColl = collision.has_mcCollision() && std::fabs(collision.template mcCollision_as<FilteredCollisionsMCP>().posZ()) < vertexZCut;

for (const auto& track : tracks) {
if (track.eta() <= trackEtaMin || track.eta() >= trackEtaMax) {
if (!jetderiveddatautilities::selectTrack(track, trackSelectionBits) || track.eta() <= trackEtaMin || track.eta() >= trackEtaMax) {
continue;
}
registry.fill(HIST("h_trackpT"), track.pt(), weightEvt);
Expand All @@ -1072,9 +1100,37 @@ struct BjetTaggingGnn {
continue;
}
auto particle = track.template mcParticle_as<aod::JetParticles>();
if (particle.isPhysicalPrimary() && particle.eta() > trackEtaMin && particle.eta() < trackEtaMax) {
registry.fill(HIST("h2_trackpT_partpT"), track.pt(), particle.pt(), weightEvt);
registry.fill(HIST("h_partpT_matched_fine"), particle.pt(), weightEvt);
if (particle.isPhysicalPrimary()) {
if (particle.eta() > trackEtaMin && particle.eta() < trackEtaMax) {
registry.fill(HIST("h2_trackpT_partpT"), track.pt(), particle.pt(), weightEvt);
registry.fill(HIST("h_partpT_matched_fine"), particle.pt(), weightEvt);
// Track association accuracy as a function of DCA
if (track.pt() >= trackPtMin) {
if (particle.mcCollisionId() == collision.mcCollisionId()) {
registry.fill(HIST("h_dcaXY_coll_matched"), std::fabs(track.dcaXY()), weightEvt); // Matched to particle from the same MC collision
registry.fill(HIST("h_dcaZ_coll_matched"), std::fabs(track.dcaZ()), weightEvt);
int origin = RecoDecay::getParticleOrigin(allParticles, particle, false);
if (origin == RecoDecay::OriginType::NonPrompt) {
registry.fill(HIST("h_dcaXY_coll_matched_b"), std::fabs(track.dcaXY()), weightEvt);
registry.fill(HIST("h_dcaZ_coll_matched_b"), std::fabs(track.dcaZ()), weightEvt);
} else if (origin == RecoDecay::OriginType::Prompt) {
registry.fill(HIST("h_dcaXY_coll_matched_c"), std::fabs(track.dcaXY()), weightEvt);
registry.fill(HIST("h_dcaZ_coll_matched_c"), std::fabs(track.dcaZ()), weightEvt);
} else {
registry.fill(HIST("h_dcaXY_coll_matched_lf"), std::fabs(track.dcaXY()), weightEvt);
registry.fill(HIST("h_dcaZ_coll_matched_lf"), std::fabs(track.dcaZ()), weightEvt);
}
} else {
registry.fill(HIST("h_dcaXY_coll_mismatched"), std::fabs(track.dcaXY()), weightEvt); // Matched to particle from a different MC collision
registry.fill(HIST("h_dcaZ_coll_mismatched"), std::fabs(track.dcaZ()), weightEvt);
}
}
}
} else {
if (particle.eta() > trackEtaMin && particle.eta() < trackEtaMax && track.pt() >= trackPtMin) {
registry.fill(HIST("h_dcaXY_npp"), std::fabs(track.dcaXY()), weightEvt);
registry.fill(HIST("h_dcaZ_npp"), std::fabs(track.dcaZ()), weightEvt);
}
}
}

Expand Down
Loading