libsvm Source Code Analysis 2

Reference:
http://doc.okbase.net/bentuwuying/archive/259870.html

Analysis of the svm_train function

When svm_train is used from Java, the call typically looks like this:

public static void main(String[] args) throws IOException {
    // TODO Auto-generated method stub
    String[] arg = { "trainfile\\train1.txt", // path of the training data used to build the SVM model
            "trainfile\\model_r.txt" }; // path where the model trained from that data is saved

    String[] parg = { "trainfile\\train2.txt", // the test data
            "trainfile\\model_r.txt", // the trained model to use
            "trainfile\\out_r.txt" }; // path of the prediction output file
    System.out.println("........SVM run started..........");
    // create a training object
    svm_train t = new svm_train();
    // create a prediction (classification) object
    svm_predict p = new svm_predict();
    t.main(arg); // run training
    p.main(parg); // run prediction
}
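Note that the same arg array can also carry libsvm options ahead of the file names; as the parse_command_line() walkthrough below shows, any leading "-" flags are consumed first and the remaining entries are treated as the input and model file paths. A small illustrative variant of the arg array above (the option values here are examples, not taken from the original post):

// Illustrative only: each flag and its value are separate array elements,
// matching how parse_command_line() scans argv.
String[] arg = { "-s", "0",          // C_SVC
        "-t", "2",                   // RBF kernel
        "-g", "0.08",                // gamma
        "-c", "1",                   // penalty parameter C
        "trainfile\\train1.txt",     // training data
        "trainfile\\model_r.txt" };  // output model file
t.main(arg);                         // same call as before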

svm_train in libsvm supports both regression (SVR, Support Vector Regression) and classification (SVC, the default). Only the classification path is covered here.

Structure of svm_train

class svm_train
    private svm_parameter param;        // set by parse_command_line; the SVM model parameters
    private svm_problem prob;           // set by read_problem; stores the number of samples, the target values y, and the feature vectors x (see class svm_problem)
    private svm_model model;
    private String input_file_name;     // set by parse_command_line; input file name
    private String model_file_name;     // set by parse_command_line; model file name
    private String error_msg;           // error message
    private int cross_validation;       // cross-validation flag
    private int nr_fold;                // number of folds for cross validation

    private void do_cross_validation()...
    private void run(String argv[]) ...
    public static void main(String argv[])...
    private void parse_command_line(String argv[])...
    private void read_problem()...

The svm_train class therefore holds an svm_model (among other fields) and provides methods such as main(), run(), and parse_command_line(). The svm_model instance stores the trained model's data.

svm_model

package libsvm;
public class svm_model implements java.io.Serializable
{
    public svm_parameter param; // parameter
    public int nr_class;        // number of classes, = 2 in regression/one class svm
    public int l;           // total #SV
    public svm_node[][] SV; // SVs (SV[l])
    public double[][] sv_coef;  // coefficients for SVs in decision functions (sv_coef[k-1][l])
    public double[] rho;        // constants in decision functions (rho[k*(k-1)/2])
    public double[] probA;         // pairwise probability information
    public double[] probB;
    public int[] sv_indices;       // sv_indices[0,...,nSV-1] are values in [1,...,num_training_data] to indicate SVs in the training set

    // for classification only

    public int[] label;     // label of each class (label[k])
    public int[] nSV;       // number of SVs for each class (nSV[k])
                // nSV[0] + nSV[1] + ... + nSV[k-1] = l
};
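To make these fields concrete, here is a minimal sketch (not from the original post) that loads a previously saved classification model and prints a few of them; it assumes a model file such as trainfile\\model_r.txt already exists:

import libsvm.*;

public class InspectModel {
    public static void main(String[] args) throws java.io.IOException {
        // Assumes the model file was produced by an earlier classification training run.
        svm_model model = svm.svm_load_model("trainfile\\model_r.txt");

        System.out.println("number of classes: " + model.nr_class);
        System.out.println("total #SV        : " + model.l);
        for (int i = 0; i < model.nr_class; i++)
            System.out.println("class " + model.label[i] + " has " + model.nSV[i] + " SVs");
        // rho holds the constants of the k*(k-1)/2 pairwise decision functions
        System.out.println("rho[0]           : " + model.rho[0]);
    }
}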

The main() method

    public static void main(String argv[]) throws IOException
    {
        svm_train t = new svm_train();
        t.run(argv);
    }

The run() method:

private void run(String argv[]) throws IOException
    {
        parse_command_line(argv);  // 1. parse the arguments and set the SVM parameters
        read_problem();            // 2. read the training data into prob
        error_msg = svm.svm_check_parameter(prob,param); // defined in class svm; checks whether the SVM parameters are valid

        if(error_msg != null)
        {
            System.err.print("ERROR: "+error_msg+"\n");
            System.exit(1);
        }

        if(cross_validation != 0)
        {
            do_cross_validation(); // only runs when -v was given; for classification, cross-validation helps assess how well the model generalizes
        }
        else
        {
            model = svm.svm_train(prob,param);
            svm.svm_save_model(model_file_name,model);
        }
    }
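For reference, the cross-validation branch ultimately relies on svm.svm_cross_validation. A minimal sketch of that call for the classification case (assuming prob and param have already been filled in as above; this is not the full do_cross_validation() body):

// Sketch: run 5-fold cross validation and report accuracy.
int nr_fold = 5;
double[] target = new double[prob.l];      // predicted label for each training sample
svm.svm_cross_validation(prob, param, nr_fold, target);

int correct = 0;
for (int i = 0; i < prob.l; i++)
    if (target[i] == prob.y[i]) ++correct;
System.out.printf("Cross Validation Accuracy = %g%%%n", 100.0 * correct / prob.l);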

The parse_command_line() method

/**
     * Initializes the SVM parameters.
     * @param argv the training data file and the output model file
     */
    private void parse_command_line(String argv[]) // argv: training data file and output model file
    {
        int i;
        svm_print_interface print_func = null;  // default printing to stdout

        param = new svm_parameter();
        // default values

        param.svm_type = svm_parameter.C_SVC;
        //param.svm_type = svm_parameter.EPSILON_SVR; // uncomment to run the SVR algorithm instead

        param.kernel_type = svm_parameter.RBF;  // radial basis function (RBF) kernel
        param.degree = 3;

        param.gamma = 0;    // 1/num_features
        param.gamma = 0.08;
        // gamma is the RBF kernel parameter; the libsvm default is 1/num_features.
        // Here it is hard-coded to 0.08, i.e. gamma = 1/(2*sig^2) with sig = 2.5.
        // RBF kernel: exp(-gamma*|Xi-Xj|^2)

        param.coef0 = 0;
        param.nu = 0.5;
        param.cache_size = 100;  // kernel cache size (MB)
        param.C = 1;             // penalty parameter C (the author suggests also trying 100)
        param.eps = 1e-3;        // stopping tolerance (or 0.005)
        param.p = 0.1;           // epsilon in the loss function of EPSILON_SVR (or 0.01)
        param.shrinking = 1;
        param.probability = 0;   // probability estimates
        param.nr_weight = 0;     // number of class weights
        param.weight_label = new int[0];
        param.weight = new double[0];
        cross_validation = 0;    // cross validation: 0 = off, 1 = on

        // parse options
        for(i=0;i<argv.length;i++)     // argv.length = 2 here, i.e. two strings
        {
            if(argv[i].charAt(0) != '-') break;
            // the first argument (trainfile\\train1.txt) does not start with '-', so the loop breaks immediately with i = 0

            if(++i>=argv.length)
                exit_with_help();
            switch(argv[i-1].charAt(1))
            {
                case 's':
                    param.svm_type = atoi(argv[i]);
                    break;
                case 't':
                    param.kernel_type = atoi(argv[i]);
                    break;
                case 'd':
                    param.degree = atoi(argv[i]);
                    break;
                case 'g':
                    param.gamma = atof(argv[i]);
                    break;
                case 'r':
                    param.coef0 = atof(argv[i]);
                    break;
                case 'n':
                    param.nu = atof(argv[i]);
                    break;
                case 'm':
                    param.cache_size = atof(argv[i]);
                    break;
                case 'c':
                    param.C = atof(argv[i]);
                    break;
                case 'e':
                    param.eps = atof(argv[i]);
                    break;
                case 'p':
                    param.p = atof(argv[i]);
                    break;
                case 'h':
                    param.shrinking = atoi(argv[i]);
                    break;
                case 'b':
                    param.probability = atoi(argv[i]);
                    break;
                case 'q':
                    print_func = svm_print_null;
                    i--;
                    break;
                case 'v':
                    cross_validation = 1;
                    nr_fold = atoi(argv[i]);
                    if(nr_fold < 2)
                    {
                        System.err.print("n-fold cross validation: n must >= 2\n");
                        exit_with_help();
                    }
                    break;
                case 'w':
                    ++param.nr_weight;
                    {
                        int[] old = param.weight_label;
                        param.weight_label = new int[param.nr_weight];
                        System.arraycopy(old,0,param.weight_label,0,param.nr_weight-1);
                    }

                    {
                        double[] old = param.weight;
                        param.weight = new double[param.nr_weight];
                        System.arraycopy(old,0,param.weight,0,param.nr_weight-1);
                    }

                    param.weight_label[param.nr_weight-1] = atoi(argv[i-1].substring(2));
                    param.weight[param.nr_weight-1] = atof(argv[i]);
                    break;
                default:
                    System.err.print("Unknown option: " + argv[i-1] + "\n");
                    exit_with_help();
            }
        }

        svm.svm_set_print_string_function(print_func);   // install the print callback (svm_print_null when -q was given, otherwise the default stdout printer)

        // determine filenames

        if(i>=argv.length)                 // argv.length = 2 and i = 0, so this branch is not taken
            exit_with_help();

        input_file_name = argv[i];         // training data file name, i.e. trainfile\\train1.txt in the example above

        if(i<argv.length-1)                // i = 0 and argv.length-1 = 1, so the condition holds
            model_file_name = argv[i+1];   // model file name, i.e. trainfile\\model_r.txt
        else
        {
            int p = argv[i].lastIndexOf('/');
            ++p;    // whew...
            model_file_name = argv[i].substring(p)+".model";
        }
    }
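Since parse_command_line() only translates command-line flags into an svm_parameter, the same setup can also be done programmatically without going through svm_train at all. A minimal sketch under that assumption (the field values mirror the defaults listed above; the TrainDirectly class and its train() helper are illustrative, not part of libsvm):

import libsvm.*;

public class TrainDirectly {
    // Hypothetical helper: trains a C-SVC model from an already-built svm_problem.
    static svm_model train(svm_problem prob) throws java.io.IOException {
        svm_parameter param = new svm_parameter();
        param.svm_type     = svm_parameter.C_SVC;
        param.kernel_type  = svm_parameter.RBF;
        param.gamma        = 0.08;           // same value the author hard-codes above
        param.C            = 1;
        param.eps          = 1e-3;
        param.cache_size   = 100;
        param.shrinking    = 1;
        param.nr_weight    = 0;
        param.weight_label = new int[0];
        param.weight       = new double[0];

        String check = svm.svm_check_parameter(prob, param); // same validity check run() performs
        if (check != null)
            throw new IllegalArgumentException(check);

        svm_model model = svm.svm_train(prob, param);
        svm.svm_save_model("trainfile\\model_r.txt", model);
        return model;
    }
}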


Reposted from blog.csdn.net/answer100answer/article/details/80114325