用C++搭建三层神经网络

版权声明:我是南七小僧,微信: to_my_love ,2020年硕士毕业,寻找 自然语言处理,图像处理,软件开发等相关工作,欢迎交流思想碰撞。 https://blog.csdn.net/qq_25439417/article/details/85611483

(原文此处有一张插图,图片未能保留。)

C++写三层神经网络:

整体网络框架搭建:

bp.h

#include <vector>

// Network dimensions. They are compile-time constants (not macros) because
// they size the fixed arrays inside BP.
//   LAYER: number of layers -- x[0] input, x[1] hidden, x[2] output.
//   NUM:   maximum number of neurons in any one layer.
// NOTE(review): bp.cpp also uses A, B, ERROR, ACCU, ETA_W and ETA_B, which
// are not defined in this header; they must be provided before it compiles.
constexpr int LAYER = 3;
constexpr int NUM = 10;

using namespace std;
// One training sample.
struct Data
{
    std::vector<double> x;  // input features; BP copies them into the input layer
    std::vector<double> y;  // expected outputs; compared with the output layer
};

class BP{
    private:
        int in_num;
        int out_num;
        int hd_num;

        vector<Data> data;
        vector<vector<double>> testdata;
        vector<vector<double>> result;
        int rowLen;
        int restrowLen;

        double w[LAYER][NUM][NUM];
        double b[LAYER][NUM];

        double x[LAYER][NUM];
        double d[LAYER][NUM];

    private:
        void InitNetWork();
        void GetNums();
        void ForwordTransfer();
        void BackwordTransfer();
        void CalcDelta(int);
        void UpdateNetwork();
        double GetError(int);
        double GetAcc();
        double Sigmoid(double);
        // void split(char)

    public:
        void GetData();
        void Train();
        vector<double> Predict(vector<double>);


};

框架具体实现代码

bp.cpp

#include "bp.h"

#include <assert.h>
#include <math.h>
#include <stdio.h>
#include <string.h>

#include <cstdlib>
#include <fstream>
#include <iostream>

using namespace std;

// Upper bound on training epochs; Train() runs while iter <= ITERS.
constexpr int ITERS = 100;

// Replaces the stored training set with the given samples.
void BP::GetData(vector<Data> datalist){
    data.assign(datalist.begin(), datalist.end());
}

// Trains the network on the stored samples.
// In every epoch each sample is presented repeatedly (forward pass followed by
// back-propagation) until its error falls below ERROR. After each epoch the
// mean error over all samples (GetAcc) is compared with ACCU for early stop.
void BP::Train(){
    printf("======================\n");
    printf("==神经网络要跑起来了==\n");
    printf("======================\n");
    BP::InitNetWork();

    int num = data.size();
    for(int iter = 0;iter <= ITERS;iter++)
    {
        for(int cnt = 0;cnt < num; cnt++)
        {
            // Load sample `cnt` into the input layer.
            for(int i = 0;i<BP::in_num;i++)
            {
                BP::x[0][i] = data.at(cnt).x[i];
            }

            // Fit this sample until its error drops below ERROR. The loop
            // previously had no exit, so training could never finish.
            // NOTE(review): still unbounded if the error never gets below ERROR.
            while(1)
            {
                BP::ForwordTransfer();
                if(GetError(cnt)>=ERROR)
                {
                    BP::BackwordTransfer(cnt);
                }
                else
                {
                    break;
                }
            }
        }
        printf("第%d轮循环结束\n",iter);
        // GetAcc() returns a mean (half) squared error: lower is better.
        double acc = GetAcc();
        printf("ACC%f\n",acc);  // acc is a double; %d here was undefined behavior
        if(acc<ACCU)
        {
            break;
        }
    }
    printf("BP结束\n");
}

void BP::ForwordTransfer(){
    for(int j =0;j<hd_num;j++)
    {
        double t = 0;
        for(int i=0;i<in_num;i++){
            t += w[1][i][j] * x[0][i];
        }
        t += b[1][j];
        x[1][j] = BP::Sigmoid(t);
    }

    for(int j = 0;j<out_num;j++){
        double t = 0;
        for(int i = 0 ;i<hd_num;i++){
            t += w[2][i][j] * x[1][i];
        }
        t += b[2][j];
        x[2][j] = Sigmoid(t);
    }
}

// Half the sum of squared differences between the current output layer x[2]
// and the targets of sample `cnt`. Uses the activations of the last forward pass.
double BP::GetError(int cnt){
    double err = 0;
    for (int k = 0; k < out_num; ++k) {
        const double diff = x[2][k] - data.at(cnt).y[k];
        err += 0.5 * diff * diff;
    }
    return err;
}

// One back-propagation step for sample `cnt`: compute the deltas, then apply
// the gradient update to weights and biases.
void BP::BackwordTransfer(int cnt){
    BP::CalcDelta(cnt);
    BP::UpdateNetwork();
}

double BP::GetAcc(){
    double ans = 0 ;
    int num = data.size();
    for(int i =0;i<num;i++){
        int m = data.at(i).x.size();
        for(int j =0;j<m;j++){
            x[0][j] = data.at(i).x[j];

        }
        ForwordTransfer();
        int n = data.at(i).y.size();
        for (int j=0;j<n;j++){
              ans += 0.5*(x[2][i] - data.at(i).y[i])*(x[2][i] - data.at(i).y[i]);
        }
    }
    return ans/num;
}

// Error terms (deltas) for sample `cnt`. For f(z) = A / (1 + exp(-z / B)) the
// derivative expressed through the output v = f(z) is v * (A - v) / (A * B).
//   output: d[2][o] = (x[2][o] - y[o]) * f'
//   hidden: d[1][h] = (sum_o w[2][h][o] * d[2][o]) * f'
void BP::CalcDelta(int cnt){
    for (int o = 0; o < out_num; ++o) {
        const double out = x[2][o];
        d[2][o] = (out - data.at(cnt).y[o]) * out * (A - out) / (A * B);
    }
    for (int h = 0; h < hd_num; ++h) {
        double back = 0;
        for (int o = 0; o < out_num; ++o)
            back += w[2][h][o] * d[2][o];
        const double act = x[1][h];
        d[1][h] = back * act * (A - act) / (A * B);
    }
}

// Gradient-descent step using the deltas computed by CalcDelta:
//   w[l][i][j] -= ETA_W * d[l][j] * x[l-1][i]
//   b[l][j]    -= ETA_B * d[l][j]
// for the output layer (l = 2) and the hidden layer (l = 1).
void BP::UpdateNetwork(){
    for (int i=0;i<hd_num;i++){
        for(int j =0;j<out_num;j++){
            w[2][i][j] -= ETA_W*d[2][j]*x[1][i];
        }
    }

    for(int i=0;i<out_num;i++){
        b[2][i] -= ETA_B*d[2][i];
    }

    for (int i =0;i<in_num;i++){
        for(int j =0;j<hd_num;j++){
            // Subtract the gradient; this previously overwrote the weight with it.
            w[1][i][j] -= ETA_W*x[0][i]*d[1][j];
        }
    }
    for(int i=0;i<hd_num;i++){
        b[1][i] -= ETA_B*d[1][i];
    }
}

// Scaled logistic function A / (1 + exp(-x / B)): the output lies strictly
// between 0 and A, and B stretches the input axis.
double BP::Sigmoid(double x){
    const double e = exp(-x / B);
    return A / (1 + e);
}

------------------------------

可运行版本(在上面实现的基础上补充了 InitNetWork 初始化函数和 main 测试入口):

#include <string.h>  
#include <stdio.h>  
#include <math.h>  
#include <assert.h>  
#include <cstdlib>
#include <fstream>
#include <iostream>
#include "bp.h"

using namespace std;

// Upper bound on training epochs; Train() runs while iter <= ITERS.
constexpr int ITERS = 100;

// Replaces the stored training set with the given samples.
void BP::GetData(vector<Data> datalist){
    data.assign(datalist.begin(), datalist.end());
}

// Trains the network on the stored samples.
// In every epoch each sample is presented repeatedly (forward pass followed by
// back-propagation) until its error falls below ERROR. After each epoch the
// mean error over all samples (GetAcc) is compared with ACCU for early stop.
// NOTE(review): the per-sample loop is unbounded if the error never gets below ERROR.
void BP::Train(){
    printf("======================\n");
    printf("=====net will run=====\n");
    printf("======================\n");
    BP::InitNetWork();

    int num = data.size();
    printf("size%d\n ",num);
    for(int iter = 0;iter <= ITERS;iter++)
    {
        for(int cnt = 0;cnt < num; cnt++)
        {
            printf("No%d cnt\n",cnt);
            // Load sample `cnt` into the input layer.
            for(int i = 0;i<BP::in_num;i++)
            {
                BP::x[0][i] = data.at(cnt).x[i];
            }

            while(1)
            {
                BP::ForwordTransfer();
                if(GetError(cnt)>=ERROR)
                {
                    BP::BackwordTransfer(cnt);
                }else{
                    break;
                }
            }
        }
        printf("No%d Epochs\n",iter);
        // GetAcc() returns a mean (half) squared error: lower is better.
        double acc = GetAcc();
        printf("ACC%f\n",acc);  // acc is a double; %d here was undefined behavior
        if(acc<ACCU)
        {
            break;
        }
    }
    printf("\n\nBP End\n");
}

void BP::ForwordTransfer(){
    for(int j =0;j<hd_num;j++)
    {
        double t = 0;
        for(int i=0;i<in_num;i++){
            t += w[1][i][j] * x[0][i];
        }
        t += b[1][j];
        x[1][j] = BP::Sigmoid(t);
    }

    for(int j = 0;j<out_num;j++){
        double t = 0;
        for(int i = 0 ;i<hd_num;i++){
            t += w[2][i][j] * x[1][i];
        }
        t += b[2][j];
        x[2][j] = Sigmoid(t);
    }
}

// Half the sum of squared differences between the current output layer x[2]
// and the targets of sample `cnt`. Uses the activations of the last forward pass.
double BP::GetError(int cnt){
    double err = 0;
    for (int k = 0; k < out_num; ++k) {
        const double diff = x[2][k] - data.at(cnt).y[k];
        err += 0.5 * diff * diff;
    }
    return err;
}

// One back-propagation step for sample `cnt`: compute the deltas, then apply
// the gradient update to weights and biases.
void BP::BackwordTransfer(int cnt){
    BP::CalcDelta(cnt);
    BP::UpdateNetwork();
}

double BP::GetAcc(){
    double ans = 0 ;
    int num = data.size();
    for(int i =0;i<num;i++){
        int m = data.at(i).x.size();
        for(int j =0;j<m;j++){
            x[0][j] = data.at(i).x[j];

        }
        ForwordTransfer();
        int n = data.at(i).y.size();
        for (int j=0;j<n;j++){
              ans += 0.5*(x[2][i] - data.at(i).y[i])*(x[2][i] - data.at(i).y[i]);
        }
    }
    return ans/num;
}

// Error terms (deltas) for sample `cnt`. For f(z) = A / (1 + exp(-z / B)) the
// derivative expressed through the output v = f(z) is v * (A - v) / (A * B).
//   output: d[2][o] = (x[2][o] - y[o]) * f'
//   hidden: d[1][h] = (sum_o w[2][h][o] * d[2][o]) * f'
void BP::CalcDelta(int cnt){
    for (int o = 0; o < out_num; ++o) {
        const double out = x[2][o];
        d[2][o] = (out - data.at(cnt).y[o]) * out * (A - out) / (A * B);
    }
    for (int h = 0; h < hd_num; ++h) {
        double back = 0;
        for (int o = 0; o < out_num; ++o)
            back += w[2][h][o] * d[2][o];
        const double act = x[1][h];
        d[1][h] = back * act * (A - act) / (A * B);
    }
}

// Gradient-descent step using the deltas computed by CalcDelta:
//   w[l][i][j] -= ETA_W * d[l][j] * x[l-1][i]
//   b[l][j]    -= ETA_B * d[l][j]
// for the output layer (l = 2) and the hidden layer (l = 1).
void BP::UpdateNetwork(){
    for (int i=0;i<hd_num;i++){
        for(int j =0;j<out_num;j++){
            w[2][i][j] -= ETA_W*d[2][j]*x[1][i];
        }
    }

    for(int i=0;i<out_num;i++){
        b[2][i] -= ETA_B*d[2][i];
    }

    for (int i =0;i<in_num;i++){
        for(int j =0;j<hd_num;j++){
            // Subtract the gradient; this previously overwrote the weight with it.
            w[1][i][j] -= ETA_W*x[0][i]*d[1][j];
        }
    }
    for(int i=0;i<hd_num;i++){
        b[1][i] -= ETA_B*d[1][i];
    }
}

// Scaled logistic function A / (1 + exp(-x / B)): the output lies strictly
// between 0 and A, and B stretches the input axis.
double BP::Sigmoid(double x){
    const double e = exp(-x / B);
    return A / (1 + e);
}

//初始化网络  
void BP::InitNetWork()
{
    memset(w, 0, sizeof(w));      //初始化权值和阀值为0,也可以初始化随机值  
    memset(b, 0, sizeof(b));
}

int main(int argc,char* argsc[])
{
    vector<Data> datalist;
    Data* data = new Data;
    data->x.push_back(1);
    data->x.push_back(2);
    data->x.push_back(3);
    data->x.push_back(4);
    data->x.push_back(5);
    data->y.push_back(1);

    datalist.push_back(*data);
    data = new Data;
    data->x.push_back(6);
    data->x.push_back(7);
    data->x.push_back(8);
    data->x.push_back(9);
    data->x.push_back(10);    
    data->y.push_back(1);
    datalist.push_back(*data);



    BP* net = new BP;
    net->GetData(datalist);
    net->Train();

    system("pause");
    return 0;


}

猜你喜欢

转载自blog.csdn.net/qq_25439417/article/details/85611483