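The samples are read from the CSV file (sample.csv).
Logistic regression is trained by batch gradient descent for 10 iterations with a learning rate of 0.001.
The result (the fitted parameters theta0, theta1 and theta2, printed to the console) is: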
Sample set (features x1 and x2; label y is 0 or 1):
x1 x2 y
34.6 78 0
30.2 43.8 0
35.8 72.9 0
60.1 86.3 1
79 75.3 1
45 56.3 0
61.1 96.5 1
75 46.5 1
76 87.4 1
84.4 43.5 1
95.8 38.2 0
75 30.6 0
// Logistic regression trained with batch gradient descent.
// Samples are read from a CSV file whose rows are "x1,x2,y"; every stored row
// is {1.0 (bias), x1, x2, y}.
#include <array>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Training samples: each row is {1.0 (bias), x1, x2, y}, filled by read_csv().
std::vector<std::vector<double>> v_data;

/// Logistic (sigmoid) function 1 / (1 + e^(-z)); maps any real z into (0, 1).
double h_theta(double z) {
    return 1.0 / (1.0 + std::exp(-z));
}

namespace {

/// Parses the whole of `text` as a finite decimal number (e.g. "-12.5", "3e-2").
/// @return true and stores the value in `out` on success; false (leaving `out`
///         untouched) for empty text, trailing garbage, or non-finite values.
bool tryParseDouble(const std::string& text, double& out) {
    if (text.empty()) {
        return false;
    }
    const char* const begin = text.c_str();
    char* end = nullptr;
    const double value = std::strtod(begin, &end);
    // Reject "no digits consumed", partial parses such as "12abc", and inf/nan.
    if (end == begin || *end != '\0' || !std::isfinite(value)) {
        return false;
    }
    out = value;
    return true;
}

}  // namespace

/// Converts a decimal string such as "-34.6" to double.
/// @return the parsed value, or NaN if `num` is empty or not a valid number
///         (instead of throwing on "" or returning garbage for non-digits).
double stringToDouble(std::string num) {
    double value = 0.0;
    return tryParseDouble(num, value) ? value : std::nan("");
}

/// Strips leading and trailing spaces, tabs, CR and LF from `str` in place.
/// @return a copy of the trimmed string.
std::string Trim(std::string& str) {
    static const char* const kWhitespace = " \t\r\n";
    str.erase(0, str.find_first_not_of(kWhitespace));
    // For an already-empty string find_last_not_of returns npos, and npos + 1 == 0,
    // so this erase is then a harmless no-op.
    str.erase(str.find_last_not_of(kWhitespace) + 1);
    return str;
}

/// Reads samples from the CSV file `filename` and appends them to v_data.
/// Each usable line holds at least three comma-separated numbers x1, x2, y.
/// Blank, short, or non-numeric lines (e.g. a "x1,x2,y" header) are skipped.
void read_csv(std::string filename) {
    std::ifstream fin(filename);
    if (!fin) {
        std::cerr << "Failed to open " << filename << '\n';
        return;
    }

    std::string line;
    while (std::getline(fin, line)) {
        std::cout << "Original string:" << line << '\n';

        // Split the line on commas.
        std::istringstream sin(line);
        std::vector<std::string> fields;
        std::string field;
        while (std::getline(sin, field, ',')) {
            fields.push_back(field);
        }
        // fields[0..2] are indexed below; without this check a blank or short
        // line (e.g. a trailing newline) would be undefined behavior.
        if (fields.size() < 3) {
            continue;
        }

        const std::string one = Trim(fields[0]);    // x1
        const std::string two = Trim(fields[1]);    // x2
        const std::string three = Trim(fields[2]);  // y (label)
        // The literal below reads "String after processing:".
        std::cout << "处理之后的字符串:" << one << "\t" << two << "\t" << three << '\n';

        double x1 = 0.0;
        double x2 = 0.0;
        double y = 0.0;
        if (!tryParseDouble(one, x1) || !tryParseDouble(two, x2) || !tryParseDouble(three, y)) {
            std::cerr << "Skipping non-numeric row: " << line << '\n';
            continue;
        }
        // The leading 1.0 is the bias feature that multiplies theta[0].
        v_data.push_back({1.0, x1, x2, y});
    }
}

/// Fits theta for h(x) = sigmoid(theta0 + theta1*x1 + theta2*x2) on v_data with
/// batch gradient descent, then prints theta0..theta2.
/// @param study_alpha  learning rate (default 0.001)
/// @param step_count   number of gradient-descent iterations (default 10)
void Gradient(double study_alpha = 0.001, int step_count = 10) {
    constexpr std::size_t kParams = 3;  // bias, x1, x2
    if (v_data.empty()) {
        // The 1/m factor below would otherwise divide by zero and make theta NaN.
        std::cerr << "No samples loaded; nothing to train.\n";
        return;
    }
    const double inv_m = 1.0 / v_data.size();

    std::array<double, kParams> theta{};  // all zeros initially
    for (int step = 0; step < step_count; ++step) {
        // Accumulate the full gradient with the current theta before updating,
        // so all components are updated simultaneously.
        std::array<double, kParams> total{};
        for (const std::vector<double>& row : v_data) {
            double z = 0.0;
            for (std::size_t h = 0; h < kParams; ++h) {
                z += theta[h] * row[h];
            }
            const double error = h_theta(z) - row[kParams];  // row[3] is the label y
            for (std::size_t j = 0; j < kParams; ++j) {
                total[j] += error * row[j];
            }
        }
        for (std::size_t j = 0; j < kParams; ++j) {
            theta[j] = theta[j] - inv_m * study_alpha * total[j];
        }
    }

    for (std::size_t t = 0; t < kParams; ++t) {
        std::cout << "theta" << t << " : " << theta[t] << '\n';
    }
}

/// Entry point: loads samples (from argv[1] if given, otherwise "sample.csv")
/// and trains the model with the default learning rate and iteration count.
int main(int argc, char* argv[]) {
    const std::string filename = (argc > 1) ? argv[1] : "sample.csv";
    read_csv(filename);
    Gradient();
    return 0;
}