Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

python - Pruning Decision Trees

Below is a snippet of the decision tree, since the full tree is pretty huge.

[Image: snippet of the decision tree]

How can I make the tree stop growing when the lowest value in a node (the smallest class count) is under 5? Below is the code that produces the decision tree. Judging from the scikit-learn DecisionTreeClassifier documentation, the only way to do this seems to be min_impurity_decrease, but I am not sure how it works exactly.

import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier


X, y = make_classification(n_samples=1000,
                           n_features=6,
                           n_informative=3,
                           n_classes=2,
                           random_state=0,
                           shuffle=False)

# Creating a dataFrame
df = pd.DataFrame({'Feature 1': X[:, 0],
                   'Feature 2': X[:, 1],
                   'Feature 3': X[:, 2],
                   'Feature 4': X[:, 3],
                   'Feature 5': X[:, 4],
                   'Feature 6': X[:, 5],
                   'Class': y})


y_train = df['Class']
X_train = df.drop('Class', axis=1)

dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train, y_train)

from IPython.display import display, Image
import os
import pydotplus
from sklearn import tree

# Make the Graphviz binaries visible (raw string, so '\b' is not read as a backspace)
os.environ["PATH"] += os.pathsep + r'C:\Anaconda3\Library\bin\graphviz'

# Write the tree to a .dot file (export_graphviz returns None when out_file is given)
tree.export_graphviz(dt, out_file='thisIsTheImagetree.dot',
                     feature_names=list(X_train.columns), filled=True,
                     rounded=True, special_characters=True)

graph = pydotplus.graph_from_dot_file('thisIsTheImagetree.dot')  

thisIsTheImage = Image(graph.create_png())
display(thisIsTheImage)
#print(dt.tree_.feature)

from subprocess import check_call
check_call(['dot','-Tpng','thisIsTheImagetree.dot','-o','thisIsTheImagetree.png'])

Update

I think min_impurity_decrease can, in a way, help reach the goal, since tweaking min_impurity_decrease does actually prune the tree. Can anyone explain how min_impurity_decrease works?
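According to the scikit-learn documentation, a node is split only if the split decreases the weighted impurity by at least `min_impurity_decrease`, so raising it stops growth earlier. A quick sketch on the question's data (the list of thresholds is arbitrary) shows the effect on the number of leaves:

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Same synthetic data as in the question
X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)

# A split is kept only if its weighted impurity decrease is >= min_impurity_decrease,
# so larger values block more (mostly deep, small-node) splits.
leaves = {}
for threshold in [0.0, 0.001, 0.005, 0.01]:
    clf = DecisionTreeClassifier(min_impurity_decrease=threshold, random_state=42)
    clf.fit(X, y)
    leaves[threshold] = clf.get_n_leaves()
    print(f"min_impurity_decrease={threshold}: "
          f"{leaves[threshold]} leaves, depth {clf.get_depth()}")
```

Note that this limits the size of the impurity gain, not the smallest class count, so it can only approximate the "lowest value under 5" rule.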

I am trying to understand the equation in the scikit-learn documentation, but I am not sure what the values of impurity, right_impurity and left_impurity are:

N = 256
N_t = 256
impurity = ??
N_t_R = 242
N_t_L = 14
right_impurity = ??
left_impurity = ??

New_Value = N_t / N * (impurity - ((N_t_R / N_t) * right_impurity)
                    - ((N_t_L / N_t) * left_impurity))
New_Value
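With the default `criterion='gini'`, `impurity` is the Gini impurity 1 - Σ p_k² of the node being split (p_k is the fraction of class k in that node), and `right_impurity`/`left_impurity` are the Gini impurities of its two children. N is the total (weighted) number of training samples, and N_t, N_t_R, N_t_L are the (weighted) sample counts of the node and its children. A fitted tree stores these in `dt.tree_.impurity` and `dt.tree_.weighted_n_node_samples`, so a sketch like the following (the helper names `gini` and `impurity_decrease` are my own) can evaluate the formula for every split:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
dt = DecisionTreeClassifier(random_state=42).fit(X, y)
t = dt.tree_

def gini(counts):
    """Gini impurity 1 - sum(p_k^2) of a vector of class counts."""
    p = np.asarray(counts, dtype=float) / np.sum(counts)
    return 1.0 - np.sum(p ** 2)

def impurity_decrease(t, node):
    """Weighted impurity decrease of the split at `node` (formula from the docs)."""
    left, right = t.children_left[node], t.children_right[node]
    N = t.weighted_n_node_samples[0]        # all training samples (root)
    N_t = t.weighted_n_node_samples[node]   # samples reaching this node
    N_t_L = t.weighted_n_node_samples[left]
    N_t_R = t.weighted_n_node_samples[right]
    return N_t / N * (t.impurity[node]
                      - N_t_R / N_t * t.impurity[right]
                      - N_t_L / N_t * t.impurity[left])

# Evaluate the formula at every internal (non-leaf) node
internal = np.where(t.children_left != -1)[0]
decs = np.array([impurity_decrease(t, n) for n in internal])
print(f"root: impurity={t.impurity[0]:.4f}, decrease={decs[0]:.4f}")
print(f"smallest decrease in the tree: {decs.min():.6f}")
```

Any split whose value is below the chosen `min_impurity_decrease` would not have been made. These per-split values are also what `feature_importances_` sums up per feature before normalizing.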

Update 2

Instead of pruning at a fixed value, I would like to prune under a certain condition: for example, split at 6/4 and 5/5 but not at 6000/4 or 5000/5. In other words, stop when one class count is below a certain percentage of the other class count in the node, rather than below a fixed value.

        11/9
       /    \
     6/4     5/5
    /   \   /   \
  6/0  0/4 2/2  3/3

He who fights with dragons too long becomes a dragon himself; gaze long into the abyss, and the abyss gazes back into you…

1 Reply

Directly restricting the lowest value (the number of occurrences of a particular class) of a leaf cannot be done with min_impurity_decrease or any other built-in stopping criterion.

I think the only way to accomplish this without changing the scikit-learn source code is to post-prune your tree. To do so, you can traverse the tree and remove all children of the nodes whose minimum class count is less than 5 (or any other condition you can think of). Continuing your example:

from sklearn.tree._tree import TREE_LEAF

def prune_index(inner_tree, index, threshold):
    if inner_tree.value[index].min() < threshold:
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
    # if there are children, visit them as well
    if inner_tree.children_left[index] != TREE_LEAF:
        prune_index(inner_tree, inner_tree.children_left[index], threshold)
        prune_index(inner_tree, inner_tree.children_right[index], threshold)

# nodes marked as leaves before pruning
print(sum(dt.tree_.children_left < 0))
# start pruning from the root
prune_index(dt.tree_, 0, 5)
# nodes marked as leaves after pruning
print(sum(dt.tree_.children_left < 0))

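This code prints first 74 and then 91, meaning that 17 internal nodes have been turned into leaves (by removing the links to their children). The second count still includes the original leaves below the pruned nodes, which are no longer reachable from the root, so the pruned tree actually has fewer reachable leaves than the original 74. The tree, which previously looked like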

[Image: the original decision tree]

now looks like

[Image: the pruned decision tree]

so you can see that it has indeed shrunk a lot.
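The code above compares `inner_tree.value[index].min()` directly with 5. Depending on the scikit-learn version, `tree_.value` for a classifier may hold raw class counts or per-class fractions (I believe recent releases store fractions), and with fractions every node would pass the `< 5` test. Below is a variant under these assumptions: it rescales `value` by `n_node_samples` to recover counts (valid only without `sample_weight` or `class_weight`), takes the stopping rule as a function so the ratio condition from Update 2 can be plugged in, and counts only the leaves still reachable from the root. The 10% ratio and the helper names `class_counts`, `minority_ratio_too_small`, `prune` and `reachable_leaves` are hypothetical.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

TREE_LEAF = -1  # same value as sklearn.tree._tree.TREE_LEAF

X, y = make_classification(n_samples=1000, n_features=6, n_informative=3,
                           n_classes=2, random_state=0, shuffle=False)
dt = DecisionTreeClassifier(random_state=42).fit(X, y)

def class_counts(t, node):
    """Per-class sample counts at `node`; normalizing first makes this work
    whether tree_.value holds counts or fractions (assumes no sample weights)."""
    v = t.value[node, 0]
    return v / v.sum() * t.n_node_samples[node]

def minority_ratio_too_small(t, node, ratio=0.1):
    """Hypothetical rule from Update 2: smallest class < ratio * largest class
    (6/4 and 5/5 keep their split, 6000/4 and 5000/5 do not)."""
    c = class_counts(t, node)
    return c.min() < ratio * c.max()

def prune(t, node, condition):
    """Turn `node` into a leaf if `condition` holds, else recurse into its children."""
    if t.children_left[node] == TREE_LEAF:
        return
    if condition(t, node):
        # unlink the children; the orphaned subtree stays in the arrays
        t.children_left[node] = TREE_LEAF
        t.children_right[node] = TREE_LEAF
        return
    prune(t, t.children_left[node], condition)
    prune(t, t.children_right[node], condition)

def reachable_leaves(t, node=0):
    """Count leaves that can still be reached from the root."""
    if t.children_left[node] == TREE_LEAF:
        return 1
    return (reachable_leaves(t, t.children_left[node])
            + reachable_leaves(t, t.children_right[node]))

before = reachable_leaves(dt.tree_)
prune(dt.tree_, 0, minority_ratio_too_small)
after = reachable_leaves(dt.tree_)
print(f"reachable leaves: {before} -> {after}")
```

Passing `lambda t, n: class_counts(t, n).min() < 5` instead reproduces the answer's original "lowest value under 5" rule. Predictions keep working after pruning, because traversal now stops at the new leaves and uses their stored class distribution.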

...