Skip to content

Commit d9e3198

Browse files
authored
Merge branch 'HFTagging' into HFTagging_git_upload
2 parents 949f9b9 + 8022672 commit d9e3198

File tree

4 files changed

+297
-3
lines changed

4 files changed

+297
-3
lines changed

PWGJE/Core/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ o2physics_add_library(PWGJECore
1414
SOURCES FastJetUtilities.cxx
1515
JetFinder.cxx
1616
JetBkgSubUtils.cxx
17-
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib)
17+
PUBLIC_LINK_LIBRARIES O2Physics::AnalysisCore FastJet::FastJet FastJet::Contrib ONNXRuntime::ONNXRuntime)
1818

1919
o2physics_target_root_dictionary(PWGJECore
2020
HEADERS JetFinder.h

PWGJE/Core/JetTaggingUtilities.h

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
#include "Common/Core/trackUtilities.h"
3939
#include "PWGJE/Core/JetUtilities.h"
4040

41+
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
42+
#include <onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>
43+
#else
44+
#include <onnxruntime_cxx_api.h>
45+
#endif
46+
4147
using namespace o2::constants::physics;
4248

4349
enum JetTaggingSpecies {
@@ -102,6 +108,159 @@ struct BJetSVParams {
102108
double mDecayLength3DError = 0.0;
103109
};
104110

111+
// ONNX Runtime tensor (Ort::Value) allocator for using customized inputs of ML models.
112+
class TensorAllocator
113+
{
114+
protected:
115+
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
116+
Ort::MemoryInfo mem_info;
117+
#endif
118+
public:
119+
TensorAllocator()
120+
#if !__has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
121+
: mem_info(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault))
122+
#endif
123+
{
124+
}
125+
~TensorAllocator() = default;
126+
template <typename T>
127+
Ort::Value createTensor(std::vector<T>& input, std::vector<int64_t>& inputShape)
128+
{
129+
#if __has_include(<onnxruntime/core/session/onnxruntime_cxx_api.h>)
130+
return Ort::Experimental::Value::CreateTensor<T>(input.data(), input.size(), inputShape);
131+
#else
132+
return Ort::Value::CreateTensor<T>(mem_info, input.data(), input.size(), inputShape.data(), inputShape.size());
133+
#endif
134+
}
135+
};
136+
137+
// TensorAllocator for GNN b-jet tagger
138+
class GNNBjetAllocator : public TensorAllocator
139+
{
140+
private:
141+
int64_t nJetFeat;
142+
int64_t nTrkFeat;
143+
int64_t nFlav;
144+
int64_t nTrkOrigin;
145+
int64_t maxNNodes;
146+
147+
std::vector<float> tfJetMean;
148+
std::vector<float> tfJetStdev;
149+
std::vector<float> tfTrkMean;
150+
std::vector<float> tfTrkStdev;
151+
152+
std::vector<std::vector<int64_t>> edgesList;
153+
154+
// Jet feature normalization
155+
template <typename T>
156+
T jetFeatureTransform(T feat, int idx) const
157+
{
158+
return (feat - tfJetMean[idx]) / tfJetStdev[idx];
159+
}
160+
161+
// Track feature normalization
162+
template <typename T>
163+
T trkFeatureTransform(T feat, int idx) const
164+
{
165+
return (feat - tfTrkMean[idx]) / tfTrkStdev[idx];
166+
}
167+
168+
// Edge input of GNN (fully-connected graph)
169+
void setEdgesList(void)
170+
{
171+
for (int64_t nNodes = 0; nNodes <= maxNNodes; ++nNodes) {
172+
std::vector<std::pair<int64_t, int64_t>> edges;
173+
// Generate all permutations of (i, j) where i != j
174+
for (int64_t i = 0; i < nNodes; ++i) {
175+
for (int64_t j = 0; j < nNodes; ++j) {
176+
if (i != j) {
177+
edges.emplace_back(i, j);
178+
}
179+
}
180+
}
181+
// Add self-loops (i, i)
182+
for (int64_t i = 0; i < nNodes; ++i) {
183+
edges.emplace_back(i, i);
184+
}
185+
// Flatten
186+
std::vector<int64_t> flattenedEdges;
187+
for (const auto& edge : edges) {
188+
flattenedEdges.push_back(edge.first);
189+
}
190+
for (const auto& edge : edges) {
191+
flattenedEdges.push_back(edge.second);
192+
}
193+
edgesList.push_back(flattenedEdges);
194+
}
195+
}
196+
197+
// Replace NaN in a vector into value
198+
template <typename T>
199+
static int replaceNaN(std::vector<T>& vec, T value)
200+
{
201+
int numNaN = 0;
202+
for (auto& el : vec) {
203+
if (std::isnan(el)) {
204+
el = value;
205+
++numNaN;
206+
}
207+
}
208+
return numNaN;
209+
}
210+
211+
public:
212+
GNNBjetAllocator() : TensorAllocator(), nJetFeat(4), nTrkFeat(13), nFlav(3), nTrkOrigin(5), maxNNodes(40) {}
213+
GNNBjetAllocator(int64_t nJetFeat, int64_t nTrkFeat, int64_t nFlav, int64_t nTrkOrigin, std::vector<float>& tfJetMean, std::vector<float>& tfJetStdev, std::vector<float>& tfTrkMean, std::vector<float>& tfTrkStdev, int64_t maxNNodes = 40)
214+
: TensorAllocator(), nJetFeat(nJetFeat), nTrkFeat(nTrkFeat), nFlav(nFlav), nTrkOrigin(nTrkOrigin), maxNNodes(maxNNodes), tfJetMean(tfJetMean), tfJetStdev(tfJetStdev), tfTrkMean(tfTrkMean), tfTrkStdev(tfTrkStdev)
215+
{
216+
setEdgesList();
217+
}
218+
~GNNBjetAllocator() = default;
219+
220+
// Copy operator for initializing GNNBjetAllocator using Configurable values
221+
GNNBjetAllocator& operator=(const GNNBjetAllocator& other)
222+
{
223+
nJetFeat = other.nJetFeat;
224+
nTrkFeat = other.nTrkFeat;
225+
nFlav = other.nFlav;
226+
nTrkOrigin = other.nTrkOrigin;
227+
maxNNodes = other.maxNNodes;
228+
tfJetMean = other.tfJetMean;
229+
tfJetStdev = other.tfJetStdev;
230+
tfTrkMean = other.tfTrkMean;
231+
tfTrkStdev = other.tfTrkStdev;
232+
setEdgesList();
233+
return *this;
234+
}
235+
236+
// Allocate & Return GNN input tensors (std::vector<Ort::Value>)
237+
template <typename T>
238+
void getGNNInput(std::vector<T>& jetFeat, std::vector<std::vector<T>>& trkFeat, std::vector<T>& feat, std::vector<Ort::Value>& gnnInput)
239+
{
240+
int64_t nNodes = trkFeat.size();
241+
242+
std::vector<int64_t> edgesShape{2, nNodes * nNodes};
243+
gnnInput.emplace_back(createTensor(edgesList[nNodes], edgesShape));
244+
245+
std::vector<int64_t> featShape{nNodes, nJetFeat + nTrkFeat};
246+
247+
int numNaN = replaceNaN(jetFeat, 0.f);
248+
for (auto& aTrkFeat : trkFeat) {
249+
for (size_t i = 0; i < jetFeat.size(); ++i)
250+
feat.push_back(jetFeatureTransform(jetFeat[i], i));
251+
numNaN += replaceNaN(aTrkFeat, 0.f);
252+
for (size_t i = 0; i < aTrkFeat.size(); ++i)
253+
feat.push_back(trkFeatureTransform(aTrkFeat[i], i));
254+
}
255+
256+
gnnInput.emplace_back(createTensor(feat, featShape));
257+
258+
if (numNaN > 0) {
259+
LOGF(info, "NaN found in GNN input feature, number of NaN: %d", numNaN);
260+
}
261+
}
262+
};
263+
105264
//________________________________________________________________________
106265
bool isBHadron(int pc)
107266
{
@@ -1054,6 +1213,63 @@ void analyzeJetTrackInfo4ML(AnalysisJet const& analysisJet, AnyTracks const& /*a
10541213
// Sort the tracks based on their IP significance in descending order
10551214
std::sort(tracksParams.begin(), tracksParams.end(), compare);
10561215
}
1216+
1217+
// Looping over the track info and putting them in the input vector (for GNN b-jet tagging)
1218+
template <typename AnalysisJet, typename AnyTracks, typename AnyOriginalTracks>
1219+
void analyzeJetTrackInfo4GNN(AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, AnyOriginalTracks const& /*origTracks*/, std::vector<std::vector<float>>& tracksParams, float trackPtMin = 0.5, int64_t nMaxConstit = 40)
1220+
{
1221+
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
1222+
1223+
if (constituent.pt() < trackPtMin) {
1224+
continue;
1225+
}
1226+
1227+
int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);
1228+
1229+
auto origConstit = constituent.template track_as<AnyOriginalTracks>();
1230+
1231+
if (static_cast<int64_t>(tracksParams.size()) < nMaxConstit) {
1232+
tracksParams.emplace_back(std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()});
1233+
} else {
1234+
// If there are more than nMaxConstit constituents in the jet, select only nMaxConstit constituents with the highest DCA_XY significance.
1235+
size_t minIdx = 0;
1236+
for (size_t i = 0; i < tracksParams.size(); ++i) {
1237+
if (tracksParams[i][4] / tracksParams[i][5] < tracksParams[minIdx][4] / tracksParams[minIdx][5])
1238+
minIdx = i;
1239+
}
1240+
if (std::abs(constituent.dcaXY()) * sign / constituent.sigmadcaXY() > tracksParams[minIdx][4] / tracksParams[minIdx][5])
1241+
tracksParams[minIdx] = std::vector<float>{constituent.pt(), origConstit.phi(), constituent.eta(), static_cast<float>(constituent.sign()), std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), static_cast<float>(origConstit.itsNCls()), static_cast<float>(origConstit.tpcNClsFound()), static_cast<float>(origConstit.tpcNClsCrossedRows()), origConstit.itsChi2NCl(), origConstit.tpcChi2NCl()};
1242+
}
1243+
}
1244+
}
1245+
1246+
// Discriminant value for GNN b-jet tagging
1247+
template <typename T>
1248+
T Db(const std::vector<T>& logits, double fC = 0.018)
1249+
{
1250+
auto softmax = [](const std::vector<T>& logits) {
1251+
std::vector<T> res;
1252+
T maxLogit = *std::max_element(logits.begin(), logits.end());
1253+
T sumLogit = 0.;
1254+
for (size_t i = 0; i < logits.size(); ++i) {
1255+
res.push_back(std::exp(logits[i] - maxLogit));
1256+
sumLogit += res[i];
1257+
}
1258+
for (size_t i = 0; i < logits.size(); ++i) {
1259+
res[i] /= sumLogit;
1260+
}
1261+
return res;
1262+
};
1263+
1264+
std::vector<T> softmaxLogits = softmax(logits);
1265+
1266+
if (softmaxLogits[1] == 0. && softmaxLogits[2] == 0.) {
1267+
LOG(debug) << "jettaggingutilities::Db, Divide by zero: softmaxLogits = (" << softmaxLogits[0] << ", " << softmaxLogits[1] << ", " << softmaxLogits[2] << ")";
1268+
}
1269+
1270+
return std::log(softmaxLogits[0] / (fC * softmaxLogits[1] + (1. - fC) * softmaxLogits[2]));
1271+
}
1272+
10571273
}; // namespace jettaggingutilities
10581274

10591275
#endif // PWGJE_CORE_JETTAGGINGUTILITIES_H_

PWGJE/TableProducer/jetTaggerHF.cxx

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,29 @@ struct JetTaggerHFTask {
9393
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
9494
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};
9595

96+
9697
Configurable<std::string> IPparameterPathsCCDB{"IPparameterPathsCCDB", "Users/l/leehy/LHC24g4/", "Paths for fitting parameters of resolution functions for IP method on CCDB"};
9798
Configurable<std::vector<int64_t>> IPtimestampCCDB{"IPtimestampCCDB", std::vector<int64_t>{1737027389227, 1737027391774, 1737027393668, 1737027395548, 1737027397505, 1737027399396, 1737027401294}, "timestamp of the resolution function for IP method used to query in CCDB"};
9899
Configurable<bool> usepTcategorize{"usepTcategorize", false, "p_T categorize TF1 function with Inclusive jet"};
99100

101+
// GNN configuration
102+
Configurable<double> fC{"fC", 0.018, "Parameter f_c for D_b calculation"};
103+
Configurable<int64_t> nJetFeat{"nJetFeat", 4, "Number of jet GNN input features"};
104+
Configurable<int64_t> nTrkFeat{"nTrkFeat", 13, "Number of track GNN input features"};
105+
Configurable<int64_t> nTrkOrigin{"nTrkOrigin", 5, "Number of track origin categories"};
106+
Configurable<std::vector<float>> transformFeatureJetMean{"transformFeatureJetMean",
107+
std::vector<float>{3.7093048e+01, 3.1462731e+00, -8.9617318e-04, 4.5036483e+00},
108+
"Mean values for each GNN input feature (jet)"};
109+
Configurable<std::vector<float>> transformFeatureJetStdev{"transformFeatureJetStdev",
110+
std::vector<float>{3.9559139e+01, 1.8156786e+00, 2.8845072e-01, 4.6293869e+00},
111+
"Stdev values for each GNN input feature (jet)"};
112+
Configurable<std::vector<float>> transformFeatureTrkMean{"transformFeatureTrkMean",
113+
std::vector<float>{5.8772368e+00, 3.1470699e+00, -1.4703944e-03, 1.9976571e-03, 1.7700187e-03, 3.5821514e-03, 1.9987826e-03, 7.3673888e-03, 6.6411214e+00, 1.3810074e+02, 1.4888744e+02, 6.5751970e-01, 1.6469173e+00},
114+
"Mean values for each GNN input feature (track)"};
115+
Configurable<std::vector<float>> transformFeatureTrkStdev{"transformFeatureTrkStdev",
116+
std::vector<float>{9.2763824e+00, 1.8162115e+00, 3.1512174e-01, 9.9999982e-01, 5.6147423e-02, 2.3086982e-02, 1.6523319e+00, 4.8507337e-02, 8.1565088e-01, 1.2891182e+01, 1.1064601e+01, 9.5457840e-01, 2.8930053e-01},
117+
"Stdev values for each GNN input feature (track)"};
118+
100119
// axis spec
101120
ConfigurableAxis binTrackProbability{"binTrackProbability", {100, 0.f, 1.f}, ""};
102121
ConfigurableAxis binJetFlavour{"binJetFlavour", {6, -0.5, 5.5}, ""};
@@ -105,6 +124,7 @@ struct JetTaggerHFTask {
105124
o2::ccdb::CcdbApi ccdbApi;
106125

107126
using JetTracksExt = soa::Join<aod::JetTracks, aod::JTrackExtras, aod::JTrackPIs>;
127+
using OriginalTracks = soa::Join<aod::Tracks, aod::TracksCov, aod::TrackSelection, aod::TracksDCA, aod::TracksDCACov, aod::TracksExtra>;
108128

109129
bool useResoFuncFromIncJet = false;
110130
int maxOrder = -1;
@@ -137,6 +157,8 @@ struct JetTaggerHFTask {
137157
std::vector<uint16_t> decisionNonML;
138158
std::vector<float> scoreML;
139159

160+
jettaggingutilities::GNNBjetAllocator tensorAlloc;
161+
140162
template <typename T, typename U>
141163
float calculateJetProbability(int origin, T const& jet, U const& tracks, bool const& isMC = false)
142164
{
@@ -220,6 +242,25 @@ struct JetTaggerHFTask {
220242
}
221243
}
222244
}
245+
if (doprocessAlgorithmGNN) {
246+
if constexpr (isMC) {
247+
switch (origin) {
248+
case 2:
249+
registry.fill(HIST("h_db_b"), scoreML[jet.globalIndex()]);
250+
break;
251+
case 1:
252+
registry.fill(HIST("h_db_c"), scoreML[jet.globalIndex()]);
253+
break;
254+
case 0:
255+
case 3:
256+
registry.fill(HIST("h_db_lf"), scoreML[jet.globalIndex()]);
257+
break;
258+
default:
259+
LOGF(debug, "doprocessAlgorithmGNN, Unexpected origin value: %d (%d)", origin, jet.globalIndex());
260+
}
261+
}
262+
registry.fill(HIST("h2_pt_db"), jet.pt(), scoreML[jet.globalIndex()]);
263+
}
223264
taggingTable(decisionNonML[jet.globalIndex()], jetProb, scoreML[jet.globalIndex()]);
224265
}
225266

@@ -352,7 +393,7 @@ struct JetTaggerHFTask {
352393
}
353394
}
354395

355-
if (doprocessAlgorithmML) {
396+
if (doprocessAlgorithmML || doprocessAlgorithmGNN) {
356397
bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl);
357398
if (loadModelsFromCCDB) {
358399
ccdbApi.init(ccdbUrl);
@@ -363,6 +404,14 @@ struct JetTaggerHFTask {
363404
// bMlResponse.cacheInputFeaturesIndices(namesInputFeatures);
364405
bMlResponse.init();
365406
}
407+
408+
if (doprocessAlgorithmGNN) {
409+
tensorAlloc = jettaggingutilities::GNNBjetAllocator(nJetFeat.value, nTrkFeat.value, nClassesMl.value, nTrkOrigin.value, transformFeatureJetMean.value, transformFeatureJetStdev.value, transformFeatureTrkMean.value, transformFeatureTrkStdev.value, nJetConst);
410+
registry.add("h_db_b", "#it{D}_{b} b-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
411+
registry.add("h_db_c", "#it{D}_{b} c-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
412+
registry.add("h_db_lf", "#it{D}_{b} lf-jet;#it{D}_{b}", {HistType::kTH1F, {{50, -10., 35.}}});
413+
registry.add("h2_pt_db", "#it{p}_{T} vs. #it{D}_{b};#it{p}_{T}^{ch jet} (GeV/#it{c}^{2});#it{D}_{b}", {HistType::kTH2F, {{100, 0., 200.}, {50, -10., 35.}}});
414+
}
366415
}
367416

368417
template <typename AnyJets, typename AnyTracks, typename SecondaryVertices>
@@ -392,6 +441,29 @@ struct JetTaggerHFTask {
392441
}
393442
}
394443

444+
template <typename AnyJets, typename AnyTracks, typename AnyOriginalTracks>
445+
void analyzeJetAlgorithmGNN(AnyJets const& jets, AnyTracks const& tracks, AnyOriginalTracks const& origTracks)
446+
{
447+
for (const auto& jet : jets) {
448+
std::vector<std::vector<float>> trkFeat;
449+
jettaggingutilities::analyzeJetTrackInfo4GNN(jet, tracks, origTracks, trkFeat, trackPtMin, nJetConst);
450+
451+
std::vector<float> jetFeat{jet.pt(), jet.phi(), jet.eta(), jet.mass()};
452+
453+
if (trkFeat.size() > 0) {
454+
std::vector<float> feat;
455+
std::vector<Ort::Value> gnnInput;
456+
tensorAlloc.getGNNInput(jetFeat, trkFeat, feat, gnnInput);
457+
458+
auto modelOutput = bMlResponse.getModelOutput(gnnInput, 0);
459+
scoreML[jet.globalIndex()] = jettaggingutilities::Db(modelOutput, fC);
460+
} else {
461+
scoreML[jet.globalIndex()] = -999.;
462+
LOGF(debug, "doprocessAlgorithmGNN, trkFeat.size() <= 0 (%d)", jet.globalIndex());
463+
}
464+
}
465+
}
466+
395467
void processDummy(aod::JetCollisions const&)
396468
{
397469
}
@@ -430,6 +502,12 @@ struct JetTaggerHFTask {
430502
}
431503
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmML, "Fill ML evaluation score for charged jets", false);
432504

505+
void processAlgorithmGNN(JetTable const& jets, JetTracksExt const& jtracks, OriginalTracks const& origTracks)
506+
{
507+
analyzeJetAlgorithmGNN(jets, jtracks, origTracks);
508+
}
509+
PROCESS_SWITCH(JetTaggerHFTask, processAlgorithmGNN, "Fill GNN evaluation score (D_b) for charged jets", false);
510+
433511
void processFillTables(std::conditional_t<isMCD, soa::Join<JetTable, aod::ChargedMCDetectorLevelJetFlavourDef>, JetTable>::iterator const& jet, JetTracksExt const& tracks)
434512
{
435513
fillTables<isMCD>(jet, tracks);

Tools/ML/MlResponse.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ class MlResponse
158158
LOG(fatal) << "Model index " << nModel << " is out of range! The number of initialised models is " << mModels.size() << ". Please check your configurables.";
159159
}
160160

161-
TypeOutputScore* outputPtr = mModels[nModel].evalModel(input);
161+
TypeOutputScore* outputPtr = mModels[nModel].template evalModel<TypeOutputScore>(input);
162162
return std::vector<TypeOutputScore>{outputPtr, outputPtr + mNClasses};
163163
}
164164

0 commit comments

Comments
 (0)