Use MLlib to complete the MNIST handwritten digit recognition task

  1. Tip: restart an exited container with the restart command

    sudo docker restart <container id>

  2. Complete the identification task preparation

    1. Download the dataset from the following website:

      MNIST Database of Handwritten Digits, Yann LeCun, Corinna Cortes, and Chris Burges (http://yann.lecun.com/exdb/mnist/)

      The dataset consists of the following four gzip archives; download and decompress them to obtain the dataset files:

      • t10k-images-idx3-ubyte.gz
      • t10k-labels-idx1-ubyte.gz
      • train-images-idx3-ubyte.gz
      • train-labels-idx1-ubyte.gz
    2. Convert the dataset files to CSV files with the following Python program:

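      def convert(imgf, labelf, outf, n):
          """Write the first n MNIST samples to outf as CSV rows: label, then 784 pixels."""
          with open(imgf, "rb") as f, open(labelf, "rb") as l, open(outf, "w") as o:
              f.read(16)  # skip the image file header (magic number, count, rows, cols)
              l.read(8)   # skip the label file header (magic number, count)
              for _ in range(n):
                  image = [ord(l.read(1))]       # first column: the digit label (0-9)
                  image.extend(f.read(28 * 28))  # then 28 x 28 pixel bytes (0-255)
                  o.write(",".join(str(pix) for pix in image) + "\n")
      
      
      # The dataset can be downloaded from http://yann.lecun.com/exdb/mnist/
      # Names below are those produced by decompressing the .gz archives; adjust them
      # if your extraction tool names the files differently (e.g. train-images.idx3-ubyte).
      convert("train-images-idx3-ubyte", "train-labels-idx1-ubyte",
              "mnist_train.csv", 60000)
      convert("t10k-images-idx3-ubyte", "t10k-labels-idx1-ubyte",
              "mnist_test.csv", 10000)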
      

      This program generates the following two files in the current working directory:

      • mnist_train.csv
      • mnist_test.csv
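
      The f.read(16) and l.read(8) calls skip the IDX headers. Every IDX file starts with two zero bytes, a type byte (0x08 for unsigned bytes) and a dimension count, followed by one 4-byte big-endian size per dimension: 3 dimensions (4 + 12 = 16 bytes) for the image files and 1 dimension (4 + 4 = 8 bytes) for the label files. A minimal sketch of a header parser using the standard struct module; read_idx_header is a hypothetical helper, not part of any library:

```python
import struct


def read_idx_header(data):
    """Return (type_code, dims, header_len) parsed from the start of an IDX file."""
    # Bytes 0-1 are zero, byte 2 is the element type, byte 3 the dimension count.
    zero, type_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0:
        raise ValueError("not an IDX file")
    header_len = 4 + 4 * ndim
    # One big-endian unsigned 32-bit size per dimension.
    dims = struct.unpack(">" + "I" * ndim, data[4:header_len])
    return type_code, dims, header_len
```

      Passing the first 16 bytes of train-images-idx3-ubyte should give (8, (60000, 28, 28), 16), and the first 8 bytes of train-labels-idx1-ubyte should give (8, (60000,), 8).
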
    3. Convert the CSV files to LIBSVM format with the following Python program:

      import csv
      
      
      def execute(data, savepath):
          """Convert a CSV file (label, pixel1, ..., pixel784) to LIBSVM format."""
          with open(data, newline="") as src, \
                  open(savepath, "w", encoding="utf-8", newline="\n") as dst:
              for line in csv.reader(src):
                  label, features = line[0], line[1:]
                  # LIBSVM uses 1-based feature indices: "<label> 1:<v1> 2:<v2> ..."
                  pairs = (f"{i}:{v}" for i, v in enumerate(features, start=1))
                  dst.write(label + " " + " ".join(pairs) + "\n")
      
      
      execute('mnist_train.csv', 'mnist_train.libsvm')
      execute('mnist_test.csv', 'mnist_test.libsvm')
      

      The program will generate the following two .libsvm files:

      • mnist_test.libsvm
      • mnist_train.libsvm
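
      Each output line has the form <label> 1:<pixel1> 2:<pixel2> ... 784:<pixel784>. Spark's libsvm data source treats these indices as one-based and converts them to zero-based after loading. Because every line lists all 784 indices, including zero pixels, Spark infers 784 features for both files; if zeros were dropped to make the files sparse, the inferred size could shrink, and the reader's numFeatures option would be needed to pin it. The sketch below decodes one line in the same spirit, into a dense zero-based vector; parse_libsvm_line is a hypothetical helper:

```python
def parse_libsvm_line(line, num_features=784):
    """Decode '<label> i:v i:v ...' into (label, dense vector) with 0-based positions."""
    label, *pairs = line.split()
    dense = [0.0] * num_features
    for pair in pairs:
        index, value = pair.split(":")
        dense[int(index) - 1] = float(value)  # LIBSVM index 1 -> position 0
    return float(label), dense


label, x = parse_libsvm_line("7 1:0 2:255 4:13", num_features=4)
print(label, x)  # 7.0 [0.0, 255.0, 0.0, 13.0]
```
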
    4. Copy the two .libsvm files into the directory shared with the spark-master container; the steps below load them from /data inside the container.

    5. Open a shell in the spark-master container

      sudo docker exec -it spark-master /bin/bash

    6. Open spark-shell

      spark-shell is located in the /spark/bin directory.

      Run ./spark-shell from that directory to start the shell.

  3. Complete the recognition task

    1. Read the training set

      val train = spark.read.format("libsvm").load("/data/mnist_train.libsvm")
      
    2. Read the test set

      val test = spark.read.format("libsvm").load("/data/mnist_test.libsvm")
      
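    3. Define the network structure: 784 input neurons (one per pixel of a 28 x 28 image), two hidden layers of 784 neurons each, and 10 output neurons (one per digit). On a less powerful computer, you can make the hidden layers smaller.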

      val layers = Array[Int](784, 784, 784, 10)
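
      Each fully connected layer between sizes n_in and n_out holds n_in x n_out weights plus n_out biases. Counting them shows how much smaller hidden layers help; a sketch of the arithmetic, where count_parameters is a hypothetical helper and [784, 128, 64, 10] is only an example of a reduced layout:

```python
def count_parameters(layers):
    """Weights plus biases of a fully connected network with the given layer sizes."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(layers, layers[1:]))


print(count_parameters([784, 784, 784, 10]))  # 1238730
print(count_parameters([784, 128, 64, 10]))   # 109386
```

      That is roughly 1.24 million parameters versus about 0.11 million, so each training iteration has roughly a tenth of the weights to update.
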
      
    4. Import the multilayer perceptron classifier and the multiclass classification evaluator.

      import org.apache.spark.ml.classification.MultilayerPerceptronClassifier
      import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
      
    5. Create the trainer: a multilayer perceptron classifier with the layers defined above, a block size of 128, random seed 1234 and at most 100 iterations.

      val trainer = new MultilayerPerceptronClassifier().setLayers(layers).setBlockSize(128).setSeed(1234L).setMaxIter(100)
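
      Spark's documentation describes this classifier as a feedforward network whose hidden layers use the sigmoid function and whose output layer uses softmax; the prediction is the class with the highest output. The sketch below mirrors that forward pass in plain Python on a tiny network to show what a fitted model computes. It is illustrative only: forward and init_layers are hypothetical helpers, the random initialisation is not Spark's, and Spark trains the weights (with L-BFGS by default) instead of leaving them random:

```python
import math
import random


def init_layers(sizes, seed=1234):
    """Random weights and zero biases for each pair of consecutive layers."""
    rng = random.Random(seed)
    weights = [[[rng.uniform(-0.1, 0.1) for _ in range(n_in)] for _ in range(n_out)]
               for n_in, n_out in zip(sizes, sizes[1:])]
    biases = [[0.0] * n_out for n_out in sizes[1:]]
    return weights, biases


def forward(x, weights, biases):
    """Sigmoid in hidden layers, softmax in the output layer; returns class probabilities."""
    a = x
    last = len(weights) - 1
    for k, (W, b) in enumerate(zip(weights, biases)):
        z = [sum(w * v for w, v in zip(row, a)) + bias for row, bias in zip(W, b)]
        if k < last:
            a = [1.0 / (1.0 + math.exp(-t)) for t in z]
        else:
            m = max(z)  # subtract the max for a numerically stable softmax
            e = [math.exp(t - m) for t in z]
            s = sum(e)
            a = [t / s for t in e]
    return a


weights, biases = init_layers([4, 3, 2])
probs = forward([0.0, 0.5, 1.0, 0.25], weights, biases)
prediction = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability
```
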
      
    6. Train the model

      val model = trainer.fit(train)
      
    7. Apply the model to the test set

      val result = model.transform(test)
      
    8. Select the predicted and actual labels from the results

      val predictionAndLabels = result.select("prediction", "label")
      
    9. Create an evaluator that measures accuracy

      val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")
      
    10. Calculate the recognition accuracy

      println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}")
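
      With setMetricName("accuracy") and no weight column, the evaluator reports the fraction of test rows whose prediction equals the label. The same quantity computed by hand on a few (prediction, label) pairs, plus a tally of errors that can help spot which digits get confused; accuracy and confusion_counts are hypothetical helpers and the pairs are made-up example values, not results from this model:

```python
from collections import Counter


def accuracy(pairs):
    """Fraction of (prediction, label) pairs where the prediction is correct."""
    pairs = list(pairs)
    return sum(1 for pred, label in pairs if pred == label) / len(pairs)


def confusion_counts(pairs):
    """Count (label, prediction) combinations for the misclassified pairs only."""
    return Counter((label, pred) for pred, label in pairs if pred != label)


sample = [(7.0, 7.0), (2.0, 2.0), (1.0, 7.0), (0.0, 0.0)]  # made-up example pairs
print(accuracy(sample))          # 0.75
print(confusion_counts(sample))  # Counter({(7.0, 1.0): 1})
```
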
      
    11. Register the result as a temporary view

      result.createOrReplaceTempView("deep_learning")
      
    12. Use Spark SQL to calculate the recognition accuracy

      spark.sql("select (select count(*) from deep_learning where label=prediction)/count(*) as accuracy from deep_learning").show()
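
      The query divides the number of rows where label = prediction by the total row count. Spark SQL's / operator performs fractional division even on integer operands, so the result is a decimal accuracy; many other SQL engines divide integers with integer division. The sketch below runs the same query shape against an in-memory SQLite table using Python's sqlite3 module, where multiplying by 1.0 avoids that truncation; sql_accuracy is a hypothetical helper and the rows are made-up examples:

```python
import sqlite3


def sql_accuracy(rows):
    """Accuracy of (label, prediction) rows, computed with the same SQL shape as the Spark query."""
    con = sqlite3.connect(":memory:")
    con.execute("create table deep_learning (label real, prediction real)")
    con.executemany("insert into deep_learning values (?, ?)", rows)
    (acc,) = con.execute(
        "select (select count(*) from deep_learning where label = prediction) * 1.0"
        " / count(*) as accuracy from deep_learning"
    ).fetchone()
    con.close()
    return acc


print(sql_accuracy([(7.0, 7.0), (2.0, 2.0), (7.0, 1.0), (0.0, 0.0)]))  # 0.75
```
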
      
Origin: blog.csdn.net/weixin_45795947/article/details/124558069