参考:
http://doc.okbase.net/bentuwuying/archive/259870.html
关于svm_train函数分析
java使用svm_train时,有如下格式:
public static void main(String[] args) throws IOException {
// Argument vectors in libsvm's command-line style.
String[] arg = { "trainfile\\train1.txt", // path of the SVM training data
"trainfile\\model_r.txt" }; // path where the trained model is written
String[] parg = { "trainfile\\train2.txt", // test data to predict on
"trainfile\\model_r.txt", // trained model to load
"trainfile\\out_r.txt" }; // path of the prediction output file
System.out.println("........SVM运行开始..........");
// main() is static in both classes, so invoke it on the class itself
// rather than creating throwaway instances as the original did.
svm_train.main(arg); // train: reads arg[0], saves the model to arg[1]
svm_predict.main(parg); // predict: loads parg[1], classifies parg[0], writes parg[2]
}
在 libsvm 中的 svm_train 分为回归 SVR(Support Vector Regression)和分类(默认 SVC,Support Vector Classification)两部分,这里只对其中的分类部分做介绍。
svm_train结构
class svm_train
private svm_parameter param; // set by parse_command_line 用于设置svm模型的参数
private svm_problem prob; // set by read_problem用来存储样本序号、样本的目标变量Y、样本自变量X 详看class svm_problem
private svm_model model;
private String input_file_name; // set by parse_command_line 输入文件名
private String model_file_name; // set by parse_command_line 模型文件名
private String error_msg; //错误信息
private int cross_validation; //交叉验证
private int nr_fold;
private void do_cross_validation()...
private void run(String argv[]) ...
public static void main(String argv[])...
private void parse_command_line(String argv[])...
private void read_problem()...
svm_train 类中含有 svm_model 等属性,以及 main()、run()、parse_command_line() 等方法。其中,svm_model 用于存储训练得到的模型数据。
1 svm_model
package libsvm;
/**
 * Trained SVM model as produced by svm.svm_train and persisted by
 * svm.svm_save_model. Serializable so it can be written to / read from disk.
 */
public class svm_model implements java.io.Serializable
{
public svm_parameter param; // training parameters the model was built with
public int nr_class; // number of classes, = 2 in regression/one class svm
public int l; // total #SV
public svm_node[][] SV; // SVs (SV[l])
public double[][] sv_coef; // coefficients for SVs in decision functions (sv_coef[k-1][l])
public double[] rho; // constants in decision functions (rho[k*(k-1)/2])
public double[] probA; // pairwise probability information
public double[] probB;
public int[] sv_indices; // sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set
// for classification only
public int[] label; // label of each class (label[k])
public int[] nSV; // number of SVs for each class (nSV[k])
// nSV[0] + nSV[1] + ... + nSV[k-1] = l
};
main() 方法
public static void main(String argv[]) throws IOException
{
// Entry point: construct a trainer and hand the CLI arguments over to run().
svm_train trainer = new svm_train();
trainer.run(argv);
}
run() 方法:
private void run(String argv[]) throws IOException
{
parse_command_line(argv); // 1. Parse options and filenames; fills in param, the file names and the cross_validation flag.
read_problem(); // 2. Load the training samples from input_file_name into prob. (NOTE(review): the original comment said "read the error message" — read_problem actually reads the training data.)
error_msg = svm.svm_check_parameter(prob,param); // Defined in class svm: validates param against prob; returns null when everything is consistent.
if(error_msg != null)
{
System.err.print("ERROR: "+error_msg+"\n");
System.exit(1);
}
if(cross_validation != 0)
{
do_cross_validation(); // -v was given: run n-fold cross validation (reports accuracy/MSE, produces no model file).
}
else
{
model = svm.svm_train(prob,param); // Train the model on prob with param...
svm.svm_save_model(model_file_name,model); // ...and persist it to model_file_name.
}
}
parse_command_line() 方法
/**
 * Initializes the SVM parameters and resolves the input/model file names.
 *
 * @param argv command-line style arguments: dash-options first (e.g. -c, -g),
 *             then the training-data file name and, optionally, the model file name.
 */
private void parse_command_line(String argv[])
{
int i;
svm_print_interface print_func = null; // default printing to stdout
param = new svm_parameter();
// default values for every parameter; overridable by the dash-options below
param.svm_type = svm_parameter.C_SVC;
//param.svm_type = svm_parameter.EPSILON_SVR; // uncomment to run the SVR algorithm instead
param.kernel_type = svm_parameter.RBF; // radial basis function kernel
param.degree = 3;
param.gamma = 0.08;
// gamma is the RBF kernel parameter; libsvm's stock default is 1/num_features,
// here hard-coded to 0.08 (gamma = 1/(2*sig^2) with sig = 2.5).
// RBF kernel: exp(-gamma*|Xi-Xj|^2)
// NOTE(review): the original assigned param.gamma = 0 and immediately overwrote
// it with 0.08; the dead store was removed.
param.coef0 = 0;
param.nu = 0.5;
param.cache_size = 100; // kernel cache size in MB
param.C = 1; // penalty parameter C (100 is another candidate to try)
param.eps = 1e-3; // stopping tolerance (0.005 is another candidate)
param.p = 0.1; // the EPSILON of EPSILON_SVR (0.01 is another candidate)
param.shrinking = 1;
param.probability = 0; // probability estimates off
param.nr_weight = 0; // number of per-class weights
param.weight_label = new int[0];
param.weight = new double[0];
cross_validation = 0; // 0 -- no cross validation, 1 -- run cross validation
// parse options
for(i=0;i<argv.length;i++) // e.g. argv.length == 2: training file + model file
{
if(argv[i].charAt(0) != '-') break;
// The first non-option argument (e.g. trainfile\train1.txt) ends option parsing.
if(++i>=argv.length)
exit_with_help();
switch(argv[i-1].charAt(1))
{
case 's':
param.svm_type = atoi(argv[i]);
break;
case 't':
param.kernel_type = atoi(argv[i]);
break;
case 'd':
param.degree = atoi(argv[i]);
break;
case 'g':
param.gamma = atof(argv[i]);
break;
case 'r':
param.coef0 = atof(argv[i]);
break;
case 'n':
param.nu = atof(argv[i]);
break;
case 'm':
param.cache_size = atof(argv[i]);
break;
case 'c':
param.C = atof(argv[i]);
break;
case 'e':
param.eps = atof(argv[i]);
break;
case 'p':
param.p = atof(argv[i]);
break;
case 'h':
param.shrinking = atoi(argv[i]);
break;
case 'b':
param.probability = atoi(argv[i]);
break;
case 'q':
print_func = svm_print_null;
i--; // -q takes no value argument; undo the ++i above
break;
case 'v':
cross_validation = 1;
nr_fold = atoi(argv[i]);
if(nr_fold < 2)
{
System.err.print("n-fold cross validation: n must >= 2\n");
exit_with_help();
}
break;
case 'w': // -wN V: grow both weight arrays by one slot and append the pair
++param.nr_weight;
{
int[] old = param.weight_label;
param.weight_label = new int[param.nr_weight];
System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
}
{
double[] old = param.weight;
param.weight = new double[param.nr_weight];
System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
}
param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2)); // class label N embedded in "-wN"
param.weight[param.nr_weight-1] = atof(argv[i]);
break;
default:
System.err.print("Unknown option: " + argv[i-1] + "\n");
exit_with_help();
}
}
svm.svm_set_print_string_function(print_func); // install the print callback (null = stdout)
// determine filenames
if(i>=argv.length) // no non-option argument left -> usage error
exit_with_help();
input_file_name = argv[i]; // training-data file name, e.g. trainfile\train1.txt
if(i<argv.length-1) // a second non-option argument names the model file
model_file_name = argv[i+1]; // model file name, e.g. trainfile\model_r.txt
else
{
// No model file given: derive it from the training file's base name.
int p = argv[i].lastIndexOf('/');
++p; // whew...
model_file_name = argv[i].substring(p)+".model";
}
}