25#include <boost/format.hpp>
26#include <boost/optional.hpp>
27#include <boost/tuple/tuple.hpp>
28#include <boost/assign/std/vector.hpp>
29using boost::assign::operator+=;
30#include <boost/unordered_set.hpp>
31#include <boost/noncopyable.hpp>
// Out-of-class definition of the static member Node::nrNodes for each
// DecisionTree<L, Y> instantiation; the value starts at zero.
// NOTE(review): presumably a node counter used by the DT_DEBUG_MEMORY traces — confirm.
44 template<
typename L,
typename Y>
45 int DecisionTree<L, Y>::Node::nrNodes = 0;
51 template<
typename L,
typename Y>
61 constant_(constant) {}
64 const Y& constant()
const {
70 return constant_ == q.constant_;
75 return (q.isLeaf() && q.sameLeaf(*
this));
80 const Leaf* other =
dynamic_cast<const Leaf*
> (&q);
81 if (!other)
return false;
82 return std::abs(
double(this->constant_ - other->constant_)) < tol;
86 void print(
const std::string& s)
const override {
88 if (showZero || constant_) std::cout << s <<
" Leaf " << constant_ << std::endl;
92 void dot(std::ostream& os,
bool showZero)
const override {
93 if (showZero || constant_) os <<
"\"" << this->id() <<
"\" [label=\""
94 << boost::format(
"%4.2g") % constant_
95 <<
"\", shape=box, rank=sink, height=0.35, fixedsize=true]\n";
114 NodePtr apply_f_op_g(
const Node& g,
const Binary& op)
const override {
115 return g.apply_g_op_fL(*
this, op);
119 NodePtr apply_g_op_fL(
const Leaf& fL,
const Binary& op)
const override {
120 NodePtr h(
new Leaf(op(fL.constant_, constant_)));
125 NodePtr apply_g_op_fC(
const Choice& fC,
const Binary& op)
const override {
126 return fC.apply_fC_op_gL(*
this, op);
// Node-kind query: this override always returns true.
134 bool isLeaf()
const override {
return true; }
141 template<
typename L,
typename Y>
148 std::vector<NodePtr> branches_;
154 typedef boost::shared_ptr<const Choice> ChoicePtr;
159#ifdef DT_DEBUG_MEMORY
// Debug trace (only built when DT_DEBUG_MEMORY is defined): prints
// Node::nrNodes followed by this node's id. The stream and manipulator live
// directly in namespace std, so they are qualified with a single std::.
160 std::cout << Node::nrNodes <<
" destructing (Choice) " << this->id() << std::endl;
168 assert(f->branches().size() > 0);
170 assert(f0->isLeaf());
171 NodePtr newLeaf(
new Leaf(boost::dynamic_pointer_cast<const Leaf>(f0)->constant()));
// Node-kind query: this override always returns false.
178 bool isLeaf()
const override {
return false; }
182 label_(label), allSame_(true) {
183 branches_.reserve(count);
193 if (f.label() > g.label()) {
196 size_t count = f.nrChoices();
197 branches_.reserve(count);
198 for (
size_t i = 0; i < count; i++)
199 push_back(f.branches_[i]->apply_f_op_g(g, op));
200 }
else if (g.label() > f.label()) {
203 size_t count = g.nrChoices();
204 branches_.reserve(count);
205 for (
size_t i = 0; i < count; i++)
206 push_back(g.branches_[i]->apply_g_op_fC(f, op));
210 size_t count = f.nrChoices();
211 branches_.reserve(count);
212 for (
size_t i = 0; i < count; i++)
213 push_back(f.branches_[i]->apply_f_op_g(*g.branches_[i], op));
217 const L& label()
const {
221 size_t nrChoices()
const {
222 return branches_.size();
225 const std::vector<NodePtr>& branches()
const {
232 if (allSame_ && !branches_.empty()) {
233 allSame_ = node->sameLeaf(*branches_.back());
235 branches_.push_back(node);
239 void print(
const std::string& s)
const override {
240 std::cout << s <<
" Choice(";
242 std::cout << label_ <<
") " << std::endl;
243 for (
size_t i = 0; i < branches_.size(); i++)
244 branches_[i]->
print((boost::format(
"%s %d") % s % i).str());
248 void dot(std::ostream& os,
bool showZero)
const override {
249 os <<
"\"" << this->id() <<
"\" [shape=circle, label=\"" << label_
251 for (
size_t i = 0; i < branches_.size(); i++) {
256 const Leaf* leaf =
dynamic_cast<const Leaf*
> (branch.get());
257 if (leaf && !leaf->
constant())
continue;
260 os <<
"\"" << this->id() <<
"\" -> \"" << branch->id() <<
"\"";
261 if (i == 0) os <<
" [style=dashed]";
262 if (i > 1) os <<
" [style=bold]";
264 branch->dot(os, showZero);
275 return (q.isLeaf() && q.sameLeaf(*
this));
281 if (!other)
return false;
282 if (this->label_ != other->label_)
return false;
283 if (branches_.size() != other->branches_.size())
return false;
285 for (
size_t i = 0; i < branches_.size(); i++)
286 if (!(branches_[i]->
equals(*(other->branches_[i]), tol)))
return false;
295 std::cout <<
"Trying to find value for " << label_ << std::endl;
296 throw std::invalid_argument(
297 "DecisionTree::operator(): value undefined for a label");
300 size_t index = x.at(label_);
301 NodePtr child = branches_[index];
309 label_(label), allSame_(true) {
311 branches_.reserve(f.branches_.size());
312 for (
const NodePtr& branch: f.branches_)
313 push_back(branch->apply(op));
318 boost::shared_ptr<Choice> r(
new Choice(label_, *
this, op));
327 NodePtr apply_f_op_g(
const Node& g,
const Binary& op)
const override {
328 return g.apply_g_op_fC(*
this, op);
332 NodePtr apply_g_op_fL(
const Leaf& fL,
const Binary& op)
const override {
333 boost::shared_ptr<Choice> h(
new Choice(label(), nrChoices()));
335 h->push_back(fL.apply_f_op_g(*branch, op));
340 NodePtr apply_g_op_fC(
const Choice& fC,
const Binary& op)
const override {
341 boost::shared_ptr<Choice> h(
new Choice(fC, *
this, op));
346 template<
typename OP>
347 NodePtr apply_fC_op_gL(
const Leaf& gL, OP op)
const {
348 boost::shared_ptr<Choice> h(
new Choice(label(), nrChoices()));
349 for(
const NodePtr& branch: branches_)
350 h->push_back(branch->apply_f_op_g(gL, op));
357 return branches_[index];
360 boost::shared_ptr<Choice> r(
new Choice(label_, branches_.size()));
361 for(
const NodePtr& branch: branches_)
362 r->push_back(branch->choose(label, index));
371 template<
typename L,
typename Y>
375 template<
typename L,
typename Y>
381 template<
typename L,
typename Y>
387 template<
typename L,
typename Y>
389 const L& label,
const Y& y1,
const Y& y2) {
390 boost::shared_ptr<Choice> a(
new Choice(label, 2));
394 root_ = Choice::Unique(a);
398 template<
typename L,
typename Y>
400 const LabelC& labelC,
const Y& y1,
const Y& y2) {
401 if (labelC.second != 2)
throw std::invalid_argument(
402 "DecisionTree: binary constructor called with non-binary label");
403 boost::shared_ptr<Choice> a(
new Choice(labelC.first, 2));
407 root_ = Choice::Unique(a);
411 template<
typename L,
typename Y>
413 const std::vector<Y>& ys) {
415 root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
419 template<
typename L,
typename Y>
421 const std::string& table) {
425 std::istringstream iss(table);
426 copy(std::istream_iterator<Y>(iss), std::istream_iterator<Y>(),
430 root_ = create(labelCs.begin(), labelCs.end(), ys.begin(), ys.end());
434 template<
typename L,
typename Y>
436 Iterator begin, Iterator end,
const L& label) {
437 root_ = compose(begin, end, label);
441 template<
typename L,
typename Y>
444 std::vector<DecisionTree> functions;
446 root_ = compose(functions.begin(), functions.end(), label);
450 template<
typename L,
typename Y>
451 template<
typename M,
typename X>
453 const std::map<M, L>& map, std::function<Y(
const X&)> op) {
454 root_ = convert(other.root_, map, op);
463 template<
typename L,
typename Y>
template<
typename Iterator>
465 Iterator end,
const L& label)
const {
468 boost::optional<L> highestLabel;
469 size_t nrChoices = 0;
470 for (Iterator it = begin; it != end; it++) {
471 if (it->root_->isLeaf())
473 boost::shared_ptr<const Choice> c =
474 boost::dynamic_pointer_cast<const Choice>(it->root_);
475 if (!highestLabel || c->label() > *highestLabel) {
476 highestLabel.reset(c->label());
477 nrChoices = c->nrChoices();
482 if (!nrChoices || !highestLabel || label > *highestLabel) {
483 boost::shared_ptr<Choice> choiceOnLabel(
new Choice(label, end - begin));
484 for (Iterator it = begin; it != end; it++)
485 choiceOnLabel->push_back(it->root_);
486 return Choice::Unique(choiceOnLabel);
489 boost::shared_ptr<Choice> choiceOnHighestLabel(
new Choice(*highestLabel, nrChoices));
491 for (
size_t index = 0; index < nrChoices; index++) {
494 std::vector<DecisionTree> functions;
495 for (Iterator it = begin; it != end; it++) {
497 DecisionTree chosen = it->choose(*highestLabel, index);
498 functions.push_back(chosen);
501 NodePtr fi = compose(functions.begin(), functions.end(), label);
502 choiceOnHighestLabel->push_back(fi);
504 return Choice::Unique(choiceOnHighestLabel);
529 template<
typename L,
typename Y>
530 template<
typename It,
typename ValueIt>
532 It begin, It end, ValueIt beginY, ValueIt endY)
const {
535 size_t nrChoices = begin->second;
536 size_t size = endY - beginY;
539 It labelC = begin + 1;
543 if (size != nrChoices) {
544 std::cout <<
"Trying to create DD on " << begin->first << std::endl;
545 std::cout << boost::format(
"DecisionTree::create: expected %d values but got %d instead") % nrChoices % size << std::endl;
546 throw std::invalid_argument(
"DecisionTree::create invalid argument");
548 boost::shared_ptr<Choice> choice(
new Choice(begin->first, endY - beginY));
549 for (ValueIt y = beginY; y != endY; y++)
551 return Choice::Unique(choice);
557 std::vector<DecisionTree> functions;
558 size_t split = size / nrChoices;
559 for (
size_t i = 0; i < nrChoices; i++, beginY +=
split) {
560 NodePtr f = create<It, ValueIt>(labelC, end, beginY, beginY +
split);
563 return compose(functions.begin(), functions.end(), begin->first);
567 template<
typename L,
typename Y>
568 template<
typename M,
typename X>
571 std::function<Y(
const X&)> op) {
574 typedef typename MX::Leaf MXLeaf;
575 typedef typename MX::Choice MXChoice;
576 typedef typename MX::NodePtr MXNodePtr;
581 const MXLeaf* leaf =
dynamic_cast<const MXLeaf*
> (f.get());
582 if (leaf)
return NodePtr(
new Leaf(op(leaf->constant())));
585 boost::shared_ptr<const MXChoice> choice = boost::dynamic_pointer_cast<const MXChoice> (f);
586 if (!choice)
throw std::invalid_argument(
587 "DecisionTree::Convert: Invalid NodePtr");
590 M oldLabel = choice->label();
591 L newLabel = map.at(oldLabel);
594 std::vector<LY> functions;
595 for(
const MXNodePtr& branch: choice->branches()) {
596 LY converted(convert<M, X>(branch, map, op));
597 functions += converted;
599 return LY::compose(functions.begin(), functions.end(), newLabel);
603 template<
typename L,
typename Y>
605 return root_->equals(*other.root_, tol);
608 template<
typename L,
typename Y>
613 template<
typename L,
typename Y>
615 return root_->equals(*other.root_);
618 template<
typename L,
typename Y>
620 return root_->operator ()(x);
623 template<
typename L,
typename Y>
629 template<
typename L,
typename Y>
631 const Binary& op)
const {
633 NodePtr h = root_->apply_f_op_g(*g.root_, op);
648 template<
typename L,
typename Y>
650 size_t cardinality,
const Binary& op)
const {
652 for (
size_t index = 1; index < cardinality; index++) {
654 result = result.apply(chosen, op);
660 template<
typename L,
typename Y>
662 os <<
"digraph G {\n";
663 root_->dot(os, showZero);
664 os <<
" [ordering=out]}" << std::endl;
667 template<
typename L,
typename Y>
669 std::ofstream os((name +
".dot").c_str());
672 (
"dot -Tpdf " + name +
".dot -o " + name +
".pdf >& /dev/null").c_str());
673 if (result==-1)
throw std::runtime_error(
"DecisionTree::dot system call failed");
Concept check for values that can be used in unit tests.
Decision Tree for use in DiscreteFactors.
Global functions in a separate testing namespace.
Definition: chartTesting.h:28
void split(const G &g, const PredecessorMap< KEY > &tree, G &Ab1, G &Ab2)
Split the graph into two parts: one corresponds to the given spanning tree, and the other corresponds...
Definition: graph-inl.h:255
double dot(const V1 &a, const V2 &b)
Dot product.
Definition: Vector.h:194
Template to create a binary predicate.
Definition: Testable.h:111
An assignment from labels to value index (size_t).
Definition: Assignment.h:34
Definition: DecisionTree-inl.h:52
NodePtr choose(const L &label, size_t index) const override
choose a branch, create new memory !
Definition: DecisionTree-inl.h:130
bool equals(const Node &q, double tol) const override
equality up to tolerance
Definition: DecisionTree-inl.h:79
NodePtr apply(const Unary &op) const override
apply unary operator
Definition: DecisionTree-inl.h:104
void print(const std::string &s) const override
print
Definition: DecisionTree-inl.h:86
Leaf(const Y &constant)
Constructor from constant.
Definition: DecisionTree-inl.h:60
bool sameLeaf(const Leaf &q) const override
Leaf-Leaf equality.
Definition: DecisionTree-inl.h:69
void dot(std::ostream &os, bool showZero) const override
to graphviz file
Definition: DecisionTree-inl.h:92
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
Definition: DecisionTree-inl.h:74
const Y & constant() const
return the constant
Definition: DecisionTree-inl.h:64
Definition: DecisionTree-inl.h:142
NodePtr apply(const Unary &op) const override
apply unary operator
Definition: DecisionTree-inl.h:317
bool equals(const Node &q, double tol) const override
equality up to tolerance
Definition: DecisionTree-inl.h:279
void print(const std::string &s) const override
print (as a tree)
Definition: DecisionTree-inl.h:239
bool sameLeaf(const Node &q) const override
polymorphic equality: if q is a leaf, could be...
Definition: DecisionTree-inl.h:274
Choice(const Choice &f, const Choice &g, const Binary &op)
Construct from applying binary op to two Choice nodes.
Definition: DecisionTree-inl.h:189
void push_back(const NodePtr &node)
add a branch: TODO merge into constructor
Definition: DecisionTree-inl.h:230
void dot(std::ostream &os, bool showZero) const override
output to graphviz (as a graph)
Definition: DecisionTree-inl.h:248
Choice(const L &label, size_t count)
Constructor, given choice label and mandatory expected branch count.
Definition: DecisionTree-inl.h:181
NodePtr choose(const L &label, size_t index) const override
choose a branch, recursively
Definition: DecisionTree-inl.h:355
Choice(const L &label, const Choice &f, const Unary &op)
Construct from applying unary op to a Choice node.
Definition: DecisionTree-inl.h:308
bool sameLeaf(const Leaf &q) const override
Choice-Leaf equality: always false.
Definition: DecisionTree-inl.h:269
static NodePtr Unique(const ChoicePtr &f)
If all branches of a choice node f are the same, just return a branch.
Definition: DecisionTree-inl.h:165
Decision Tree. L = label for variables; Y = function range (any algebra), e.g., bool, ...
Definition: DecisionTree.h:38
DecisionTree apply(const Unary &op) const
apply Unary operation "op" to f
Definition: DecisionTree-inl.h:624
NodePtr create(It begin, It end, ValueIt beginY, ValueIt endY) const
Internal recursive function to create from keys, cardinalities, and Y values.
Definition: DecisionTree-inl.h:531
void dot(std::ostream &os, bool showZero=true) const
output to graphviz format, stream version
Definition: DecisionTree-inl.h:661
std::pair< L, size_t > LabelC
A label annotated with cardinality.
Definition: DecisionTree.h:47
NodePtr convert(const typename DecisionTree< M, X >::NodePtr &f, const std::map< M, L > &map, std::function< Y(const X &)> op)
Convert to a different type.
Definition: DecisionTree-inl.h:569
DecisionTree combine(const L &label, size_t cardinality, const Binary &op) const
combine subtrees on key with binary operation "op"
Definition: DecisionTree-inl.h:649
const Y & operator()(const Assignment< L > &x) const
evaluate
Definition: DecisionTree-inl.h:619
Node::Ptr NodePtr
------------------------ Node base class ------------------------
Definition: DecisionTree.h:98
bool operator==(const DecisionTree &q) const
equality
Definition: DecisionTree-inl.h:614
void print(const std::string &s="DecisionTree") const
GTSAM-style print.
Definition: DecisionTree-inl.h:609
DecisionTree()
Default constructor.
Definition: DecisionTree-inl.h:372
std::function< Y(const Y &)> Unary
Handy typedefs for unary and binary function types.
Definition: DecisionTree.h:43
------------------------ Node base class ------------------------
Definition: DecisionTree.h:54