gtsam  4.1.0
gtsam
WeightedSampler.h
Go to the documentation of this file.
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
19 #pragma once
20 
21 #include <cmath>
22 #include <queue>
23 #include <random>
24 #include <stdexcept>
25 #include <utility>
26 #include <vector>
27 
namespace gtsam {

/**
 * Fast weighted random sampling without replacement.
 *
 * Implements weighted reservoir sampling with exponential jumps (A-ExpJ,
 * Efraimidis & Spirakis). It works with negated logarithms of the keys, so a
 * max-heap can serve as the reservoir.
 *
 * Example usage:
 *   std::mt19937 rng(42);
 *   WeightedSampler<std::mt19937> sampler(&rng);
 *   auto samples = sampler.sampleWithoutReplacement(5, weights);
 *
 * @tparam Engine Random number engine passed to the <random> distributions
 *         (std::mt19937 by default).
 */
template <class Engine = std::mt19937>
class WeightedSampler {
 private:
  Engine* engine_;  // random number generation engine (not owned)

 public:
  /**
   * Construct from a random number generation engine.
   * Only a pointer to the engine is stored, so the engine must outlive this
   * sampler.
   */
  explicit WeightedSampler(Engine* engine) : engine_(engine) {}

  /**
   * Draw numSamples distinct indices into `weights`, favouring indices with
   * larger weights.
   *
   * @param numSamples Number of indices to draw; may equal weights.size().
   * @param weights    Sampling weights. They are not validated here and are
   *                   expected to be positive.
   * @return numSamples distinct indices in [0, weights.size()), ordered by
   *         increasing log-key e_i / w_i (see implementation notes).
   * @throws std::runtime_error if numSamples > weights.size().
   */
  std::vector<size_t> sampleWithoutReplacement(
      size_t numSamples, const std::vector<double>& weights) {
    // Implementation adapted from code accompanying paper at
    // https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
    const size_t n = weights.size();
    if (n < numSamples) {
      throw std::runtime_error(
          "numSamples must be smaller than or equal to weights.size()");
    }

    // Return empty array if numSamples == 0
    std::vector<size_t> result(numSamples);
    if (numSamples == 0) return result;

    // The paper uses keys k_i = u_i^(1/w_i) with u_i ~ U(0, 1) and keeps the
    // items with the largest keys. Here we store -log k_i = e_i / w_i instead,
    // where e_i = -log u_i is a *random* standard exponential variate
    // (Exp(1)), not the constant e. Because -log is decreasing, we keep the
    // items with the smallest log-keys in a max-heap, whose top() is the
    // threshold item T_w that is evicted next.
    std::exponential_distribution<double> exp1(1.0);
    std::priority_queue<std::pair<double, size_t> > reservoir;

    // Steps 1-2: the first numSamples items go straight into the reservoir.
    for (size_t i = 0; i < numSamples; ++i) {
      reservoir.push(std::make_pair(exp1(*engine_) / weights[i], i));
    }

    // Step 4: repeat Steps 5-10 until the population is exhausted.
    size_t i = numSamples;
    while (i < n) {
      // Steps 3 and 10: the threshold is the largest log-key in the
      // reservoir. Read it by value every round, because the top changes
      // after each replacement and a reference would not survive pop().
      const double T_w = reservoir.top().first;

      // Step 5: X_w = log(r) / log(T_w) becomes e / T_w with e ~ Exp(1).
      const double X_w = exp1(*engine_) / T_w;

      // Steps 6-7: skip items until the accumulated weight reaches X_w, i.e.
      // w_c + ... + w_{i-1} < X_w <= w_c + ... + w_{i-1} + w_i.
      double w = 0.0;
      for (; i < n; ++i) {
        w += weights[i];
        if (X_w <= w) break;
      }

      // Step 7: no such item, terminate.
      if (i == n) break;

      // Step 9: t_w = T_w^{w_i}, r_2 = random(t_w, 1), k_i = r_2^{1/w_i}.
      // In log form: log t_w = -T_w * w_i and the new log-key is
      // -log(r_2) / w_i.
      const double w_i = weights[i];
      std::uniform_real_distribution<double> r2(std::exp(-T_w * w_i), 1.0);
      const double k_i = -std::log(r2(*engine_)) / w_i;

      // Step 8: item i replaces the reservoir item with the largest log-key.
      reservoir.pop();
      reservoir.push(std::make_pair(k_i, i));
      ++i;
    }

    // The max-heap yields the largest log-key first, so fill the result from
    // the back; result[0] ends up with the smallest log-key.
    for (auto iret = result.rbegin(); iret != result.rend(); ++iret) {
      if (reservoir.empty()) {
        throw std::runtime_error(
            "Reservoir empty before all elements have been filled");
      }
      *iret = reservoir.top().second;
      reservoir.pop();
    }

    if (!reservoir.empty()) {
      throw std::runtime_error(
          "Reservoir not empty after all elements have been filled");
    }

    return result;
  }
};  // class WeightedSampler

}  // namespace gtsam
gtsam::WeightedSampler::WeightedSampler
WeightedSampler(Engine *engine)
Construct from random number generation engine. We only store a pointer to it.
Definition: WeightedSampler.h:46
gtsam
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
gtsam::WeightedSampler
Definition: WeightedSampler.h:37