Hand-written BP neural network from the watermelon book: MNIST-10, C# version

This article follows the formulas given in Chapter 5 of the watermelon book, which lays out the implementation logic of a fully connected neural network. On that basis, this article builds an MNIST-10 case that recognizes the ten handwritten digits. There are other hand-rolled examples online as well. The demo is written in Unity rather than UE, which is convenient and makes errors easy to check.
This case is for learning only; the blogger is just picking up some machine learning knowledge in his spare time. Suggestions are welcome~

Test results:
[Figure: test-run screenshot]

Source code download address:
https://download.csdn.net/download/grayrail/87802798

1. Meaning of symbols

First, let's straighten out the meanings of the symbols in Chapter 5 of the watermelon book:

  • $x$ — raw value of the input
  • $y$ — raw (true) value of the output
  • $\hat{y}$ — predicted value of the output
  • $\eta$ — eta, the learning rate, between 0 and 1
  • $d$ — the number of neurons in the input layer
  • $q$ — the number of neurons in the hidden layer
  • $l$ — the number of neurons in the output layer
  • $i$ — subscript index over the input layer
  • $h$ — subscript index over the hidden layer
  • $j$ — subscript index over the output layer
  • $v_{i,h}$ — connection weight from input neuron $i$ to hidden neuron $h$
  • $w_{h,j}$ — connection weight from hidden neuron $h$ to output neuron $j$
  • $\gamma_h$ — gamma, the threshold of hidden-layer neuron $h$ (the threshold plays the role of $b$ in $y = ax + b$)
  • $\theta_j$ — theta, the threshold of output-layer neuron $j$ (likewise the $b$ in $y = ax + b$)
  • $\alpha_h$ — alpha, the input received by hidden neuron $h$; in the book, $\alpha_h = \sum_{i=1}^{d} v_{i,h} x_i$
  • $\beta_j$ — beta, the input received by output neuron $j$; in the book, $\beta_j = \sum_{h=1}^{q} w_{h,j} b_h$
  • $b_h$ — the hidden-layer value after the activation function; the array length equals the hidden-layer size
  • $yhats$ — the output-layer values after the activation function; the array length equals the output-layer size (not in the book, but such a collection is needed)
  • $g_j$ — the backpropagation $g$ term; the array length equals the output-layer size
  • $e_h$ — the backpropagation $e$ term; the array length equals the hidden-layer size

2. Forward Propagation

Chapter 5 of the watermelon book jumps straight to backpropagation, so let's briefly cover forward propagation first.

[Diagram: the input layer (d) holds x1, x2, x3; each hidden-layer (q) neuron h sums x1·v_{1,h}, x2·v_{2,h}, x3·v_{3,h}; each output-layer (l) neuron j combines the hidden outputs through w_{h,j} to produce ŷ1 … ŷ4]

Take the figure above as an example: the input layer has dimension [3], the hidden-layer weights have dimension [n, 3], and the output-layer weights have dimension [4, n], so the final output has dimension [4]. The input layer usually carries the raw input information, while the hidden layer carries the intermediate computation. The hidden-layer weights are n × m, where m is the dimension of the input data itself and n can be understood as n kinds of "possibilities" (the blogger's own interpretation); for example, if the hidden layer's size is 50, the network is assumed to train 50 possibilities.
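To make the shapes concrete, here is a small C# sketch (the sizes d = 3, n = 50, l = 4 are just the example numbers from this section):

const int d = 3, n = 50, l = 4;     // input size, hidden size, output size
float[] x = new float[d];           // input layer: [3]
float[][] v = new float[n][];       // input-to-hidden weights: [n, 3]
float[][] w = new float[l][];       // hidden-to-output weights: [4, n]
float[] yhat = new float[l];        // final output: [4]
for (int h = 0; h < n; ++h) v[h] = new float[d];
for (int j = 0; j < l; ++j) w[j] = new float[n];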

Based on this, the forward propagation process is as follows (a condensed code sketch follows the list):

  1. Initialize the connection weights $v$ of the hidden layer. One dimension is the length of the input data; the other is how many "possibilities" there are, which you can set to any suitable value.
  2. Loop over the hidden-layer length $q$; for each $h$, take the dot product of the input with the weights, multiplying each element $x_i$ by $v_{i,h}$ (in the book: $\alpha_h = \sum_{i=1}^{d} v_{i,h} x_i$), then subtract the threshold $\gamma_h$, pass the result through the activation function, and write it into the $b$ collection.
  3. Loop over the output-layer length $l$; for each $j$, accumulate the dot products of all hidden outputs with the weights (in the book: $\beta_j = \sum_{h=1}^{q} w_{h,j} b_h$), then subtract the threshold $\theta_j$, pass the result through the activation function, and write it into the yhats collection.
  4. Optionally apply a softmax to yhats to normalize it into probabilities, then take the index of the maximum value (argmax); that index is the output result.
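To make steps 2 and 3 concrete, here is a minimal, self-contained C# sketch of the forward pass. It mirrors the full Unity version in section 4; the array shapes (x[d], v[q][d], gamma[q], w[l][q], theta[l]) are assumptions that match the symbol table above.

using System;

static class ForwardSketch
{
    static float Sigmoid(float z) => 1f / (1f + (float)Math.Exp(-z));

    // Returns the raw output-layer activations (before any softmax/argmax).
    public static float[] Forward(float[] x, float[][] v, float[] gamma, float[][] w, float[] theta)
    {
        int d = x.Length, q = v.Length, l = w.Length;

        var b = new float[q];                    // hidden activations
        for (int h = 0; h < q; ++h)
        {
            float alpha = 0f;                    // alpha_h = sum_i v_{i,h} * x_i
            for (int i = 0; i < d; ++i) alpha += v[h][i] * x[i];
            b[h] = Sigmoid(alpha - gamma[h]);    // subtract the threshold, then activate
        }

        var yhat = new float[l];                 // output activations
        for (int j = 0; j < l; ++j)
        {
            float beta = 0f;                     // beta_j = sum_h w_{h,j} * b_h
            for (int h = 0; h < q; ++h) beta += w[j][h] * b[h];
            yhat[j] = Sigmoid(beta - theta[j]);
        }
        return yhat;
    }
}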

3. Backpropagation (BackPropagation)

One of the difficulties of backpropagation is the chain-rule derivation. The watermelon book has already written out the derivation for us. Here I will share a few tips first, then sort out the backpropagation process.

3.1 About the loss function E

[Figure: the loss function E as given in the book]
The loss function E is only mentioned once in the book. In the subsequent steps, the plain difference $\hat{y}-y$ is substituted straight into the sigmoid partial-derivative formula, which is a bit confusing at first.
After consulting ChatGPT and some other articles, I learned that the plain difference is exactly the partial derivative of the 1/2 MSE loss. If 1/2 MSE is not your loss function, replace $\hat{y}-y$ with the derivative of whatever loss you use.
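Concretely, with the 1/2 MSE loss the 1/2 cancels the exponent when differentiating, leaving the bare difference:

$E = \frac{1}{2}\sum_{j=1}^{l}(\hat{y}_j - y_j)^2 \quad\Rightarrow\quad \frac{\partial E}{\partial \hat{y}_j} = \hat{y}_j - y_j$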

3.2 Formula sorting

The formulas in the book are a bit scattered; the figures below sort them out in order:

3.2.1 Partial derivative of E with respect to w

[Figure: the book's derivation of ∂E/∂w]

3.2.2 Partial derivative of E with respect to v

[Figure: the book's derivation of ∂E/∂v]
This part should be the essence of the chapter; I don't understand it deeply enough yet, so I won't comment further.

3.2.3 Process combing

Based on this, the backpropagation process is as follows (the two error terms are written out after the list):

  1. Compute the $g$ term for each output neuron and store it in an array, bringing the true value $y$ and the predicted value $\hat{y}$ into the derivative formula of the sigmoid.
  2. Compute the $e$ term for each hidden neuron and store it in an array.
  3. Compute and apply the delta of the output-layer threshold (bias) $\theta$: $\Delta\theta_j = -\eta g_j$
  4. Compute and apply the delta of $w$: $\Delta w_{h,j} = \eta g_j b_h$ (note that $b_h$ here is the activated output of the previous layer, not a bias)
  5. Compute and apply the delta of the hidden-layer threshold (bias) $\gamma$: $\Delta\gamma_h = -\eta e_h$
  6. Compute and apply the delta of $v$: $\Delta v_{i,h} = \eta e_h x_i$
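Written out, the two error terms the list refers to are (using the sigmoid derivative $f'(z) = f(z)(1-f(z))$; these match the g and e arrays in the code below):

$g_j = \hat{y}_j(1-\hat{y}_j)(y_j-\hat{y}_j), \qquad e_h = b_h(1-b_h)\sum_{j=1}^{l} w_{h,j}\,g_j$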

3.3 Optimizer

Reading further, I found that besides forward propagation and backpropagation, a neural network has two more important pieces: the loss function and the optimizer. The loss function determines the gradient of each parameter, and the optimizer decides how to apply those gradients.

For example, the C# code given at the end of this article uses momentum, which is itself an optimizer. The Adam optimizer commonly seen in TensorFlow demos is roughly RMSProp + momentum: the magnitude of the momentum changes dynamically, and even the learning rate becomes something the optimizer adapts.

To understand the RMSProp optimizer, see section 7.6.1 of "Dive into Deep Learning", which includes a related implementation.
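As a point of reference for the code in section 4, here is the momentum update in isolation. This is only a sketch: the class name is made up, and the constants mirror kMomentum and kEta1 from the full listing below, where one cached delta is kept per weight (lastWMomentum / lastVMomentum) rather than a single field.

// A minimal sketch of the momentum update used in BackPropagation below:
// each step blends the previous delta into the new one, smoothing the descent.
class MomentumOptimizer
{
    readonly float _momentum;   // e.g. kMomentum = 0.3f in the full code
    readonly float _eta;        // learning rate, e.g. kEta1 = 0.03f
    float _lastDelta;           // state carried between steps

    public MomentumOptimizer(float momentum, float eta)
    {
        _momentum = momentum;
        _eta = eta;
    }

    // Applies one update to the weight and returns the delta that was applied.
    public float Step(ref float weight, float gradient)
    {
        float delta = _momentum * _lastDelta + _eta * gradient;
        weight += delta;
        _lastDelta = delta;
        return delta;
    }
}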

4. Code implementation

Take the MNIST case as an example: the neural network recognizes the handwritten digits 0-9 inside a 28×28-pixel image. The C# version of the MNIST code follows. Once the script is loaded there are three modes:
[Figure: the three mode buttons rendered by OnGUI]

  • Draw Image Mode: draw the digits 0-9. After drawing, click the Save button in OnGUI; the file is stored in the Assets directory by default. You then need to manually refresh the Project panel (right-click > Refresh) and rename the file according to the naming rule.
  • User Mode: use the trained network for digit recognition (there is no persistence, so the network has to be retrained each session).
  • Train Mode: training mode. Fill in the image path in DataPath. The label is taken from the file-name prefix; for example, 3_04 means the true value of this image is 3 and it is the fourth sample of that digit (see the loader sketch after this list).
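As noted in the Train Mode bullet, a sample file such as 3_04.txt holds 784 comma-separated pixel values, with the label encoded in the prefix. A hypothetical loader for one such file (the name LoadSample is illustrative; the full listing does the same work inline in OnGUI):

// Hypothetical helper matching the naming rule above: "3_04" -> label 3.
// Drop-in instance method; requires using System and System.IO.
(int label, float[] pixels) LoadSample(string path)
{
    string name = Path.GetFileNameWithoutExtension(path);  // e.g. "3_04"
    int label = name[0] - '0';                             // the prefix digit is the true value
    float[] pixels = Array.ConvertAll(File.ReadAllText(path).Split(','), float.Parse);
    return (label, pixels);
}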

In this case, momentum, softmax, dropout, and a wider initial random range (-1, 1) are added on top of the watermelon book. Softmax uses the formula provided in the book "Introduction to Deep Learning: Theory and Implementation Based on Python".
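That formula is the numerically stable softmax: every input is shifted by the maximum before exponentiating. The shift cancels in the ratio but keeps $e^{x}$ from overflowing:

$\mathrm{softmax}(x_j) = \dfrac{e^{x_j - \max_k x_k}}{\sum_{i} e^{x_i - \max_k x_k}}$

This is exactly what the Softmax method in the listing below computes.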
Running results after a few rounds of training:
[Figure: console output after training]

The C# code is as follows:

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using UnityEngine;

public class TestMnist10 : MonoBehaviour
{
    public enum EMode { Train, DrawImage, User }

    const float kDropoutProb = 0.4f;

    /// <summary>
    /// d: number of input-layer neurons
    /// </summary>
    int d;

    /// <summary>
    /// q: number of hidden-layer neurons
    /// </summary>
    int q;

    /// <summary>
    /// l: number of output-layer neurons
    /// </summary>
    int l;

    /// <summary>
    /// Raw input-layer values
    /// </summary>
    float[] x;

    /// <summary>
    /// Connection weights from the input layer to hidden-layer neurons
    /// </summary>
    float[][] v;

    /// <summary>
    /// Caches the previous update to v (momentum term)
    /// </summary>
    float[][] lastVMomentum;

    /// <summary>
    /// Connection weights from hidden-layer neurons to output-layer neurons
    /// </summary>
    float[][] w;

    /// <summary>
    /// Caches the previous update to w (momentum term)
    /// </summary>
    float[][] lastWMomentum;

    float[] wDropout;

    /// <summary>
    /// Backpropagation g term
    /// </summary>
    float[] g;

    /// <summary>
    /// Backpropagation e term
    /// </summary>
    float[] e;

    /// <summary>
    /// Hidden-layer values after the activation function (list length = hidden-layer size)
    /// </summary>
    List<float> b;

    /// <summary>
    /// Output-layer values after the activation function (list length = output-layer size)
    /// </summary>
    List<float> yhats;

    /// <summary>
    /// Thresholds (theta) of the output-layer neurons
    /// </summary>
    float[] theta;

    /// <summary>
    /// Thresholds (gamma) of the hidden-layer neurons
    /// </summary>
    float[] gamma;


    public void Init(int inputLayerCount, int hiddenLayerCount, int outputLayerCount)
    {
        d = inputLayerCount;
        q = hiddenLayerCount;
        l = outputLayerCount;

        x = new float[inputLayerCount];
        b = new List<float>(1024);
        yhats = new List<float>(1024);

        e = new float[hiddenLayerCount];
        g = new float[outputLayerCount];

        v = GenDimsArray(typeof(float), new int[] { q, d }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        w = GenDimsArray(typeof(float), new int[] { l, q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[][];
        wDropout = GenDimsArray(typeof(float), new int[] { l }, 0, null) as float[];

        lastVMomentum = GenDimsArray(typeof(float), new int[] { q, d }, 0, null) as float[][];
        lastWMomentum = GenDimsArray(typeof(float), new int[] { l, q }, 0, null) as float[][];

        theta = GenDimsArray(typeof(float), new int[] { l }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
        gamma = GenDimsArray(typeof(float), new int[] { q }, 0, () => UnityEngine.Random.Range(-1f, 1f)) as float[];
    }
    public void ForwardPropagation(float[] input, out int output)
    {
        x = input;

        // Re-roll the dropout mask for the output layer.
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var r = UnityEngine.Random.value < kDropoutProb ? 1f : 0f;
            wDropout[jIndex] = r;
        }

        // Hidden layer: alpha_h = sum_i(x_i * v_{i,h}), then subtract the threshold and activate.
        b.Clear();
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var sum = 0f;
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                var u = input[iIndex] * v[hIndex][iIndex];
                sum += u;
            }
            var alpha = sum - gamma[hIndex];

            var r = Sigmoid(alpha);

            b.Add(r);
        }

        // Output layer: beta_j = sum_h(b_h * w_{h,j}), then subtract the threshold and activate.
        yhats.Clear();
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var sum = 0f;
            for (int hIndex = 0; hIndex < q; ++hIndex)
            {
                var u = b[hIndex] * w[jIndex][hIndex];
                sum += u;
            }
            var beta = sum - theta[jIndex];

            var r = Sigmoid(beta);

            // Dropout is enabled during training and disabled at inference time.
            if (_EnableDropout)
            {
                r *= wDropout[jIndex];
                r /= kDropoutProb;
            }

            yhats.Add(r);
        }

        var softmaxResult = Softmax(yhats.ToArray());
        for (int i = 0; i < yhats.Count; i++)
        {
            yhats[i] = softmaxResult[i];
        }

        // Argmax: the index of the largest probability is the predicted digit.
        int index = 0;
        float maxValue = yhats[0];
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            if (yhats[jIndex] > maxValue)
            {
                maxValue = yhats[jIndex];
                index = jIndex;
            }
        }
        output = index;
    }
    public void BackPropagation(float[] correct)
    {
        const float kEta1 = 0.03f;
        const float kEta2 = 0.01f;

        const float kMomentum = 0.3f;

        // g_j = yhat_j * (1 - yhat_j) * (y_j - yhat_j)
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            var yhat = this.yhats[jIndex];
            var y = correct[jIndex];
            g[jIndex] = yhat * (1f - yhat) * (y - yhat);
        }

        // e_h = b_h * (1 - b_h) * sum_j(w_{h,j} * g_j)
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            var bh = b[hIndex];
            var sum = 0f;
            // This inner loop is the subtle part: it gathers the gradients of the other layer's neurons.
            for (int jIndex = 0; jIndex < l; ++jIndex)
                sum += w[jIndex][hIndex] * g[jIndex];
            e[hIndex] = bh * (1f - bh) * sum;
        }

        // Delta theta_j = -eta * g_j
        for (int jIndex = 0; jIndex < l; ++jIndex)
        {
            theta[jIndex] += -kEta1 * g[jIndex];
        }

        // Delta w_{h,j} = eta * g_j * b_h, with momentum blended in
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int jIndex = 0; jIndex < l; ++jIndex)
            {
                var bh = b[hIndex];
                var delta = kMomentum * lastWMomentum[jIndex][hIndex] + kEta1 * g[jIndex] * bh;

                // Dropout is enabled during training and disabled at inference time.
                if (_EnableDropout)
                {
                    var dropout = wDropout[jIndex];
                    delta *= dropout;
                    delta /= kDropoutProb;
                }

                w[jIndex][hIndex] += delta;
                lastWMomentum[jIndex][hIndex] = delta;
            }
        }

        // Delta gamma_h = -eta * e_h
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            gamma[hIndex] += -kEta2 * e[hIndex];
        }

        // Delta v_{i,h} = eta * e_h * x_i, with momentum blended in
        for (int hIndex = 0; hIndex < q; ++hIndex)
        {
            for (int iIndex = 0; iIndex < d; ++iIndex)
            {
                var delta = kMomentum * lastVMomentum[hIndex][iIndex] + kEta2 * e[hIndex] * x[iIndex];

                v[hIndex][iIndex] += delta;
                lastVMomentum[hIndex][iIndex] = delta;
            }
        }
    }
    void Start()
    {
        Init(784, 64, 10);
    }

    EMode _Mode;
    int[] _DrawNumberImage;
    bool _EnableDropout;
    string _DataPath;

    float Sigmoid(float val)
    {
        return 1f / (1f + Mathf.Exp(-val));
    }
    float[] Softmax(float[] inputs)
    {
        float[] outputs = new float[inputs.Length];
        float maxInput = inputs.Max();

        // Shift by the maximum for numerical stability before exponentiating.
        for (int i = 0; i < inputs.Length; i++)
        {
            outputs[i] = Mathf.Exp(inputs[i] - maxInput);
        }

        float expSum = outputs.Sum();
        for (int i = 0; i < outputs.Length; i++)
        {
            outputs[i] /= expSum;
        }

        return outputs;
    }
    float[] GetOneHot(string input)
    {
        // The file-name prefix encodes the label, e.g. "3_04" -> 3; anything outside 0-8 maps to 9.
        var oneHot = new float[10];
        int label = input.Length > 0 ? input[0] - '0' : 9;
        oneHot[label >= 0 && label < 9 ? label : 9] = 1f;
        return oneHot;
    }
    void Shuffle<T>(List<T> cardList)
    {
        int tempIndex = 0;
        T temp = default;
        for (int i = 0; i < cardList.Count; ++i)
        {
            tempIndex = UnityEngine.Random.Range(0, cardList.Count);
            temp = cardList[tempIndex];
            cardList[tempIndex] = cardList[i];
            cardList[i] = temp;
        }
    }
    /// <summary>
    /// Quickly builds a jagged array with the given dimensions
    /// </summary>
    Array GenDimsArray(Type type, int[] dims, int deepIndex, Func<object> initFunc = null)
    {
        if (deepIndex < dims.Length - 1)
        {
            var sub_template = GenDimsArray(type, dims, deepIndex + 1, null);
            var current = Array.CreateInstance(sub_template.GetType(), dims[deepIndex]);

            for (int i = 0; i < dims[deepIndex]; ++i)
            {
                var sub = GenDimsArray(type, dims, deepIndex + 1, initFunc);
                current.SetValue(sub, i);
            }

            return current;
        }
        else
        {
            var arr = Array.CreateInstance(type, dims[deepIndex]);
            if (initFunc != null)
            {
                for (int i = 0; i < arr.Length; ++i)
                    arr.SetValue(initFunc(), i);
            }
            return arr;
        }
    }
    void OnGUI()
    {
        if (_DrawNumberImage == null)
            _DrawNumberImage = new int[784];

        GUILayout.BeginHorizontal();
        if (GUILayout.Button("Draw Image Mode"))
        {
            _Mode = EMode.DrawImage;

            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("User Mode"))
        {
            _Mode = EMode.User;

            Array.Clear(_DrawNumberImage, 0, _DrawNumberImage.Length);
        }
        if (GUILayout.Button("Train Mode"))
        {
            _Mode = EMode.Train;
            _DataPath = Directory.GetCurrentDirectory() + "/TrainData";
        }
        GUILayout.EndHorizontal();

        var lastRect = GUILayoutUtility.GetLastRect();

        switch (_Mode)
        {
            case EMode.Train:
                {
                    GUILayout.BeginHorizontal();
                    GUILayout.Label("Data Path: ");
                    _DataPath = GUILayout.TextField(_DataPath);
                    GUILayout.EndHorizontal();

                    _EnableDropout = GUILayout.Button("dropout(" + (_EnableDropout ? "True" : "False") + ")")
                        ? !_EnableDropout : _EnableDropout;

                    if (GUILayout.Button("Train 10"))
                    {
                        // Load every sample: the file-name prefix is the label, the body is 784 comma-separated pixels.
                        var files = Directory.GetFiles(_DataPath);
                        List<(string, float[])> datas = new(512);
                        for (int i = 0; i < files.Length; ++i)
                        {
                            var strArr = File.ReadAllText(files[i]).Split(',');
                            datas.Add((Path.GetFileNameWithoutExtension(files[i]), Array.ConvertAll(strArr, m => float.Parse(m))));
                        }

                        // 10 epochs, shuffling the samples each time.
                        for (int s = 0; s < 10; ++s)
                        {
                            Shuffle(datas);

                            for (int i = 0; i < datas.Count; ++i)
                            {
                                ForwardPropagation(datas[i].Item2, out int output);
                                UnityEngine.Debug.Log("<color=#00ff00> Input Number: " + datas[i].Item1 + " output: " + output + "</color>");
                                BackPropagation(GetOneHot(datas[i].Item1));
                            }
                        }
                    }
                }
                break;
            case EMode.DrawImage:
                {
                    lastRect.y += 50f;
                    var size = 20f;
                    var spacing = 2f;
                    var mousePosition = Event.current.mousePosition;
                    var mouseLeftIsPress = Input.GetMouseButton(0);
                    var mouseRightIsPress = Input.GetMouseButton(1);
                    var containSpacingSize = size + spacing;

                    // A 28x28 grid of cells: left-click paints a pixel, right-click erases it.
                    for (int y = 0, i = 0; y < 28; ++y)
                    {
                        for (int x = 0; x < 28; ++x)
                        {
                            var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);
                            GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);

                            if (rect.Contains(mousePosition))
                            {
                                if (mouseLeftIsPress)
                                    _DrawNumberImage[i] = 1;
                                else if (mouseRightIsPress)
                                    _DrawNumberImage[i] = 0;
                            }

                            ++i;
                        }
                    }
                    if (GUILayout.Button("Save"))
                    {
                        File.WriteAllText(Directory.GetCurrentDirectory() + "/Assets/tmp.txt", string.Join(",", _DrawNumberImage));
                    }
                }
                break;
            case EMode.User:
                {
                    lastRect.y += 150f;
                    var size = 20f;
                    var spacing = 2f;
                    var mousePosition = Event.current.mousePosition;
                    var mouseLeftIsPress = Input.GetMouseButton(0);
                    var mouseRightIsPress = Input.GetMouseButton(1);
                    var containSpacingSize = size + spacing;

                    for (int y = 0, i = 0; y < 28; ++y)
                    {
                        for (int x = 0; x < 28; ++x)
                        {
                            var rect = new Rect(lastRect.x + x * containSpacingSize, lastRect.y + y * containSpacingSize, size, size);
                            GUI.DrawTexture(rect, _DrawNumberImage[i] == 1 ? Texture2D.blackTexture : Texture2D.whiteTexture);

                            if (rect.Contains(mousePosition))
                            {
                                if (mouseLeftIsPress)
                                    _DrawNumberImage[i] = 1;
                                else if (mouseRightIsPress)
                                    _DrawNumberImage[i] = 0;
                            }

                            ++i;
                        }
                    }
                    if (GUILayout.Button("Recognize"))
                    {
                        ForwardPropagation(Array.ConvertAll(_DrawNumberImage, m => (float)m), out int output);
                        Debug.Log("output: " + output);
                    }
                    break;
                }
        }
    }
}

Original article: blog.csdn.net/grayrail/article/details/130769374