Commit 4935fafd authored by CARDOSI Paul's avatar CARDOSI Paul
Browse files

Merge branch 'philox-gen-clean-up' into 'master'

Philox gen clean up

See merge request bramas/spetabaru!28
parents 87ba288b f92f1958
......@@ -35,18 +35,16 @@ class SpPhiloxGenerator {
philox4x32() = default;
explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES) {
explicit philox4x32(uint64 seed, int cycles = DEFAULT_CYCLES)
: counter_(), temp_results_(), key_(), temp_counter_(0), cycles_(cycles),
force_computation_(true), operatorPPcounter(0)
{
// Splitting the seed in two
key_[0] = static_cast<uint32>(seed);
key_[1] = static_cast<uint32>(seed >> 32);
counter_.fill(0);
temp_results_.fill(0);
cycles_ = cycles;
operatorPPcounter = 0;
}
// Returns the minimum value productible by the engine
......@@ -58,45 +56,60 @@ class SpPhiloxGenerator {
// Skip the specified number of steps
void Skip(uint64 count) {
if(count > 0) {
skipOnWrap();
const auto position = temp_counter_ + count;
if (position > 3) {
force_compute_ = true;
temp_counter_ = position % 4;
const auto nbOfCounterIncrements = position / 4;
const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);
// 128 bit add
counter_[0] += count_lo;
if (counter_[0] < count_lo) {
++count_hi;
}
const auto nbStepsToNextMultipleOf4 = 4 - temp_counter_;
if(count <= nbStepsToNextMultipleOf4) {
temp_counter_ += count;
return;
}
count -= nbStepsToNextMultipleOf4;
// We need to add 1 to the counter because we have moved past
// all the 4 results from the current temp_results_ array. This
// also includes the special case where we already are on the edge
// (temp_counter_ == 4) but we haven't triggered a counter increment yet.
// We can safely add 1 here (instead of calling SkipOne). I won't cause any
// overfow since we are dividing the value of count by 4 and count has a
// width of 64 bits.
const auto nbOfCounterIncrements = count / 4 + 1;
temp_counter_ = count % 4;
const auto count_lo = static_cast<uint32>(nbOfCounterIncrements);
auto count_hi = static_cast<uint32>(nbOfCounterIncrements >> 32);
// 128 bit add
counter_[0] += count_lo;
if (counter_[0] < count_lo) {
++count_hi;
}
counter_[1] += count_hi;
if (counter_[1] < count_hi) {
if (++counter_[2] == 0) {
++counter_[3];
}
counter_[1] += count_hi;
if (counter_[1] < count_hi) {
if (++counter_[2] == 0) {
++counter_[3];
}
} else {
temp_counter_ = position;
}
force_computation_ = true;
}
}
// Returns an random number using the philox engine
// Returns a random number using the philox engine
uint32 operator()() {
operatorPPcounter += 1;
operatorPPcounter++;
skipOnWrap();
if(temp_counter_ == 4) {
temp_counter_ = 0;
SkipOne();
force_computation_ = true;
}
if (force_compute_) {
force_compute_ = false;
if(force_computation_) {
force_computation_ = false;
temp_results_ = counter_;
ExecuteRounds();
}
......@@ -107,7 +120,7 @@ class SpPhiloxGenerator {
return value;
}
int getOperatorPPCounter() const{
auto getOperatorPPCounter() const{
return operatorPPcounter;
}
......@@ -134,17 +147,17 @@ class SpPhiloxGenerator {
Key key_;
// To iterate through the temp_results_
uint64 temp_counter_ = 0;
uint64 temp_counter_;
// The number of cycles used to generate randomness
int cycles_;
// To force the engine to compute the rounds to populates temp_results_
bool force_compute_ = true;
bool force_computation_;
// The number of times operator () is called to ensure that the STL
// always call it once
int operatorPPcounter;
uint32 operatorPPcounter;
// Skip one step
void SkipOne() {
......@@ -200,14 +213,6 @@ class SpPhiloxGenerator {
(*key)[0] += kPhiloxW32A;
(*key)[1] += kPhiloxW32B;
}
void skipOnWrap() {
if(temp_counter_ > 3) {
temp_counter_ = 0;
SkipOne();
force_compute_ = true;
}
}
};
philox4x32 phEngine;
......@@ -240,10 +245,10 @@ public:
}
RealType getRand01() {
nbValuesGenerated += 1;
[[maybe_unused]] const int counterOperatorPPBefore = phEngine.getOperatorPPCounter();
nbValuesGenerated++;
[[maybe_unused]] const auto counterOperatorPPBefore = phEngine.getOperatorPPCounter();
const RealType number = dis01(phEngine);
[[maybe_unused]] const int counterOperatorPPAfter = phEngine.getOperatorPPCounter();
[[maybe_unused]] const auto counterOperatorPPAfter = phEngine.getOperatorPPCounter();
assert(counterOperatorPPAfter == counterOperatorPPBefore+1);
return number;
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment