#include "fenwicktree.hpp"
#include <cassert>
#include <random>
#include <vector>
// State of a single vertex.
// BAD vertices are the ones ResamplerBase::doMove() chooses from.
enum VState {
    BAD = 0,
    GOOD = 1,
};
// Base class for resamplers (CRTP: Derived is the concrete subclass).
// Stores the states of n vertices and efficiently picks a uniformly
// random BAD vertex using a Fenwick tree (binary indexed tree).
// The graph structure must be specified by the subclass, which
// implements a function
// ``void resampleNeighbors(int i) {
//     resampleVertex(i);
//     resampleVertex(j1);
//     // ...
//     resampleVertex(jk);
// }``
// where j1, ..., jk are the vertices to resample together with i
// (e.g. its neighbors in the graph).
template <typename Derived, typename RNG>
class ResamplerBase {
public:
// p - probability of sampling a BAD
// n - total number of vertices
// rng - random number generator, usually std::mt19937
// Default initial state is all-GOOD
ResamplerBase(float p_, size_t n, RNG& rng_)
: p(p_), state(n, GOOD), fwTree(n), nBads(0), rng(rng_), distr(p) {
assert(p >= 0.0f && p <= 1.0f);
}
// Returns the chosen site or -1 if in all-good state
int doMove() {
if (nBads == 0)
return -1;
assert(fwTree.sum(state.size()) == nBads);
std::uniform_int_distribution<> intDist(1, nBads);
unsigned int a = intDist(rng);
// There are 'nBads' 1's in state
// Select the a'th Bad using a binary search on the fenwicktree
// Where a is 1-based, as well as the fenwick tree
int i = 0;
{
// The a-th Bad is in the range
// [x_lower, ..., x_upper-1]
// where lower and upper are 0-based
int lower = 0;
int upper = state.size();
while (upper - lower > 1) {
int cur = (upper + lower) / 2;
// Check the number of 1's in [x0,...,x_cur-1]
// fwTree is 1-based so no -1 needed
if (fwTree.sum(cur) >= a) {
// The a-th 1 is in [x_lower,...,x_cur-1]
upper = cur;
} else {
// The a-th 1 is in the range [x_cur,...,x_upper-1]
lower = cur;
}
}
i = lower;
}
// state[i] is now the a'th Bad
((Derived*)this)->resampleNeighbors(i);
return i;
}
size_t getN() const { return state.size(); };
size_t numBads() const { return nBads; }
const std::vector<VState>& getState() const { return state; }
protected:
void setVertex(int i, VState value) {
if (state[i] == BAD) {
fwTree.update(i + 1, -1);
nBads--;
}
state[i] = value;
if (state[i] == BAD) {
fwTree.update(i + 1, 1);
nBads++;
}
}
// Used by Derived::resampleNeighbors
void resampleVertex(int i) { setVertex(i, (distr(rng) ? BAD : GOOD)); }
private:
float p;
std::vector<VState> state;
BIT<unsigned int> fwTree; // fenwicktree
size_t nBads;
RNG& rng;
std::bernoulli_distribution distr;
};