Skip to content

Commit 0752b3d

Browse files
authored
[PWGDQ] updates to the ML training data producer task (#15758)
1 parent 58d53ae commit 0752b3d

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

PWGDQ/Tasks/mftMchMatcher.cxx

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
#include <cstdint>
5252
#include <map>
5353
#include <memory>
54+
#include <random>
5455
#include <string>
5556
#include <unordered_map>
5657
#include <utility>
@@ -120,6 +121,7 @@ DECLARE_SOA_COLUMN(TimeResMFT, timeResMFT, float);
120121
DECLARE_SOA_COLUMN(Chi2MFT, chi2MFT, float);
121122
DECLARE_SOA_COLUMN(McMaskMFT, mcMaskMFT, int);
122123
DECLARE_SOA_COLUMN(MftClusterSizesAndTrackFlags, mftClusterSizesAndTrackFlags, uint64_t);
124+
DECLARE_SOA_COLUMN(TrackTypeMFT, trackTypeMFT, int);
123125

124126
DECLARE_SOA_COLUMN(CXXMFT, cXXMFT, float);
125127
DECLARE_SOA_COLUMN(CYYMFT, cYYMFT, float);
@@ -183,6 +185,7 @@ DECLARE_SOA_TABLE(FwdMatchMLCandidates, "AOD", "FWDMLCAND",
183185
fwdmatchcandidates::TimeResMFT,
184186
fwdmatchcandidates::Chi2MFT,
185187
fwdmatchcandidates::MftClusterSizesAndTrackFlags,
188+
fwdmatchcandidates::TrackTypeMFT,
186189
fwdmatchcandidates::CXXMFT,
187190
fwdmatchcandidates::CYYMFT,
188191
fwdmatchcandidates::CPhiPhiMFT,
@@ -236,6 +239,8 @@ struct mftMchMatcher {
236239
Configurable<bool> fKeepBestMatch{"cfgKeepBestMatch", false, "Keep only the best match global muons in the skimming"};
237240
Configurable<float> fzMatching{"cfgzMatching", -77.5f, "Plane for MFT-MCH matching"};
238241

242+
Configurable<float> fSamplingFraction{"cfgSamplingFraction", 1.f, "Fraction of randomly selected events to be processed"};
243+
239244
//// Variables for ccdb
240245
Configurable<std::string> ccdburl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
241246
Configurable<std::string> grpPath{"grpPath", "GLO/GRP/GRP", "Path of the grp file"};
@@ -263,6 +268,9 @@ struct mftMchMatcher {
263268

264269
o2::parameters::GRPMagField* fGrpMag = nullptr;
265270

271+
std::uniform_real_distribution<double> mDistribution{0.0, 1.0};
272+
std::mt19937 mGenerator;
273+
266274
o2::globaltracking::MatchGlobalFwd mMatching;
267275

268276
std::unordered_map<int64_t, int32_t> mftCovIndexes;
@@ -403,8 +411,15 @@ struct mftMchMatcher {
403411
ccdbManager->get<TGeoManager>(geoPath);
404412
}
405413

406-
// int matchTypeMax = static_cast<int>(kMatchTypeUndefined);
407-
AxisSpec matchTypeAxis = {static_cast<int>(kMatchTypeUndefined), 0, static_cast<double>(kMatchTypeUndefined), ""};
414+
if (fSamplingFraction < 1.0) {
415+
std::random_device rd;
416+
mGenerator = std::mt19937(rd());
417+
}
418+
auto hAcceptedEvents = std::get<std::shared_ptr<TH1>>(registry.add("acceptedEvents", "Accepted events", {HistType::kTH1F, {{2, 0, 2.f, ""}}}));
419+
hAcceptedEvents->GetXaxis()->SetBinLabel(1, "total");
420+
hAcceptedEvents->GetXaxis()->SetBinLabel(2, "accepted");
421+
422+
AxisSpec matchTypeAxis = {static_cast<int>(kMatchTypeUndefined) + 1, 0, static_cast<double>(kMatchTypeUndefined) + 1, ""};
408423
auto hMatchType = std::get<std::shared_ptr<TH1>>(registry.add("matchType", "Match type", {HistType::kTH1F, {matchTypeAxis}}));
409424
hMatchType->GetXaxis()->SetBinLabel(1, "true (leading)");
410425
hMatchType->GetXaxis()->SetBinLabel(2, "wrong (leading)");
@@ -414,6 +429,7 @@ struct mftMchMatcher {
414429
hMatchType->GetXaxis()->SetBinLabel(6, "wrong (non leading)");
415430
hMatchType->GetXaxis()->SetBinLabel(7, "decay (non leading)");
416431
hMatchType->GetXaxis()->SetBinLabel(8, "fake (non leading)");
432+
hMatchType->GetXaxis()->SetBinLabel(9, "undefined");
417433
}
418434

419435
template <typename TMuons>
@@ -584,6 +600,16 @@ struct mftMchMatcher {
584600
VarManager::SetMatchingPlane(fzMatching.value);
585601
}
586602

603+
registry.get<TH1>(HIST("acceptedEvents"))->Fill(0);
604+
// reject a randomly selected fraction of events
605+
if (fSamplingFraction < 1.0) {
606+
double rnd = mDistribution(mGenerator);
607+
if (rnd > fSamplingFraction) {
608+
return;
609+
}
610+
}
611+
registry.get<TH1>(HIST("acceptedEvents"))->Fill(1);
612+
587613
fillBestMuonMatches(muonTracks);
588614

589615
std::vector<std::pair<int64_t, int64_t>> matchablePairs;
@@ -684,6 +710,7 @@ struct mftMchMatcher {
684710
mfttrack.trackTimeRes(),
685711
mfttrack.chi2(),
686712
mfttrack.mftClusterSizesAndTrackFlags(),
713+
(mfttrack.isCA() ? 1 : 0),
687714
mftpropCov(0, 0),
688715
mftpropCov(1, 1),
689716
mftpropCov(2, 2),

0 commit comments

Comments
 (0)