diff --git a/math/mathcore/inc/TRandom.h b/math/mathcore/inc/TRandom.h index fbbcb69295134..4ab98bb6b36c4 100644 --- a/math/mathcore/inc/TRandom.h +++ b/math/mathcore/inc/TRandom.h @@ -23,6 +23,7 @@ ////////////////////////////////////////////////////////////////////////// #include "TNamed.h" +#include class TRandom : public TNamed, public ROOT::Math::TRandomEngine { @@ -56,6 +57,16 @@ class TRandom : public TNamed, public ROOT::Math::TRandomEngine { virtual Double_t Uniform(Double_t x1, Double_t x2); virtual void WriteRandom(const char *filename) const; + // std::UniformRandomBitGenerator interface -- makes TRandom usable directly + // with std::shuffle, std::uniform_int_distribution and similar. + using result_type = UInt_t; + static constexpr result_type min() { return 0; } + static constexpr result_type max() { return std::numeric_limits::max(); } + /// \note Rndm() returns a double in ]0,1[, so converting back to UInt_t + /// has a small precision loss. Subclasses with access to raw integer output + /// should override this for better accuracy. + virtual result_type operator()() { return static_cast(Rndm() * (static_cast(max()) + 1.0)); } + ClassDefOverride(TRandom,3) //Simple Random number generator (periodicity = 10**9) }; diff --git a/math/mathcore/inc/TRandom2.h b/math/mathcore/inc/TRandom2.h index 41968af132988..1396924273a3f 100644 --- a/math/mathcore/inc/TRandom2.h +++ b/math/mathcore/inc/TRandom2.h @@ -39,6 +39,7 @@ class TRandom2 : public TRandom { void RndmArray(Int_t n, Double_t *array) override; void SetSeed(ULong_t seed=0) override; UInt_t GetSeed() const override; + result_type operator()() override; ClassDefOverride(TRandom2, 1) // Random number generator with periodicity of 10**26 }; diff --git a/math/mathcore/inc/TRandom3.h b/math/mathcore/inc/TRandom3.h index e60e60a2992ee..726b68bf50099 100644 --- a/math/mathcore/inc/TRandom3.h +++ b/math/mathcore/inc/TRandom3.h @@ -43,6 +43,7 @@ class TRandom3 : public TRandom { void RndmArray(Int_t n, Double_t *array) override; void SetSeed(ULong_t seed=0) override; virtual const UInt_t *GetState() const { return fMt; } + result_type operator()() override; ClassDefOverride(TRandom3,2) //Random number generator: Mersenne Twister }; diff --git a/math/mathcore/src/TRandom2.cxx b/math/mathcore/src/TRandom2.cxx index 543e5e1850835..ee9952df24239 100644 --- a/math/mathcore/src/TRandom2.cxx +++ b/math/mathcore/src/TRandom2.cxx @@ -26,6 +26,8 @@ The publications are available online at #include "TRandom2.h" #include "TUUID.h" +#define TAUSWORTHE(s,a,b,c,d) (((s &c) <>b) + //////////////////////////////////////////////////////////////////////////////// @@ -53,17 +55,11 @@ TRandom2::~TRandom2() Double_t TRandom2::Rndm() { -#define TAUSWORTHE(s,a,b,c,d) (((s &c) <>b) - // scale by 1./(Max + 1) = 1./4294967296 const double kScale = 2.3283064365386963e-10; // range in 32 bit ( 1/(2**32) - fSeed = TAUSWORTHE (fSeed, 13, 19, 4294967294UL, 12); - fSeed1 = TAUSWORTHE (fSeed1, 2, 25, 4294967288UL, 4); - fSeed2 = TAUSWORTHE (fSeed2, 3, 11, 4294967280UL, 17); - - UInt_t iy = fSeed ^ fSeed1 ^ fSeed2; - if (iy) return kScale*static_cast(iy); + UInt_t iy = operator()(); + if (iy) return kScale * static_cast(iy); return Rndm(); } @@ -166,3 +162,17 @@ UInt_t TRandom2::GetSeed() const { return fSeed; } + +//////////////////////////////////////////////////////////////////////////////// +/// \brief Return a random 32-bit integer, advancing the generator state by one step. +/// +/// Implements the std::UniformRandomBitGenerator interface. Returns the raw +/// Tausworthe XOR output directly, avoiding the round-trip through double. + +TRandom::result_type TRandom2::operator()() +{ + fSeed = TAUSWORTHE(fSeed, 13, 19, 4294967294UL, 12); + fSeed1 = TAUSWORTHE(fSeed1, 2, 25, 4294967288UL, 4); + fSeed2 = TAUSWORTHE(fSeed2, 3, 11, 4294967280UL, 17); + return fSeed ^ fSeed1 ^ fSeed2; +} diff --git a/math/mathcore/src/TRandom3.cxx b/math/mathcore/src/TRandom3.cxx index 82207332b51af..e47d8f71cd0fe 100644 --- a/math/mathcore/src/TRandom3.cxx +++ b/math/mathcore/src/TRandom3.cxx @@ -105,6 +105,21 @@ TRandom3::~TRandom3() /// Generate number in interval (0,1): 0 and 1 are not included in the interval Double_t TRandom3::Rndm() +{ + // 2.3283064365386963e-10 == 1./(UINT_MAX+1UL) -> then returned value cannot be = 1.0 + UInt_t y = operator()(); + if (y) return ((Double_t) y * 2.3283064365386963e-10); // * Power(2,-32) + return Rndm(); +} + +//////////////////////////////////////////////////////////////////////////////// +/// \brief Return a random 32-bit integer, advancing the generator state by one step. +/// +/// Implements the std::UniformRandomBitGenerator interface. Returns the raw +/// Mersenne Twister output directly, including zero, avoiding the round-trip +/// through double. + +TRandom::result_type TRandom3::operator()() { UInt_t y; @@ -119,12 +134,12 @@ Double_t TRandom3::Rndm() if (fCount624 >= kN) { Int_t i; - for (i=0; i < kN-kM; i++) { + for (i = 0; i < kN-kM; i++) { y = (fMt[i] & kUpperMask) | (fMt[i+1] & kLowerMask); fMt[i] = fMt[i+kM] ^ (y >> 1) ^ ((y & 0x1) ? kMatrixA : 0x0); } - for ( ; i < kN-1 ; i++) { + for (; i < kN-1; i++) { y = (fMt[i] & kUpperMask) | (fMt[i+1] & kLowerMask); fMt[i] = fMt[i+kM-kN] ^ (y >> 1) ^ ((y & 0x1) ? kMatrixA : 0x0); } @@ -136,13 +151,11 @@ Double_t TRandom3::Rndm() y = fMt[fCount624++]; y ^= (y >> 11); - y ^= ((y << 7 ) & kTemperingMaskB ); - y ^= ((y << 15) & kTemperingMaskC ); + y ^= ((y << 7 ) & kTemperingMaskB); + y ^= ((y << 15) & kTemperingMaskC); y ^= (y >> 18); - // 2.3283064365386963e-10 == 1./(UINT_MAX+1UL) -> then returned value cannot be = 1.0 - if (y) return ( (Double_t) y * 2.3283064365386963e-10); // * Power(2,-32) - return Rndm(); + return y; } //////////////////////////////////////////////////////////////////////////////// diff --git a/math/mathcore/test/testMathRandom.cxx b/math/mathcore/test/testMathRandom.cxx index 9ac1655a33a93..2ff300aaabdaf 100644 --- a/math/mathcore/test/testMathRandom.cxx +++ b/math/mathcore/test/testMathRandom.cxx @@ -15,9 +15,11 @@ //#include "TRandomNew3.h" #include "TStopwatch.h" +#include #include - +#include #include +#include using namespace ROOT::Math; @@ -193,21 +195,80 @@ bool test4() { } +bool test5() { + std::cout << "\nTesting TRandom std::UniformRandomBitGenerator interface" << std::endl; + + TRandom3 rng(42); + + static_assert(std::is_same::value, "result_type must be UInt_t"); + + if (TRandom::min() != 0) { + std::cout << "TRandom::min() is not 0" << std::endl; + return false; + } + if (TRandom::max() != std::numeric_limits::max()) { + std::cout << "TRandom::max() is wrong" << std::endl; + return false; + } + + for (int i = 0; i < 10000; i++) { + auto v = rng(); + if (v < TRandom::min() || v > TRandom::max()) { + std::cout << "operator() returned out-of-range value: " << v << std::endl; + return false; + } + } + + std::vector vec(10); + std::iota(vec.begin(), vec.end(), 1); + std::shuffle(vec.begin(), vec.end(), rng); + std::sort(vec.begin(), vec.end()); + for (int i = 0; i < 10; i++) { + if (vec[i] != i + 1) { + std::cout << "std::shuffle corrupted the elements" << std::endl; + return false; + } + } + + std::uniform_int_distribution dist(0, 99); + for (int i = 0; i < 10000; i++) { + int v = dist(rng); + if (v < 0 || v > 99) { + std::cout << "std::uniform_int_distribution produced out-of-range value: " << v << std::endl; + return false; + } + } + + // verify TRandom2 override returns raw integers (no double round-trip) + TRandom2 rng2(42); + for (int i = 0; i < 10000; i++) { + auto v2 = rng2(); + if (v2 < TRandom::min() || v2 > TRandom::max()) { + std::cout << "TRandom2::operator() returned out-of-range value: " << v2 << std::endl; + return false; + } + } + + std::cout << "TRandom std interface: OK" << std::endl; + return true; +} + bool testMathRandom() { - + bool ret = true; std::cout << "testing generating " << NR << " numbers " << std::endl; - ret &= test1(); - ret &= test2(); - ret &= test3(); - ret &= test4(); + ret &= test1(); + ret &= test2(); + ret &= test3(); + ret &= test4(); + ret &= test5(); if (!ret) Error("testMathRandom","Test Failed"); else std::cout << "\nTestMathRandom: OK \n"; - return ret; + return ret; } int main() {