Univariate Linear Regression (implemented with gradient descent, in Python and MATLAB)

Implementation of univariate linear regression:
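Both versions fit the hypothesis h_θ(x) = θ0 + θ1·x by minimizing the squared-error cost with batch gradient descent. In the notation the code implements (α is the learning rate, m the number of training examples, and both parameters are updated simultaneously):

h_\theta(x) = \theta_0 + \theta_1 x

J(\theta_0, \theta_1) = \frac{1}{2m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)^2

\theta_0 := \theta_0 - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)

\theta_1 := \theta_1 - \frac{\alpha}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x^{(i)}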

Python version:

import numpy as np
import matplotlib.pyplot as plt

# Load the data: column 0 is the input x, column 1 is the target y.
# (Pass delimiter=',' to np.loadtxt if your copy of the file is comma-separated.)
a = np.loadtxt('ex1data1.txt')
m = a.shape[0]   # number of training examples
print(m)
print(type(a))
x = a[:, 0]
y = a[:, 1]
plt.scatter(x, y, marker='*', color='r', s=20)

theta0 = 0
theta1 = 0
iterations = 1500
alpha = 0.01

def gradientdescent(x, y, theta0, theta1, iterations, alpha):
    m = x.shape[0]                   # compute m locally rather than relying on the global
    J_h = np.zeros((iterations, 1))  # cost history, one entry per iteration
    for i in range(iterations):
        y_hat = theta0 + theta1 * x
        # Simultaneous update: compute both new values before overwriting either.
        temp0 = theta0 - alpha * (1 / m) * np.sum(y_hat - y)
        temp1 = theta1 - alpha * (1 / m) * np.sum((y_hat - y) * x)
        theta0 = temp0
        theta1 = temp1
        # Record the cost J = (1/(2m)) * sum((h(x) - y)^2) after the update.
        y_hat2 = theta0 + theta1 * x
        J_h[i, :] = np.sum((y_hat2 - y) ** 2) / (2 * m)
    return theta0, theta1, J_h

theta0, theta1, J_h = gradientdescent(x, y, theta0, theta1, iterations, alpha)
print(theta1)
print(theta0)
plt.plot(x, theta0 + theta1 * x)  # fitted line over the scatter plot
plt.title("fitting curve")
plt.show()

x2 = np.arange(iterations)
plt.plot(x2, J_h)                 # cost should decrease monotonically
plt.title("cost function")
plt.show()
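The same update can also be written in matrix form, using a design matrix that carries a column of ones for the intercept. A minimal vectorized sketch, assuming the same whitespace-delimited ex1data1.txt as above:

import numpy as np

a = np.loadtxt('ex1data1.txt')   # pass delimiter=',' if the file is comma-separated
m = a.shape[0]
X = np.column_stack((np.ones(m), a[:, 0]))   # design matrix: [1, x]
y = a[:, 1]
theta = np.zeros(2)   # theta[0] = intercept, theta[1] = slope
alpha, iterations = 0.01, 1500
for _ in range(iterations):
    theta -= alpha / m * (X.T @ (X @ theta - y))   # simultaneous update of both parameters
print(theta)   # should match (theta0, theta1) printed above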

MATLAB version:

data = load('ex1data1.txt');
X = data(:, 1);
y = data(:, 2);
m = length(y); % number of training examples
hold on
plot(X, y, 'r*', 'MarkerSize', 10)
iterations = 1500;
alpha = 0.01;
theta0 = 0;
theta1 = 0;
computeCost1(X, y, theta0, theta1)   % display the initial cost at theta0 = theta1 = 0
[theta0, theta1, J_history] = gradientDescent1(X, y, theta0, theta1, alpha, iterations);
y_hat = theta0 + theta1 * X;
plot(X, y_hat)   % fitted line over the scatter plot
hold off

figure, plot(1:iterations, J_history)   % cost history; should decrease monotonically

function J = computeCost1(X, y, theta0, theta1)
m = length(y); % number of training examples
y_hat1 = theta0 + theta1 * X;
J = sum((y_hat1 - y).^2) / (2*m);   % J = (1/(2m)) * sum of squared errors
end

function [theta0, theta1, J_history] = gradientDescent1(X, y, theta0, theta1, alpha, num_iters)
m = length(y); % number of training examples
J_history = zeros(num_iters, 1);
    for iter = 1:num_iters
        y_hat2 = theta0 + theta1 * X;
        % Simultaneous update: compute both new values before overwriting either.
        temp0 = theta0 - alpha / m * sum(y_hat2 - y);
        temp1 = theta1 - alpha / m * sum((y_hat2 - y) .* X);
        theta0 = temp0;
        theta1 = temp1;
        J_history(iter) = computeCost1(X, y, theta0, theta1);
    end
end
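As a quick sanity check, the closed-form least-squares solution should land close to the gradient-descent result after 1500 iterations. A minimal sketch in Python, under the same assumptions about the data file:

import numpy as np

a = np.loadtxt('ex1data1.txt')   # same delimiter caveat as above
X = np.column_stack((np.ones(a.shape[0]), a[:, 0]))
y = a[:, 1]
theta_exact, *_ = np.linalg.lstsq(X, y, rcond=None)   % minimizes ||X*theta - y||^2
print(theta_exact)   # compare with [theta0, theta1] from gradient descent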

Test dataset: https://download.csdn.net/download/qq_20406597/10370009


Reposted from blog.csdn.net/qq_20406597/article/details/80020528