gtsam 4.1.1
gtsam
ClusterTree-inst.h
1
10#pragma once
11
12#include <gtsam/inference/ClusterTree.h>
15#include <gtsam/base/timing.h>
17
18namespace gtsam {
19
20/* ************************************************************************* */
21template<class GRAPH>
22void ClusterTree<GRAPH>::Cluster::print(const std::string& s,
23 const KeyFormatter& keyFormatter) const {
24 std::cout << s << " (" << problemSize_ << ")";
26}
27
28/* ************************************************************************* */
29template <class GRAPH>
31 std::vector<size_t> nrFrontals;
32 nrFrontals.reserve(nrChildren());
33 for (const sharedNode& child : children)
34 nrFrontals.push_back(child->nrFrontals());
35 return nrFrontals;
36}
37
38/* ************************************************************************* */
39template <class GRAPH>
40void ClusterTree<GRAPH>::Cluster::merge(const boost::shared_ptr<Cluster>& cluster) {
41 // Merge keys. For efficiency, we add keys in reverse order at end, calling reverse after..
42 orderedFrontalKeys.insert(orderedFrontalKeys.end(), cluster->orderedFrontalKeys.rbegin(),
43 cluster->orderedFrontalKeys.rend());
44 factors.push_back(cluster->factors);
45 children.insert(children.end(), cluster->children.begin(), cluster->children.end());
46 // Increment problem size
47 problemSize_ = std::max(problemSize_, cluster->problemSize_);
48}
49
50/* ************************************************************************* */
51template<class GRAPH>
53 const std::vector<bool>& merge) {
54 gttic(Cluster_mergeChildren);
55 assert(merge.size() == this->children.size());
56
57 // Count how many keys, factors and children we'll end up with
58 size_t nrKeys = orderedFrontalKeys.size();
59 size_t nrFactors = factors.size();
60 size_t nrNewChildren = 0;
61 // Loop over children
62 size_t i = 0;
63 for(const sharedNode& child: this->children) {
64 if (merge[i]) {
65 nrKeys += child->orderedFrontalKeys.size();
66 nrFactors += child->factors.size();
67 nrNewChildren += child->nrChildren();
68 } else {
69 nrNewChildren += 1; // we keep the child
70 }
71 ++i;
72 }
73
74 // now reserve space, and really merge
75 auto oldChildren = this->children;
76 this->children.clear();
77 this->children.reserve(nrNewChildren);
78 orderedFrontalKeys.reserve(nrKeys);
79 factors.reserve(nrFactors);
80 i = 0;
81 for (const sharedNode& child : oldChildren) {
82 if (merge[i]) {
83 this->merge(child);
84 } else {
85 this->addChild(child); // we keep the child
86 }
87 ++i;
88 }
89 std::reverse(orderedFrontalKeys.begin(), orderedFrontalKeys.end());
90}
91
92/* ************************************************************************* */
93template <class GRAPH>
94void ClusterTree<GRAPH>::print(const std::string& s, const KeyFormatter& keyFormatter) const {
95 treeTraversal::PrintForest(*this, s, keyFormatter);
96}
98/* ************************************************************************* */
99template <class GRAPH>
101 // Start by duplicating the tree.
103 return *this;
104}
105
106/* ************************************************************************* */
107// Elimination traversal data - stores a pointer to the parent data and collects
108// the factors resulting from elimination of the children. Also sets up BayesTree
109// cliques with parent and child pointers.
110template<class CLUSTERTREE>
112 // Typedefs
113 typedef typename CLUSTERTREE::sharedFactor sharedFactor;
114 typedef typename CLUSTERTREE::FactorType FactorType;
115 typedef typename CLUSTERTREE::FactorGraphType FactorGraphType;
116 typedef typename CLUSTERTREE::ConditionalType ConditionalType;
117 typedef typename CLUSTERTREE::BayesTreeType::Node BTNode;
118
119 EliminationData* const parentData;
120 size_t myIndexInParent;
121 FastVector<sharedFactor> childFactors;
122 boost::shared_ptr<BTNode> bayesTreeNode;
123
124 EliminationData(EliminationData* _parentData, size_t nChildren) :
125 parentData(_parentData), bayesTreeNode(boost::make_shared<BTNode>()) {
126 if (parentData) {
127 myIndexInParent = parentData->childFactors.size();
128 parentData->childFactors.push_back(sharedFactor());
129 } else {
130 myIndexInParent = 0;
131 }
132 // Set up BayesTree parent and child pointers
133 if (parentData) {
134 if (parentData->parentData) // If our parent is not the dummy node
135 bayesTreeNode->parent_ = parentData->bayesTreeNode;
136 parentData->bayesTreeNode->children.push_back(bayesTreeNode);
137 }
139
140 // Elimination pre-order visitor - creates the EliminationData structure for the visited node.
141 static EliminationData EliminationPreOrderVisitor(
142 const typename CLUSTERTREE::sharedNode& node,
143 EliminationData& parentData) {
144 assert(node);
145 EliminationData myData(&parentData, node->nrChildren());
146 myData.bayesTreeNode->problemSize_ = node->problemSize();
147 return myData;
148 }
149
150 // Elimination post-order visitor - combine the child factors with our own factors, add the
151 // resulting conditional to the BayesTree, and add the remaining factor to the parent.
153 const typename CLUSTERTREE::Eliminate& eliminationFunction_;
154 typename CLUSTERTREE::BayesTreeType::Nodes& nodesIndex_;
155
156 public:
157 // Construct functor
159 const typename CLUSTERTREE::Eliminate& eliminationFunction,
160 typename CLUSTERTREE::BayesTreeType::Nodes& nodesIndex) :
161 eliminationFunction_(eliminationFunction), nodesIndex_(nodesIndex) {
162 }
163
164 // Function that does the HEAVY lifting
165 void operator()(const typename CLUSTERTREE::sharedNode& node, EliminationData& myData) {
166 assert(node);
167
168 // Gather factors
169 FactorGraphType gatheredFactors;
170 gatheredFactors.reserve(node->factors.size() + node->nrChildren());
171 gatheredFactors += node->factors;
172 gatheredFactors += myData.childFactors;
173
174 // Check for Bayes tree orphan subtrees, and add them to our children
175 // TODO(frank): should this really happen here?
176 for (const sharedFactor& factor: node->factors) {
177 auto asSubtree = dynamic_cast<const BayesTreeOrphanWrapper<BTNode>*>(factor.get());
178 if (asSubtree) {
179 myData.bayesTreeNode->children.push_back(asSubtree->clique);
180 asSubtree->clique->parent_ = myData.bayesTreeNode;
181 }
182 }
183
184 // >>>>>>>>>>>>>> Do dense elimination step >>>>>>>>>>>>>>>>>>>>>>>>>>>>>
185 auto eliminationResult = eliminationFunction_(gatheredFactors, node->orderedFrontalKeys);
186 // <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
187
188 // Store conditional in BayesTree clique, and in the case of ISAM2Clique also store the
189 // remaining factor
190 myData.bayesTreeNode->setEliminationResult(eliminationResult);
191
192 // Fill nodes index - we do this here instead of calling insertRoot at the end to avoid
193 // putting orphan subtrees in the index - they'll already be in the index of the ISAM2
194 // object they're added to.
195 for (const Key& j: myData.bayesTreeNode->conditional()->frontals())
196 nodesIndex_.insert(std::make_pair(j, myData.bayesTreeNode));
197
198 // Store remaining factor in parent's gathered factors
199 if (!eliminationResult.second->empty())
200 myData.parentData->childFactors[myData.myIndexInParent] = eliminationResult.second;
201 }
202 };
203};
204
205/* ************************************************************************* */
206template<class BAYESTREE, class GRAPH>
208 const This& other) {
210
211 // Assign the remaining factors - these are pointers to factors in the original factor graph and
212 // we do not clone them.
213 remainingFactors_ = other.remainingFactors_;
214
215 return *this;
216}
217
218/* ************************************************************************* */
219template <class BAYESTREE, class GRAPH>
220std::pair<boost::shared_ptr<BAYESTREE>, boost::shared_ptr<GRAPH> >
222 gttic(ClusterTree_eliminate);
223 // Do elimination (depth-first traversal). The rootsContainer stores a 'dummy' BayesTree node
224 // that contains all of the roots as its children. rootsContainer also stores the remaining
225 // un-eliminated factors passed up from the roots.
226 boost::shared_ptr<BayesTreeType> result = boost::make_shared<BayesTreeType>();
227
228 typedef EliminationData<This> Data;
229 Data rootsContainer(0, this->nrRoots());
230
231 typename Data::EliminationPostOrderVisitor visitorPost(function, result->nodes_);
232 {
233 TbbOpenMPMixedScope threadLimiter; // Limits OpenMP threads since we're mixing TBB and OpenMP
234 treeTraversal::DepthFirstForestParallel(*this, rootsContainer, Data::EliminationPreOrderVisitor,
235 visitorPost, 10);
236 }
237
238 // Create BayesTree from roots stored in the dummy BayesTree node.
239 result->roots_.insert(result->roots_.end(), rootsContainer.bayesTreeNode->children.begin(),
240 rootsContainer.bayesTreeNode->children.end());
241
242 // Add remaining factors that were not involved with eliminated variables
243 boost::shared_ptr<FactorGraphType> remaining = boost::make_shared<FactorGraphType>();
244 remaining->reserve(remainingFactors_.size() + rootsContainer.childFactors.size());
245 remaining->push_back(remainingFactors_.begin(), remainingFactors_.end());
246 for (const sharedFactor& factor : rootsContainer.childFactors) {
247 if (factor)
248 remaining->push_back(factor);
249 }
250
251 // Return result
252 return std::make_pair(result, remaining);
253}
254
255} // namespace gtsam
Timing utilities.
Bayes Tree is a tree of cliques of a Bayes Chain.
Variable ordering for the elimination algorithm.
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
void PrintKeyVector(const KeyVector &keys, const string &s, const KeyFormatter &keyFormatter)
Utility function to print sets of keys with optional prefix.
Definition: Key.cpp:77
std::uint64_t Key
Integer nonlinear key type.
Definition: types.h:69
std::function< std::string(Key)> KeyFormatter
Typedef for a function to format a key, i.e. to convert it to a string.
Definition: Key.h:35
FastVector< boost::shared_ptr< typename FOREST::Node > > CloneForest(const FOREST &forest)
Clone a tree, copy-constructing new nodes (calling boost::make_shared) and setting up child pointers ...
Definition: treeTraversal-inst.h:189
void PrintForest(const FOREST &forest, std::string str, const KeyFormatter &keyFormatter)
Print a tree, prefixing each line with str, and formatting keys using keyFormatter.
Definition: treeTraversal-inst.h:219
void DepthFirstForestParallel(FOREST &forest, DATA &rootData, VISITOR_PRE &visitorPre, VISITOR_POST &visitorPost, int problemSizeThreshold=10)
Traverse a forest depth-first with pre-order and post-order visits.
Definition: treeTraversal-inst.h:154
An object whose scope defines a block where TBB and OpenMP parallelism are mixed.
Definition: types.h:161
A cluster-tree that eliminates to a Bayes tree.
Definition: ClusterTree.h:184
This & operator=(const This &other)
Assignment operator - makes a deep copy of the tree structure, but only pointers to factors are copied.
Definition: ClusterTree-inst.h:207
std::pair< boost::shared_ptr< BayesTreeType >, boost::shared_ptr< FactorGraphType > > eliminate(const Eliminate &function) const
Eliminate the factors to a Bayes tree and remaining factor graph.
Definition: ClusterTree-inst.h:221
boost::shared_ptr< FactorType > sharedFactor
Shared pointer to a factor.
Definition: ClusterTree.h:197
GRAPH::Eliminate Eliminate
Typedef for an eliminate subroutine.
Definition: ClusterTree.h:195
Definition: BayesTree.h:270
Definition: ClusterTree-inst.h:111
Definition: ClusterTree-inst.h:152
A cluster-tree is associated with a factor graph and is defined as in Koller-Friedman: each node k re...
Definition: ClusterTree.h:25
This & operator=(const This &other)
Assignment operator - makes a deep copy of the tree structure, but only pointers to factors are copied.
Definition: ClusterTree-inst.h:100
FastVector< sharedNode > roots_
The root clusters of the cluster forest.
Definition: ClusterTree.h:116
GRAPH::FactorType FactorType
The type of factors.
Definition: ClusterTree.h:31
GRAPH FactorGraphType
The factor graph type.
Definition: ClusterTree.h:27
void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
Print the cluster tree.
Definition: ClusterTree-inst.h:94
boost::shared_ptr< FactorType > sharedFactor
Shared pointer to a factor.
Definition: ClusterTree.h:32
virtual void print(const std::string &s="", const KeyFormatter &keyFormatter=DefaultKeyFormatter) const
print this node
Definition: ClusterTree-inst.h:22
Keys orderedFrontalKeys
Frontal keys of this node.
Definition: ClusterTree.h:41
void mergeChildren(const std::vector< bool > &merge)
Merge all children for which bit is set into this node.
Definition: ClusterTree-inst.h:52
void merge(const boost::shared_ptr< Cluster > &cluster)
Merge in given cluster.
Definition: ClusterTree-inst.h:40
std::vector< size_t > nrFrontalsOfChildren() const
Return a vector with nrFrontal keys for each child.
Definition: ClusterTree-inst.h:30