【徒手写机器学习算法】感知机算法

今天开始又开一个新坑叫“徒手写机器学习算法”,先说明一下实际上你徒手写出的这些算法没什么毛用,但是如今天天在写paper,感觉远离了之前的统计机器学习算法和编程,所以弄个这个专栏督促自己复习一下不要忘了之前学的东西。


关于这个系列“徒手写机器学习算法”

首先既然是徒手肯定就别想什么”import tensorflow as tf”之类的了,准备还是用c++和几个基础库来写,就eigen吧,数据I/O就用csv读取。今天先展示写个感知机。


关于感知机算法

感知机算法是误差驱动的在线学习算法。首先初始化一个权重 W 0 、截距 b ,然后每次分错样本时,就用这个样本改变权重以及截距 b 。其实就是构造了一个分类超平面。感知机算法的学习策略是极小化误分类点到超平面的距离。

这里写图片描述


关于数据I/O

我们单独写一个c++文件csv.hpp来读取数据文件,数据文件data.csv看起来像这样:

6.62641e+02,3.72825e+02,1.56250e+02,...,1.0e+00
6.61418e+02,3.73390e+02,1.84375e+02,...,1.0e+00
6.61835e+02,3.72908e+02,1.31250e+02,...,1.0e+00
...

其中最后一维的数据是标签 y ,而除最后一维外为数据 X


核心程序

#include "csv.hpp"
#include <Eigen/Dense>
#include <iostream>
#include <vector>

//the max training steps
#define MAX_STEPS 200

using namespace std;  
using namespace Eigen;

static VectorXd* W;

template<typename DType>
void train(std::vector<VectorXd*> X, std::vector<DType> y)
{
    int step = 0;
    while(step < MAX_STEPS)
    {
        int flag = 0; //judge if all samples from <X,y> meet y<W,X> >= 0;
        for (int t = 0; t < X.size(); ++t)
        {
            if ( y[t]*( W->dot(*X) ) < 0 )
            {
                W->array() += y[t]*(X->array());
                flag++;
            }
        }
        if (flag == 0)
        {
            return;
        }
        step++;
    }
}

现在给出数据I/O程序csv.hpp

调用方式是:

X1=csv::get_X("d1.csv");
y1=csv::get_y("d1.csv");

整个程序:

#include <iostream>  
#include <string>  
#include <vector>  
#include <fstream>  
#include <sstream>  
#include <Eigen/Dense>
using namespace std;  
using namespace Eigen;
//g++ csv.cpp -o csv -I/download/eigen
namespace csv
{   
    template <class Type>  
    Type stringToNum(const string& str)
    {  
        istringstream iss(str);  
        Type num;  
        iss >> num;  
        return num;      
    }  

    vector<VectorXd*> get_X(string path)  
    { 
        // 读文件  
        const char *char_path = path.data();
        ifstream inFile(char_path, ios::in);  
        string lineStr;  
        static vector<VectorXd*> X;  
        if (X.size()>0)
        {
            X.clear();
        }
        int line = 0;
        while (getline(inFile, lineStr))  
        {  
            // 打印整行字符串  
            //cout << lineStr << endl;  
            // 存成二维表结构  
            stringstream ss(lineStr);  
            string str;  
            VectorXd x(10);
            // 按照逗号分隔
            int dim = 0;  
            while (getline(ss, str, ',') && dim<9)  
            {
                //lineArray.push_back(str);  
                //cout<<stringToNum<double>(str)<<endl;
                x[dim] = stringToNum<double>(str);
                dim++;
            }
            //VectorXd* p_x = new VectorXd(10);
            //(*p_x).array() = x.array();    
            //cout<<x[0]<<endl;
            //X.push_back(p_x);
            X.push_back(new VectorXd(10));
            (*(X[line])).array() = x.array();
            line++;  
        }  
        return X;
    } 

    vector<int> get_y(string path)  
    {  
        // 读文件  
        const char *char_path = path.data();
        ifstream inFile(char_path, ios::in);  
        string lineStr;  
        static vector<int> y;
        if (y.size()>0)
        {
            y.clear();
        }  
        int line = 0;
        while (getline(inFile, lineStr))  
        {  
            // 打印整行字符串  
            //cout << lineStr << endl;  
            // 存成二维表结构  
            stringstream ss(lineStr);  
            string str;  
            int y_i;
            // 按照逗号分隔
            int dim = 0;  
            while (getline(ss, str, ',') && dim<11)  
            {
                //lineArray.push_back(str);  
                //cout<<stringToNum<double>(str)<<endl;
                y_i = int(stringToNum<float>(str));
                dim++;
            }
            y.push_back(y_i);
            line++;  
        }  
        return y;
    }
}


//test
/*
int main(int argc, char const *argv[])
{
    vector<VectorXd*> X = csv::get_X("./new.csv");
    for (std::vector<VectorXd*>::iterator iter = X.begin(); iter != X.end(); ++iter )
    {
        cout<<(*(*iter))[0]<<endl;
    }
    vector<int> y = csv::get_y("./d.csv");
    for (std::vector<int>::iterator iter = y.begin(); iter != y.end(); ++iter )
    {
        cout<<*iter<<endl;
    }
    //string str =  "2.620444107055664062e+01";
    //cout<<csv::stringToNum<double>(str)<<endl;
    return 0;
}*/

猜你喜欢

转载自blog.csdn.net/hanss2/article/details/80350688