Decision trees in python with scikit-learn and pandas

In this post I will cover decision trees (for classification) in python, using scikit-learn and pandas. The emphasis will be on the basics and understanding the resulting decision tree. I will cover:

  • Importing a csv file using pandas,
  • Using pandas to prep the data for the scikit-learn decision tree code,
  • Drawing the tree, and
  • Producing pseudocode that represents the tree.

The last two parts will go over what the tree has actually found– this is one of the really nice parts of a decision tree: the findings can be inspected and we can learn something about the patterns in our data. If this sounds interesting to you, read on. Also, if you have other ideas about how to do related things please leave comments below!


Before we get going, the code is available as a gist, so you don’t have to copy and paste (unless you want to). I’ll go through the functions and usage from scratch here– usage of the gist code is detailed there in a README file.


So, first we do some imports, including the print_function for python3-style print statements. I also import the usual suspects, using common abbreviations, which I’ll discuss below:

from __future__ import print_function

import os
import subprocess

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz

data with pandas

Next, we need some data to consider. I’ll use the famous iris data set, which has various measurements for a variety of different iris types. I think both pandas and scikit-learn have easy import options for this data, but I’m going to write a function to import from a csv file, using pandas. The point of this is to demonstrate how pandas can be used with scikit-learn. So, we define a function for getting the iris data:

def get_iris_data():
    """Get the iris data, from local csv or pandas repo."""
    if os.path.exists("iris.csv"):
        print("-- iris.csv found locally")
        df = pd.read_csv("iris.csv", index_col=0)
    else:
        print("-- trying to download from github")
        fn = ""  # URL of the remote iris.csv
        try:
            df = pd.read_csv(fn)
        except Exception:
            exit("-- Unable to download iris.csv")

        # keep a local copy so later runs can skip the download
        with open("iris.csv", 'w') as f:
            print("-- writing to local iris.csv file")
            df.to_csv(f)

    return df


  • This function first tries to read the data locally, using pandas. This is why I import os above: to make use of the os.path.exists() function. If the iris.csv file is found in the local directory, pandas is used to read the file using pd.read_csv(). Note that pandas has been imported using import pandas as pd. This is typical usage for the package.
  • If a local iris.csv is not found, pandas is used to grab the data from a url and a local copy is saved for future runs. A try and except are used to exit and provide a note if there are problems– maybe the user is not connected to the internet?
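The read above uses index_col=0. As a rough sketch of why that matters when a DataFrame is written out and read back, the snippet below round-trips a tiny made-up frame through an in-memory buffer (the two-row frame and the use of io.StringIO are assumptions for illustration, not part of the original code):

```python
import io

import pandas as pd

# Tiny stand-in for the iris data, just enough to show the round trip.
df = pd.DataFrame({"SepalLength": [5.1, 4.9],
                   "Name": ["Iris-setosa", "Iris-setosa"]})

buffer = io.StringIO()
df.to_csv(buffer)            # the row index is written as an unnamed first column
csv_text = buffer.getvalue()

# Without index_col=0 the saved index comes back as an extra data column.
with_extra = pd.read_csv(io.StringIO(csv_text))

# With index_col=0 the first column is used as the index again.
restored = pd.read_csv(io.StringIO(csv_text), index_col=0)

print(list(with_extra.columns))
print(list(restored.columns))
```

So a file saved with the default to_csv() settings is best read back with index_col=0, otherwise the old index shows up as a spurious feature column.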

Hopefully the above code gives a sense of how to load a csv data file, locally as well as from a remote location. The next step is to get the data and use the head() and tail() methods to see what the data is like. These show the start and end of the dataframe, respectively. So, first get the data:

df = get_iris_data()
-- iris.csv found locally

then, head and tail:

print("* df.head()", df.head(), sep="\n", end="\n\n")
print("* df.tail()", df.tail(), sep="\n", end="\n\n")
* df.head()
   SepalLength  SepalWidth  PetalLength  PetalWidth         Name
0          5.1         3.5          1.4         0.2  Iris-setosa
1          4.9         3.0          1.4         0.2  Iris-setosa
2          4.7         3.2          1.3         0.2  Iris-setosa
3          4.6         3.1          1.5         0.2  Iris-setosa
4          5.0         3.6          1.4         0.2  Iris-setosa

* df.tail()
     SepalLength  SepalWidth  PetalLength  PetalWidth            Name
145          6.7         3.0          5.2         2.3  Iris-virginica
146          6.3         2.5          5.0         1.9  Iris-virginica
147          6.5         3.0          5.2         2.0  Iris-virginica
148          6.2         3.4          5.4         2.3  Iris-virginica
149          5.9         3.0          5.1         1.8  Iris-virginica

From this information we can talk about our goal: to predict Name (or, type of iris) given the features SepalLength, SepalWidth, PetalLength and PetalWidth. We can use pandas to show the three iris types:

print("* iris types:", df["Name"].unique(), sep="\n")
* iris types:
['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']


In order to pass this data into scikit-learn we need to encode the Names to integers. To do this we’ll write another function and return the modified data frame as well as a list of the target (class) names:

def encode_target(df, target_column):
    """Add column to df with integers for the target.

    df -- pandas DataFrame.
    target_column -- column to map to int, producing
                     new Target column.

    df_mod -- modified DataFrame.
    targets -- list of target names.
    """
    df_mod = df.copy()
    targets = df_mod[target_column].unique()
    map_to_int = {name: n for n, name in enumerate(targets)}
    df_mod["Target"] = df_mod[target_column].replace(map_to_int)

    return (df_mod, targets)

Let’s see what we have (I’ll show just Name and Target columns to prevent wrapping):

df2, targets = encode_target(df, "Name")
print("* df2.head()", df2[["Target", "Name"]].head(),
      sep="\n", end="\n\n")
print("* df2.tail()", df2[["Target", "Name"]].tail(),
      sep="\n", end="\n\n")
print("* targets", targets, sep="\n", end="\n\n")
* df2.head()
   Target         Name
0       0  Iris-setosa
1       0  Iris-setosa
2       0  Iris-setosa
3       0  Iris-setosa
4       0  Iris-setosa

* df2.tail()
     Target            Name
145       2  Iris-virginica
146       2  Iris-virginica
147       2  Iris-virginica
148       2  Iris-virginica
149       2  Iris-virginica

* targets
['Iris-setosa' 'Iris-versicolor' 'Iris-virginica']

Looks good: Iris-setosa has been mapped to zero, Iris-versicolor to one, and Iris-virginica to two. Next, we get the names of the feature columns:

features = list(df2.columns[:4])
print("* features:", features, sep="\n")
* features:
['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']
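Before fitting, a side note on the target encoding step: pandas also ships pd.factorize, which produces the same kind of integer codes in one call. This is only a sketch of an alternative, not what the function above does, and the short names Series is made up for illustration:

```python
import pandas as pd

# Made-up sample of the Name column.
names = pd.Series(["Iris-setosa", "Iris-setosa", "Iris-versicolor",
                   "Iris-virginica", "Iris-versicolor"])

# codes holds one integer per row; uniques holds the labels in order of
# first appearance, so uniques[code] recovers the original name.
codes, uniques = pd.factorize(names)

print(list(codes))
print(list(uniques))
```

Like the enumerate-based mapping, the codes follow the order in which the names first appear in the column.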

fitting the decision tree with scikit-learn

Now we can fit the decision tree, using the DecisionTreeClassifier imported above, as follows:

y = df2["Target"]
X = df2[features]
dt = DecisionTreeClassifier(min_samples_split=20, random_state=99)
dt.fit(X, y)


  • We pull the X and y data from the pandas dataframe using simple indexing.
  • The decision tree, imported at the start of the post, is initialized with two parameters: min_samples_split=20 requires 20 samples in a node for it to be split (this will make more sense when we see the result) and random_state=99 to seed the random number generator.
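To get a feel for what min_samples_split does, one option is to fit a second tree with the default setting and compare the sizes. The sketch below assumes scikit-learn's bundled copy of the iris data (load_iris) as a stand-in for iris.csv; the names dt_limited and dt_full are just for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

# Assumption: the bundled iris data stands in for the csv file used above.
iris = load_iris()
X, y = iris.data, iris.target

# Same settings as in the post: nodes with fewer than 20 samples are not split.
dt_limited = DecisionTreeClassifier(min_samples_split=20, random_state=99)
dt_limited.fit(X, y)

# Default min_samples_split: impure nodes keep splitting.
dt_full = DecisionTreeClassifier(random_state=99)
dt_full.fit(X, y)

print("nodes with min_samples_split=20:", dt_limited.tree_.node_count)
print("nodes with default settings:    ", dt_full.tree_.node_count)
```

The restricted tree should come out smaller, which is what makes it easier to read in the sections that follow.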

visualizing the tree

We can produce a graphic (if graphviz is available on your system– if not check the site and see if you can install) using the following function:

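def visualize_tree(tree, feature_names):
    """Create tree png using graphviz.

    tree -- scikit-learn DecisionTree.
    feature_names -- list of feature names.
    """
    # "" below is the path of the intermediate dot file
    with open("", 'w') as f:
        export_graphviz(tree, out_file=f,
                        feature_names=feature_names)

    command = ["dot", "-Tpng", "", "-o", "dt.png"]
    try:
        subprocess.check_call(command)
    except (OSError, subprocess.CalledProcessError):
        exit("Could not run dot, ie graphviz, to "
             "produce visualization")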


  • The export_graphviz function, imported from scikit-learn above, writes a dot file. This file is used to produce the graphic.
  • subprocess, imported above, is used to process the dot file and generate the graphic dt.png– see the example below.

So, running the function:

visualize_tree(dt, features)

results in the graphic dt.png.


Okay, what does this all mean? Well, we can use this figure to understand the patterns found by the decision tree:

  • Imagine that all data (all rows) start in a single bin at the top of the tree.
  • All features are considered to see how the data can be split in the most informative way– this uses the gini measure by default, but this can be changed to entropy if you prefer; see decision tree classifier documentation.
  • At the top we see the most informative condition is PetalLength <= 2.4500. If this condition is true, take the left branch to get to the 50 samples of value = [50. 0. 0.]. This means there are 50 examples of class/target 0, in this case Iris-setosa. Unfortunately, the default scikit-learn export to graphviz/dot does not seem to be able to include this information (but see below). The other 100 samples, of the 150 total, go to the right bin.
  • This splitting continues until
  1. The split creates a bin with only one class– for example the bin with 50 Iris-setosa is not split again.
  2. Or, the resulting bin has fewer than 20 samples. This is because we set min_samples_split=20 when initializing the decision tree. If we had not set this value, the tree would keep splitting until all bins have a single class.
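To make the "most informative split" idea concrete, here is a small worked example of the gini measure for the first split described above. It is plain Python rather than scikit-learn's own implementation, and the helper names gini and weighted_gini are made up for this sketch:

```python
def gini(counts):
    """Gini impurity of a bin, given its per-class sample counts."""
    total = float(sum(counts))
    if total == 0:
        return 0.0
    return 1.0 - sum((c / total) ** 2 for c in counts)


def weighted_gini(children):
    """Sample-weighted average impurity of the bins produced by a split."""
    total = float(sum(sum(c) for c in children))
    return sum(sum(c) / total * gini(c) for c in children)


root = [50, 50, 50]    # all 150 rows, 50 of each iris type
left = [50, 0, 0]      # PetalLength <= 2.45: only Iris-setosa
right = [0, 50, 50]    # the remaining 100 rows

print(round(gini(root), 4))                    # 0.6667
print(round(weighted_gini([left, right]), 4))  # 0.3333
```

The split drops the impurity from about 0.67 to about 0.33, and the left bin has impurity 0, which is why it is never split again.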

So, that’s it for the visualization– you should be able to trace, from top to bottom, and see how the rules discussed above were applied to the iris data.
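As a possible follow-up on the missing class information in the figure: more recent scikit-learn releases accept a class_names argument in export_graphviz, and return the dot source as a string when out_file=None. The sketch below assumes such a release and uses scikit-learn's bundled iris data in place of iris.csv; treat it as one way to do it rather than the method used for the figure above:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Assumption: bundled iris data, with names matching the csv columns.
iris = load_iris()
feature_names = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]
class_names = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]

dt = DecisionTreeClassifier(min_samples_split=20, random_state=99)
dt.fit(iris.data, iris.target)

# out_file=None makes export_graphviz return the dot source as a string;
# class_names adds a "class = ..." line (the majority class) to each node.
dot_text = export_graphviz(dt, out_file=None,
                           feature_names=feature_names,
                           class_names=class_names)

print(dot_text[:300])
```

The returned text can be written to a file and passed to dot in the same way as before.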

pseudocode for the decision tree

Finally, let’s consider generation of pseudocode that represents the learned decision tree. In particular, the target names (classes) and feature names should be included in the output so that it is simple to follow the patterns found. The function below is based on the answer to a stackoverflow question. I’ve made some additions to the function to meet the requirements I’ve stated above:

  • The target names can be passed to the function and are included in the output. The output now shows both the features used for branching conditions as well as the class, or classes, found in the resulting node/bin.
  • The if/else structure has indenting, using the spacer_base argument to make the output easier to read (I think).

That said, the function is:

def get_code(tree, feature_names, target_names,
             spacer_base="    "):
    """Produce pseudo-code for decision tree.

    tree -- scikit-learn DecisionTree.
    feature_names -- list of feature names.
    target_names -- list of target (class) names.
    spacer_base -- used for spacing code (default: "    ").

    based on the answer to a stackoverflow question.
    """
    left      = tree.tree_.children_left
    right     = tree.tree_.children_right
    threshold = tree.tree_.threshold
    features  = [feature_names[i] for i in tree.tree_.feature]
    value = tree.tree_.value

    def recurse(left, right, threshold, features, node, depth):
        spacer = spacer_base * depth
        if (threshold[node] != -2):
            print(spacer + "if ( " + features[node] + " <= " + \
                  str(threshold[node]) + " ) {")
            if left[node] != -1:
                    recurse(left, right, threshold, features,
                            left[node], depth+1)
            print(spacer + "}\n" + spacer +"else {")
            if right[node] != -1:
                    recurse(left, right, threshold, features,
                            right[node], depth+1)
            print(spacer + "}")
        else:
            # leaf node: report every class with a nonzero count
            target = value[node]
            for i, v in zip(np.nonzero(target)[1],
                            target[np.nonzero(target)]):
                target_name = target_names[i]
                target_count = int(v)
                print(spacer + "return " + str(target_name) + \
                      " ( " + str(target_count) + " examples )")

    recurse(left, right, threshold, features, 0, 0)

and the resulting output for application to the iris data is:

get_code(dt, features, targets)
if ( PetalLength <= 2.45000004768 ) {
    return Iris-setosa ( 50 examples )
}
else {
    if ( PetalWidth <= 1.75 ) {
        if ( PetalLength <= 4.94999980927 ) {
            if ( PetalWidth <= 1.65000009537 ) {
                return Iris-versicolor ( 47 examples )
            }
            else {
                return Iris-virginica ( 1 examples )
            }
        }
        else {
            return Iris-versicolor ( 2 examples )
            return Iris-virginica ( 4 examples )
        }
    }
    else {
        if ( PetalLength <= 4.85000038147 ) {
            return Iris-versicolor ( 1 examples )
            return Iris-virginica ( 2 examples )
        }
        else {
            return Iris-virginica ( 43 examples )
        }
    }
}

This should be compared with the graphic output above. It is just a different representation of the learned decision tree. However, I think the addition of target/classes and features really makes this useful.
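If the array bookkeeping inside get_code is hard to picture, the following sketch writes out a tiny tree by hand in the same parallel-array layout (children_left, children_right, feature, threshold, value) and walks it for a single sample. The three-node tree and the find_leaf helper are hypothetical, and value is simplified to a flat list of counts per node:

```python
# Hypothetical three-node tree, written out by hand in the same
# parallel-array layout that get_code reads from tree.tree_.
TREE_LEAF = -1        # children_left / children_right entry for a leaf
TREE_UNDEFINED = -2   # feature / threshold entry for a leaf

children_left = [1, TREE_LEAF, TREE_LEAF]
children_right = [2, TREE_LEAF, TREE_LEAF]
feature = [2, TREE_UNDEFINED, TREE_UNDEFINED]   # index 2 -> PetalLength
threshold = [2.45, -2.0, -2.0]
value = [[50, 50, 50], [50, 0, 0], [0, 50, 50]]  # simplified class counts


def find_leaf(sample, node=0):
    """Follow the split conditions from node down to a leaf; return its id."""
    while children_left[node] != TREE_LEAF:
        if sample[feature[node]] <= threshold[node]:
            node = children_left[node]
        else:
            node = children_right[node]
    return node


# SepalLength, SepalWidth, PetalLength, PetalWidth (first and last rows above)
setosa_like = [5.1, 3.5, 1.4, 0.2]
virginica_like = [5.9, 3.0, 5.1, 1.8]

print(find_leaf(setosa_like), value[find_leaf(setosa_like)])
print(find_leaf(virginica_like), value[find_leaf(virginica_like)])
```

Node 0 is the root and each entry at position i describes node i, which is why get_code can recurse using nothing more than node indices.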

Okay, that’s it for this post. There are many topics I have not covered, but I think that I’ve provided some useful code for understanding a decision tree learned with scikit-learn. To dig deeper, the decision tree documentation at the scikit-learn site is a useful next stop.

Importantly, I have not covered how to set parameters or how to avoid overfitting; that’s beyond the scope of this post. The best place to start for this is the cross-validation tools in scikit-learn.

As always, post comments and questions below. Corrections and typos are also welcomed!

