Logistic regression implemented in C++

The training samples are read from a CSV file (`sample.csv`), one comma-separated sample per line.

The parameters are fitted by batch gradient descent: 10 iterations with a learning rate of 0.001.

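The program prints the fitted parameters theta0, theta1 and theta2 (the output shown in the original article is not reproduced here).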


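Sample set (x1 and x2 are the features, y is the 0/1 class label; in the CSV file each row is written as comma-separated values, e.g. `34.6,78,0`):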

x1    x2      y
34.6 78 0
30.2 43.8 0
35.8 72.9 0
60.1 86.3 1
79 75.3 1
45 56.3 0
61.1 96.5 1
75 46.5 1
76 87.4 1
84.4 43.5 1
95.8 38.2 0
75 30.6 0

#include <array>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

using namespace std;

// Training samples filled by read_csv(): each row is {1 (bias term), x1, x2, y}.
std::vector<std::vector<double>> v_data{};


// Logistic (sigmoid) function: maps any real z into [0, 1].
// It is the hypothesis h_theta(x) = 1 / (1 + e^(-z)), where z = theta . x.
// For very negative z, std::exp(-z) overflows to +inf and the result is 0.0, not NaN.
// For very positive z the result saturates at 1.0.
double h_theta(double z)
{
    // std::exp is declared in <cmath>. The original called exp without that
    // header and compiled only when another header pulled it in transitively.
    return 1.0 / (1.0 + std::exp(-z));
}
// Converts a decimal string such as "34.6", "-2.5", "+7" or "1e3" to a double.
//
// Parsing is delegated to std::stod, which skips leading whitespace and
// accepts an optional sign, a fraction and an exponent. Trailing whitespace is
// tolerated.
//
// Throws std::invalid_argument if `num` is empty, does not start with a number,
// or has non-whitespace characters after the number (e.g. a CSV header cell
// such as "x1", or "12abc").
// Throws std::out_of_range if the value cannot be represented as a double.
//
// This replaces a hand-rolled digit loop that threw std::out_of_range on "".
// That loop also turned any non-digit character ('+', 'e', letters) into
// garbage digits and accumulated decimals with extra rounding error.
double stringToDouble(string num)
{
    std::size_t consumed = 0;
    // std::stod throws std::invalid_argument when no conversion is possible.
    const double value = std::stod(num, &consumed);

    // Reject partially numeric input instead of silently using a prefix.
    if (num.find_first_not_of(" \t\r\n", consumed) != string::npos)
    {
        throw std::invalid_argument("stringToDouble: trailing characters in \"" + num + "\"");
    }
    return value;
}
// Strips leading and trailing spaces, tabs, CRs and LFs from `str`, modifying
// it in place.
// Returns a copy of the trimmed string. A string made only of those characters
// becomes empty.
string Trim(string& str)
{
    const char* const blanks = " \t\r\n";

    const string::size_type first = str.find_first_not_of(blanks);
    if (first == string::npos)
    {
        // Nothing but blanks (or already empty).
        str.clear();
        return str;
    }

    const string::size_type last = str.find_last_not_of(blanks);
    str = str.substr(first, last - first + 1);
    return str;
}

// Reads comma-separated samples "x1,x2,y" from `filename`.
// For each sample it appends the row {1, x1, x2, y} to the global v_data; the
// leading 1 is the bias input multiplied by theta[0] in Gradient().
//
// Each raw line and its trimmed fields are echoed to stdout.
// Fields beyond the third are ignored.
// Lines with fewer than three fields (e.g. a blank trailing line) are skipped,
// with a warning on stderr unless the line is blank. The original indexed such
// lines out of range.
// If the file cannot be opened, an error goes to stderr and v_data is unchanged.
// Field values are parsed with stringToDouble, which may throw on malformed input.
void read_csv(string filename)
{
    ifstream fin(filename);
    if (!fin)
    {
        cerr << "read_csv: cannot open file \"" << filename << "\"" << endl;
        return;
    }

    string line;
    while (getline(fin, line)) // one sample per line
    {
        cout << "Original string:" << line << endl;

        // Split the line on commas.
        istringstream sin(line);
        vector<string> fields;
        string field;
        while (getline(sin, field, ','))
        {
            fields.push_back(field);
        }

        if (fields.size() < 3)
        {
            // fields[0..2] would be out of range; skip, warning only for non-blank lines.
            if (!Trim(line).empty())
            {
                cerr << "read_csv: skipping line with fewer than 3 fields: " << line << endl;
            }
            continue;
        }

        // Trim spaces, tabs and CR/LF (e.g. Windows line endings) from each field.
        string one = Trim(fields[0]);   // x1
        string two = Trim(fields[1]);   // x2
        string three = Trim(fields[2]); // y (class label)
        // The label below is Chinese for "Processed string:".
        cout << "处理之后的字符串:" << one << "\t" << two << "\t" << three << endl;

        const double first_data = stringToDouble(one);
        const double second_data = stringToDouble(two);
        const double three_data = stringToDouble(three);

        vector<double> v_rows;
        v_rows.reserve(4);
        v_rows.push_back(1); // bias input
        v_rows.push_back(first_data);
        v_rows.push_back(second_data);
        v_rows.push_back(three_data);

        v_data.push_back(std::move(v_rows));
    }
}

void Gradient()
{

    double study_alpha = 0.001;
    int v_data_size = v_data.size();

    double *theta = new double[3];
    for (int i = 0;i < 3;i++)
    {
        theta[i] = 0;
    }

    int step_time = 0;
    int flag = 0;
    while (step_time < 10)
    {
        double *theta_copy = new double[3];

        for (int j = 0;j < 3;j++)
        {

            double total = 0.0;

            for (int i = 1;i <=v_data_size;i++)
            {
                double z = 0.0;

                for (int h = 0;h < 3;h++)
                {
                    z = z + theta[h] * v_data[i-1][h];
                }

              //  cout << "z value" << z << endl;
                total = total + (h_theta (z) - v_data [i-1] [3]) * v_data [i-1] [j];
            }


         //  cout << "这是total" << total << endl;

            theta_copy[j] = theta[j] - (1.0/v_data.size())*study_alpha*total;
            //cout << "j" <<j<<endl;
        }

        for (int p = 0;p < 3;p++)
        {
            theta[p] = theta_copy[p];
            //cout << "theta" << theta[p] << endl;
        }

        step_time++;

        delete[] theta_copy;


    }
    for (int t = 0;t < 3;t++)
    {
        cout << "theta" << t << " : " << theta[t] << endl;
    }



}
intmain()
{
    string filename = "sample.csv";

    read_csv(filename);

    Gradient();


}



Originally published at: http://43.154.161.224:23101/article/api/json?id=324637889&siteId=291194637