Introduction
In earlier articles we demonstrated how simple distance metrics and a nearest-neighbour recommendation system could be implemented in ruby. We will stick with this machine learning (ML) theme in the present article and discuss the implementation of a simple decision tree algorithm, again in ruby.
We will implement a decision tree ML algorithm and we will then train our model with data for a classification task. Once trained, our decision tree model should be able to predict the classification for previously unseen data. We will use a simple dataset as we develop the algorithm, but we will also demonstrate how the algorithm can then be applied to a more challenging problem.
The algorithms implemented in this article are inspired by the python versions presented by Toby Segaran in Collective Intelligence (Amazon), which I highly recommend for anyone interested in learning the fundamentals of classic machine learning methods.
Decision Trees
The venerable decision tree has been the basis of human classification systems from the beginning of time, and permeates many disciplines. Many of you will have come across these ideas in biology class as a mechanism for discriminating between different living organisms, for example:
This approach is characterised by a series of binary decision nodes, where each node represents a single question with two branches emanating from it: the true branch and the false branch. At the end of a series of decision nodes we should hopefully have a leaf node, which defines the categorization that should apply.
Once we have a decision tree, such as the one depicted, using it to classify a new case is pretty intuitive. The question is, how do we generate this decision tree?
In machine learning we want the algorithm to build this tree for us, using existing data. Let's introduce a simple dataset which we can use as a reference as we tease out the algorithm. We will take the same dataset presented in [1]:
module SimpleData
  DATA = [
    ['slashdot', 'USA', 'yes', 18, 'None'],
    ['google', 'France', 'yes', 23, 'Premium'],
    ['digg', 'USA', 'yes', 24, 'Basic'],
    ['kiwitobes', 'France', 'yes', 23, 'Basic'],
    ['google', 'UK', 'no', 21, 'Premium'],
    ['(direct)', 'New Zealand', 'no', 12, 'None'],
    ['(direct)', 'UK', 'no', 21, 'Basic'],
    ['google', 'USA', 'no', 24, 'Premium'],
    ['slashdot', 'France', 'yes', 19, 'None'],
    ['digg', 'USA', 'no', 18, 'None'],
    ['google', 'UK', 'no', 18, 'None'],
    ['kiwitobes', 'UK', 'no', 19, 'None'],
    ['digg', 'New Zealand', 'yes', 12, 'Basic'],
    ['slashdot', 'UK', 'no', 21, 'None'],
    ['google', 'UK', 'no', 18, 'Basic'],
    ['kiwitobes', 'France', 'yes', 19, 'Basic']
  ]
  …
end
Each row in the dataset represents a visitor to some particular website, and the columns of each row include the following data:
- The referring website, from which the visitor has arrived
- The geographic location of the visitor
- Whether the visitor has read the FAQ
- The number of pages viewed by the visitor
- The service to which the visitor has subscribed
As a brief aside, you can see that we have wrapped our DATA in a SimpleData module; the only other thing exposed by this module is a singleton method, get_train_and_test_data:
def self.get_train_and_test_data(split: 0.1)
  train_data = DATA.dup
  # Let's remove :split of data for testing
  full_size = train_data.size
  test_data = (0..(full_size * split)).to_a.map do |i|
    train_data.delete_at(rand(full_size - i))
  end
  [train_data, test_data]
end
This method provides an interface to return our dataset as two separate arrays, a training set and a test set. This is a very important concept in building machine learning models, but I will not be able to do it justice in this post. In brief, we use our training set to train (or build) our decision tree model, and then we use our test set to evaluate the accuracy of the model we have produced. The get_train_and_test_data method will randomly pull rows out of the training data to populate the test set, in the ratio specified by split.
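As a quick illustration of what this returns: which rows end up in the test set is random, but for our 16-row dataset the sizes should come out like this (a small sketch using the default split):

train_rows, test_rows = SimpleData.get_train_and_test_data(split: 0.1)
train_rows.size # => 14
test_rows.size  # => 2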
We aim to find a series of conditions which, when applied to our data, will split it into pure sets containing only a single label. As an example, let's consider splitting our dataset by the values in the fourth column, i.e. the number of pages viewed. Suppose our first node splits the data depending on whether the number of pages viewed is greater than 18. How would our data look if we applied this condition:
SimpleData::DATA.partition{ |row| row[3] > 18}
=>
[[["google", "France", "yes", 23, "Premium"],
["digg", "USA", "yes", 24, "Basic"],
["kiwitobes", "France", "yes", 23, "Basic"],
["google", "UK", "no", 21, "Premium"],
["(direct)", "UK", "no", 21, "Basic"],
["google", "USA", "no", 24, "Premium"],
["slashdot", "France", "yes", 19, "None"],
["kiwitobes", "UK", "no", 19, "None"],
["slashdot", "UK", "no", 21, "None"],
["kiwitobes", "France", "yes", 19, "Basic"]],
[["slashdot", "USA", "yes", 18, "None"],
["(direct)", "New Zealand", "no", 12, "None"],
["digg", "USA", "no", 18, "None"],
["google", "UK", "no", 18, "None"],
["digg", "New Zealand", "yes", 12, "Basic"],
["google", "UK", "no", 18, "Basic"]]]
In this snippet we are using the very neat Enumerable#partition method. Take a look at the docs and you will see that this little beauty is tailor-made for our decision tree implementation; it allows us to split an array in two, based on whether the given block evaluates to true or false. Using this criterion, i.e. num_pageviews > 18, we have gathered all of the Premium visitors into our first partition. This definitely feels like progress, like the two partitions are somehow purer. But is there any way that we can evaluate this perceived improvement? This is where we need to introduce the idea of entropy.
Entropy is a quantity that has been defined by humans because it turns out to be very useful for describing different physical systems. It is often described as a measure of disorder in a system, but in our case we will use it to tell us how mixed our arrays are. We want to apply criteria that will tend to reduce the average entropy in our partitions, i.e. will make them less mixed. The entropy in any given partition of our data can be defined to be: $$ \begin{aligned} U = - \sum_{i=1}^{N} p_{i} \log_{2} p_{i} \end{aligned} $$ where $p_{i}$ is the probability of obtaining label $i$ in the partition, and $N$ is 3, as we have 3 labels to sum over.
Things will be easier if we consider an example. In our original dataset (containing 16 rows) we had a mixture of labels: 6 visitors were Basic (b), 3 visitors were Premium (p) and 7 visitors were None (n). We can calculate the entropy of this set as follows:
$$ \begin{aligned} U_{\rm{orig}} &= - \sum_{i=\{\rm{b}, \rm{p}, \rm{n}\}} p_{i} \log_{2} p_{i} \\ &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\ &= - \frac{6}{16} \log_{2} \left( \frac{6}{16} \right) - \frac{3}{16} \log_{2} \left( \frac{3}{16} \right) - \frac{7}{16} \log_{2} \left( \frac{7}{16} \right) \\ &= 1.5052 \end{aligned} $$
We then use our condition, num_pageviews > 18, to split the data into two separate partitions of 10 and 6 rows. Each partition has its own entropy:
$$
\begin{aligned}
U_{1} &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\
&= - \frac{4}{10} \log_{2} \left( \frac{4}{10} \right) - \frac{3}{10} \log_{2} \left( \frac{3}{10} \right) - \frac{3}{10} \log_{2} \left( \frac{3}{10} \right) \\
&= 1.5709 \\
\\
U_{2} &= - p_{\rm{b}} \log_{2} p_{\rm{b}} - p_{\rm{p}} \log_{2} p_{\rm{p}} - p_{\rm{n}} \log_{2} p_{\rm{n}} \\
&= - \frac{2}{6} \log_{2} \left( \frac{2}{6} \right) - \frac{0}{6} \log_{2} \left( \frac{0}{6} \right) - \frac{4}{6} \log_{2} \left( \frac{4}{6} \right) \\
&= 0.9183
\end{aligned}
$$
To compare the entropy before and after the split we need to take the weighted average of the entropy for each partition, i.e. $\frac{10}{16} U_{1} + \frac{6}{16} U_{2} = 1.3255$. So our split criterion has succeeded in reducing the entropy by $1.5052 - 1.3255 = 0.1797$. This decrease in entropy can also be referred to as the information gain.
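In general, if a split divides $n$ rows into partitions of size $n_{1}$ and $n_{2}$, the information gain is just this difference: $$ \begin{aligned} \rm{gain} = U_{\rm{orig}} - \left( \frac{n_{1}}{n} U_{1} + \frac{n_{2}}{n} U_{2} \right) \end{aligned} $$ This weighted-average form is exactly the quantity our training code will compute for every candidate split.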
Now that we understand what the entropy is, the implementation in ruby should be straightforward. Our entropy method will take a set of rows and group them according to the label in the last column, getting a count for each label. The count for a given label, divided by the total number of rows, gives the probability for that label. We then combine these probabilities using the formula already presented:
def entropy(rows)
  # Group the rows by label (the last column); each label's count / total gives p_i
  rows.group_by(&:last).reduce(0) do |entropy, (key, value)|
    prob = value.length.to_f / rows.size
    entropy -= prob * Math.log(prob, 2)
  end
end
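We can sanity-check the hand calculations from the previous section by pointing this method at the full dataset and at the two partitions from our num_pageviews > 18 split (calling entropy as a standalone method here purely for illustration; in the final trainer it lives as an instance method):

high, low = SimpleData::DATA.partition{ |row| row[3] > 18 }

entropy(SimpleData::DATA) # => 1.5052...
entropy(high)             # => 1.5709...
entropy(low)              # => 0.9183...

# Weighted average of the two partitions, and the resulting information gain
new_entropy = (high.size * entropy(high) + low.size * entropy(low)) / SimpleData::DATA.size.to_f
entropy(SimpleData::DATA) - new_entropy # => 0.1797...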
We will see this entropy enter our algorithm later in the article.
Before that, let's recall that our decision to split our data on this particular attribute and value was selected quasi-randomly. We could have chosen a different pageview limit in our criterion, or we could have split by a completely different attribute. Our decision tree algorithm must explore all the possible ways to split the data, and then choose the criterion which gives the best information gain from before the split to after it. Let's look at how we can write this algorithm in ruby.
Representing a Decision Tree
The basis of our decision tree is the individual node. Each node is either a decision node or a leaf node. We could represent these as separate classes, but for our implementation we will handle both cases within a single class:
class TreeNode
  attr_reader :column_index, :value, :true_branch, :false_branch, :results

  def initialize(column_index: nil, value: nil, true_branch: nil, false_branch: nil, results: nil)
    @column_index = column_index
    @value = value
    @true_branch = true_branch
    @false_branch = false_branch
    @results = results
  end
  …
end
A decision node will be defined by a column_index, for the column upon which the node is partitioning the data, along with the value of that column which defines the threshold for the split. The decision node will also hold a reference to the true_branch and false_branch; these are references to the next TreeNode to be applied to the true and false partitions, respectively. For a decision node the results attribute will be nil. By contrast, if we are dealing with a leaf node then results will be populated with the training instances (or rows) that were grouped into this leaf. As the leaf node is at the end of a tree branch, the other attributes of column_index, value, true_branch and false_branch will not be populated for such a node.
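To make the two roles concrete, here is a tiny hand-assembled fragment (the column, threshold and leaf rows are chosen purely for illustration): a single decision node splitting on the pageview count, with a leaf node hanging off each branch.

premium_leaf = TreeNode.new(results: [['google', 'France', 'yes', 23, 'Premium']])
none_leaf    = TreeNode.new(results: [['slashdot', 'USA', 'yes', 18, 'None']])

root = TreeNode.new(
  column_index: 3,           # pages viewed
  value: 21,                 # numeric threshold: follow the true branch when row[3] >= 21
  true_branch: premium_leaf,
  false_branch: none_leaf
)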
With a decision tree composed in this manner, if we are given the root node we should be able to classify a previously unseen item. We do this by recursively navigating our tree nodes, applying the decision criteria at each node to decide which branch should be followed next, until we reach a leaf node. This classification process for a new item (or row) is captured in the TreeNode#classify method, which leans on the classify_prob method:
class TreeNode
  …

  def classify(row)
    classify_prob(row).max_by{ |_,v| v }[0]
  end

  def classify_prob(row)
    return summary_results if results

    if value.is_a?(Numeric)
      if row[column_index] >= value
        true_branch.classify_prob(row)
      else
        false_branch.classify_prob(row)
      end
    elsif row[column_index] == value
      true_branch.classify_prob(row)
    else
      false_branch.classify_prob(row)
    end
  end

  private

  def summary_results
    results.group_by(&:last).reduce({}) do |memo, (key, value)|
      (memo[key] = value.size) && memo
    end
  end
  …
end
If the current node is a decision node (i.e. where results is not present), then the classify_prob method will take the new item (or row) and will apply the decision encapsulated in this current node. We apply a slightly different threshold check depending upon whether the TreeNode#value is numeric or non-numeric. However, the recursive pattern is the same in both cases, so let's focus on the case where the current TreeNode#value is a number:
if value.is_a?(Numeric)
  if row[column_index] >= value
    true_branch.classify_prob(row)
  else
    false_branch.classify_prob(row)
  end
else …
This decision node tells us that we are interested in the attribute in column column_index, so we extract this element from row. We then compare this extracted element with the threshold value defined by our decision node. If our row element equals or exceeds the threshold value, we pass the row to classify_prob on the true_branch, otherwise we pass the row to false_branch.classify_prob.
Thus we are recursively applying the classify_prob method to the row, as we navigate from node to node. This recursion will exit once we reach a leaf node, which has the results populated.
If the current node is a leaf node (i.e. with results present) then the method will return a summary of the training rows which have been stored in that leaf. This summary will give a count of the different classes within the leaf node's results. For instance, if the leaf node has the following results stored:
@results = [
  ["kiwitobes", "UK", "no", 19, "None"],
  ["slashdot", "UK", "no", 21, "None"],
  ["kiwitobes", "France", "yes", 19, "Basic"],
  ["slashdot", "USA", "yes", 18, "None"]
]
Then summary_results will return a hash that looks like this: { "None" => 3, "Basic" => 1 }, reflecting that this leaf node has categorized 3 cases with a label of 'None' and one case of 'Basic'. If our row is categorized to this leaf node, we can conclude that our row should be classified as 'None', with a probability of 75%. The classify method is simply responsible for pulling out the label that has the maximum probability:
def classify(row)
  classify_prob(row).max_by{ |_,v| v }[0]
end
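Putting the two methods together, classifying a new visitor against a trained tree (tree here stands in for a hypothetical root TreeNode) might look like this:

tree.classify_prob(['kiwitobes', 'UK', 'no', 19]) # => e.g. { "None" => 3, "Basic" => 1 }
tree.classify(['kiwitobes', 'UK', 'no', 19])      # => "None"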
Algorithm Implementation
We have seen how the TreeNode can be used to represent our decision tree, and how we can use an existing tree to classify new data. We now introduce our DecisionTreeTrainer, which will use the training data to build our tree of nodes:
class DecisionTreeTrainer
  def self.train(rows, max_depth: nil)
    self.new(rows, max_depth: max_depth).train
  end

  attr_reader :all_rows, :total_rows, :num_attributes, :root_node, :max_depth

  def initialize(rows, max_depth: 10)
    @all_rows = rows
    @num_attributes = all_rows.first.size - 1 # Never split on last attribute (the label)
    @total_rows = all_rows.size
    @root_node = nil
    @max_depth = max_depth || 10 # honour the supplied max_depth, falling back to 10
  end

  # Returns a TreeNode, which is the root of the tree we have trained
  def train
    return @root_node if @root_node

    puts "Starting training ..."
    start_time = Time.now
    @root_node = build_decision_node(self.all_rows)
    puts "Completed training in #{Time.now - start_time} seconds"
    @root_node
  end
  …
…
The DecisionTreeTrainer is initialized with our training data, which is stored in the all_rows attribute, and we can also specify a max_depth parameter. Let's summarize the instance variables on the DecisionTreeTrainer class:
- all_rows: holds our training data. This is the data that will help us to build our decision tree model.
- max_depth: prevents our recursive algorithm from building a tree that is too deep (defaults to a value of 10).
- num_attributes: captures the number of attributes that make up each item in our dataset, i.e. the number of columns in each row.
- total_rows: the total number of items in our training data.
- root_node: the ultimate result of our training is the root node of the tree we have constructed. With this node we can traverse the tree to classify new items.
The DecisionTreeTrainer#train method will return this root_node if it is defined, otherwise it will calculate the root_node by calling build_decision_node, passing the full dataset. This is where all the magic happens:
…
def build_decision_node(rows, depth = 0)
  max_info_gain = 0
  split_index = -1
  split_value = nil
  initial_entropy = entropy(rows)
  return TreeNode.new(results: rows) if (initial_entropy == 0 || depth >= max_depth)

  num_attributes.times do |i|
    rows.map{ |r| r[i] }.uniq.each do |value|
      new_rows = divide_set(rows, i, value)
      new_entropy = new_rows.reduce(0) do |memo, branch|
        memo += entropy(branch) * branch.size / rows.size
      end
      info_gain = (initial_entropy - new_entropy)
      if info_gain > max_info_gain
        max_info_gain = info_gain
        split_index = i
        split_value = value
      end
    end
  end

  # OK if we have an info gain, lets split according to best criteria found
  if max_info_gain <= 0
    TreeNode.new(results: rows)
  else
    true_rows, false_rows = divide_set(rows, split_index, split_value)
    TreeNode.new(
      column_index: split_index,
      value: split_value,
      true_branch: build_decision_node(true_rows, depth + 1),
      false_branch: build_decision_node(false_rows, depth + 1),
    )
  end
end
The build_decision_node method takes our training rows and an optional depth argument, which defaults to 0. We start by calculating the initial_entropy of our training set, using the entropy method defined earlier. If the initial_entropy evaluates to zero (meaning all the rows have the same label) then we just return a leaf node with the results set accordingly. We also terminate the process in the same way if the depth value reaches or exceeds the max_depth parameter which we previously defined.
Presuming that neither of these conditions has been met, we proceed to try and find a candidate column and value which we can use to split our training dataset. We start by looping over each column index, i, in turn. Then for each column we look at all the distinct values in that column, and try to split our training set on each value, using the divide_set method.
The splitting of the training dataset by a particular value in a particular column (column_index) can be achieved pretty neatly in ruby using the Enumerable#partition method, as follows:
def divide_set(rows, column_index, value)
  split_function = if value.is_a? Numeric
    lambda{ |row| row[column_index] >= value }
  else
    lambda{ |row| row[column_index] == value }
  end
  rows.partition(&split_function)
end
If the value in the column is numeric then we check each item (or row) to see if the corresponding element is greater-than-or-equal to this threshold value. If it is, then the item/row will be returned within the true partition, otherwise the item/row will be returned in the false partition. If the value is non-numeric then the true partition will include all rows where the matching element equals the value, otherwise the row will be returned in the false partition.
The resulting new_rows will contain our split dataset, with some items in the first array (or partition), new_rows[0], and the remaining items in the second partition, new_rows[1]. We calculate the average entropy of these partitions, new_entropy, and compare this to the initial_entropy to determine the info_gain for this particular split.
By looping over each candidate value in each column, i, we eventually find the particular column and value which give us the greatest information gain. These are represented by split_index and split_value, respectively. Having defined our best split criteria, we are now in a position to build our TreeNode:
if max_info_gain <= 0
  TreeNode.new(results: rows)
else
  true_rows, false_rows = divide_set(rows, split_index, split_value)
  TreeNode.new(
    column_index: split_index,
    value: split_value,
    true_branch: build_decision_node(true_rows, depth + 1),
    false_branch: build_decision_node(false_rows, depth + 1),
  )
end
One special case which we need to handle here is when our best split doesn't actually improve the information gain. In this case no split can improve the separation of our examples, so we just return a leaf node with the results populated with the full set of rows.
Presuming our best split does, indeed, provide a positive information gain, we then proceed to split our dataset using these optimal criteria, split_index and split_value, to retrieve the two new partitions: true_rows and false_rows. We then proceed to build our TreeNode with these optimal criteria, but how do we know the next node in the tree for the true_branch and false_branch?
The answer is that we don't know these nodes yet, we need to calculate them. We calculate these nodes by calling the build_decision_node method again, recursively. For the true_branch we invoke the method on the true_rows partition, whilst the false_branch should be trained on the false_rows partition. In each case we increment the depth parameter by 1.
If you are anything like me, this recursive invocation will hurt your head! But understanding this part is key to understanding the whole algorithm, so take your time to let these few lines sink in.
As we initialize our root TreeNode, it will partition the training data using the best criteria it can find and then it will use the partitioned data to build two new nodes for the true and false branches. As those nodes are constructed they will also partition the data they have received and build subsequent nodes, and so on. This will happen until we try to build a node for which one of the exit criteria is triggered. These exit criteria are:
- the initial_entropy for the rows is zero (i.e. they all have the same label),
- we have reached our max_depth, or
- there is no split which will result in an information gain.
In each of these cases we return a leaf node, with its results populated with the rows which we hold at that point.
Training and running the model
And with that we have all the pieces we need to train our decision tree model. We can initialize our training data using the SimpleData module and pass the data into DecisionTreeTrainer.train to build our model. With the model in hand we can then test how accurate it is against our test data:
require './decision_tree_trainer'
require './simple_data'

train_rows, test_rows = SimpleData.get_train_and_test_data(split: 0.1)

tree = DecisionTreeTrainer.train(train_rows)
tree.print

correct_count = test_rows.reduce(0) do |count, row|
  count += 1 if tree.classify(row) == row[-1]
  count
end
puts "Accuracy: #{correct_count / test_rows.size.to_f}"
You can see that we also make use of a print function which we have defined on the TreeNode. This gives a basic visual representation of our tree hierarchy, but I will refer you to the source to see how this is implemented.
Running this script we see the following output:
Starting training ...
Completed training in 0.000835017 seconds
(column: 0, value: google)
  true:
    (column: 3, value: 21)
      true:
        ["google", "France", "yes", 23, "Premium"]
        ["google", "UK", "no", 21, "Premium"]
      false:
        ["google", "UK", "no", 18, "None"]
        ["google", "UK", "no", 18, "Basic"]
  false:
    (column: 2, value: yes)
      true:
        (column: 0, value: slashdot)
          true:
            ["slashdot", "France", "yes", 19, "None"]
          false:
            ["digg", "USA", "yes", 24, "Basic"]
            ["kiwitobes", "France", "yes", 23, "Basic"]
            ["digg", "New Zealand", "yes", 12, "Basic"]
            ["kiwitobes", "France", "yes", 19, "Basic"]
      false:
        (column: 0, value: (direct))
          true:
            (column: 1, value: New Zealand)
              true:
                ["(direct)", "New Zealand", "no", 12, "None"]
              false:
                ["(direct)", "UK", "no", 21, "Basic"]
          false:
            ["digg", "USA", "no", 18, "None"]
            ["kiwitobes", "UK", "no", 19, "None"]
            ["slashdot", "UK", "no", 21, "None"]
Accuracy: 1.0
For the run displayed above the accuracy was 1.0, but on other runs we see an accuracy of 0.5 or even 0.0. This is mainly because our dummy dataset is very small: with only a couple of test rows, a single misclassification swings the accuracy dramatically. One rough way to smooth out this variance is to average the accuracy over a number of random train/test splits, as sketched below. Ultimately, though, we can better evaluate our algorithm by applying it to a more meaningful dataset.
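Here is that sketch, reusing the pieces we have already built (the number of repetitions, 20, and the split ratio are arbitrary choices):

accuracies = 20.times.map do
  train_rows, test_rows = SimpleData.get_train_and_test_data(split: 0.25)
  tree = DecisionTreeTrainer.train(train_rows)
  correct = test_rows.count{ |row| tree.classify(row) == row[-1] }
  correct / test_rows.size.to_f
end
puts "Mean accuracy: #{accuracies.sum / accuracies.size}"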
To this end, we will consider the mushroom dataset published by UC Irvine. This dataset contains thousands of samples of gilled mushrooms, with each sample described by a set of categorical features, along with a classification of whether the sample was poisonous (p) or edible (e). The first 10 rows from the dataset are shown here:
> head mushroom/agaricus-lepiota.data
p,x,s,n,t,p,f,c,n,k,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,y,t,a,f,c,b,k,e,c,s,s,w,w,p,w,o,p,n,n,g
e,b,s,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,n,m
p,x,y,w,t,p,f,c,n,n,e,e,s,s,w,w,p,w,o,p,k,s,u
e,x,s,g,f,n,f,w,b,k,t,e,s,s,w,w,p,w,o,e,n,a,g
e,x,y,y,t,a,f,c,b,n,e,c,s,s,w,w,p,w,o,p,k,n,g
e,b,s,w,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,n,m
e,b,y,w,t,l,f,c,b,n,e,c,s,s,w,w,p,w,o,p,n,s,m
p,x,y,w,t,p,f,c,n,p,e,e,s,s,w,w,p,w,o,p,k,v,g
e,b,s,y,t,a,f,c,b,g,e,c,s,s,w,w,p,w,o,p,k,s,m
The first column contains the classification, p or e. The subsequent columns represent different features for the sample. For example, column 2 refers to the cap-shape and can take one of the values bell (b), conical (c), convex (x), flat (f), knobbed (k) or sunken (s). The next column is the cap-surface, which can be fibrous (f), grooves (g), scaly (y) or smooth (s). The actual details of these features are not important to our decision tree algorithm; it will find the best way to separate these samples based on their values.
We will introduce a MushroomData module to take care of loading the data from the file and putting the data in the correct format expected by our algorithm, i.e. we need to place the label data (p or e) at the end of the row rather than the start. In addition, we will not train our model on the full set of 22 features for each sample; instead we will just use the first 7 features to build our model:
require 'csv'

module MushroomData
  def self.get_train_and_test_data(max_rows: 10_000, split: 0.1)
    raw_data = CSV.read('mushroom/agaricus-lepiota.data')[0..max_rows]
    # Keep the first 7 features and move the label (column 0) to the end of each row
    train_data = raw_data.map{ |row| row[1..7].append(row[0]) }
    # Let's remove :split of data for testing
    full_size = train_data.size
    test_data = (0..(full_size * split)).to_a.map do |i|
      train_data.delete_at(rand(full_size - i))
    end
    [train_data, test_data]
  end
end
With this we can adapt our existing script to load the mushroom data and train our decision tree classifier with the new dataset:
require './decision_tree_trainer'
require './simple_data'
require './mushroom_data'

train_rows, test_rows = if ARGV.first == "mushroom"
  MushroomData.get_train_and_test_data(max_rows: 8000, split: 0.1)
else
  SimpleData.get_train_and_test_data(split: 0.1)
end

tree = DecisionTreeTrainer.train(train_rows)
tree.print

correct_count = test_rows.reduce(0) do |count, row|
  count += 1 if tree.classify(row) == row[-1]
  count
end
puts "Accuracy: #{correct_count / test_rows.size.to_f}"
Running this script we will reuse our decision tree algorithm to train a model on this completely new, and unrelated dataset. The representation of the tree produced will be output to the terminal, but we truncate it here for brevity:
> ruby runner.rb mushroom
Starting training ...
Completed training in 0.385516171 seconds
(column: 4, value: n)
true:
(column: 2, value: y)
true:
{"p"=>22}
false:
(column: 0, value: b)
true:
(column: 3, value: t)
…
false:
(column: 3, value: t)
true:
(column: 4, value: f)
true:
{"p"=>256}
false:
(column: 4, value: p)
true:
{"p"=>231}
false:
{"e"=>721}
false:
{"p"=>2900}
Accuracy: 0.9975031210986267
An accuracy of 99.75%, not too bad for about 200 lines of ruby!
Conclusion
We have discussed the fundamental ideas behind the decision tree model, along with an algorithm for training a decision tree using labelled data. We showed how this algorithm could be implemented in ruby and tested our implementation on a realistic data set.
Whilst there are other machine learning techniques which receive a lot more attention, the decision tree remains a simple but powerful technique. It is simple to implement and intuitive to interpret. It can deal with both numeric and categorical data and requires very little data preparation.
Notwithstanding these benefits, it is not all roses for the decision tree model. Unchecked, the model has a tendency to overfit the data, i.e. generate very specific rules that only apply to the precise training data, but do not apply more generally. In addition the technique can be quite computationally expensive and does not scale easily to very large datasets.
To address some of these shortcomings, the accuracy and performance of decision trees can often be improved by employing extensions such as pruning, random forests, boosting and bagging. These techniques warrant their own discussion, so that will need to wait for a future post.
I hope you enjoyed this article and found it useful. If you would like to be kept up-to-date when we publish new articles, please subscribe.
References
- Intro to the CART algorithm
- Collective Intelligence by Toby Segaran, available on Amazon
- Chapter 7: Modeling with Decision Trees, from Collective Intelligence by Toby Segaran
- Article on differentiation of the biological kingdoms
- Blog post on the virtues of the decision tree model
- Docs on Ruby's Enumerable#partition method
- Machine Learning Mastery blog post covering some background theory on decision trees
- GitHub repo with the code presented in this blog post
- Ruby docs for the Array#partition method
- Wiki entry on training and test data for machine learning models
- UCI mushroom dataset
- Wikipedia entry for entropy