Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Welcome To Ask or Share your Answers For Others

Categories

0 votes
840 views
in Technique[技术] by (71.8m points)

scikit learn - Translate Python export_text decision rules to SAS IF THEN DO; END code

I am attempting to translate Python's sklearn.tree export_text output into SAS conditions. There has been a solution given in Python for this problem here (at the very end of the page) : How to extract the decision rules from scikit-learn decision-tree?

I tried to adapt the code for generating SAS code, but I have an issue with handling nested DO; END;

Here's my code (I create a DATA STEP) :

def get_sas_from_text(tree, tree_id, features, text, spacing=2):
    # tree is a decision tree from a RandomForestClassifier for instance
    # tree id is a number I use for naming the table I create
    # features is a list of features names
    # text is the output of the export_text function from sklearn.tree
    # spacing is used to handle the file size
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'data decision_tree_' + str(tree_id) + ';'
    code += ' set input_data; '
    n_end_last = 0 # Number of 'END;' to add at the end of the data step
    splitted_text = text.split('
')
    text_list = []
    for i, line in enumerate(splitted_text):
        line = line.rstrip().replace('|',' ')

        # Handling rows for IF conditions
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g} THEN DO;'.format(line, float(val))
            n_end_last += 1 # need to add an END;
            if i > 0 and 'PREDICTED_VALUE' in text_list[i-1]: # If there's a PREDICTED_VALUE in line above, then condition is ELSE DO
                line = "ELSE DO; " + line
                n_end_last += 1 # add another END
        # Handling rows for PREDICTED_VALUE
        else:
            line = line.replace(' {} class:'.format(dash), 'PREDICTED_VALUE =')
            line += ';'
            line += '
 end;' # Immediately add END after PREDICTED_VALUE = .. ;
            n_end_last -= 1
        text_list.append(line)
        code += skip + line + '
'
    code = code[:-1] 
    code += 'end; + '
''* n_end_last # add END;
    code += 'run;'
    return code

Here's an example with the iris data set :

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import _tree
import string
from sklearn.tree import export_text

iris = datasets.load_iris()

data=pd.DataFrame({
    'sepal length':iris.data[:,0],
    'sepal width':iris.data[:,1],
    'petal length':iris.data[:,2],
    'petal width':iris.data[:,3],
    'species':iris.target
})

X=data[['sepal length', 'sepal width', 'petal length', 'petal width']]  # Features
y=data['species']  # Labels

# Split dataset into training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3) # 70% training and 30% test

clf=RandomForestClassifier(n_estimators=100)

#Train the model using the training sets y_pred=clf.predict(X_test)
clf.fit(X_train,y_train)

# Function to export the tree rules' code : in Python (works) and in SAS (issue with DO; END;
def export_code(tree, tree_id, feature_names, max_depth=100, spacing=2):
    if spacing < 2:
        raise ValueError('spacing must be > 1')

    # Clean up feature names (for correctness)
    nums = string.digits
    alnums = string.ascii_letters + nums
    clean = lambda s: ''.join(c if c in alnums else '_' for c in s)
    features = [clean(x) for x in feature_names]
    features = ['_'+x if x[0] in nums else x for x in features if x]
    if len(set(features)) != len(feature_names):
        raise ValueError('invalid feature names')

    # First: export tree to text
    res = export_text(tree, feature_names=features, 
                        max_depth=max_depth,
                        decimals=6,
                        spacing=spacing-1)

    code_sas = get_sas_from_text(tree, tree_id, features, res, spacing)
    code_py = get_py_from_text(tree, tree_id, features, res, spacing)
    return res, code_sas, code_py # to take a look at the different code outputs

# Python function
def get_py_from_text(tree, tree_id, features, text, spacing):
    skip, dash = ' '*spacing, '-'*(spacing-1)
    code = 'def decision_tree_'+ str(tree_id) + '({}):
'.format(', '.join(features))
    for line in repr(tree).split('
'):
        code += skip + "# " + line + '
'
    for line in text.split('
'):
        line = line.rstrip().replace('|',' ')
        if '<' in line or '>' in line:
            line, val = line.rsplit(maxsplit=1)
            line = line.replace(' ' + dash, 'if')
            line = '{} {:g}:'.format(line, float(val))
        else:
            line = line.replace(' {} class:'.format(dash), 'return')
        code += skip + line + '
'

    return code

I then get the generated codes in SAS and Python for a tree's decision rules :

# Rules for first decision tree (there are 100 of them)
exported_text, sas_text, py_text = export_code(clf[0], 0, iris.feature_names)

Here are the decision rules in Python for first tree :

def decision_tree_0(sepal_length__cm_, sepal_width__cm_, petal_length__cm_, petal_width__cm_):
  # DecisionTreeClassifier(max_features='auto', random_state=1087864992)
  if sepal_length__cm_ <= 5.35:
    if sepal_width__cm_ <= 2.7:
      return 1.0
    if sepal_width__cm_ > 2.7:
      return 0.0
  if sepal_length__cm_ > 5.35:
    if petal_width__cm_ <= 1.75:
      if petal_length__cm_ <= 2.5:
        return 0.0
      if petal_length__cm_ > 2.5:
        if sepal_length__cm_ <= 7.1:
          if petal_width__cm_ <= 1.45:
            if petal_length__cm_ <= 5.15:
              return 1.0
            if petal_length__cm_ > 5.15:
              return 2.0
          if petal_width__cm_ > 1.45:
            return 1.0
        if sepal_length__cm_ > 7.1:
          return 2.0
    if petal_width__cm_ > 1.75:
      return 2.0

The issue I have with my program get_sas_from_text is that the 'END;' statements are not all well placed within the code, and this does not correspond to the decision rules given as an input (when compared to the corresponding Python function :

data decision_tree_0;

     set input_data;

 

     if sepal_length__cm_ <= 5.35 THEN

          DO;

                if sepal_width__cm_ <= 2.7 THEN

                     DO;

                          PREDICTED_VALUE = 1.0;

                     end;

                ELSE

                     DO;

                          if sepal_width__cm_ > 2.7 THEN

                               DO;

                                     PREDICTED_VALUE = 0.0;

                               end;

                          ELSE

                               DO;

                                     if sepal_length__cm_ > 5.35 THEN

                                          DO;

                                               if petal_width__cm_ <= 1.75 THEN

                                                    DO;

                                                          if petal_length__cm_ <= 2.5 THEN

                                                               DO;

                                                                    PREDICTED_VALUE = 0.0;

                                                               end;

                                                          ELSE

                                                               DO;

                                                                    if petal_length__cm_ > 2.5 THEN

                                                                         DO;

                                                                              if sepal_length__cm_ <= 7.1 THEN

                                                                                    DO;

                                                                                         if petal_width__cm_ <= 1.45 THEN

                                                                                               DO;

                                                                                                    if petal_length__cm_ <= 5.15 THEN

                                                                                                         DO;

                                                                                                              PREDICTED_VALUE = 1.0;

                                                                                                         end;

                                                                                                    ELSE

                                                                                                         DO;

                                                                                                              if petal_length__cm_ > 5.15 THEN

                                                                                                                    DO;

                                                                                                                         PREDICTED_VALUE = 2.0;

                                                                                                                    end;

                                                                                                              ELSE

                                                                                                                    DO;

                                                                                                                         if petal_width__cm_ > 1.45 THEN

                                                                                                                              DO;

                                                                                                                                   PREDICTED_VALUE = 1.0;

                                                                                                                              end;

                                                                                                                         ELSE

                                                                                                                              DO;

                                                                                                                                  if sepal_length__cm_ > 7.1 THEN

                                                                                                                                         DO;

                                                                                                                                              PREDICTED_VALUE = 2.0;

                                                                                                                                         end;

                                                                                                                                   ELSE

                                                                                                                                         DO;

                                                                                             

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
Welcome To Ask or Share your Answers For Others

1 Reply

0 votes
by (71.8m points)
Waitting for answers

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
OGeek|极客中国-欢迎来到极客的世界,一个免费开放的程序员编程交流平台!开放,进步,分享!让技术改变生活,让极客改变未来! Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share
Click Here to Ask a Question

...