View on GitHub

data-science

Notebooks and Python about data science

If you like this project please add your Star

Multiclass classification with 2 features using trees

Classification trees are a wide category of algorithms to process classification (and linear regression).

In this notebook, we will reuse the generated dataset of the multiclass linear regression with Keras (HTML / Jupyter) : the Czech and Norvegian flags.

Learning goals:

  • Theory of trees
  • Custom code for trees
  • Understanding of the tree partitioning and performance
  • Decision tree using Scikit-Learn model
In [27]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as pltcolors
from sklearn import metrics, tree
import graphviz
from IPython.display import SVG
import seaborn as sns

Helpers

In [32]:
def flagPlot(X, y, title, colors, ax=None):
    """ Plot the flag as a 2D parametric label"""
    if ax is None:
        ax = plt.gca()
    ax.set_facecolor((0.8, 0.8, 0.8))
    f = ax.scatter(X[:,0], X[:,1], c=y, cmap=pltcolors.ListedColormap(colors), marker='x', alpha=0.5);
    ax.set_xlabel('$x_0$')
    ax.set_ylabel('$x_1$')
    ax.grid()
    ax.set_title(title)
    cb = plt.colorbar(f, ax=ax)
    loc = np.arange(0,2.1,1)
    cb.set_ticks(loc)
    cb.set_ticklabels([0,1,2]);
    
def plotHeatMap(X, classes, title=None, fmt='.2g', ax=None, xlabel=None, ylabel=None):
    """ Fix heatmap plot from Seaborn with pyplot 3.1.0, 3.1.1
        https://stackoverflow.com/questions/56942670/matplotlib-seaborn-first-and-last-row-cut-in-half-of-heatmap-plot
    """
    ax = sns.heatmap(X, xticklabels=classes, yticklabels=classes, annot=True, \
                     fmt=fmt, cmap=plt.cm.Blues, ax=ax) #notation: "annot" not "annote"
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    if title:
        ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)
        
def plotConfusionMatrix(yTrue, yEst, classes, title=None, fmt='.2g', ax=None):
    plotHeatMap(metrics.confusion_matrix(yTrue, yEst), classes, title, fmt, ax, xlabel='Estimations', \
                ylabel='True values');

Data models - Flags

In [3]:
nFeatures = 2
nClasses = 3
In [4]:
# Czech flag colors
czechColors = np.array(['blue', 'red', 'white'])
def generateCzechBatch(n, noiseSigma = 0.1):
    """ Generate a multi class sample based on the Czech banner geometry. 
        Tip of the triangle is set at the origin 
    """
    # xMax adjusted such that the 3 classes are quasi equiprobable
    xMin = np.array([-1, -1])
    xMax = np.array([0.5, 1])
    #
    X = np.random.uniform(xMin, xMax, (n, nFeatures))
    noisyX = X + np.random.normal(0, noiseSigma, X.shape)
    y = np.zeros(n)
    y[noisyX[:,1] > 0]  = 2                                       # White
    y[noisyX[:,1] <= 0] = 1                                       # Red
    y[(noisyX[:,0] <= 0) & (np.abs(noisyX[:,1]) < np.abs(noisyX[:,0]))] = 0 # Blue triangle
    return X, y
In [34]:
# Norway flag colors
norwayColors = ['red', 'white', 'navy']
def generateNorwayBatch(n, noiseSigma = 0.1):
    """ Generate a multi class sample based on the Norway banner geometry """
    # xMax adjusted such that the 3 classes are quasi equiprobable
    xMin = np.array([-1, -1])
    xMax = np.array([1, 1])
    #
    X = np.random.uniform(xMin, xMax, (n, nFeatures))
    a1 = 0.43
    a2 = 0.18
    noisyX = X + np.random.normal(0, noiseSigma, X.shape)
    y = np.zeros(n)                                                                # Red = background
    y[((noisyX[:,0] > -a1) & (noisyX[:,0] < a1)) | ((noisyX[:,1] > -a1) & (noisyX[:,1] < a1))]  = 1    # White cross
    y[((noisyX[:,0] > -a2) & (noisyX[:,0] < a2)) | ((noisyX[:,1] > -a2) & (noisyX[:,1] < a2))]  = 2    # Navy cross over white
    return X, y
In [33]:
N = 1000
xTrainC, yTrainC = generateCzechBatch(N)
fig, axes = plt.subplots(1, 2, figsize=(15,4))
flagPlot(xTrainC, yTrainC, 'Generated Czech flag', czechColors, axes[0])
xTrainN, yTrainN = generateNorwayBatch(N)
flagPlot(xTrainN, yTrainN, 'Generated Norway flag', norwayColors, axes[1])

Test data

In [7]:
xTestC, yTestC = generateCzechBatch(N)
xTestN, yTestN = generateNorwayBatch(N)

Tree classifier

Tree classifiers are a very large family of non-parametric classifiers, and have many refinements. They are splitting recursively a portion of the space with an hyperplan.

In our case, the space is 2 dimension, the splitting hyperplan is thus 1D : a line.

We will first implement a tree classifier from scratch, then use the Scikit-Learn implementation.

Trees main advantages are :

  • Better ability to approximate complex functions than linear regression
  • Explanability of the decisions since the boundaries are known. However, the decisions might have no human understanding
  • Far more efficient than the kNN classifier (HTML / Jupyter) since the model is fitted only once

Their main drawbacks :

  • Cost and difficulty to reach the optimum. We will see later the Random Forest method that is mitigating this issue

Algorithm

We will use a simple recursive split algorithm similar to CART.

At each iteration of the split :

  • Select the feature (co-variable) bringing the best improvement on the impurity (also known as information gain)
  • Split the space with an hyperplan orthogonal to that feature's axis

For simplification purpose, following algo choices are set:

  • Impurity improvement is measured as entropy minimization, other metrics are often in use like Gini
  • Split is at the middle point of the axis, finer grain split is possible
  • Partitioning stop condition is on the leaf minimum item number and the minimum entropy

Impurity is calculated as the average of the entropies of the two partitions created by the split, weighted by the number of items in each partition.

Entropy is defined using the empirical probabilities of the $Y$ labels within each classes : $$ - \sum_{c \in C} \mathbb{P}(y=c) log(\mathbb{P}(y=c))$$

$C$ is the set of classes in $Y$, in our case $ C = \{0, 1, 2\}$ or the set of three colors.

References

Home made implementation

During training, the classifier algorithm is partitioning the original X space and producing a tree representing this partitioning.

Two Python classes are designed :

  • The classifier MyTreeClassifier
  • The tree element TreeNode that may be intermediate or terminal. Intermediate nodes have two children

The algorithm used for the partitioning is quite naive as it heavily uses numpy slicing based on range inequalities.

In [8]:
class TreeNode:
    """ Tree node element
        Either an intermediate node with two children (left-right) or a leaf node
    """
    
    def __init__(self, depth, parent, boundary, entropy):
        self.depth = depth
        self.parent = parent
        self.boundary = boundary
        self.splitOnFeature = None
        self.halfPoint = None
        self.entropy = entropy
        self.leftChild = None
        self.rightChild = None
        self.valueFrequencies = None
        self.isLeaf = False
        
    def split(self, splitOnFeature, halfPoint, childL, childR, valueFrequencies=None):
        """ Split node into to two child nodes """
        self.splitOnFeature = splitOnFeature
        self.halfPoint = halfPoint
        self.leftChild = childL
        self.rightChild = childR
        self.valueFrequencies = valueFrequencies
   
    def setAsLeaf(self, valueFrequencies):
        """ Set this node to be a terminal (a leaf) """
        self.valueFrequencies = valueFrequencies
        self.isLeaf = True
        
    def predict(self, x):
        """ Recursively predict value """
        if len(x.shape) > 1:
            if self.isLeaf:
                return self.getValue() * np.ones(x.shape[0])
            else:
                result = np.zeros(x.shape[0])
                left = (x[:,self.splitOnFeature] <= self.halfPoint)
                result[left] = self.leftChild.predict(x[left])
                result[~left] = self.rightChild.predict(x[~left])
                return result
        else:
            if self.isLeaf:
                return self.getValue()
            elif(x[self.splitOnFeature] > self.halfPoint):
                return self.rightChild.predict(x)
            else:
                return self.leftChild.predict(x)

    def getValue(self):
        """ Take the class with highest frequency as value """
        return np.argmax(self.valueFrequencies)
In [9]:
class MyTreeClassifier:
    """ Tree classifier using CART recursive algorithm 
        Stop criteria are the leaf minimum number of items, and the minimum entropy
        Entropy is used to measure the impurity
    """

    def __init__(self, leafMinSize=5, entropyMin=1e-5):
        self.leafMinSize = leafMinSize
        self.entropyMin = entropyMin
        self.root = None

    def fit(self, X, y, yClasses=None):
        """ Partition X partition in order to predict y """
        if not yClasses:
            yClasses = np.unique(y)
        p = X.shape[1]
        # Stack X and y in order to extract slices
        Xy = np.hstack((X, y.reshape(-1, 1)))
        # Get overall X bounding box
        epsilon = 1e-4
        originalBoundary = [[X[:,j].min() - epsilon, X[:,j].max()] for j in range(p)]
        
        # Initialize FIFO of sub-partitions to split, 
        #  it avoids recursive calls with large sub-partitions in memory
        self.root = TreeNode(0, None, originalBoundary, self.getEntropy(y, yClasses))
        splitList = [(Xy, self.root)]
        processedList = []

        # Run until splitList FIFO empty
        while splitList:
            partition, node = splitList.pop(0)
            selectedFeature, halfPoint, newBoundary, entropies = \
                self.splitDecision(partition, node.boundary, yClasses)
            partitionValue = self.getValueFrequencies(partition[:,nFeatures], len(yClasses))
            
            # Split
            children  = [TreeNode(node.depth + 1, node, newBoundary[i], entropies[i]) for i in range(2)]
            node.split(selectedFeature, halfPoint, children[0], children[1], partitionValue)
            processedList.append(node)
            
            # (Schedule) process of children
            for child, bound, entropy in zip(children, newBoundary, entropies):
                subpartition = self.selectSubPartition(partition, bound)
                if (len(subpartition) > self.leafMinSize) and (entropy > self.entropyMin):
                    splitList.append((subpartition, child))
                else:
                    child.setAsLeaf(self.getValueFrequencies(subpartition[:,nFeatures], len(yClasses)))
                    processedList.append(child)
        return processedList
      
    def predict(self, X):
        """ Predict y values given X """
        assert(self.root != None)
        return self.root.predict(X)   

    def splitDecision(self, Xy, xBoundary, yClasses):
        """ Select variable axis to be split computing impurity """
        p = len(xBoundary)
        nTotal = len(Xy)
        # Distribution after split
        halfPoints = [(xBoundary[i][0] + xBoundary[i][1])/2 for i in range(nFeatures)]
        impurity = []
        entropies = []
        for xIndex in range(nFeatures):
            xRangeLeft  = (Xy[:,xIndex] > xBoundary[xIndex][0]) & (Xy[:,xIndex] <= halfPoints[xIndex])
            xRangeRight = (Xy[:,xIndex] > halfPoints[xIndex]) & (Xy[:,xIndex] <= xBoundary[xIndex][1])
            weightedEntropies = []
            for ra in [xRangeLeft, xRangeRight]:
                partitionHalf  = Xy[ra]
                nPartition  = len(partitionHalf)
                if nPartition > 0:
                    entropy = self.getEntropy(partitionHalf[:,nFeatures], yClasses)
                    entropies.append(entropy)
                    weightedEntropies.append(entropy * nPartition)
                else:
                    entropies.append(0)
            impurity.append(np.sum(weightedEntropies) / nTotal)
        # Select the feature to split that is leading to the lowest impurity
        selectedFeature = np.argmin(impurity)
        newBoundary = [[[xBoundary[i][0], halfPoints[i] if i == selectedFeature else xBoundary[i][1]] \
                        for i in range(nFeatures)],
                      [[halfPoints[i] if i == selectedFeature else xBoundary[i][0], xBoundary[i][1]] \
                        for i in range(nFeatures)]]
        return selectedFeature, halfPoints[selectedFeature], newBoundary, \
            entropies[selectedFeature*2:(selectedFeature+1)*2]
    
    def getEntropy(self, y, yClasses):
        """ Compute entropy given the empirical probabilities for each class"""
        nPartition = len(y)
        frequencies = np.array([np.sum(y == c) / nPartition for c in yClasses])
        # epsilon added to avoid log of 0
        epsilon = 1e-10
        return - np.sum(np.dot(frequencies, np.log(frequencies + epsilon)))

    def selectSubPartition(self, Xy, xBoundary):
        """ Select a rectangular sub-partition """
        p = len(xBoundary)
        xRanges = [(Xy[:,j] > xBoundary[j][0]) & (Xy[:,j] <= xBoundary[j][1]) for j in range(p)]
        return Xy[xRanges[0] & xRanges[1]]
    
    def getValueFrequencies(self, y, nClasses):
        binEdges = np.arange(nClasses + 1) - 0.5
        return np.histogram(y, bins=binEdges)[0]

Model Fit on Czech and Norway flags

In [35]:
model0Czech = MyTreeClassifier()
processedListC = model0Czech.fit(xTrainC, yTrainC, range(3))
In [11]:
model0Norway = MyTreeClassifier()
processedListN = model0Norway.fit(xTrainN, yTrainN, range(3))

Analysis and evaluation

Plot tree partitioning

In [12]:
def plotNodeArea(node, palette, ax):
    """ Plot the area and split line of a tree node """
    color = palette[int(np.round(node.getValue()))]
    ax.fill_between(node.boundary[0], node.boundary[1][0], node.boundary[1][1], color=color)
    if node.halfPoint:
        if node.splitOnFeature == 0:
            return ax.plot(np.ones(2) * node.halfPoint, node.boundary[1], color='k')
        else:
            return ax.plot(node.boundary[0], np.ones(2) * node.halfPoint, color='k')
In [13]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, processedList, palette in zip(axes, [processedListC, processedListN], [czechColors, norwayColors]):
    ax.set_facecolor((0.8, 0.8, 0.8))
    for targetDepth in range(1, 10):
        for n in filter(lambda n: n.depth == targetDepth, processedList):
            plotNodeArea(n, palette, ax)

axes[0].set_title("Czech flag classification tree")
axes[0].set_xlim(-1, 0.5)
axes[0].set_ylim(-1, 1);
axes[1].set_title("Norway flag classification tree")
axes[1].set_xlim(-1, 1)
axes[1].set_ylim(-1, 1);

The Czech flag is quite homogeneous within the three color areas. Challenges are the color area boundaries where the gaussian noise is mixing color points, and the diagonal boundaries that mus be approximated by small rectangles.

The Norway flag is more complex because of the small width of the white area. Noise is also adding some challenges. We observe that the partitioning is not similar on the four quadrants.

Weighted entropy plot for Norway flag

All along the tree fit, the weighted entropy is computed for each split area (left / right).

In [24]:
def getMinMaxEntropies(nodeList, depthRange):
    weithedEntropiesMin = []
    weithedEntropiesMax = []
    for targetDepth in depthRange:
        weithedEntropies = np.array([node.entropy * np.sum(node.valueFrequencies) for node in filter(lambda n: n.depth == targetDepth, nodeList)])
        weithedEntropiesMin.append(weithedEntropies.min())
        weithedEntropiesMax.append(weithedEntropies.max())
    return weithedEntropiesMin, weithedEntropiesMax

depthRange = range(1, 10)
entropiesMin, entropiesMax = getMinMaxEntropies(processedListN, depthRange)
fig, ax = plt.subplots(1, 1, figsize=(12, 5))
ax.fill_between(depthRange, entropiesMin, entropiesMax)
ax.set_xlabel("Depth")
ax.set_ylabel("Weighted entropy")
ax.set_title("Min and max entropy weighted by partition size - Norway flag")
ax.grid()

Czech flag classifier performance

In [15]:
yEstC = model0Czech.predict(xTestC)
plotConfusionMatrix(yTestC, yEstC, czechColors, title="Czech flag confusion matrix")
print(metrics.classification_report(yTestC, yEstC))
              precision    recall  f1-score   support

         0.0       0.86      0.89      0.88       336
         1.0       0.93      0.91      0.92       328
         2.0       0.92      0.90      0.91       336

    accuracy                           0.90      1000
   macro avg       0.90      0.90      0.90      1000
weighted avg       0.90      0.90      0.90      1000

Norway flag classifier performance

In [16]:
yEstN = model0Norway.predict(xTestN)
plotConfusionMatrix(yTestN, yEstN, norwayColors, title="Norway flag confusion matrix")
print(metrics.classification_report(yTestN, yEstN))
              precision    recall  f1-score   support

         0.0       0.80      0.84      0.82       319
         1.0       0.59      0.62      0.61       366
         2.0       0.71      0.64      0.67       315

    accuracy                           0.70      1000
   macro avg       0.70      0.70      0.70      1000
weighted avg       0.70      0.70      0.70      1000

Norway flag problem is a little more challenging as the there are more borders between the three colors.

Let's investigate the best parameters using accuracy

Hyper-parameter search using grid on Norway flag

A grid search is performed on the two tree partitioner stop conditions : on minimum impurity, and on minimum leave size

In [17]:
leafSizeRange = range(4, 21)
entropyRange = np.logspace(-3, 0, 10)
accuracy = np.zeros((len(leafSizeRange), len(entropyRange)))
for i,leafMinSize in enumerate(leafSizeRange):
    for j,entropy in enumerate(entropyRange): 
        model1Norway = MyTreeClassifier(leafMinSize=leafMinSize, entropyMin=entropy)
        model1Norway.fit(xTrainN, yTrainN, range(3))
        yEstN1 = model1Norway.predict(xTestN)
        accuracy[i,j] = np.mean(yEstN1 == yTestN)

bestParams = np.argmax(accuracy)
In [18]:
fig, ax = plt.subplots(1, figsize=(10, 5))
for j,h in enumerate(entropyRange):
    ax.plot(leafSizeRange, accuracy[:,j], label = "Min entropy %.2e" % h)
ax.legend()
ax.set_title("Accuracy as function of min leaf size and min entropy")
ax.set_xlabel("Leaf minimum size")
ax.set_ylabel("Accuracy")
ax.grid()
In [19]:
leafMinSize = leafSizeRange[bestParams // len(leafSizeRange)]
entropyMin = entropyRange[bestParams % len(leafSizeRange)]
model1Norway = MyTreeClassifier(leafMinSize=leafMinSize, entropyMin=entropyMin)
processedListN1 = model1Norway.fit(xTrainN, yTrainN, range(3))
yEstN1 = model1Norway.predict(xTestN)

print("Norway flag classification with leaf min size = %d, min entropy = %.3f" % (leafMinSize, entropyMin))
print(metrics.classification_report(yTestN, yEstN1))

fig, axes = plt.subplots(1, 3, figsize=(16, 4))
plotConfusionMatrix(yTestN, yEstN1, norwayColors, ax=axes[0], title="Norway flag confusion matrix")
axes[1].set_title("Partitions of the decision tree")
axes[1].set_facecolor((0.8, 0.8, 0.8))
for targetDepth in range(1, 10):
    for n in filter(lambda n: n.depth == targetDepth, processedListN1):
        plotNodeArea(n, norwayColors, axes[1])
flagPlot(xTestN, yTestN, "Norway flag with true test samples", norwayColors, ax=axes[2])
Norway flag classification with leaf min size = 13, min entropy = 0.215
              precision    recall  f1-score   support

         0.0       0.84      0.81      0.82       319
         1.0       0.60      0.67      0.63       366
         2.0       0.73      0.65      0.69       315

    accuracy                           0.71      1000
   macro avg       0.72      0.71      0.71      1000
weighted avg       0.71      0.71      0.71      1000

The minimum entropy stop criteria has the advantage of avoid very fine and small partitions and thus decrease the need of consolidating the leafs (also called pruning). However it may also prevent useful splits at the early stages as seen on the large white rectangles at right hand of above partition map.

Comparing the partition map based on the training samples and the plot of the test samples, we see that there is some overfitting.

Classification tree with Scikit-Learn

In [20]:
model2Norway = tree.DecisionTreeClassifier(min_samples_leaf=8, criterion='entropy')
model2Norway.fit(xTrainN, yTrainN)
Out[20]:
DecisionTreeClassifier(class_weight=None, criterion='entropy', max_depth=None,
                       max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=8, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort=False,
                       random_state=None, splitter='best')
In [21]:
yEstN2 = model2Norway.predict(xTestN)
print("Norway flag classification with leaf min size = 8")
print(metrics.classification_report(yTestN, yEstN2))

plotConfusionMatrix(yTestN, yEstN2, norwayColors)
Norway flag classification with leaf min size = 8
              precision    recall  f1-score   support

         0.0       0.87      0.82      0.84       319
         1.0       0.62      0.60      0.61       366
         2.0       0.67      0.74      0.71       315

    accuracy                           0.71      1000
   macro avg       0.72      0.72      0.72      1000
weighted avg       0.72      0.71      0.71      1000

Visualization of the decision tree

Scikit-Learn provides a graph view of the Tree... not that easy to visualize

In [31]:
dot_data = tree.export_graphviz(model2Norway, out_file=None, 
                    feature_names=['x_1','x_2'],  
                    class_names=norwayColors,  
                    filled=True, rounded=False,  
                    special_characters=True)  
graph = graphviz.Source(dot_data)  
display(SVG(graph.pipe(format='svg')))
Tree 0 x_1 ≤ -0.418 entropy = 1.583 samples = 1000 value = [310, 351, 339] class = white 1 x_2 ≤ -0.375 entropy = 1.401 samples = 281 value = [161, 71, 49] class = red 0->1 True 38 x_1 ≤ 0.372 entropy = 1.529 samples = 719 value = [149, 280, 290] class = navy 0->38 False 2 x_1 ≤ -0.582 entropy = 0.473 samples = 89 value = [80, 9, 0] class = red 1->2 11 x_2 ≤ 0.341 entropy = 1.555 samples = 192 value = [81, 62, 49] class = red 1->11 3 x_1 ≤ -0.884 entropy = 0.201 samples = 64 value = [62, 2, 0] class = red 2->3 8 x_1 ≤ -0.497 entropy = 0.855 samples = 25 value = [18, 7, 0] class = red 2->8 4 x_1 ≤ -0.94 entropy = 0.469 samples = 20 value = [18, 2, 0] class = red 3->4 7 entropy = 0.0 samples = 44 value = [44, 0, 0] class = red 3->7 5 entropy = 0.0 samples = 12 value = [12, 0, 0] class = red 4->5 6 entropy = 0.811 samples = 8 value = [6, 2, 0] class = red 4->6 9 entropy = 0.75 samples = 14 value = [11, 3, 0] class = red 8->9 10 entropy = 0.946 samples = 11 value = [7, 4, 0] class = red 8->10 12 x_2 ≤ -0.166 entropy = 1.236 samples = 100 value = [5, 46, 49] class = navy 11->12 29 x_1 ≤ -0.535 entropy = 0.667 samples = 92 value = [76, 16, 0] class = red 11->29 13 x_1 ≤ -0.818 entropy = 0.817 samples = 24 value = [2, 20, 2] class = white 12->13 18 x_2 ≤ 0.134 entropy = 1.142 samples = 76 value = [3, 26, 47] class = navy 12->18 14 entropy = 0.811 samples = 8 value = [2, 6, 0] class = white 13->14 15 x_2 ≤ -0.25 entropy = 0.544 samples = 16 value = [0, 14, 2] class = white 13->15 16 entropy = 0.544 samples = 8 value = [0, 7, 1] class = white 15->16 17 entropy = 0.544 samples = 8 value = [0, 7, 1] class = white 15->17 19 x_1 ≤ -0.766 entropy = 0.433 samples = 45 value = [0, 4, 41] class = navy 18->19 24 x_1 ≤ -0.774 entropy = 1.136 samples = 31 value = [3, 22, 6] class = white 18->24 20 entropy = 0.0 samples = 19 value = [0, 0, 19] class = navy 19->20 21 x_1 ≤ -0.596 entropy = 0.619 samples = 26 value = [0, 4, 22] class = navy 19->21 22 entropy = 0.837 samples = 15 value = [0, 4, 11] class = navy 21->22 23 entropy = 0.0 samples = 11 value = [0, 0, 11] class = navy 21->23 25 entropy = 0.353 samples = 15 value = [1, 14, 0] class = white 24->25 26 x_2 ≤ 0.244 entropy = 1.406 samples = 16 value = [2, 8, 6] class = white 24->26 27 entropy = 1.299 samples = 8 value = [1, 5, 2] class = white 26->27 28 entropy = 1.406 samples = 8 value = [1, 3, 4] class = navy 26->28 30 x_2 ≤ 0.561 entropy = 0.52 samples = 77 value = [68, 9, 0] class = red 29->30 37 entropy = 0.997 samples = 15 value = [8, 7, 0] class = red 29->37 31 x_1 ≤ -0.644 entropy = 0.894 samples = 29 value = [20, 9, 0] class = red 30->31 36 entropy = 0.0 samples = 48 value = [48, 0, 0] class = red 30->36 32 x_1 ≤ -0.782 entropy = 0.959 samples = 21 value = [13, 8, 0] class = red 31->32 35 entropy = 0.544 samples = 8 value = [7, 1, 0] class = red 31->35 33 entropy = 0.779 samples = 13 value = [10, 3, 0] class = red 32->33 34 entropy = 0.954 samples = 8 value = [3, 5, 0] class = white 32->34 39 x_1 ≤ -0.24 entropy = 1.104 samples = 414 value = [8, 174, 232] class = navy 38->39 110 x_2 ≤ -0.389 entropy = 1.5 samples = 305 value = [141, 106, 58] class = red 38->110 40 x_2 ≤ 0.357 entropy = 1.143 samples = 88 value = [7, 61, 20] class = white 39->40 55 x_1 ≤ 0.173 entropy = 0.959 samples = 326 value = [1, 113, 212] class = navy 39->55 41 x_2 ≤ -0.162 entropy = 1.19 samples = 65 value = [4, 41, 20] class = white 40->41 52 x_2 ≤ 0.597 entropy = 0.559 samples = 23 value = [3, 20, 0] class = white 40->52 42 x_1 ≤ -0.375 entropy = 0.879 samples = 43 value = [4, 35, 4] class = white 41->42 49 x_2 ≤ -0.026 entropy = 0.845 samples = 22 value = [0, 6, 16] class = navy 41->49 43 entropy = 0.918 samples = 12 value = [4, 8, 0] class = white 42->43 44 x_2 ≤ -0.492 entropy = 0.555 samples = 31 value = [0, 27, 4] class = white 42->44 45 x_1 ≤ -0.311 entropy = 0.742 samples = 19 value = [0, 15, 4] class = white 44->45 48 entropy = 0.0 samples = 12 value = [0, 12, 0] class = white 44->48 46 entropy = 0.544 samples = 8 value = [0, 7, 1] class = white 45->46 47 entropy = 0.845 samples = 11 value = [0, 8, 3] class = white 45->47 50 entropy = 0.0 samples = 8 value = [0, 0, 8] class = navy 49->50 51 entropy = 0.985 samples = 14 value = [0, 6, 8] class = navy 49->51 53 entropy = 0.811 samples = 8 value = [2, 6, 0] class = white 52->53 54 entropy = 0.353 samples = 15 value = [1, 14, 0] class = white 52->54 56 x_1 ≤ -0.175 entropy = 0.688 samples = 218 value = [0, 40, 178] class = navy 55->56 91 x_1 ≤ 0.303 entropy = 0.969 samples = 108 value = [1, 73, 34] class = white 55->91 57 x_2 ≤ -0.337 entropy = 0.999 samples = 31 value = [0, 15, 16] class = navy 56->57 62 x_2 ≤ 0.889 entropy = 0.567 samples = 187 value = [0, 25, 162] class = navy 56->62 58 entropy = 0.811 samples = 8 value = [0, 6, 2] class = white 57->58 59 x_1 ≤ -0.193 entropy = 0.966 samples = 23 value = [0, 9, 14] class = navy 57->59 60 entropy = 0.684 samples = 11 value = [0, 2, 9] class = navy 59->60 61 entropy = 0.98 samples = 12 value = [0, 7, 5] class = white 59->61 63 x_1 ≤ 0.059 entropy = 0.596 samples = 173 value = [0, 25, 148] class = navy 62->63 90 entropy = 0.0 samples = 14 value = [0, 0, 14] class = navy 62->90 64 x_2 ≤ 0.392 entropy = 0.485 samples = 114 value = [0, 12, 102] class = navy 63->64 79 x_1 ≤ 0.1 entropy = 0.761 samples = 59 value = [0, 13, 46] class = navy 63->79 65 x_2 ≤ -0.314 entropy = 0.368 samples = 85 value = [0, 6, 79] class = navy 64->65 74 x_1 ≤ -0.119 entropy = 0.736 samples = 29 value = [0, 6, 23] class = navy 64->74 66 x_1 ≤ -0.075 entropy = 0.551 samples = 47 value = [0, 6, 41] class = navy 65->66 73 entropy = 0.0 samples = 38 value = [0, 0, 38] class = navy 65->73 67 x_1 ≤ -0.122 entropy = 0.811 samples = 16 value = [0, 4, 12] class = navy 66->67 70 x_1 ≤ -0.006 entropy = 0.345 samples = 31 value = [0, 2, 29] class = navy 66->70 68 entropy = 0.811 samples = 8 value = [0, 2, 6] class = navy 67->68 69 entropy = 0.811 samples = 8 value = [0, 2, 6] class = navy 67->69 71 entropy = 0.0 samples = 16 value = [0, 0, 16] class = navy 70->71 72 entropy = 0.567 samples = 15 value = [0, 2, 13] class = navy 70->72 75 entropy = 0.918 samples = 9 value = [0, 3, 6] class = navy 74->75 76 x_2 ≤ 0.64 entropy = 0.61 samples = 20 value = [0, 3, 17] class = navy 74->76 77 entropy = 0.845 samples = 11 value = [0, 3, 8] class = navy 76->77 78 entropy = 0.0 samples = 9 value = [0, 0, 9] class = navy 76->78 80 x_1 ≤ 0.087 entropy = 0.964 samples = 18 value = [0, 7, 11] class = navy 79->80 83 x_1 ≤ 0.136 entropy = 0.601 samples = 41 value = [0, 6, 35] class = navy 79->83 81 entropy = 0.722 samples = 10 value = [0, 2, 8] class = navy 80->81 82 entropy = 0.954 samples = 8 value = [0, 5, 3] class = white 80->82 84 x_1 ≤ 0.115 entropy = 0.297 samples = 19 value = [0, 1, 18] class = navy 83->84 87 x_2 ≤ -0.154 entropy = 0.773 samples = 22 value = [0, 5, 17] class = navy 83->87 85 entropy = 0.544 samples = 8 value = [0, 1, 7] class = navy 84->85 86 entropy = 0.0 samples = 11 value = [0, 0, 11] class = navy 84->86 88 entropy = 0.918 samples = 12 value = [0, 4, 8] class = navy 87->88 89 entropy = 0.469 samples = 10 value = [0, 1, 9] class = navy 87->89 92 x_2 ≤ -0.853 entropy = 1.064 samples = 71 value = [1, 42, 28] class = white 91->92 105 x_2 ≤ 0.227 entropy = 0.639 samples = 37 value = [0, 31, 6] class = white 91->105 93 entropy = 1.406 samples = 8 value = [1, 4, 3] class = white 92->93 94 x_1 ≤ 0.262 entropy = 0.969 samples = 63 value = [0, 38, 25] class = white 92->94 95 x_1 ≤ 0.232 entropy = 0.91 samples = 40 value = [0, 27, 13] class = white 94->95 102 x_2 ≤ -0.025 entropy = 0.999 samples = 23 value = [0, 11, 12] class = navy 94->102 96 x_1 ≤ 0.204 entropy = 0.987 samples = 30 value = [0, 17, 13] class = white 95->96 101 entropy = 0.0 samples = 10 value = [0, 10, 0] class = white 95->101 97 x_1 ≤ 0.182 entropy = 0.874 samples = 17 value = [0, 12, 5] class = white 96->97 100 entropy = 0.961 samples = 13 value = [0, 5, 8] class = navy 96->100 98 entropy = 0.764 samples = 9 value = [0, 7, 2] class = white 97->98 99 entropy = 0.954 samples = 8 value = [0, 5, 3] class = white 97->99 103 entropy = 0.954 samples = 8 value = [0, 5, 3] class = white 102->103 104 entropy = 0.971 samples = 15 value = [0, 6, 9] class = navy 102->104 106 x_2 ≤ -0.209 entropy = 0.881 samples = 20 value = [0, 14, 6] class = white 105->106 109 entropy = 0.0 samples = 17 value = [0, 17, 0] class = white 105->109 107 entropy = 0.0 samples = 11 value = [0, 11, 0] class = white 106->107 108 entropy = 0.918 samples = 9 value = [0, 3, 6] class = navy 106->108 111 x_2 ≤ -0.566 entropy = 0.735 samples = 92 value = [73, 19, 0] class = red 110->111 120 x_2 ≤ 0.309 entropy = 1.565 samples = 213 value = [68, 87, 58] class = white 110->120 112 x_1 ≤ 0.561 entropy = 0.483 samples = 67 value = [60, 7, 0] class = red 111->112 117 x_2 ≤ -0.496 entropy = 0.999 samples = 25 value = [13, 12, 0] class = red 111->117 113 x_2 ≤ -0.81 entropy = 0.989 samples = 16 value = [9, 7, 0] class = red 112->113 116 entropy = 0.0 samples = 51 value = [51, 0, 0] class = red 112->116 114 entropy = 0.811 samples = 8 value = [6, 2, 0] class = red 113->114 115 entropy = 0.954 samples = 8 value = [3, 5, 0] class = white 113->115 118 entropy = 0.946 samples = 11 value = [4, 7, 0] class = white 117->118 119 entropy = 0.94 samples = 14 value = [9, 5, 0] class = red 117->119 121 x_2 ≤ -0.182 entropy = 1.214 samples = 112 value = [5, 49, 58] class = navy 120->121 142 x_2 ≤ 0.462 entropy = 0.955 samples = 101 value = [63, 38, 0] class = red 120->142 122 x_2 ≤ -0.237 entropy = 1.041 samples = 36 value = [3, 27, 6] class = white 121->122 129 x_2 ≤ 0.218 entropy = 1.03 samples = 76 value = [2, 22, 52] class = navy 121->129 123 x_1 ≤ 0.549 entropy = 0.889 samples = 28 value = [1, 22, 5] class = white 122->123 128 entropy = 1.299 samples = 8 value = [2, 5, 1] class = white 122->128 124 entropy = 0.0 samples = 8 value = [0, 8, 0] class = white 123->124 125 x_2 ≤ -0.298 entropy = 1.076 samples = 20 value = [1, 14, 5] class = white 123->125 126 entropy = 0.922 samples = 10 value = [1, 8, 1] class = white 125->126 127 entropy = 0.971 samples = 10 value = [0, 6, 4] class = white 125->127 130 x_2 ≤ 0.105 entropy = 0.781 samples = 63 value = [1, 11, 51] class = navy 129->130 141 entropy = 0.773 samples = 13 value = [1, 11, 1] class = white 129->141 131 x_2 ≤ 0.013 entropy = 0.574 samples = 46 value = [1, 4, 41] class = navy 130->131 138 x_1 ≤ 0.75 entropy = 0.977 samples = 17 value = [0, 7, 10] class = navy 130->138 132 x_1 ≤ 0.755 entropy = 0.754 samples = 31 value = [1, 4, 26] class = navy 131->132 137 entropy = 0.0 samples = 15 value = [0, 0, 15] class = navy 131->137 133 x_1 ≤ 0.464 entropy = 0.937 samples = 22 value = [1, 4, 17] class = navy 132->133 136 entropy = 0.0 samples = 9 value = [0, 0, 9] class = navy 132->136 134 entropy = 0.0 samples = 8 value = [0, 0, 8] class = navy 133->134 135 entropy = 1.198 samples = 14 value = [1, 4, 9] class = navy 133->135 139 entropy = 0.918 samples = 9 value = [0, 3, 6] class = navy 138->139 140 entropy = 1.0 samples = 8 value = [0, 4, 4] class = white 138->140 143 x_1 ≤ 0.711 entropy = 0.65 samples = 24 value = [4, 20, 0] class = white 142->143 146 x_1 ≤ 0.482 entropy = 0.785 samples = 77 value = [59, 18, 0] class = red 142->146 144 entropy = 0.414 samples = 12 value = [1, 11, 0] class = white 143->144 145 entropy = 0.811 samples = 12 value = [3, 9, 0] class = white 143->145 147 x_1 ≤ 0.425 entropy = 0.896 samples = 16 value = [5, 11, 0] class = white 146->147 150 x_1 ≤ 0.753 entropy = 0.514 samples = 61 value = [54, 7, 0] class = red 146->150 148 entropy = 0.811 samples = 8 value = [2, 6, 0] class = white 147->148 149 entropy = 0.954 samples = 8 value = [3, 5, 0] class = white 147->149 151 x_2 ≤ 0.598 entropy = 0.711 samples = 36 value = [29, 7, 0] class = red 150->151 156 entropy = 0.0 samples = 25 value = [25, 0, 0] class = red 150->156 152 entropy = 0.991 samples = 9 value = [5, 4, 0] class = red 151->152 153 x_1 ≤ 0.544 entropy = 0.503 samples = 27 value = [24, 3, 0] class = red 151->153 154 entropy = 0.881 samples = 10 value = [7, 3, 0] class = red 153->154 155 entropy = 0.0 samples = 17 value = [17, 0, 0] class = red 153->155

Conclusion

Tree classifier provides a simple algorithm that is able to cope with non linearly separable problems like Norway flag.

Its performance, as measured by the F1-score or accuracy, is very close to the one achieved by neural networks (see HTML / Jupyter) but the design cost is way lower. The prediction cost might also be lower as well : it requires maximum $d = log_2(N)$ comparisons where $d$ is the depth and $N$ the number of leaves of the tree.

Where to go from here

  • Same multi-class problem but using neural networks and the Keras framework (HTML / Jupyter)
  • Binary classification with Logistic regression (HTML / Jupyter), or k Nearest Neighbors (HTML / Jupyter)