Welcome to OGeek Q&A Community for programmer and developer-Open, Learning and Share

python - Scikit Learn OneHotEncoder fit and transform Error: ValueError: X has different shape than during fitting

Below is my code.

I know why the error occurs during transform: the feature lists at fit time and at transform time do not match. How can I solve this? How can I get 0 for all the remaining features?

After this, I want to use the encoded data with partial_fit of an SGD classifier.

Jupyter QtConsole 4.3.1
Python 3.6.2 |Anaconda custom (64-bit)| (default, Sep 21 2017, 18:29:43) 
Type 'copyright', 'credits' or 'license' for more information
IPython 6.1.0 -- An enhanced Interactive Python. Type '?' for help.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder

input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                             color=['Red', 'Orange','Green'],
                             is_sweet = [0,0,1],
                             country=['USA','India','Asia']))
input_df
Out[1]: 
    color country   fruit  is_sweet
0     Red     USA   Apple         0
1  Orange   India  Orange         0
2   Green    Asia    Pine         1



filtered_df = input_df.apply(pd.to_numeric, errors='ignore')
filtered_df.info()
# apply one hot encode
refreshed_df = pd.get_dummies(filtered_df)
refreshed_df
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3 entries, 0 to 2
Data columns (total 4 columns):
color       3 non-null object
country     3 non-null object
fruit       3 non-null object
is_sweet    3 non-null int64
dtypes: int64(1), object(3)
memory usage: 176.0+ bytes


Out[2]: 
   is_sweet  color_Green  color_Orange  color_Red  country_Asia  
0         0            0             0          1             0   
1         0            0             1          0             0   
2         1            1             0          0             1   

   country_India  country_USA  fruit_Apple  fruit_Orange  fruit_Pine  
0              0            1            1             0           0  
1              1            0            0             1           0  
2              0            0            0             0           1  



enc = OneHotEncoder()
enc.fit(refreshed_df)

Out[3]: 
OneHotEncoder(categorical_features='all', dtype=<class 'numpy.float64'>,
       handle_unknown='error', n_values='auto', sparse=True)



new_df = pd.DataFrame(dict(fruit=['Apple'], 
                             color=['Red'],
                             is_sweet = [0],
                             country=['USA']))
new_df


Out[4]: 
  color country  fruit  is_sweet
0   Red     USA  Apple         0



filtered_df1 = new_df.apply(pd.to_numeric, errors='ignore')
filtered_df1.info()
# apply one hot encode
refreshed_df1 = pd.get_dummies(filtered_df1)
refreshed_df1
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 1 entries, 0 to 0
Data columns (total 4 columns):
color       1 non-null object
country     1 non-null object
fruit       1 non-null object
is_sweet    1 non-null int64
dtypes: int64(1), object(3)
memory usage: 112.0+ bytes



Out[5]: 
   is_sweet  color_Red  country_USA  fruit_Apple
0         0          1            1            1

enc.transform(refreshed_df1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-6-33a6a884ba3f> in <module>()
----> 1 enc.transform(refreshed_df1)

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in transform(self, X)
   2073         """
   2074         return _transform_selected(X, self._transform,
-> 2075                                    self.categorical_features, copy=True)
   2076 
   2077 

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in _transform_selected(X, transform, selected, copy)
   1810 
   1811     if isinstance(selected, six.string_types) and selected == "all":
-> 1812         return transform(X)
   1813 
   1814     if len(selected) == 0:

~/anaconda3/lib/python3.6/site-packages/sklearn/preprocessing/data.py in _transform(self, X)
   2030             raise ValueError("X has different shape than during fitting."
   2031                              " Expected %d, got %d."
-> 2032                              % (indices.shape[0] - 1, n_features))
   2033 
   2034         # We use only those categorical features of X that are known using fit.

ValueError: X has different shape than during fitting. Expected 10, got 4.
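The traceback reports 10 expected features because OneHotEncoder was fitted on refreshed_df, which pd.get_dummies() had already expanded to 10 dummy columns, while pd.get_dummies() on the single new row only creates the 4 columns whose values occur in it. As a minimal sketch (not the accepted answer's approach), the get_dummies route can be kept by reindexing the new frame against the training columns, filling missing dummies with 0. The aligned frame can then go straight to a model, with no OneHotEncoder on top. The names train_df, train_columns and new_dummies are illustrative only.

```python
import pandas as pd

train_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'],
                             color=['Red', 'Orange', 'Green'],
                             is_sweet=[0, 0, 1],
                             country=['USA', 'India', 'Asia']))
new_df = pd.DataFrame(dict(fruit=['Apple'], color=['Red'],
                           is_sweet=[0], country=['USA']))

# Dummy columns produced from the training data (is_sweet + 9 dummies)
train_dummies = pd.get_dummies(train_df, dtype=int)
train_columns = train_dummies.columns

# get_dummies on the new row only creates 4 columns; reindex adds every
# missing training column filled with 0 and drops columns unseen in training
new_dummies = pd.get_dummies(new_df, dtype=int).reindex(columns=train_columns,
                                                        fill_value=0)

print(train_dummies.shape)  # (3, 10)
print(new_dummies.shape)    # (1, 10)
```

Because reindex also drops columns that were not present at training time, a value never seen before simply ends up as all zeros in its group.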

1 Reply


Instead of using pd.get_dummies(), use LabelEncoder together with OneHotEncoder. Both encoders store the values they saw during fit, so the same mapping, and therefore the same number of output columns, is applied to the new data.
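To illustrate what "storing the original values" means, here is a minimal sketch with a single LabelEncoder fitted on the color values from the question. classes_ records the sorted labels seen at fit time, each label's code is its position in that array, and a label that was never seen is rejected. The label 'Blue' is a made-up example.

```python
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder().fit(['Red', 'Orange', 'Green'])

# classes_ holds the sorted labels seen during fit
print(le.classes_)                 # ['Green' 'Orange' 'Red']
print(le.transform(['Red']))       # [2]
print(le.inverse_transform([0]))   # ['Green']

# A label not seen during fit raises ValueError instead of being mapped
try:
    le.transform(['Blue'])
    unseen_rejected = False
except ValueError:
    unseen_rejected = True
print('unseen label rejected:', unseen_rejected)
```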

Changing your code as below gives the required result.

import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
input_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'], 
                             color=['Red', 'Orange','Green'],
                             is_sweet = [0,0,1],
                             country=['USA','India','Asia']))

filtered_df = input_df.apply(pd.to_numeric, errors='ignore')

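# This is what you need: fit one LabelEncoder per column and keep it in
# le_dict, so the same string -> integer mapping can be reused on new data
le_dict = {}
for col in filtered_df.columns:
    le_dict[col] = LabelEncoder().fit(filtered_df[col])
    filtered_df[col] = le_dict[col].transform(filtered_df[col])

# OneHotEncoder now sees integer codes and remembers how many values each
# column had at fit time (3 + 3 + 3 + 2 = 11 output columns here)
enc = OneHotEncoder()
enc.fit(filtered_df)
refreshed_df = enc.transform(filtered_df).toarray()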

new_df = pd.DataFrame(dict(fruit=['Apple'], 
                             color=['Red'],
                             is_sweet = [0],
                             country=['USA']))
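# Reuse the fitted encoders; LabelEncoder.transform raises ValueError
# for labels that were not seen during fit
for col in new_df.columns:
    new_df[col] = le_dict[col].transform(new_df[col])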

new_refreshed_df = enc.transform(new_df).toarray()

print(filtered_df)
      color  country  fruit  is_sweet
0      2        2      0         0
1      1        1      1         0
2      0        0      2         1

print(refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]
 [ 0.  1.  0.  0.  1.  0.  0.  1.  0.  1.  0.]
 [ 1.  0.  0.  1.  0.  0.  0.  0.  1.  0.  1.]]

print(new_df)
      color  country  fruit  is_sweet
0      2        2      0         0

print(new_refreshed_df)
[[ 0.  0.  1.  0.  0.  1.  1.  0.  0.  1.  0.]]
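The answer targets an older scikit-learn (the traceback mentions categorical_features and n_values). Since scikit-learn 0.20, OneHotEncoder accepts string columns directly, and LabelEncoder is documented for encoding targets rather than input features. With handle_unknown='ignore', categories unseen during fit are encoded as all zeros, which covers the "0 for all the remaining features" case. The sketch below also shows the follow-up step from the question, SGDClassifier.partial_fit, whose first call needs the full list of classes. Assumptions: is_sweet is treated as a hypothetical target purely for illustration, and the second batch ('Banana', 'Yellow', 'Brazil') and its labels are made up.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import SGDClassifier
from sklearn.preprocessing import OneHotEncoder

cat_cols = ['fruit', 'color', 'country']

train_df = pd.DataFrame(dict(fruit=['Apple', 'Orange', 'Pine'],
                             color=['Red', 'Orange', 'Green'],
                             is_sweet=[0, 0, 1],
                             country=['USA', 'India', 'Asia']))

# Fit directly on the string columns; unknown categories -> all zeros
enc = OneHotEncoder(handle_unknown='ignore')
enc.fit(train_df[cat_cols])

X_train = enc.transform(train_df[cat_cols])  # sparse matrix, 9 columns
y_train = train_df['is_sweet'].to_numpy()    # hypothetical target

# The first partial_fit call must be given every class that can occur
clf = SGDClassifier(random_state=0)
clf.partial_fit(X_train, y_train, classes=np.array([0, 1]))

# A later batch containing categories never seen during fit
new_df = pd.DataFrame(dict(fruit=['Apple', 'Banana'],
                           color=['Red', 'Yellow'],
                           country=['USA', 'Brazil']))
X_new = enc.transform(new_df[cat_cols])      # still 9 columns
print(X_new.toarray())

# Later calls reuse the classes from the first call (labels are made up)
clf.partial_fit(X_new, np.array([0, 1]))
```

Because the encoder's column layout is fixed at fit time, every later batch has the same width, which is what partial_fit requires.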
