diff --git a/montecarlo/resampler.hpp b/montecarlo/resampler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8497dbbce98e1f4c275792248d651601b2e78c99 --- /dev/null +++ b/montecarlo/resampler.hpp @@ -0,0 +1,100 @@ +#include "fenwicktree.hpp" +#include +#include +#include + +enum VState { BAD = 0, GOOD = 1 }; + +// Base class for resampler +// Contains code for storing n vertices and efficiently picking +// a random BAD using a fenwick tree data structure. +// The graph structure should be specified by the subclass, +// by implementing a function +// ``void resampleNeighbors(int i) { +// resampleVertex(i); +// resampleVertex(j1); +// // ... +// resampleVertex(jk); +// }`` +template +class ResamplerBase { + public: + // p - probability of sampling a BAD + // n - total number of vertices + // rng - random number generator, usually std::mt19937 + // Default initial state is all-GOOD + ResamplerBase(float p_, size_t n, RNG& rng_) + : p(p_), state(n, GOOD), fwTree(n), nBads(0), rng(rng_), distr(p) { + assert(p >= 0.0f && p <= 1.0f); + } + + // Returns the chosen site or -1 if in all-good state + int doMove() { + if (nBads == 0) + return -1; + assert(fwTree.sum(state.size()) == nBads); + std::uniform_int_distribution<> intDist(1, nBads); + unsigned int a = intDist(rng); + // There are 'nBads' 1's in state + // Select the a'th Bad using a binary search on the fenwicktree + // Where a is 1-based, as well as the fenwick tree + int i = 0; + { + // The a-th Bad is in the range + // [x_lower, ..., x_upper-1] + // where lower and upper are 0-based + int lower = 0; + int upper = state.size(); + + while (upper - lower > 1) { + int cur = (upper + lower) / 2; + + // Check the number of 1's in [x0,...,x_cur-1] + // fwTree is 1-based so no -1 needed + if (fwTree.sum(cur) >= a) { + // The a-th 1 is in [x_lower,...,x_cur-1] + upper = cur; + } else { + // The a-th 1 is in the range [x_cur,...,x_upper-1] + lower = cur; + } + } + i = lower; + } + + // state[i] is now the a'th Bad + ((Derived*)this)->resampleNeighbors(i); + return i; + } + + size_t getN() const { return state.size(); }; + size_t numBads() const { return nBads; } + + const std::vector& getState() const { return state; } + + protected: + void setVertex(int i, VState value) { + if (state[i] == BAD) { + fwTree.update(i + 1, -1); + nBads--; + } + state[i] = value; + if (state[i] == BAD) { + fwTree.update(i + 1, 1); + nBads++; + } + } + + // Used by Derived::resampleNeighbors + void resampleVertex(int i) { setVertex(i, (distr(rng) ? BAD : GOOD)); } + + private: + float p; + std::vector state; + BIT fwTree; // fenwicktree + size_t nBads; + + RNG& rng; + std::bernoulli_distribution distr; +}; +