gtsam  4.0.0
gtsam
BayesTree-inst.h
1 /* ----------------------------------------------------------------------------
2 
3  * GTSAM Copyright 2010, Georgia Tech Research Corporation,
4  * Atlanta, Georgia 30332-0415
5  * All Rights Reserved
6  * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
7 
8  * See LICENSE for the license information
9 
10  * -------------------------------------------------------------------------- */
11 
21 #pragma once
22 
26 #include <gtsam/base/timing.h>
27 
28 #include <boost/optional.hpp>
29 #include <boost/assign/list_of.hpp>
30 #include <fstream>
31 
32 using boost::assign::cref_list_of;
33 
34 namespace gtsam {
35 
36  /* ************************************************************************* */
37  template<class CLIQUE>
40  for(const sharedClique& root: roots_)
41  getCliqueData(data, root);
42  return data;
43  }
44 
45  /* ************************************************************************* */
46  template<class CLIQUE>
48  data.conditionalSizes.push_back(clique->conditional()->nrFrontals());
49  data.separatorSizes.push_back(clique->conditional()->nrParents());
50  for(sharedClique c: clique->children) {
51  getCliqueData(data, c);
52  }
53  }
54 
55  /* ************************************************************************* */
56  template<class CLIQUE>
58  size_t count = 0;
59  for(const sharedClique& root: roots_)
60  count += root->numCachedSeparatorMarginals();
61  return count;
62  }
63 
64  /* ************************************************************************* */
65  template<class CLIQUE>
66  void BayesTree<CLIQUE>::saveGraph(const std::string &s, const KeyFormatter& keyFormatter) const {
67  if (roots_.empty()) throw std::invalid_argument("the root of Bayes tree has not been initialized!");
68  std::ofstream of(s.c_str());
69  of<< "digraph G{\n";
70  for(const sharedClique& root: roots_)
71  saveGraph(of, root, keyFormatter);
72  of<<"}";
73  of.close();
74  }
75 
76  /* ************************************************************************* */
77  template<class CLIQUE>
78  void BayesTree<CLIQUE>::saveGraph(std::ostream &s, sharedClique clique, const KeyFormatter& indexFormatter, int parentnum) const {
79  static int num = 0;
80  bool first = true;
81  std::stringstream out;
82  out << num;
83  std::string parent = out.str();
84  parent += "[label=\"";
85 
86  for (Key index : clique->conditional_->frontals()) {
87  if (!first) parent += ",";
88  first = false;
89  parent += indexFormatter(index);
90  }
91 
92  if (clique->parent()) {
93  parent += " : ";
94  s << parentnum << "->" << num << "\n";
95  }
96 
97  first = true;
98  for (Key sep : clique->conditional_->parents()) {
99  if (!first) parent += ",";
100  first = false;
101  parent += indexFormatter(sep);
102  }
103  parent += "\"];\n";
104  s << parent;
105  parentnum = num;
106 
107  for (sharedClique c : clique->children) {
108  num++;
109  saveGraph(s, c, indexFormatter, parentnum);
110  }
111  }
112 
113  /* ************************************************************************* */
114  template<class CLIQUE>
115  size_t BayesTree<CLIQUE>::size() const {
116  size_t size = 0;
117  for(const sharedClique& clique: roots_)
118  size += clique->treeSize();
119  return size;
120  }
121 
122  /* ************************************************************************* */
123  template<class CLIQUE>
124  void BayesTree<CLIQUE>::addClique(const sharedClique& clique, const sharedClique& parent_clique) {
125  for(Key j: clique->conditional()->frontals())
126  nodes_[j] = clique;
127  if (parent_clique != NULL) {
128  clique->parent_ = parent_clique;
129  parent_clique->children.push_back(clique);
130  } else {
131  roots_.push_back(clique);
132  }
133  }
134 
135  /* ************************************************************************* */
136  // TODO: Clean up
137  namespace {
138  template<class FACTOR, class CLIQUE>
139  int _pushClique(FactorGraph<FACTOR>& fg, const boost::shared_ptr<CLIQUE>& clique) {
140  fg.push_back(clique->conditional_);
141  return 0;
142  }
143 
144  template<class FACTOR, class CLIQUE>
145  struct _pushCliqueFunctor {
146  _pushCliqueFunctor(FactorGraph<FACTOR>& graph_) : graph(graph_) {}
147  FactorGraph<FACTOR>& graph;
148  int operator()(const boost::shared_ptr<CLIQUE>& clique, int dummy) {
149  graph.push_back(clique->conditional_);
150  return 0;
151  }
152  };
153  }
154 
155  /* ************************************************************************* */
156  template<class CLIQUE>
158  {
159  // Traverse the BayesTree and add all conditionals to this graph
160  int data = 0; // Unused
161  _pushCliqueFunctor<FactorType,CLIQUE> functor(graph);
162  treeTraversal::DepthFirstForest(*this, data, functor); // FIXME: sort of works?
163 // treeTraversal::DepthFirstForest(*this, data, boost::bind(&_pushClique<FactorType,CLIQUE>, boost::ref(graph), _1));
164  }
165 
166  /* ************************************************************************* */
167  template<class CLIQUE>
169  *this = other;
170  }
171 
172  /* ************************************************************************* */
173  namespace {
174  template<typename NODE>
175  boost::shared_ptr<NODE>
176  BayesTreeCloneForestVisitorPre(const boost::shared_ptr<NODE>& node, const boost::shared_ptr<NODE>& parentPointer)
177  {
178  // Clone the current node and add it to its cloned parent
179  boost::shared_ptr<NODE> clone = boost::make_shared<NODE>(*node);
180  clone->children.clear();
181  clone->parent_ = parentPointer;
182  parentPointer->children.push_back(clone);
183  return clone;
184  }
185  }
186 
187  /* ************************************************************************* */
188  template<class CLIQUE>
190  this->clear();
191  boost::shared_ptr<Clique> rootContainer = boost::make_shared<Clique>();
192  treeTraversal::DepthFirstForest(other, rootContainer, BayesTreeCloneForestVisitorPre<Clique>);
193  for(const sharedClique& root: rootContainer->children) {
194  root->parent_ = typename Clique::weak_ptr(); // Reset the parent since it's set to the dummy clique
195  insertRoot(root);
196  }
197  return *this;
198  }
199 
200  /* ************************************************************************* */
201  template<class CLIQUE>
202  void BayesTree<CLIQUE>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
203  std::cout << s << ": cliques: " << size() << ", variables: " << nodes_.size() << std::endl;
204  treeTraversal::PrintForest(*this, s, keyFormatter);
205  }
206 
207  /* ************************************************************************* */
208  // binary predicate to test equality of a pair for use in equals
209  template<class CLIQUE>
210  bool check_sharedCliques(
211  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v1,
212  const std::pair<Key, typename BayesTree<CLIQUE>::sharedClique>& v2
213  ) {
214  return v1.first == v2.first &&
215  ((!v1.second && !v2.second) || (v1.second && v2.second && v1.second->equals(*v2.second)));
216  }
217 
218  /* ************************************************************************* */
219  template<class CLIQUE>
220  bool BayesTree<CLIQUE>::equals(const BayesTree<CLIQUE>& other, double tol) const {
221  return size()==other.size() &&
222  std::equal(nodes_.begin(), nodes_.end(), other.nodes_.begin(), &check_sharedCliques<CLIQUE>);
223  }
224 
225  /* ************************************************************************* */
226  template<class CLIQUE>
227  template<class CONTAINER>
228  Key BayesTree<CLIQUE>::findParentClique(const CONTAINER& parents) const {
229  typename CONTAINER::const_iterator lowestOrderedParent = min_element(parents.begin(), parents.end());
230  assert(lowestOrderedParent != parents.end());
231  return *lowestOrderedParent;
232  }
233 
234  /* ************************************************************************* */
235  template<class CLIQUE>
237  // Add each frontal variable of this root node
238  for(const Key& j: subtree->conditional()->frontals()) {
239  bool inserted = nodes_.insert(std::make_pair(j, subtree)).second;
240  assert(inserted); (void)inserted;
241  }
242  // Fill index for each child
244  for(const sharedClique& child: subtree->children) {
245  fillNodesIndex(child); }
246  }
247 
248  /* ************************************************************************* */
249  template<class CLIQUE>
251  roots_.push_back(subtree); // Add to roots
252  fillNodesIndex(subtree); // Populate nodes index
253  }
254 
255  /* ************************************************************************* */
256  // First finds clique marginal then marginalizes that
257  /* ************************************************************************* */
258  template<class CLIQUE>
259  typename BayesTree<CLIQUE>::sharedConditional
260  BayesTree<CLIQUE>::marginalFactor(Key j, const Eliminate& function) const
261  {
262  gttic(BayesTree_marginalFactor);
263 
264  // get clique containing Key j
265  sharedClique clique = this->clique(j);
266 
267  // calculate or retrieve its marginal P(C) = P(F,S)
268  FactorGraphType cliqueMarginal = clique->marginal2(function);
269 
270  // Now, marginalize out everything that is not variable j
271  BayesNetType marginalBN = *cliqueMarginal.marginalMultifrontalBayesNet(
272  Ordering(cref_list_of<1,Key>(j)), boost::none, function);
273 
274  // The Bayes net should contain only one conditional for variable j, so return it
275  return marginalBN.front();
276  }
277 
278  /* ************************************************************************* */
279  // Find two cliques, their joint, then marginalizes
280  /* ************************************************************************* */
281  template<class CLIQUE>
282  typename BayesTree<CLIQUE>::sharedFactorGraph
283  BayesTree<CLIQUE>::joint(Key j1, Key j2, const Eliminate& function) const
284  {
285  gttic(BayesTree_joint);
286  return boost::make_shared<FactorGraphType>(*jointBayesNet(j1, j2, function));
287  }
288 
289  /* ************************************************************************* */
290  template<class CLIQUE>
291  typename BayesTree<CLIQUE>::sharedBayesNet
292  BayesTree<CLIQUE>::jointBayesNet(Key j1, Key j2, const Eliminate& function) const
293  {
294  gttic(BayesTree_jointBayesNet);
295  // get clique C1 and C2
296  sharedClique C1 = (*this)[j1], C2 = (*this)[j2];
297 
298  gttic(Lowest_common_ancestor);
299  // Find lowest common ancestor clique
300  sharedClique B; {
301  // Build two paths to the root
302  FastList<sharedClique> path1, path2; {
303  sharedClique p = C1;
304  while(p) {
305  path1.push_front(p);
306  p = p->parent();
307  }
308  } {
309  sharedClique p = C2;
310  while(p) {
311  path2.push_front(p);
312  p = p->parent();
313  }
314  }
315  // Find the path intersection
316  typename FastList<sharedClique>::const_iterator p1 = path1.begin(), p2 = path2.begin();
317  if(*p1 == *p2)
318  B = *p1;
319  while(p1 != path1.end() && p2 != path2.end() && *p1 == *p2) {
320  B = *p1;
321  ++p1;
322  ++p2;
323  }
324  }
325  gttoc(Lowest_common_ancestor);
326 
327  // Build joint on all involved variables
328  FactorGraphType p_BC1C2;
329 
330  if(B)
331  {
332  // Compute marginal on lowest common ancestor clique
333  gttic(LCA_marginal);
334  FactorGraphType p_B = B->marginal2(function);
335  gttoc(LCA_marginal);
336 
337  // Compute shortcuts of the requested cliques given the lowest common ancestor
338  gttic(Clique_shortcuts);
339  BayesNetType p_C1_Bred = C1->shortcut(B, function);
340  BayesNetType p_C2_Bred = C2->shortcut(B, function);
341  gttoc(Clique_shortcuts);
342 
343  // Factor the shortcuts to be conditioned on the full root
344  // Get the set of variables to eliminate, which is C1\B.
345  gttic(Full_root_factoring);
346  boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C1_B; {
347  KeyVector C1_minus_B; {
348  KeySet C1_minus_B_set(C1->conditional()->beginParents(), C1->conditional()->endParents());
349  for(const Key j: *B->conditional()) {
350  C1_minus_B_set.erase(j); }
351  C1_minus_B.assign(C1_minus_B_set.begin(), C1_minus_B_set.end());
352  }
353  // Factor into C1\B | B.
354  sharedFactorGraph temp_remaining;
355  boost::tie(p_C1_B, temp_remaining) =
356  FactorGraphType(p_C1_Bred).eliminatePartialMultifrontal(Ordering(C1_minus_B), function);
357  }
358  boost::shared_ptr<typename EliminationTraitsType::BayesTreeType> p_C2_B; {
359  KeyVector C2_minus_B; {
360  KeySet C2_minus_B_set(C2->conditional()->beginParents(), C2->conditional()->endParents());
361  for(const Key j: *B->conditional()) {
362  C2_minus_B_set.erase(j); }
363  C2_minus_B.assign(C2_minus_B_set.begin(), C2_minus_B_set.end());
364  }
365  // Factor into C2\B | B.
366  sharedFactorGraph temp_remaining;
367  boost::tie(p_C2_B, temp_remaining) =
368  FactorGraphType(p_C2_Bred).eliminatePartialMultifrontal(Ordering(C2_minus_B), function);
369  }
370  gttoc(Full_root_factoring);
371 
372  gttic(Variable_joint);
373  p_BC1C2 += p_B;
374  p_BC1C2 += *p_C1_B;
375  p_BC1C2 += *p_C2_B;
376  if(C1 != B)
377  p_BC1C2 += C1->conditional();
378  if(C2 != B)
379  p_BC1C2 += C2->conditional();
380  gttoc(Variable_joint);
381  }
382  else
383  {
384  // The nodes have no common ancestor, they're in different trees, so they're joint is just the
385  // product of their marginals.
386  gttic(Disjoint_marginals);
387  p_BC1C2 += C1->marginal2(function);
388  p_BC1C2 += C2->marginal2(function);
389  gttoc(Disjoint_marginals);
390  }
391 
392  // now, marginalize out everything that is not variable j1 or j2
393  return p_BC1C2.marginalMultifrontalBayesNet(Ordering(cref_list_of<2,Key>(j1)(j2)), boost::none, function);
394  }
395 
396  /* ************************************************************************* */
397  template<class CLIQUE>
399  // Remove all nodes and clear the root pointer
400  nodes_.clear();
401  roots_.clear();
402  }
403 
404  /* ************************************************************************* */
405  template<class CLIQUE>
407  for(const sharedClique& root: roots_) {
408  root->deleteCachedShortcuts();
409  }
410  }
411 
412  /* ************************************************************************* */
413  template<class CLIQUE>
415  {
416  if (clique->isRoot()) {
417  typename Roots::iterator root = std::find(roots_.begin(), roots_.end(), clique);
418  if(root != roots_.end())
419  roots_.erase(root);
420  } else { // detach clique from parent
421  sharedClique parent = clique->parent_.lock();
422  typename Roots::iterator child = std::find(parent->children.begin(), parent->children.end(), clique);
423  assert(child != parent->children.end());
424  parent->children.erase(child);
425  }
426 
427  // orphan my children
428  for(sharedClique child: clique->children)
429  child->parent_ = typename Clique::weak_ptr();
430 
431  for(Key j: clique->conditional()->frontals()) {
432  nodes_.unsafe_erase(j);
433  }
434  }
435 
436  /* ************************************************************************* */
437  template<class CLIQUE>
438  void BayesTree<CLIQUE>::removePath(sharedClique clique, BayesNetType& bn, Cliques& orphans)
439  {
440  // base case is NULL, if so we do nothing and return empties above
441  if (clique) {
442 
443  // remove the clique from orphans in case it has been added earlier
444  orphans.remove(clique);
445 
446  // remove me
447  this->removeClique(clique);
448 
449  // remove path above me
450  this->removePath(typename Clique::shared_ptr(clique->parent_.lock()), bn, orphans);
451 
452  // add children to list of orphans (splice also removed them from clique->children_)
453  orphans.insert(orphans.begin(), clique->children.begin(), clique->children.end());
454  clique->children.clear();
455 
456  bn.push_back(clique->conditional_);
457 
458  }
459  }
460 
461  /* ************************************************************************* */
462  template<class CLIQUE>
463  void BayesTree<CLIQUE>::removeTop(const KeyVector& keys, BayesNetType& bn, Cliques& orphans)
464  {
465  // process each key of the new factor
466  for(const Key& j: keys)
467  {
468  // get the clique
469  // TODO: Nodes will be searched again in removeClique
470  typename Nodes::const_iterator node = nodes_.find(j);
471  if(node != nodes_.end()) {
472  // remove path from clique to root
473  this->removePath(node->second, bn, orphans);
474  }
475  }
476 
477  // Delete cachedShortcuts for each orphan subtree
478  //TODO: Consider Improving
479  for(sharedClique& orphan: orphans)
480  orphan->deleteCachedShortcuts();
481  }
482 
483  /* ************************************************************************* */
484  template<class CLIQUE>
486  const sharedClique& subtree)
487  {
488  // Result clique list
489  Cliques cliques;
490  cliques.push_back(subtree);
491 
492  // Remove the first clique from its parents
493  if(!subtree->isRoot())
494  subtree->parent()->children.erase(std::find(
495  subtree->parent()->children.begin(), subtree->parent()->children.end(), subtree));
496  else
497  roots_.erase(std::find(roots_.begin(), roots_.end(), subtree));
498 
499  // Add all subtree cliques and erase the children and parent of each
500  for(typename Cliques::iterator clique = cliques.begin(); clique != cliques.end(); ++clique)
501  {
502  // Add children
503  for(const sharedClique& child: (*clique)->children) {
504  cliques.push_back(child); }
505 
506  // Delete cached shortcuts
507  (*clique)->deleteCachedShortcutsNonRecursive();
508 
509  // Remove this node from the nodes index
510  for(Key j: (*clique)->conditional()->frontals()) {
511  nodes_.unsafe_erase(j); }
512 
513  // Erase the parent and children pointers
514  (*clique)->parent_.reset();
515  (*clique)->children.clear();
516  }
517 
518  return cliques;
519  }
520 
521 }
size_t numCachedSeparatorMarginals() const
Collect number of cliques with cached separator marginals.
Definition: BayesTree-inst.h:57
void DepthFirstForest(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost)
Traverse a forest depth-first with pre-order and post-order visits.
Definition: treeTraversal-inst.h:77
sharedConditional marginalFactor(Key j, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
Return marginal on any variable.
Definition: BayesTree-inst.h:260
void clear()
Remove all nodes.
Definition: BayesTree-inst.h:398
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:57
boost::shared_ptr< Clique > sharedClique
Shared pointer to a clique.
Definition: BayesTree.h:72
void saveGraph(const std::string &s, const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Read only with side effects.
Definition: BayesTree-inst.h:66
Variable ordering for the elimination algorithm.
void addFactorsToGraph(FactorGraph< FactorType > &graph) const
Add all cliques in this BayesTree to the specified factor graph.
Definition: BayesTree-inst.h:157
bool equals(const This &other, double tol=1e-9) const
check equality
Definition: BayesTree-inst.h:220
void removeClique(sharedClique clique)
remove a clique: warning, can result in a forest
Definition: BayesTree-inst.h:414
void addClique(const sharedClique &clique, const sharedClique &parent_clique=sharedClique())
add a clique (top down)
Definition: BayesTree-inst.h:124
void removePath(sharedClique clique, BayesNetType &bn, Cliques &orphans)
Remove path from clique to root and return that path as factors plus a list of orphaned subtree roots...
Definition: BayesTree-inst.h:438
boost::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:33
Definition: FastList.h:38
Key findParentClique(const CONTAINER &parents) const
Find parent clique of a conditional.
Definition: BayesTree-inst.h:228
sharedFactorGraph joint(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables Limitation: can only calculate joint if cliques are disjoint or one of ...
Definition: BayesTree-inst.h:283
FastVector< Key > KeyVector
Define collection type once and for all - also used in wrappers.
Definition: Key.h:56
store all the sizes
Definition: BayesTree.h:46
A factor graph is a bipartite graph with factor nodes connected to variable nodes.
Definition: BayesTree.h:32
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
Print a tree, prefixing each line with str, and formatting keys using keyFormatter.
Definition: treeTraversal-inst.h:220
sharedBayesNet jointBayesNet(Key j1, Key j2, const Eliminate &function=EliminationTraitsType::DefaultEliminate) const
return joint on two variables as a BayesNet Limitation: can only calculate joint if cliques are disjo...
Definition: BayesTree-inst.h:292
This & operator=(const This &other)
Assignment operator.
Definition: BayesTree-inst.h:189
void removeTop(const KeyVector &keys, BayesNetType &bn, Cliques &orphans)
Given a list of indices, turn "contaminated" part of the tree back into a factor graph.
Definition: BayesTree-inst.h:463
Timing utilities.
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
bool equal(const T &obj1, const T &obj2, double tol)
Call equal on the object.
Definition: Testable.h:83
size_t size() const
number of cliques
Definition: BayesTree-inst.h:115
Nodes nodes_
Map from indices to Clique.
Definition: BayesTree.h:95
void deleteCachedShortcuts()
Clear all shortcut caches - use before timing on marginal calculation to avoid residual cache data.
Definition: BayesTree-inst.h:406
BayesTreeCliqueData getCliqueData() const
Gather data on all cliques.
Definition: BayesTree-inst.h:38
Cliques removeSubtree(const sharedClique &subtree)
Remove the requested subtree.
Definition: BayesTree-inst.h:485
void fillNodesIndex(const sharedClique &subtree)
Fill the nodes index for a subtree.
Definition: BayesTree-inst.h:236
void insertRoot(const sharedClique &subtree)
Insert a new subtree with known parent clique.
Definition: BayesTree-inst.h:250
Bayes Tree is a tree of cliques of a Bayes Chain.
Definition: Ordering.h:34
BayesTree()
Create an empty Bayes Tree.
Definition: BayesTree.h:107
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
print
Definition: BayesTree-inst.h:202
std::enable_if< std::is_base_of< FactorType, DERIVEDFACTOR >::value >::type push_back(boost::shared_ptr< DERIVEDFACTOR > factor)
Add a factor directly using a shared_ptr.
Definition: FactorGraph.h:159