Files
@ 82182a2f068a
Branch filter:
Location: AENC/resampling_chain/montecarlo/resampler.hpp - annotation
82182a2f068a
2.9 KiB
text/x-c++hdr
Initial commit with Monte Carlo sampler
82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a 82182a2f068a | #include "fenwicktree.hpp"
#include <cassert>
#include <random>
#include <vector>
// State of a single vertex.
// The numeric values are pinned explicitly: BAD is 0 and GOOD is 1.
enum VState {
    BAD = 0,
    GOOD = 1
};
// CRTP base class for a resampler over n vertices.
//
// Stores a GOOD/BAD state per vertex and mirrors the BAD flags in a
// Fenwick tree, so that a uniformly random BAD vertex can be located with
// a binary search over prefix counts.
//
// The graph structure must be specified by the subclass `Derived`,
// by implementing a function
// ``void resampleNeighbors(int i) {
//     resampleVertex(i);
//     resampleVertex(j1);
//     // ...
//     resampleVertex(jk);
// }``
template <typename Derived, typename RNG>
class ResamplerBase {
  public:
    // p_   - probability of sampling a BAD; must lie in [0, 1]
    // n    - total number of vertices
    // rng_ - random number generator, usually std::mt19937. It is stored
    //        by reference, so it has to outlive this object.
    // Default initial state is all-GOOD.
    ResamplerBase(float p_, size_t n, RNG& rng_)
        // p is validated while it is being initialised, i.e. *before* it is
        // passed on to std::bernoulli_distribution (which has the same
        // 0 <= p <= 1 precondition). Members are initialised in declaration
        // order, so `p` is ready by the time `distr(p)` runs.
        : p(checkedProbability(p_)),
          state(n, GOOD),
          fwTree(n),
          nBads(0),
          rng(rng_),
          distr(p) {}

    // Picks one BAD vertex uniformly at random and lets the subclass
    // resample it together with its neighbours.
    // Returns the chosen site, or -1 if in the all-GOOD state.
    int doMove() {
        if (nBads == 0)
            return -1;
        // The tree total over all vertices must agree with the cached count.
        assert(fwTree.sum(state.size()) == nBads);

        // There are 'nBads' BADs in state; draw the 1-based rank of the one
        // to resample. The distribution type is kept as the default (int) so
        // that the sequence drawn from a seeded rng is unchanged.
        std::uniform_int_distribution<> rankDist(1, static_cast<int>(nBads));
        const unsigned int rank = static_cast<unsigned int>(rankDist(rng));

        // Binary search on the Fenwick tree for the rank-th BAD.
        // Invariant: the rank-th BAD is in the range
        //   [x_lower, ..., x_upper-1]
        // where lower and upper are 0-based.
        int lower = 0;
        int upper = static_cast<int>(state.size());
        while (upper - lower > 1) {
            // Overflow-safe midpoint of [lower, upper].
            const int mid = lower + (upper - lower) / 2;
            // Number of BADs in [x_0, ..., x_mid-1];
            // the tree is 1-based so no -1 is needed.
            if (fwTree.sum(mid) >= rank) {
                // The rank-th BAD is in [x_lower, ..., x_mid-1]
                upper = mid;
            } else {
                // The rank-th BAD is in [x_mid, ..., x_upper-1]
                lower = mid;
            }
        }

        // state[lower] is now the rank-th BAD.
        const int chosen = lower;
        static_cast<Derived*>(this)->resampleNeighbors(chosen);
        return chosen;
    }

    // Total number of vertices.
    size_t getN() const { return state.size(); }

    // Current number of BAD vertices.
    size_t numBads() const { return nBads; }

    // Read-only view of the per-vertex states.
    const std::vector<VState>& getState() const { return state; }

  protected:
    // Sets vertex i (0-based) to `value`, keeping the Fenwick tree and the
    // BAD counter in sync with `state`.
    void setVertex(int i, VState value) {
        assert(i >= 0 && static_cast<size_t>(i) < state.size());

        const bool wasBad = (state[i] == BAD);
        const bool isBad = (value == BAD);
        state[i] = value;

        // Nothing to do for the tree if the BAD-ness did not change; this
        // avoids a -1 update immediately followed by a +1 update.
        if (wasBad == isBad)
            return;

        // Tree positions are 1-based, hence i + 1.
        if (isBad) {
            fwTree.update(i + 1, 1);
            nBads++;
        } else {
            // NOTE(review): the tree's value type is unsigned int, so this
            // -1 presumably relies on modular arithmetic inside BIT --
            // confirm against fenwicktree.hpp.
            fwTree.update(i + 1, -1);
            nBads--;
        }
    }

    // Used by Derived::resampleNeighbors.
    // Draws a fresh state for vertex i: BAD with probability p, else GOOD.
    void resampleVertex(int i) { setVertex(i, distr(rng) ? BAD : GOOD); }

  private:
    // Asserts 0 <= prob <= 1 and hands the value back, so the check can be
    // performed inside the constructor's member-initializer list.
    static float checkedProbability(float prob) {
        assert(prob >= 0.0f && prob <= 1.0f);
        return prob;
    }

    float p;                           // probability of sampling a BAD
    std::vector<VState> state;         // state[i] is the state of vertex i
    BIT<unsigned int> fwTree;          // fenwicktree holding the BAD flags
    size_t nBads;                      // cached number of BADs in `state`
    RNG& rng;                          // caller-owned random number generator
    std::bernoulli_distribution distr; // yields true (-> BAD) with probability p
};
|