1 训练集ex6data3.mat可视化
绘制训练集函数plotData.m
同机器学习(20)中的plotData.m
训练集可视化脚本GaussianSVM1.m 部分代码
%% Initialization
clear ; close all; clc
%% =============== Part 1: 加载并可视化数据 ================
fprintf('Loading and Visualizing Data ...\n')
% 从文件 ex6data3.mat中加载,发现环境中有X,y,Xval,yval变量值:
load('ex6data3.mat');
% 绘制训练集数据
plotData(X, y);
% 绘制交叉验证集数据
plotData(Xval, yval);
fprintf('Program paused. Press enter to continue.\n');
pause;
执行结果
左图为训练集样本共211个,右图为测试集样本共200个,每个样本的输入特征为2个。
可以看到仍属于二分类问题,只是对应的边界不明显,因此需要根据交叉验证集训练情况选用适合的参数C、σ(这也是4.2中的内容,根据交叉验证集来选参数)。
2 SVM的高斯核函数gaussianKernel.m
同机器学习(21)中的gaussianKernel.m
3 寻找最佳C、σ的函数dataset3Params.m
这里C和σ都采用公比为3的等比数列进行逐次验证,每次根据这两个对应的值去训练得到参数theta,然后得到交叉验证集的准确率,通过比较这次与上次的准确率,选用准确率更高的参数值。容易得到循环次数为length(vec)的平方,即内外各套一层循环。
function [C, sigma] = dataset3Params(X, y, Xval, yval)
% dataset3Params返回练习第3部分中选择的C和sigma
% 根据交叉验证集返回最佳的C和sigma。
% 输入: X 训练矩阵,行数为样本数,列数为输入特征数
% y 训练集输出特征向量,是一个包含1和0的列向量,行数为样本数,列数为1
% Xval 验证样本矩阵,行数为样本数,列数为输入特征数
% yval 验证样本输出特征向量,是一个包含1和0的列向量,行数为样本数,列数为1
% 输出: 用于选择最优的(C, sigma)学习参数用于支持向量机与RBF核
C = 1; %初始化C、sigma
sigma = 0.3;
% 说明:以下代码以返回使用交叉验证集找到的最佳C和西格玛学习参数。
% 可以使用svmPredict来预测交叉验证集中的标签。例如,
% predictions = svmPredict(model, Xval);
% 将返回交叉验证集上的预测
% 注意:您可以使用mean(double(predictions ~= yval))计算预测误差
vec = [0.01 0.03 0.1 0.3 1 3 10 30]';
C = 0.01;
sigma = 0.01;
model= svmTrain(X, y, C, @(x1, x2) gaussianKernel(x1, x2, sigma));
predictions = svmPredict(model,Xval);
meanMin = mean(double(predictions ~= yval));
C_optimal = C;
sigma_optimal = sigma;
for i = 1:length(vec)
for j = 1:length(vec)
C = vec(i);
sigma = vec(j);
model= svmTrain(X, y, C, @(x1, x2) gaussianKernel(x1, x2, sigma));
predictions = svmPredict(model,Xval);
if(meanMin >= mean(double(predictions ~= yval)))
meanMin = mean(double(predictions ~= yval));
C_optimal = C;
sigma_optimal = sigma;
endif
endfor
endfor
C = C_optimal;
sigma = sigma_optimal;
endfunction
4 训练函数svmTrain.m
同机器学习(20)中的svmTrain.m
5 非线性SVM脚本GaussianSVM1.m
%% Initialization
clear ; close all; clc
%% =============== Part 1: 加载并可视化数据 ================
fprintf('Loading and Visualizing Data ...\n')
% 从文件 ex6data3.mat中加载,会发现环境中有X,y变量值:
load('ex6data3.mat');
% 绘制训练集数据
plotData(X, y);
fprintf('Program paused. Press enter to continue.\n');
pause;
%% ==================== Part 2: 非线性SVM训练 ====================
% 在实现了内核之后,我们现在可以使用它来训练SVM分类器。
% 从文件 ex6data3.mat中加载,会发现环境中有X,y变量值:
fprintf('\nTraining SVM with RBF Kernel (this may take 1 to 2 minutes) ...\n');
load('ex6data3.mat');
% 在这里尝试不同的SVM参数
[C, sigma] = dataset3Params(X, y, Xval, yval);
% 训练SVM
model= svmTrain(X, y, C, @(x1, x2) gaussianKernel(x1, x2, sigma));
visualizeBoundary(X, y, model);
fprintf('Program paused. Press enter to continue.\n');
pause;