gtsam 4.1.1
gtsam
BayesTree-inst.h
1/* ----------------------------------------------------------------------------
2
3 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4 * Atlanta, Georgia 30332-0415
5 * All Rights Reserved
6 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7
8 * See LICENSE for the license information
9
10 * -------------------------------------------------------------------------- */
11
21#pragma once
22
26#include <gtsam/base/timing.h>
27
28#include <boost/optional.hpp>
29#include <boost/assign/list_of.hpp>
30#include <fstream>
31
32using boost::assign::cref_list_of;
33
34namespace gtsam {
35
36 /* ************************************************************************* */
37 template<class CLIQUE>
40 for (const sharedClique& root : roots_) getCliqueData(root, &stats);
41 return stats;
42 }
43
44 /* ************************************************************************* */
45 template <class CLIQUE>
47 BayesTreeCliqueData* stats) const {
48 const auto conditional = clique->conditional();
49 stats->conditionalSizes.push_back(conditional->nrFrontals());
50 stats->separatorSizes.push_back(conditional->nrParents());
51 for (sharedClique c : clique->children) {
52 getCliqueData(c, stats);
53 }
54 }
55
56 /* ************************************************************************* */
57 template<class CLIQUE>
59 size_t count = 0;
60 for(const sharedClique& root: roots_)
61 count += root->numCachedSeparatorMarginals();
62 return count;
63 }
64
65 /* ************************************************************************* */
66 template<class CLIQUE>
67 void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
68 if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
69 std::ofstream of(s.c_str());
70 of<< "digraph G{\n";
71 for(const sharedClique& root: roots_)
72 saveGraph(of, root, keyFormatter);
73 of<<"}";
74 of.close();
75 }
76
77 /* ************************************************************************* */
78 template<class CLIQUE>
79 void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
80 static int num = 0;
81 bool first = true;
82 std::stringstream out;
83 out << num;
84 std::string parent = out.str();
85 parent += "[label=\"";
86
87 for (Key index : clique->conditional_->frontals()) {
88 if (!first) parent += ",";
89 first = false;
90 parent += indexFormatter(index);
91 }
92
93 if (clique->parent()) {
94 parent += " : ";
95 s << parentnum << "->" << num << "\n";
96 }
97
98 first = true;
99 for (Key sep : clique->conditional_->parents()) {
100 if (!first) parent += ",";
101 first = false;
102 parent += indexFormatter(sep);
103 }
104 parent += "\"];\n";
105 s << parent;
106 parentnum = num;
107
108 for (sharedClique c : clique->children) {
109 num++;
110 saveGraph(s, c, indexFormatter, parentnum);
111 }
113
114 /* ************************************************************************* */
115 template<class CLIQUE>
116 size_t BayesTree<CLIQUE>::size() const {
117 size_t size = 0;
118 for(const sharedClique& clique: roots_)
119 size += clique->treeSize();
120 return size;
121 }
122
123 /* ************************************************************************* */
124 template<class CLIQUE>
125 void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
126 for(Key j: clique->conditional()->frontals())
127 nodes_[j] = clique;
128 if (parent_clique != nullptr) {
129 clique->parent_ = parent_clique;
130 parent_clique->children.push_back(clique);
131 } else {
132 roots_.push_back(clique);
133 }
134 }
136 /* ************************************************************************* */
137 namespace {
138 template <class FACTOR, class CLIQUE>
139 struct _pushCliqueFunctor {
140 _pushCliqueFunctor(FactorGraph<FACTOR>* graph_) : graph(graph_) {}
141 FactorGraph<FACTOR>* graph;
142 int operator()(const boost::shared_ptr<CLIQUE>& clique, int dummy) {
143 graph->push_back(clique->conditional_);
144 return 0;
145 }
146 };
147 } // namespace
148
149 /* ************************************************************************* */
150 template <class CLIQUE>
152 FactorGraph<FactorType>* graph) const {
153 // Traverse the BayesTree and add all conditionals to this graph
154 int data = 0; // Unused
155 _pushCliqueFunctor<FactorType, CLIQUE> functor(graph);
156 treeTraversal::DepthFirstForest(*this, data, functor);
157 }
158
159 /* ************************************************************************* */
160 template<class CLIQUE>
162 *this = other;
163 }
165 /* ************************************************************************* */
166 namespace {
167 template<typename NODE>
168 boost::shared_ptr<NODE>
169 BayesTreeCloneForestVisitorPre(const boost::shared_ptr<NODE>& node, const boost::shared_ptr<NODE>& parentPointer)
170 {
171 // Clone the current node and add it to its cloned parent
172 boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node);
173 clone->children.clear();
174 clone->parent_ = parentPointer;
175 parentPointer->children.push_back(clone);
176 return clone;
178 }
179
180 /* ************************************************************************* */
181 template<class CLIQUE>
183 this->clear();
184 boost::shared_ptr<Clique> rootContainer = boost::make_shared<Clique>();
185 treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);
186 for(const sharedClique& root: rootContainer->children) {
187 root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
188 insertRoot(root);
189 }
190 return *this;
191 }
192
193 /* ************************************************************************* */
194 template<class CLIQUE>
195 void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
196 std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl;
197 treeTraversal::PrintForest(*this, s, keyFormatter);
198 }
199
200 /* ************************************************************************* */
201 // binary predicate to test equality of a pair for use in equals
202 template<class CLIQUE>
203 bool check_sharedCliques(
204 const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
205 const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
206 ) {
207 return v1.first == v2.first &&
208 ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second)));
209 }
210
211 /* ************************************************************************* */
212 template<class CLIQUE>
213 bool BayesTree<CLIQUE>::equals(const BayesTree<CLIQUE>& other, double tol) const {
214 return size()==other.size() &&
215 std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
216 }
217
218 /* ************************************************************************* */
219 template<class CLIQUE>
220 template<class CONTAINER>
221 Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
222 typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
223 assert(lowestOrderedParent != parents.end());
224 return *lowestOrderedParent;
225 }
226
227 /* ************************************************************************* */
228 template<class CLIQUE>
230 // Add each frontal variable of this root node
231 for(const Key& j: subtree->conditional()->frontals()) {
232 bool inserted = nodes_.insert(std::make_pair(j, subtree)).second;
233 assert(inserted); (void)inserted;
235 // Fill index for each child
236 typedef typename BayesTree<CLIQUE>::sharedClique sharedClique;
237 for(const sharedClique& child: subtree->children) {
238 fillNodesIndex(child); }
240
241 /* ************************************************************************* */
242 template<class CLIQUE>
244 roots_.push_back(subtree); // Add to roots
245 fillNodesIndex(subtree); // Populate nodes index
247
248 /* ************************************************************************* */
249 // First finds clique marginal then marginalizes that
250 /* ************************************************************************* */
251 template<class CLIQUE>
252 typename BayesTree<CLIQUE>::sharedConditional
253 BayesTree<CLIQUE>::marginalFactor(Key j, const Eliminate& function) const
254 {
255 gttic(BayesTree_marginalFactor);
256
257 // get clique containing Key j
258 sharedClique clique = this->clique(j);
259
260 // calculate or retrieve its marginal P(C) = P(F,S)
261 FactorGraphType cliqueMarginal = clique->marginal2(function);
262
263 // Now, marginalize out everything that is not variable j
264 BayesNetType marginalBN = *cliqueMarginal.marginalMultifrontalBayesNet(
265 Ordering(cref_list_of<1,Key>(j)), function);
266
267 // The Bayes net should contain only one conditional for variable j, so return it
268 return marginalBN.front();
269 }
270
271 /* ************************************************************************* */
272 // Find two cliques, their joint, then marginalizes
273 /* ************************************************************************* */
274 template<class CLIQUE>
275 typename BayesTree<CLIQUE>::sharedFactorGraph
276 BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
277 {
278 gttic(BayesTree_joint);
279 return boost::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
280 }
281
282 /* ************************************************************************* */
283 template<class CLIQUE>
284 typename BayesTree<CLIQUE>::sharedBayesNet
285 BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
286 {
287 gttic(BayesTree_jointBayesNet);
288 // get clique C1 and C2
289 sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
290
291 gttic(Lowest_common_ancestor);
292 // Find lowest common ancestor clique
293 sharedClique B; {
294 // Build two paths to the root
295 FastList<sharedClique> path1, path2; {
296 sharedClique p = C1;
297 while(p) {
298 path1.push_front(p);
299 p = p->parent();
300 }
301 } {
302 sharedClique p = C2;
303 while(p) {
304 path2.push_front(p);
305 p = p->parent();
306 }
307 }
308 // Find the path intersection
309 typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
310 if(*p1 == *p2)
311 B = *p1;
312 while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
313 B = *p1;
314 ++p1;
315 ++p2;
316 }
317 }
318 gttoc(Lowest_common_ancestor);
319
320 // Build joint on all involved variables
321 FactorGraphType p_BC1C2;
322
323 if(B)
324 {
325 // Compute marginal on lowest common ancestor clique
326 gttic(LCA_marginal);
327 FactorGraphType p_B = B->marginal2(function);
328 gttoc(LCA_marginal);
329
330 // Compute shortcuts of the requested cliques given the lowest common ancestor
331 gttic(Clique_shortcuts);
332 BayesNetType p_C1_Bred = C1->shortcut(B, function);
333 BayesNetType p_C2_Bred = C2->shortcut(B, function);
334 gttoc(Clique_shortcuts);
335
336 // Factor the shortcuts to be conditioned on the full root
337 // Get the set of variables to eliminate, which is C1\B.
338 gttic(Full_root_factoring);
339 boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
340 KeyVector C1_minus_B; {
341 KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
342 for(const Key j: *B->conditional()) {
343 C1_minus_B_set.erase(j); }
344 C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
345 }
346 // Factor into C1\B | B.
347 sharedFactorGraph temp_remaining;
348 boost::tie(p_C1_B, temp_remaining) =
349 FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function);
350 }
351 boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
352 KeyVector C2_minus_B; {
353 KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
354 for(const Key j: *B->conditional()) {
355 C2_minus_B_set.erase(j); }
356 C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
357 }
358 // Factor into C2\B | B.
359 sharedFactorGraph temp_remaining;
360 boost::tie(p_C2_B, temp_remaining) =
361 FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function);
362 }
363 gttoc(Full_root_factoring);
364
365 gttic(Variable_joint);
366 p_BC1C2 += p_B;
367 p_BC1C2 += *p_C1_B;
368 p_BC1C2 += *p_C2_B;
369 if(C1 != B)
370 p_BC1C2 += C1->conditional();
371 if(C2 != B)
372 p_BC1C2 += C2->conditional();
373 gttoc(Variable_joint);
374 }
375 else
376 {
377 // The nodes have no common ancestor, they're in different trees, so they're joint is just the
378 // product of their marginals.
379 gttic(Disjoint_marginals);
380 p_BC1C2 += C1->marginal2(function);
381 p_BC1C2 += C2->marginal2(function);
382 gttoc(Disjoint_marginals);
383 }
384
385 // now, marginalize out everything that is not variable j1 or j2
386 return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), function);
387 }
388
389 /* ************************************************************************* */
390 template<class CLIQUE>
392 // Remove all nodes and clear the root pointer
393 nodes_.clear();
394 roots_.clear();
395 }
396
397 /* ************************************************************************* */
398 template<class CLIQUE>
400 for(const sharedClique& root: roots_) {
401 root->deleteCachedShortcuts();
402 }
403 }
404
405 /* ************************************************************************* */
406 template<class CLIQUE>
408 {
409 if (clique->isRoot()) {
410 typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
411 if(root != roots_.end())
412 roots_.erase(root);
413 } else { // detach clique from parent
414 sharedClique parent = clique->parent_.lock();
415 typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
416 assert(child != parent->children.end());
417 parent->children.erase(child);
418 }
419
420 // orphan my children
421 for(sharedClique child: clique->children)
422 child->parent_ = typename Clique::weak_ptr();
423
424 for(Key j: clique->conditional()->frontals()) {
425 nodes_.unsafe_erase(j);
426 }
427 }
428
429 /* ************************************************************************* */
430 template <class CLIQUE>
431 void BayesTree<CLIQUE>::removePath(sharedClique clique, BayesNetType* bn,
432 Cliques* orphans) {
433 // base case is nullptr, if so we do nothing and return empties above
434 if (clique) {
435 // remove the clique from orphans in case it has been added earlier
436 orphans->remove(clique);
437
438 // remove me
439 this->removeClique(clique);
440
441 // remove path above me
442 this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn,
443 orphans);
444
445 // add children to list of orphans (splice also removed them from
446 // clique->children_)
447 orphans->insert(orphans->begin(), clique->children.begin(),
448 clique->children.end());
449 clique->children.clear();
450
451 bn->push_back(clique->conditional_);
452 }
453 }
454
455 /* *************************************************************************
456 */
457 template <class CLIQUE>
458 void BayesTree<CLIQUE>::removeTop(const KeyVector& keys, BayesNetType* bn,
459 Cliques* orphans) {
460 gttic(removetop);
461 // process each key of the new factor
462 for (const Key& j : keys) {
463 // get the clique
464 // TODO(frank): Nodes will be searched again in removeClique
465 typename Nodes::const_iterator node = nodes_.find(j);
466 if (node != nodes_.end()) {
467 // remove path from clique to root
468 this->removePath(node->second, bn, orphans);
469 }
470 }
471
472 // Delete cachedShortcuts for each orphan subtree
473 // TODO(frank): Consider Improving
474 for (sharedClique& orphan : *orphans) orphan->deleteCachedShortcuts();
475 }
476
477 /* ************************************************************************* */
478 template<class CLIQUE>
480 const sharedClique& subtree)
481 {
482 // Result clique list
483 Cliques cliques;
484 cliques.push_back(subtree);
485
486 // Remove the first clique from its parents
487 if(!subtree->isRoot())
488 subtree->parent()->children.erase(std::find(
489 subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
490 else
491 roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
492
493 // Add all subtree cliques and erase the children and parent of each
494 for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
495 {
496 // Add children
497 for(const sharedClique& child: (*clique)->children) {
498 cliques.push_back(child); }
499
500 // Delete cached shortcuts
501 (*clique)->deleteCachedShortcutsNonRecursive();
502
503 // Remove this node from the nodes index
504 for(Key j: (*clique)->conditional()->frontals()) {
505 nodes_.unsafe_erase(j); }
506
507 // Erase the parent and children pointers
508 (*clique)->parent_.reset();
509 (*clique)->children.clear();
510 }
511
512 return cliques;
513 }
514
515}
Timing utilities.
Bayes Tree is a tree of cliques of a Bayes Chain.
Variable ordering for the elimination algorithm.
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:86
bool equal(const T &obj1, const T &obj2, double tol)
Call equal on the object.
Definition: Testable.h:84
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:69
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
Traverse a forest depth-first with pre-order and post-order visits.
Definition: treeTraversal-inst.h:77
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
Print a tree, prefixing each line with str, and formatting keys using keyFormatter.
Definition: treeTraversal-inst.h:219
Definition: FastList.h:40
A factor graph is a bipartite graph with factor nodes connected to variable nodes.
Definition: FactorGraph.h:93
Stores the frontal (conditional) size and the separator size of every clique.
Definition: BayesTree.h:48
Definition: BayesTree.h:67
Nodes nodes_
Map from indices to Clique.
Definition: BayesTree.h:100
void removeClique(sharedClique clique)
remove a clique: warning, can result in a forest
Definition: BayesTree-inst.h:407
sharedFactorGraph joint(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables Limitation: can only calculate joint if cliques are disjoint or one of ...
Definition: BayesTree-inst.h:276
void fillNodesIndex(const sharedClique &subtree)
Fill the nodes index for a subtree.
Definition: BayesTree-inst.h:229
void addFactorsToGraph(FactorGraph< FactorType > *graph) const
Add all cliques in this BayesTree to the specified factor graph.
Definition: BayesTree-inst.h:151
bool equals(const This &other, double tol=1e-9) const
check equality
Definition: BayesTree-inst.h:213
void saveGraph(const std::string &s, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Saves the Bayes tree to a file in GraphViz (DOT) format.
Definition: BayesTree-inst.h:67
This & operator=(const This &other)
Assignment operator.
Definition: BayesTree-inst.h:182
boost::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition: BayesTree.h:74
BayesTree()
Create an empty Bayes Tree.
Definition: BayesTree.h:109
void clear()
Remove all nodes.
Definition: BayesTree-inst.h:391
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
add a clique (top down)
Definition: BayesTree-inst.h:125
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables as a BayesNet Limitation: can only calculate joint if cliques are disjo...
Definition: BayesTree-inst.h:285
Key findParentClique(const CONTAINER &parents) const
Find parent clique of a conditional.
Definition: BayesTree-inst.h:221
size_t size() const
number of cliques
Definition: BayesTree-inst.h:116
void deleteCachedShortcuts()
Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data.
Definition: BayesTree-inst.h:399
void removePath(sharedClique clique, BayesNetType *bn, Cliques *orphans)
Remove path from clique to root and return that path as factors plus a list of orphaned subtree roots...
Definition: BayesTree-inst.h:431
sharedConditional marginalFactor(Key j, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
Return marginal on any variable.
Definition: BayesTree-inst.h:253
size_t numCachedSeparatorMarginals() const
Collect number of cliques with cached separator marginals.
Definition: BayesTree-inst.h:58
BayesTreeCliqueData getCliqueData() const
Gather data on all cliques.
Definition: BayesTree-inst.h:38
Cliques removeSubtree(const sharedClique &subtree)
Remove the requested subtree.
Definition: BayesTree-inst.h:479
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
print
Definition: BayesTree-inst.h:195
void insertRoot(const sharedClique &subtree)
Insert a new subtree as an additional root of the forest and index all of its variables.
Definition: BayesTree-inst.h:243
void removeTop(const KeyVector &keys, BayesNetType *bn, Cliques *orphans)
Given a list of indices, turn "contaminated" part of the tree back into a factor graph.
Definition: BayesTree-inst.h:458
Definition: Ordering.h:34