C++ 实现逻辑回归算法

样本从csv文件中读出

迭代10次,学习率为0.001

结果为:


样本集(保存在 sample.csv 中,程序按逗号分隔读取各字段;下面以空格分隔展示):

x1    x2      y
34.6 78 0
30.2 43.8 0
35.8 72.9 0
60.1 86.3 1
79 75.3 1
45 56.3 0
61.1 96.5 1
75 46.5 1
76 87.4 1
84.4 43.5 1
95.8 38.2 0
75 30.6 0

#include <cmath>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// Training samples shared by the whole program: read_csv appends rows and
// Gradient reads them. Each row holds four values {1, x1, x2, y}: a constant 1
// (multiplied by theta[0]), the two input columns, and the target column.
vector<vector<double>> v_data;


// Logistic (sigmoid) function: maps z to 1 / (1 + e^(-z)).
double h_theta(double z)
{
    return 1 / (1 + exp(-z));
}
/**
 * Converts decimal text such as "34.6", "78" or "-2.5" to a double.
 *
 * Parsing is delegated to std::strtod rather than a hand-written digit loop.
 * The earlier loop threw std::out_of_range on an empty string and produced
 * meaningless values for any character that was not a digit, '.' or a
 * leading '-'.
 *
 * @param num text to convert. Leading whitespace is skipped and any
 *            characters after the numeric prefix are ignored.
 * @return the parsed value, or 0.0 when num is empty or has no numeric prefix.
 */
double stringToDouble(std::string num)
{
    // strtod returns 0.0 when no conversion can be performed, which covers
    // both the empty string and non-numeric text without throwing.
    return std::strtod(num.c_str(), nullptr);
}
// Strips leading and trailing spaces, tabs, carriage returns and newlines
// from str in place, and also returns a copy of the trimmed text.
std::string Trim(std::string& str)
{
    static const char* const kBlank = " \t\r\n";

    const std::string::size_type first = str.find_first_not_of(kBlank);
    if (first == std::string::npos)
    {
        // Nothing but blanks (or already empty): the result is the empty string.
        str.clear();
        return str;
    }

    const std::string::size_type last = str.find_last_not_of(kBlank);
    str = str.substr(first, last - first + 1);
    return str;
}

/**
 * Reads comma-separated samples from a file and appends them to the global
 * v_data.
 *
 * Every line is echoed to stdout, split on ',', and each field is trimmed.
 * The first three fields are converted with stringToDouble and stored as the
 * row {1, first, second, third}. Blank lines are skipped silently; lines with
 * fewer than three fields, or with an empty field among the first three, are
 * reported on stderr and skipped instead of being indexed out of range.
 *
 * @param filename path of the file to read. If it cannot be opened, a message
 *                 is written to stderr and v_data is left untouched.
 */
void read_csv(std::string filename)
{
    std::ifstream fin(filename);
    if (!fin.is_open())
    {
        std::cerr << "read_csv: cannot open file: " << filename << std::endl;
        return;
    }

    std::string line;
    while (std::getline(fin, line))
    {
        std::cout << "原始字符串:" << line << std::endl;

        // Split the line on commas, trimming each field as it is collected.
        std::istringstream sin(line);
        std::vector<std::string> fields;
        std::string field;
        while (std::getline(sin, field, ','))
        {
            fields.push_back(Trim(field));
        }

        // A line is blank when it produced no fields or only empty ones
        // (for example a lone "\r" left over from Windows line endings).
        bool blank = true;
        for (const std::string& f : fields)
        {
            if (!f.empty())
            {
                blank = false;
                break;
            }
        }
        if (blank)
        {
            continue;
        }

        // The row needs three non-empty fields before they can be converted.
        if (fields.size() < 3 || fields[0].empty() || fields[1].empty() || fields[2].empty())
        {
            std::cerr << "read_csv: skipping malformed line: " << line << std::endl;
            continue;
        }

        std::cout << "处理之后的字符串:" << fields[0] << "\t" << fields[1] << "\t" << fields[2] << std::endl;

        // The leading 1 is the constant input paired with theta[0] in Gradient.
        std::vector<double> row;
        row.reserve(4);
        row.push_back(1);
        row.push_back(stringToDouble(fields[0]));
        row.push_back(stringToDouble(fields[1]));
        row.push_back(stringToDouble(fields[2]));

        v_data.push_back(row);
    }
}

void Gradient()
{

    double study_alpha = 0.001;
    int v_data_size = v_data.size();

    double *theta = new double[3];
    for (int i = 0;i < 3;i++)
    {
        theta[i] = 0;
    }

    int step_time = 0;
    int flag = 0;
    while (step_time < 10)
    {
        double *theta_copy = new double[3];

        for (int j = 0;j < 3;j++)
        {

            double total = 0.0;

            for (int i = 1;i <=v_data_size;i++)
            {
                double z = 0.0;

                for (int h = 0;h < 3;h++)
                {
                    z = z + theta[h] * v_data[i-1][h];
                }

              //  cout << "z value" << z << endl;
                total = total + (h_theta(z) - v_data[i-1][3])*v_data[i-1][j];
            }


         //  cout << "这是total" << total << endl;

            theta_copy[j] = theta[j] - (1.0/v_data.size())*study_alpha*total;
            //cout << "j" <<j<<endl;
        }

        for (int p = 0;p < 3;p++)
        {
            theta[p] = theta_copy[p];
            //cout << "theta" << theta[p] << endl;
        }

        step_time++;

        delete[] theta_copy;


    }
    for (int t = 0;t < 3;t++)
    {
        cout << "theta" << t << " : " << theta[t] << endl;
    }



}
// Program entry point: loads the samples from sample.csv, then runs the
// gradient-descent fit, which prints the resulting parameters.
int main()
{
    read_csv("sample.csv");
    Gradient();
    return 0;
}



猜你喜欢

转载自blog.csdn.net/u014133104/article/details/80019261
今日推荐