gtsam 4.1.1
gtsam
WeightedSampler.h
Go to the documentation of this file.
1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
19#pragma once
20
21#include <cmath>
22#include <queue>
23#include <random>
24#include <stdexcept>
25#include <utility>
26#include <vector>
27
28namespace gtsam {
/*
 * Fast weighted sampling without replacement.
 *
 * Draws `numSamples` distinct indices from a weight vector using a
 * reservoir with exponential jumps, so that long runs of items can be
 * skipped with a single random draw.
 *
 * Example usage:
 *   std::mt19937 rng(42);
 *   WeightedSampler<std::mt19937> sampler(&rng);
 *   auto samples = sampler.sampleWithoutReplacement(5, weights);
 */
template <class Engine = std::mt19937>
class WeightedSampler {
 private:
  Engine* engine_;  // random number generation engine (only the pointer is stored)

 public:
  /**
   * Construct from random number generation engine.
   * We only store a pointer to it, so the engine must outlive the sampler.
   */
  explicit WeightedSampler(Engine* engine) : engine_(engine) {}

  /**
   * Sample indices into `weights` without replacement.
   *
   * @param numSamples number of distinct indices to draw; may be zero and
   *        may equal weights.size()
   * @param weights one weight per item. NOTE(review): the algorithm divides
   *        by the weights, so they are presumably expected to be strictly
   *        positive; zero or negative weights are not handled specially.
   * @return vector of `numSamples` distinct zero-based indices, ordered by
   *         increasing reservoir key
   * @throws std::runtime_error if numSamples > weights.size(), or if the
   *         internal reservoir does not end up with exactly numSamples items
   */
  std::vector<size_t> sampleWithoutReplacement(
      size_t numSamples, const std::vector<double>& weights) {
    // Implementation adapted from code accompanying paper at
    // https://www.ethz.ch/content/dam/ethz/special-interest/baug/ivt/ivt-dam/vpl/reports/1101-1200/ab1141.pdf
    const size_t n = weights.size();
    if (n < numSamples) {
      throw std::runtime_error(
          "numSamples must be smaller than or equal to weights.size()");
    }

    // Return empty array if numSamples==0
    std::vector<size_t> result(numSamples);
    if (numSamples == 0) return result;

    // Source of the e ~ Exp(1) variates used below. For u = random(0, 1),
    // -log(u) is exponentially distributed with rate 1, so working with
    // e / w is equivalent to working with -log(u^(1/w)).
    std::exponential_distribution<double> unitExponential(1.0);

    // Step 1: The first m items of V are inserted into reservoir
    // Step 2: For each item v_i in reservoir: Calculate a key k_i = u_i^(1/w),
    // where u_i = random(0, 1)
    // (Modification: Calculate and store -log k_i = e_i / w where e_i ~ Exp(1);
    // reservoir is a priority queue that pops the *maximum* elements, i.e. the
    // item with the minimum original key.)
    std::priority_queue<std::pair<double, size_t> > reservoir;
    for (size_t i = 0; i < numSamples; ++i) {
      const double k_i = unitExponential(*engine_) / weights[i];
      reservoir.push(std::make_pair(k_i, i));
    }

    // Step 4: Repeat Steps 5-10 until the population is exhausted.
    // Incrementing `it` in the loop header is part of Step 7.
    for (auto it = weights.begin() + numSamples; it != weights.end(); ++it) {
      // Step 3 / Step 10: The threshold T_w is the minimum key of reservoir
      // (Modification: this is now the maximum of the stored -log keys).
      // It is re-read on every pass because Step 8 replaces the top element.
      const double T_w = reservoir.top().first;

      // Step 5: Let r = random(0, 1) and X_w = log(r) / log(T_w)
      // (Modification: use e ~ Exp(1) in place of -log(r))
      const double X_w = unitExponential(*engine_) / T_w;

      // Step 6: From the current item v_c skip items until item v_i, such
      // that:
      // Step 7: w_c + w_{c+1} + ... + w_{i-1} < X_w <= w_c + w_{c+1} + ... +
      // w_{i-1} + w_i
      double w = 0.0;
      for (; it != weights.end(); ++it) {
        w += *it;
        if (X_w <= w) break;
      }

      // Step 7: No such item, terminate
      if (it == weights.end()) break;

      // Step 9: Let t_w = T_w^{w_i}, r_2 = random(t_w, 1) and v_i's key: k_i
      // = (r_2)^{1/w_i} (Mod: Let t_w = log(T_w) * {w_i}, e_2 =
      // log(random(e^{t_w}, 1)) and v_i's key: k_i = -e_2 / w_i)
      const double t_w = -T_w * *it;
      std::uniform_real_distribution<double> truncatedUniform(std::exp(t_w),
                                                              1.0);
      const double e_2 = std::log(truncatedUniform(*engine_));
      const double k_i = -e_2 / *it;

      // Step 8: The item in reservoir with the minimum key is replaced by
      // item v_i
      reservoir.pop();
      reservoir.push(
          std::make_pair(k_i, static_cast<size_t>(it - weights.begin())));
    }

    // Drain the reservoir back to front: the largest stored key is popped
    // first and lands in the last slot of the result.
    for (auto iret = result.end(); iret != result.begin();) {
      --iret;

      if (reservoir.empty()) {
        throw std::runtime_error(
            "Reservoir empty before all elements have been filled");
      }

      *iret = reservoir.top().second;
      reservoir.pop();
    }

    if (!reservoir.empty()) {
      throw std::runtime_error(
          "Reservoir not empty after all elements have been filled");
    }

    return result;
  }
};  // class WeightedSampler
137} // namespace gtsam
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
Definition: WeightedSampler.h:37
WeightedSampler(Engine *engine)
Construct from random number generation engine. We only store a pointer to it.
Definition: WeightedSampler.h:46