Files @ a31c9fa6dae1
Branch filter:

Location: AENC/resampling_chain/montecarlo/resampler.hpp

Tom Bannink
Merge remote-tracking branch 'cwigit/master'
#include "fenwicktree.hpp"
#include <cassert>
#include <random>
#include <vector>

// State of a single vertex in a resampler.
//   BAD  - a vertex that ResamplerBase::doMove() may pick for resampling
//   GOOD - the initial state of every vertex
// Plain (unscoped) enum on purpose: BAD/GOOD are used unqualified.
enum VState {
    BAD = 0,
    GOOD = 1
};

// Base class for resamplers (CRTP).
//
// Stores the GOOD/BAD state of n vertices and picks a uniformly random BAD
// vertex with O(log n) prefix-sum queries on a Fenwick tree of BAD counts.
// The graph structure is supplied by the subclass `Derived`, which must
// implement a function
//     ``void resampleNeighbors(int i) {
//         resampleVertex(i);
//         resampleVertex(j1);
//         // ...
//         resampleVertex(jk);
//     }``
// where j1, ..., jk are the neighbours of vertex i. doMove() calls it through
// a Derived pointer, so it must be accessible from this base class (public in
// Derived, or Derived befriends ResamplerBase).
//
// Derived - the concrete subclass (CRTP parameter); must inherit from
//           ResamplerBase<Derived, RNG>.
// RNG     - random number generator, usually std::mt19937. Only a reference
//           is stored, so the generator must outlive the resampler.
template <typename Derived, typename RNG>
class ResamplerBase {
  public:
      // p_   - probability that resampleVertex() makes a vertex BAD, in [0, 1]
      // n    - total number of vertices
      // rng_ - random number generator, usually std::mt19937 (kept by reference)
      // Default initial state is all-GOOD.
    ResamplerBase(float p_, size_t n, RNG& rng_)
        : p(checkedProbability(p_)), state(n, GOOD), fwTree(n), nBads(0),
          rng(rng_), distr(p) {}

    // Picks a uniformly random BAD vertex i (0-based index), calls
    // Derived::resampleNeighbors(i) and returns i.
    // Returns -1, without drawing from the RNG, if every vertex is GOOD.
    int doMove() {
        if (nBads == 0)
            return -1;
        assert(fwTree.sum(state.size()) == nBads);

        // 1-based rank (in index order) of the BAD vertex to pick.
        // size_t matches nBads, so no narrowing to int.
        std::uniform_int_distribution<size_t> rankDist(1, nBads);
        const size_t a = rankDist(rng);

        // Binary search for the a-th BAD.
        // Invariant: its 0-based index lies in [lower, upper - 1].
        // fwTree is 1-based, so fwTree.sum(k) counts the BADs among
        // state[0 .. k-1] and no -1 adjustment is needed when querying.
        size_t lower = 0;
        size_t upper = state.size();
        while (upper - lower > 1) {
            const size_t mid = lower + (upper - lower) / 2;
            if (fwTree.sum(mid) >= a) {
                // At least a BADs in [0, mid-1]: the a-th is below mid.
                upper = mid;
            } else {
                // Fewer than a BADs before mid: the a-th is at or after mid.
                lower = mid;
            }
        }
        const int i = static_cast<int>(lower);

        // state[i] is now the a-th BAD.
        // static_cast rather than a C-style cast: if Derived does not really
        // inherit from this base, this fails to compile instead of being UB.
        static_cast<Derived*>(this)->resampleNeighbors(i);
        return i;
    }

    size_t getN() const { return state.size(); }
    size_t numBads() const { return nBads; }

    const std::vector<VState>& getState() const { return state; }

  protected:
    // Sets state[i] to value, keeping fwTree and nBads in sync with state.
    void setVertex(int i, VState value) {
        if (state[i] == BAD) {
            // fwTree is 1-based, hence i + 1. It holds unsigned counts;
            // presumably update() applies the -1 via unsigned wrap-around
            // (TODO confirm in fenwicktree.hpp).
            fwTree.update(i + 1, -1);
            nBads--;
        }
        state[i] = value;
        if (state[i] == BAD) {
            fwTree.update(i + 1, 1);
            nBads++;
        }
    }

    // Used by Derived::resampleNeighbors: redraws vertex i, which becomes
    // BAD with probability p and GOOD otherwise.
    void resampleVertex(int i) { setVertex(i, (distr(rng) ? BAD : GOOD)); }

  private:
    // Asserts that prob is a valid probability and returns it unchanged.
    // Called from the constructor's initializer list, so the check runs
    // before distr is constructed from the value.
    static float checkedProbability(float prob) {
        assert(prob >= 0.0f && prob <= 1.0f);
        return prob;
    }

    // NOTE: declaration order matters -- distr is initialized from p.
    float p;
    std::vector<VState> state;
    BIT<unsigned int> fwTree; // fenwicktree: BAD counts, 1-based
    size_t nBads;

    RNG& rng;
    std::bernoulli_distribution distr;
};