机器学习ex1

Single-variable problem:

ex1.m

 1 clear;
 2 clc;
 3 close all;
 4 
 5 %输入数据、变量
 6 data = load ('ex1data1.txt');
 7 X = data(:, 1);
 8 y = data(:, 2);
 9 m = length(y);
10 alpha = 0.01;
11 theta = [0;0];
12 %----------plotdata------------
13 plot(X, y, 'rx', 'MarkerSize', 10);
14 xlabel('populations');
15 ylabel('profits');
16 iterations = 1500;
17 %----------gradientDescent------------
18 [theta,J_history] = gradientDescent(X, y, theta, iterations, alpha);
19 x = [ones(m, 1), X];
20 hold on;
21 plot(X, x * theta);
22 legend('Training data', 'Linear regression');
23 hold off;
24 figure;
25 plot(1:1500, J_history, 'b');
26 legend('gradientDescent');
27 J_vals = visualize(X, y, iterations, theta);

cost.m

function J = cost(X, y, theta)
%COST Squared-error cost for linear regression.
%   J = COST(X, y, theta) prepends a column of ones (intercept term) to X
%   and returns
%       J = 1/(2*m) * sum((X*theta - y).^2),   m = length(y).
%
%   X     - m-by-n matrix of raw features (without the intercept column)
%   y     - vector of m targets (row or column; treated as a column)
%   theta - (n+1)-by-1 parameter vector; theta(1) multiplies the intercept
%
%   The computation is vectorized; no local variable shadows the built-in SUM.

m = length(y);
X = [ones(m, 1), X];              % add intercept column
residuals = X * theta - y(:);     % prediction errors, m-by-1
J = sum(residuals .^ 2) / (2 * m);

end

gradientDescent.m

function [theta, J_vals] = gradientDescent(X, y, theta, iterations, alpha)
%GRADIENTDESCENT Batch gradient descent for linear regression.
%   [theta, J_vals] = GRADIENTDESCENT(X, y, theta, iterations, alpha)
%   prepends an intercept column of ones to X and runs ITERATIONS steps of
%       theta := theta - alpha/m * Xb' * (Xb*theta - y)
%   where Xb is X with the intercept column and m = length(y).
%
%   X          - m-by-n matrix of raw features (without the intercept column)
%   y          - vector of m targets (treated as a column)
%   theta      - (n+1)-by-1 initial parameters; theta(1) is the intercept
%   iterations - number of gradient steps
%   alpha      - learning rate
%
%   Returns the final THETA and an ITERATIONS-by-1 vector J_VALS holding the
%   cost (computed by cost.m) after each step. It prints the cost at every
%   iteration, then the final theta and the final cost.

m = length(y);
y = y(:);
Xb = [ones(m, 1), X];
J_vals = zeros(iterations, 1);

for iter = 1 : iterations
    % Vectorized gradient over all n+1 parameters (any number of features).
    grad = Xb' * (Xb * theta - y);
    theta = theta - alpha / m * grad;
    % cost.m adds its own intercept column, so pass the raw features.
    J_vals(iter) = cost(X, y, theta);
    fprintf('-----%f-----\n', J_vals(iter));
end
fprintf('\n');
fprintf('我们得到theta为:\n');
fprintf('%f\n', theta);
fprintf('对应代价:\n');
% With zero iterations there is no recorded cost to report.
if iterations > 0
    fprintf('%f\n', J_vals(iterations));
end

end

visualize.m

function J_vals = visualize(X, y, iterations, theta)
%VISUALIZE Plot the cost surface and contours over a grid of parameters.
%   J_vals = VISUALIZE(X, y, iterations, theta) evaluates
%   cost(X, y, [t0; t1]) on a grid of 100 values of t0 in [-10, 10] by
%   100 values of t1 in [-1, 4] (LINSPACE default point count).
%   It draws a surface plot and a logarithmic-level contour plot in two new
%   figures, and marks THETA with a red cross on the contour plot.
%
%   J_vals is returned with rows indexed by theta1 and columns by theta0,
%   the orientation SURF and CONTOUR expect for (x = theta0, y = theta1).
%   NOTE(review): ITERATIONS is accepted but not used by this function.

t0_grid = linspace(-10, 10);
t1_grid = linspace(-1, 4);
% Alternative grid:
% t0_grid = linspace(-10, 10, 100);
% t1_grid = linspace(-2, 5, 100);

n0 = numel(t0_grid);
n1 = numel(t1_grid);
J_vals = zeros(n1, n0);
for c = 1:n0
    for r = 1:n1
        % Store at (theta1 index, theta0 index) so no transpose is needed.
        J_vals(r, c) = cost(X, y, [t0_grid(c); t1_grid(r)]);
    end
end

% Surface plot of the cost.
figure;
surf(t0_grid, t1_grid, J_vals);
xlabel('theta0');
ylabel('theta1');
zlabel('J');

% Contour plot with 20 logarithmically spaced levels from 1e-2 to 1e3.
figure;
contour(t0_grid, t1_grid, J_vals, logspace(-2, 3, 20));
xlabel('\theta_0');
ylabel('\theta_1');
hold on;
plot(theta(1), theta(2), 'rx');

end

对应数据集：ex1data1.txt

猜你喜欢

转载自www.cnblogs.com/sunrise-to-set/p/11360513.html
ex1
今日推荐