The fit_transform() and transform() functions

While typing out the code from "Python Machine Learning and Practice", I found the difference between the fit_transform() and transform() functions used in data preprocessing quite vague. After consulting a good deal of material, here is a summary:


The code involved in these two functions is as follows:

# Import StandardScaler from sklearn.preprocessing
from sklearn.preprocessing import StandardScaler

# Standardize the data so that each feature has variance 1 and mean 0,
# preventing features with overly large values from dominating the prediction
ss = StandardScaler()

# fit_transform(): fit the scaler on the training data, then standardize it
X_train = ss.fit_transform(X_train)

# transform(): standardize the test data with the statistics learned above
X_test = ss.transform(X_test)
Let's first take a look at what each of these two functions does:

1. fit_transform() function


In short, fit_transform() first fits the data (computing the statistics needed for scaling) and then transforms it into standardized form.
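As a minimal sketch on toy data (not from the book), we can see both halves of fit_transform(): after the call, the scaler has learned each column's mean and variance, and the returned array is already standardized:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy training data: two features on very different scales
X_train = np.array([[1.0, 100.0],
                    [2.0, 200.0],
                    [3.0, 300.0]])

ss = StandardScaler()
X_scaled = ss.fit_transform(X_train)

# The "fit" half stored the per-column statistics...
print(ss.mean_)                 # [  2. 200.]
# ...and the "transform" half standardized the data with them
print(X_scaled.mean(axis=0))    # [0. 0.]
print(X_scaled.std(axis=0))     # [1. 1.]
```

Note that sklearn uses the population standard deviation (ddof=0) when scaling, which is why `X_scaled.std(axis=0)` comes out exactly 1.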

2. transform() function


In short, transform() standardizes data by centering and scaling it with the parameters the scaler has already learned.
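A small sketch (toy data, assumed for illustration) makes this concrete: transform() only applies (x − μ) / σ using the statistics stored during fitting, and matches the manual computation exactly:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[1.0], [2.0], [3.0]])
X_test = np.array([[2.0], [4.0]])

ss = StandardScaler()
ss.fit(X_train)                        # learns mean_ and var_ from X_train

X_test_scaled = ss.transform(X_test)   # applies (x - mean_) / sqrt(var_)

# transform() uses the *training* statistics, not the test set's own
manual = (X_test - ss.mean_) / np.sqrt(ss.var_)
print(np.allclose(X_test_scaled, manual))  # True
```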


At this point the difference seems clear from the names alone: the former performs an extra fitting step. So why don't we call fit_transform() on the test set as well when standardizing it?

The reasons are as follows:

To standardize the data (so that each feature has variance 1 and mean 0), we first compute the mean μ and variance σ^2 of each feature, and then apply:

    x' = (x − μ) / σ
Calling fit_transform() on the training set computes the mean μ and variance σ^2, i.e., it learns the transformation rule. We apply that rule to the training set, and we can then apply the very same rule directly to the test set (or even a cross-validation set). On the test set, therefore, we only need to transform the data; there is no need to fit it again.
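The pitfall can be sketched with toy numbers (assumed for illustration): reusing the training-set scaler keeps the test points on the training distribution's scale, whereas refitting on the test set forces the test set itself to mean 0, producing different (and leaky) values:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[0.0], [10.0]])   # training mean = 5, std = 5
X_test = np.array([[5.0], [15.0]])

ss = StandardScaler()
ss.fit(X_train)

right = ss.transform(X_test)                    # uses training mean/std
wrong = StandardScaler().fit_transform(X_test)  # refits on the test set

print(right.ravel())  # [0. 2.]  -> measured against the training distribution
print(wrong.ravel())  # [-1. 1.] -> test set forced to its own mean 0
```

The first result is what the model expects at prediction time; the second silently changes the feature scale between training and testing.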





