Machine Learning (2) Logistic Regression

The differences between logistic regression and linear regression:

(1) Linear regression fits a function for numerical (continuous-value) prediction, while logistic regression is a binary classification algorithm;

(2) Linear regression model:

        $h_\theta(x) = \theta^T x = \theta_0 x_0 + \theta_1 x_1 + \dots + \theta_n x_n$

Logistic regression model:

        $h_\theta(x) = g(\theta^T x) = \dfrac{1}{1 + e^{-\theta^T x}}$

That is, logistic regression builds on linear regression by adding a sigmoid activation-function mapping. Because logistic regression is a binary classification algorithm, the training labels take only two values, 1 and 0, representing the two classes. When predicting, an output greater than 0.5 is classified as class 1; otherwise it is classified as class 0. The model can therefore be read as the following probability formulas:

        $P(y = 1 \mid x; \theta) = h_\theta(x), \qquad P(y = 0 \mid x; \theta) = 1 - h_\theta(x)$

The training process fits the parameter vector $\theta$ so that, for each training example, the probability assigned to its observed label is as close to 1 as possible;
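To make the mapping concrete, here is a minimal MATLAB sketch of the sigmoid hypothesis and the 0.5 decision rule (theta and x are illustrative values, not taken from the training script below):

% sigmoid hypothesis and the 0.5 decision rule (illustrative values)
theta = [-4; 1; 1];           % illustrative parameter vector
x = [1; 2; 3];                % feature vector with a leading 1 for the bias term
h = 1/(1+exp(-theta'*x));     % h_theta(x) = g(theta'*x), a value in (0,1)
label = h > 0.5;              % class 1 if h > 0.5, otherwise class 0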

(3) The cost function commonly used for linear regression is defined as:

        $J(\theta) = \dfrac{1}{2m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2$

The logistic regression cost function is:

        $J(\theta) = \dfrac{1}{m} \sum_{i=1}^{m} \mathrm{Cost}\left( h_\theta(x^{(i)}),\, y^{(i)} \right)$

In fact, the cost term can be written separately for the two classes. For class 1 ($y = 1$):

        $\mathrm{Cost}(h_\theta(x), y) = -\log\left( h_\theta(x) \right)$

For class 0 ($y = 0$):

        $\mathrm{Cost}(h_\theta(x), y) = -\log\left( 1 - h_\theta(x) \right)$

Therefore:

        $J(\theta) = -\dfrac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\theta(x^{(i)}) + \left(1 - y^{(i)}\right) \log\left(1 - h_\theta(x^{(i)})\right) \right]$

Our goal is to minimize the value of the cost function $J(\theta)$, making it as close to 0 as possible.
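As a quick numerical check of this formula, here is a short MATLAB sketch that evaluates $J(\theta)$; X, y, and theta are illustrative names, with X carrying a leading column of ones for the bias term:

% cross-entropy cost J(theta) on illustrative data
X = [1 2 3; 1 4 5; 1 6 7];   % each row is one training example
y = [0; 0; 1];               % binary labels
theta = zeros(3,1);
m = size(X,1);
h = 1./(1+exp(-X*theta));                      % predictions in (0,1)
J = -1/m*sum(y.*log(h) + (1-y).*log(1-h));     % equals log(2) here, since h = 0.5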

(4) Solving with the gradient descent method.

Cost function simplification: substituting $h_\theta(x) = \dfrac{1}{1 + e^{-\theta^T x}}$ and combining the logarithms gives

        $J(\theta) = \dfrac{1}{m} \sum_{i=1}^{m} \left[ \log\left( 1 + e^{\theta^T x^{(i)}} \right) - y^{(i)}\, \theta^T x^{(i)} \right]$

OK, with the simplification of the formula complete, we take the partial derivative with respect to each $\theta_j$:

        $\dfrac{\partial J(\theta)}{\partial \theta_j} = \dfrac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$

After the derivation is complete, the next step is to plug this into the gradient descent update rule (with learning rate $\alpha$):

        $\theta_j := \theta_j - \alpha\, \dfrac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x_j^{(i)}$
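The same update can also be written in vectorized form; the sketch below performs one step, with X, y, theta, and alpha as illustrative names (the script further down instead updates one parameter at a time inside a loop):

% one vectorized gradient descent step on illustrative data
X = [1 2 3; 1 4 5; 1 6 7];
y = [0; 0; 1];
theta = zeros(3,1);
m = size(X,1);
alpha = 0.05;                              % learning rate
h = 1./(1+exp(-X*theta));                  % current predictions
theta = theta - alpha*(1/m)*(X'*(h-y));    % simultaneous update of all theta_j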

Then we write the MATLAB code and train it; only after writing the code yourself can you become really familiar with this algorithm:

close all;
clear;
clc;
% Generate test data
mu = [2 3]; % mean of test data 1
SIGMA = [1 0; 0 2];
r1 = mvnrnd(mu,SIGMA,100);
plot(r1(:,1),r1(:,2),'.');
hold on;
mu = [10 10]; % mean of test data 2
SIGMA = [1 0; 0 2];
r2 = mvnrnd(mu,SIGMA,100);
plot(r2(:,1),r2(:,2),'.');
data(:,2:3)=[r1;r2]; % feature columns
data(:,1)=1;         % bias column of ones

% training data label
flag=[ones(100,1);zeros(100,1)];
[m,n]=size(data);
w=zeros(n,1);
% gradient descent
sigma=0.05; % learning rate
i=1;
while i<10000
   % compute the sigmoid (hypothesis) values for all samples
   gx=1./(1+exp(-data*w));
   % update every parameter using its partial derivative
   for j=1:n
      r=-1/m*sum((flag-gx).*data(:,j));
      w(j)=w(j)-sigma*r;
   end
   i=i+1;
end
% draw the classification results
gx=1./(1+exp(-data*w)); % recompute predictions with the final weights
figure(2);
hold on;
for i=1:m
    if gx(i)>0.5
        plot(data(i,2),data(i,3),'.b');
    else
        plot(data(i,2),data(i,3),'.y');
    end
end
% draw the decision boundary line w(1)+w(2)*x+w(3)*y=0
% (scaling w does not change the boundary, so no normalization is needed)
line([4,9],[(4*w(2)+w(1))/(-w(3)),(9*w(2)+w(1))/(-w(3))]);
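Once training finishes, classifying a new point only takes the sigmoid of its dot product with the trained w; a short usage sketch (the sample [1 5 6] is illustrative):

% classify a new sample with the trained w from the script above
x_new = [1 5 6];              % leading 1 matches the bias column in data
p = 1/(1+exp(-x_new*w));      % predicted probability of class 1
if p > 0.5
    disp('class 1');
else
    disp('class 0');
end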
(Figure: scatter plot of the original data and the resulting classification with the decision boundary)
