Commit 2574864a authored by CARDOSI Paul's avatar CARDOSI Paul
Browse files

Rewrite skip function such that there is no overflow when the nb of numbers to...

Rewrite skip function such that there is no overflow when the nb of numbers to skip is very large. Make operator () call counter unsigned to avoid undefined behavior on overflow.
parent 87ba288b
......@@ -35,18 +35,16 @@ class SpPhiloxGenerator {
philox4x32() = default;
explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES) {
explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES)
: counter_(), temp_results_(), key_(), temp_counter_(0), cycles_(cycles),
force_computation_(true), operatorPPcounter(0)
{
// Splitting the seed in two
key_[0] = static_cast<uint32>(seed);
key_[1] = static_cast<uint32>(seed >> 32);
counter_.fill(0);
temp_results_.fill(0);
cycles_ = cycles;
operatorPPcounter = 0;
}
// Returns the minimum value productible by the engine
......@@ -58,45 +56,58 @@ class SpPhiloxGenerator {
// Skip the specified number of steps
void Skip(uint64 count) {
if(count > 0) {
skipOnWrap();
const auto position = temp_counter_ + count;
if (position > 3) {
force_compute_ = true;
temp_counter_ = position % 4;
const auto nbOfCounterIncrements = position / 4;
const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);
// 128 bit add
counter_[0] += count_lo;
if (counter_[0] < count_lo) {
++count_hi;
}
counter_[1] += count_hi;
if (counter_[1] < count_hi) {
if (++counter_[2] == 0) {
const auto nbStepsToNextMultipleOf4 = 4 - temp_counter_;
if(count <= nbStepsToNextMultipleOf4) {
temp_counter_ += count;
return;
}
count -= nbStepsToNextMultipleOf4;
temp_counter_ = 0;
const auto nbOfCounterIncrements = count / 4 + 1;
const auto newTempCounter = count % 4;
const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);
// 128 bit add
counter_[0] += count_lo;
if (counter_[0] < count_lo) {
if(++counter_[1] == 0) {
if(++counter_[2] == 0) {
++counter_[3];
}
}
} else {
temp_counter_ = position;
}
counter_[1] += count_hi;
if (counter_[1] < count_hi) {
if (++counter_[2] == 0) {
++counter_[3];
}
}
temp_counter_ = newTempCounter;
force_computation_ = true;
}
}
// Returns an random number using the philox engine
// Returns a random number using the philox engine
uint32 operator()() {
operatorPPcounter += 1;
operatorPPcounter++;
skipOnWrap();
if(temp_counter_ == 4) {
temp_counter_ = 0;
SkipOne();
force_computation_ = true;
}
if (force_compute_) {
force_compute_ = false;
if(force_computation_) {
force_computation_ = false;
temp_results_ = counter_;
ExecuteRounds();
}
......@@ -107,7 +118,7 @@ class SpPhiloxGenerator {
return value;
}
int getOperatorPPCounter() const{
auto getOperatorPPCounter() const{
return operatorPPcounter;
}
......@@ -134,17 +145,17 @@ class SpPhiloxGenerator {
Key key_;
// To iterate through the temp_results_
uint64 temp_counter_ = 0;
uint64 temp_counter_;
// The number of cycles used to generate randomness
int cycles_;
// To force the engine to compute the rounds to populates temp_results_
bool force_compute_ = true;
bool force_computation_;
// The number of times operator () is called to ensure that the STL
// always call it once
int operatorPPcounter;
uint32 operatorPPcounter;
// Skip one step
void SkipOne() {
......@@ -200,14 +211,6 @@ class SpPhiloxGenerator {
(*key)[0] += kPhiloxW32A;
(*key)[1] += kPhiloxW32B;
}
void skipOnWrap() {
if(temp_counter_ > 3) {
temp_counter_ = 0;
SkipOne();
force_compute_ = true;
}
}
};
philox4x32 phEngine;
......@@ -240,10 +243,10 @@ public:
}
RealType getRand01() {
nbValuesGenerated += 1;
[[maybe_unused]] const int counterOperatorPPBefore = phEngine.getOperatorPPCounter();
nbValuesGenerated++;
[[maybe_unused]] const auto counterOperatorPPBefore = phEngine.getOperatorPPCounter();
const RealType number = dis01(phEngine);
[[maybe_unused]] const int counterOperatorPPAfter = phEngine.getOperatorPPCounter();
[[maybe_unused]] const auto counterOperatorPPAfter = phEngine.getOperatorPPCounter();
assert(counterOperatorPPAfter == counterOperatorPPBefore+1);
return number;
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment