Point Cloud Library (PCL) 1.9.1
decision_tree_trainer.hpp
/*
 * Software License Agreement (BSD License)
 *
 * Point Cloud Library (PCL) - www.pointclouds.org
 * Copyright (c) 2010-2011, Willow Garage, Inc.
 *
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above
 *    copyright notice, this list of conditions and the following
 *    disclaimer in the documentation and/or other materials provided
 *    with the distribution.
 *  * Neither the name of Willow Garage, Inc. nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

#ifndef PCL_ML_DT_DECISION_TREE_TRAINER_HPP_
#define PCL_ML_DT_DECISION_TREE_TRAINER_HPP_

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>::DecisionTreeTrainer ()
  : max_tree_depth_ (15)
  , num_of_features_ (1000)
  , num_of_thresholds_ (10)
  , feature_handler_ (NULL)
  , stats_estimator_ (NULL)
  , data_set_ ()
  , label_data_ ()
  , examples_ ()
  , decision_tree_trainer_data_provider_ ()
  , random_features_at_split_node_ (false)
  , min_examples_for_split_ (0) // no minimum: nodes of any size may be split
{
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>::~DecisionTreeTrainer ()
{
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
void
pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>::train (
  DecisionTree<NodeType> & tree)
{
  // create random features
  std::vector<FeatureType> features;

  if (!random_features_at_split_node_)
    feature_handler_->createRandomFeatures (num_of_features_, features);

  // recursively build decision tree
  NodeType root_node;
  tree.setRoot (root_node);

  if (decision_tree_trainer_data_provider_)
  {
    std::cerr << "use decision_tree_trainer_data_provider_" << std::endl;

    decision_tree_trainer_data_provider_->getDatasetAndLabels (data_set_, label_data_, examples_);
    trainDecisionTreeNode (features, examples_, label_data_, max_tree_depth_, tree.getRoot ());
    label_data_.clear ();
    data_set_.clear ();
    examples_.clear ();
  }
  else
  {
    trainDecisionTreeNode (features, examples_, label_data_, max_tree_depth_, tree.getRoot ());
  }
}

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
void
pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>::trainDecisionTreeNode (
  std::vector<FeatureType> & features,
  std::vector<ExampleIndex> & examples,
  std::vector<LabelType> & label_data,
  const size_t max_depth,
  NodeType & node)
{
  const size_t num_of_examples = examples.size ();
  if (num_of_examples == 0)
  {
    PCL_ERROR ("Reached invalid point in decision tree training: Number of examples is 0!");
    return;
  }

  if (max_depth == 0)
  {
    stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
    return;
  }

  if (examples.size () < min_examples_for_split_)
  {
    stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
    return;
  }

  if (random_features_at_split_node_)
  {
    features.clear ();
    feature_handler_->createRandomFeatures (num_of_features_, features);
  }

  std::vector<float> feature_results;
  std::vector<unsigned char> flags;

  feature_results.reserve (num_of_examples);
  flags.reserve (num_of_examples);

  // find the best feature/threshold pair for the split
  int best_feature_index = -1;
  float best_feature_threshold = 0.0f;
  float best_feature_information_gain = 0.0f;

  const size_t num_of_features = features.size ();
  for (size_t feature_index = 0; feature_index < num_of_features; ++feature_index)
  {
    // evaluate the feature on all examples
    feature_handler_->evaluateFeature (features[feature_index],
                                       data_set_,
                                       examples,
                                       feature_results,
                                       flags);

    if (thresholds_.size () > 0)
    {
      // thresholds supplied by the user: compute the information gain for each
      // threshold and keep the threshold with the highest information gain
      for (size_t threshold_index = 0; threshold_index < thresholds_.size (); ++threshold_index)
      {
        const float information_gain = stats_estimator_->computeInformationGain (data_set_,
                                                                                 examples,
                                                                                 label_data,
                                                                                 feature_results,
                                                                                 flags,
                                                                                 thresholds_[threshold_index]);

        if (information_gain > best_feature_information_gain)
        {
          best_feature_information_gain = information_gain;
          best_feature_index = static_cast<int> (feature_index);
          best_feature_threshold = thresholds_[threshold_index];
        }
      }
    }
    else
    {
      // no thresholds supplied: sample thresholds uniformly over the range of
      // the feature results
      std::vector<float> thresholds;
      thresholds.reserve (num_of_thresholds_);
      createThresholdsUniform (num_of_thresholds_, feature_results, thresholds);

      // compute the information gain for each threshold and keep the threshold
      // with the highest information gain
      for (size_t threshold_index = 0; threshold_index < num_of_thresholds_; ++threshold_index)
      {
        const float threshold = thresholds[threshold_index];

        const float information_gain = stats_estimator_->computeInformationGain (data_set_,
                                                                                 examples,
                                                                                 label_data,
                                                                                 feature_results,
                                                                                 flags,
                                                                                 threshold);

        if (information_gain > best_feature_information_gain)
        {
          best_feature_information_gain = information_gain;
          best_feature_index = static_cast<int> (feature_index);
          best_feature_threshold = threshold;
        }
      }
    }
  }

  if (best_feature_index == -1)
  {
    // no feature yielded a positive information gain: turn this node into a leaf
    stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
    return;
  }

  // get branch indices for best feature and best threshold
  std::vector<unsigned char> branch_indices;
  branch_indices.reserve (num_of_examples);
  {
    feature_handler_->evaluateFeature (features[best_feature_index],
                                       data_set_,
                                       examples,
                                       feature_results,
                                       flags);

    stats_estimator_->computeBranchIndices (feature_results,
                                            flags,
                                            best_feature_threshold,
                                            branch_indices);
  }

  stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);

  // separate the data into the branches
  {
    const size_t num_of_branches = stats_estimator_->getNumOfBranches ();

    std::vector<size_t> branch_counts (num_of_branches, 0);
    for (size_t example_index = 0; example_index < num_of_examples; ++example_index)
    {
      ++branch_counts[branch_indices[example_index]];
    }

    node.feature = features[best_feature_index];
    node.threshold = best_feature_threshold;
    node.sub_nodes.resize (num_of_branches);

    for (size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
    {
      if (branch_counts[branch_index] == 0)
      {
        // empty branch: create a leaf carrying the stats of the parent's examples
        NodeType branch_node;
        stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, branch_node);

        node.sub_nodes[branch_index] = branch_node;

        continue;
      }

      std::vector<LabelType> branch_labels;
      std::vector<ExampleIndex> branch_examples;
      branch_labels.reserve (branch_counts[branch_index]);
      branch_examples.reserve (branch_counts[branch_index]);

      for (size_t example_index = 0; example_index < num_of_examples; ++example_index)
      {
        if (branch_indices[example_index] == branch_index)
        {
          branch_examples.push_back (examples[example_index]);
          branch_labels.push_back (label_data[example_index]);
        }
      }

      trainDecisionTreeNode (features, branch_examples, branch_labels, max_depth - 1, node.sub_nodes[branch_index]);
    }
  }
}
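
// Note on the split criterion (illustrative; not part of the original file).
// The StatsEstimator is free to define computeInformationGain, but for
// classification it is typically the Shannon information gain
//
//   IG(S, t) = H(S) - sum_b (|S_b| / |S|) * H(S_b),
//
// where S is the example set at the node, S_b is the subset routed to branch b
// by threshold t, and H is the entropy of the label distribution. The search
// above keeps the (feature, threshold) pair that maximizes this quantity, and
// because best_feature_information_gain starts at 0, a split is accepted only
// if it is strictly better than not splitting at all.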

//////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <class FeatureType, class DataSet, class LabelType, class ExampleIndex, class NodeType>
void
pcl::DecisionTreeTrainer<FeatureType, DataSet, LabelType, ExampleIndex, NodeType>::createThresholdsUniform (
  const size_t num_of_thresholds,
  std::vector<float> & values,
  std::vector<float> & thresholds)
{
  // estimate the range of the values
  float min_value = ::std::numeric_limits<float>::max ();
  float max_value = -::std::numeric_limits<float>::max ();

  const size_t num_of_values = values.size ();
  for (size_t value_index = 0; value_index < num_of_values; ++value_index)
  {
    const float value = values[value_index];

    if (value < min_value) min_value = value;
    if (value > max_value) max_value = value;
  }

  const float range = max_value - min_value;
  const float step = range / static_cast<float> (num_of_thresholds + 2);

  // place the thresholds uniformly inside the estimated range
  thresholds.resize (num_of_thresholds);

  for (size_t threshold_index = 0; threshold_index < num_of_thresholds; ++threshold_index)
  {
    thresholds[threshold_index] = min_value + step * (static_cast<float> (threshold_index + 1));
  }
}
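
// Worked example (illustrative; not part of the original file). For feature
// results spanning [0, 10] and num_of_thresholds = 3, the range 10 is divided
// into 3 + 2 = 5 steps of size 2, giving thresholds {2, 4, 6}: the candidates
// start one step above the minimum and end two steps below the maximum, so no
// candidate threshold sits at the extremes of the value range.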

#endif // PCL_ML_DT_DECISION_TREE_TRAINER_HPP_