【Matlab算法】随机梯度下降法 (Stochastic Gradient Descent,SGD) (附MATLAB完整代码)

前言

随机梯度下降法 (Stochastic Gradient Descent,SGD) 是一种梯度下降法的变种,用于优化损失函数并更新模型参数。与传统的梯度下降法不同,SGD每次只使用一个样本来计算梯度和更新参数,而不是使用整个数据集。这种随机性使得SGD在大型数据集上更加高效,因为它在每次迭代中只需要处理一个样本。

以下是关于随机梯度下降法的详细描述:

  1. 初姶化参数:与梯度下降法类似,首先需要初始化模型的参数,通常使用随机的初始值。
  2. 选代过程:
  • 对于每个训练样本 i i i :
  • 计算损失函数关于当前参数的梯度,即 ∇ f i ( θ ) \nabla f_i(\theta) fi(θ) ,其中 f i ( θ ) f_i(\theta) fi(θ) 是针对第 i i i 个样本的损失。
  • 使用计算得到的梯度来更新模型参数: θ = θ − η ⋅ ∇ f i ( θ ) \theta=\theta-\eta \cdot \nabla f_i(\theta) θ=θηfi(θ) ,其中 η \eta η 是学习率。
  1. 重复迭代: 重复以上过程,直到达到预定的迭代次数或满足停止条件(例如梯度的范数足够小)。
    相比于传统的梯度下降法,SGD的优点包括:
  • 高效:特别适用于大型数据集,因为每次迭代只使用一个样本。
  • 在线学习: 可以用于在线学习,即在接收到新数据时立即更新模型。

然而,由于随机性的引入,SGD的参数更新可能会更加不稳定,因此学习率的选择变得尤为重要。为了解决这个问题,有一些SGD的变种,如Mini-batch SGD,它在每次迭代中使用小批量的样本来计算梯度。这样可以在保持高效性的同时减小参数更新的方差。

正文

对于给出的函数 f ( x ) f(x) f(x) :
f ( x ) = x ( 1 ) 2 + x ( 2 ) 2 − 2 ⋅ x ( 1 ) ⋅ x ( 2 ) + sin ⁡ ( x ( 1 ) ) + cos ⁡ ( x ( 2 ) ) f(x)=x(1)^2+x(2)^2-2 \cdot x(1) \cdot x(2)+\sin (x(1))+\cos (x(2)) f(x)=x(1)2+x(2)22x(1)x(2)+sin(x(1))+cos(x(2))

  1. 初始化参数: 随机选择初始参数 x x x ,通常使用某种随机的初始值。
  2. 选择学习率: 选择一个适当的学习率 η \eta η ,这是一个重要的超参数,影响着参数更新的步长。
  3. 设置迭代次数和停止条件: 确定迭代次数的上限或设置停止条件,例如当梯度的范数小于某个容许误差时停止迭代。
  4. 随机梯度下降选代:
  • 对于每次迭代 t t t ,从训练集中随机选择一个样本 i i i
  • 计算该样本的梯度: ∇ f i ( x ( t ) ) \nabla f_i\left(x^{(t)}\right) fi(x(t))
  • 使用梯度更新参数: x ( t + 1 ) = x ( t ) − η ⋅ ∇ f i ( x ( t ) ) x^{(t+1)}=x^{(t)}-\eta \cdot \nabla f_i\left(x^{(t)}\right) x(t+1)=x(t)ηfi(x(t))
  • 检查是否满足停止条件。如果满足,停止迭代;否则,继续下一次迭代。
  1. 输出结果: 输出最终的参数 x x x ,以及在最优点的目标函数值 f ( x ) f(x) f(x)

代码实现

可运行代码

% 定义目标函数
f = @(x) x(1)^2 + x(2)^2 - 2*x(1)*x(2) + sin(x(1)) + cos(x(2));

% 定义目标函数的梯度
grad_f = @(x) [2*x(1) - 2*x(2) + cos(x(1)); 2*x(2) - 2*x(1) - sin(x(2))];

% 设置参数
learning_rate = 0.01;
max_iterations = 1000;
tolerance = 1e-6;

% 初始化起始点
x = [0; 0];

% 随机梯度下降
for iteration = 1:max_iterations
    % 随机选择一个样本
    i = randi(2);
    % 计算梯度
    gradient = grad_f(x);
    % 更新参数
    x = x - learning_rate * gradient;
    % 检查收敛性
    if norm(gradient) < tolerance
        break;
    end
end

% 显示结果
fprintf('Optimal solution: x = [%f, %f]\n', x(1), x(2));
fprintf('Optimal value of f(x): %f\n', f(x));
fprintf('Number of iterations: %d\n', iteration);

结果

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/AlbertDS/article/details/135094638