MNIST handwritten digit recognition using pycaffe

Reference: the official tutorial

Operating environment

Windows 10 + Python 3.5 + GPU build of Caffe

Steps

  1. Download the dataset
  2. Convert the dataset to LMDB
  3. Train the model
  4. Test the trained model

Download the dataset

Download the following four files from the official MNIST website. They are distributed as .gz archives, so decompress them first:

t10k-images.idx3-ubyte
t10k-labels.idx1-ubyte
train-images.idx3-ubyte
train-labels.idx1-ubyte

The file formats are described on the MNIST website. All integers are stored big-endian (most significant byte first).

Training set label file (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  60000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label
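As a sanity check of the layout above, the label file can be parsed with nothing but the standard struct module. This is a minimal sketch: the helper name read_idx1_labels is hypothetical, and it simply follows the table (two big-endian 32-bit integers, then one unsigned byte per label).

```python
import struct


def read_idx1_labels(path):
    """Parse an MNIST label file (idx1-ubyte) into a list of ints."""
    with open(path, 'rb') as f:
        # Header: magic number and item count, both big-endian uint32 ('>2I')
        magic, num_items = struct.unpack('>2I', f.read(8))
        if magic != 2049:
            raise ValueError('not an idx1 label file, magic = {}'.format(magic))
        # The rest of the file is one unsigned byte per label
        labels = list(f.read(num_items))
    if len(labels) != num_items:
        raise ValueError('label file is truncated')
    return labels
```

For example, `read_idx1_labels('train-labels.idx1-ubyte')` should return a list of 60000 ints.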

Training set image file (train-images-idx3-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  60000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel
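The image file works the same way, with a four-integer header followed by the pixels, one byte each, stored row by row. A minimal pure-Python sketch of a reader (the name read_idx3_images is hypothetical) that returns each image as a list of pixel rows:

```python
import struct


def read_idx3_images(path):
    """Parse an MNIST image file (idx3-ubyte) into a list of 2-D pixel lists."""
    with open(path, 'rb') as f:
        # Header: magic, image count, rows, columns (all big-endian uint32)
        magic, num_images, rows, cols = struct.unpack('>4I', f.read(16))
        if magic != 2051:
            raise ValueError('not an idx3 image file, magic = {}'.format(magic))
        images = []
        for _ in range(num_images):
            buf = f.read(rows * cols)  # one unsigned byte per pixel
            if len(buf) != rows * cols:
                raise ValueError('image file is truncated')
            # Split the flat buffer into `rows` rows of `cols` pixels
            images.append([list(buf[r * cols:(r + 1) * cols]) for r in range(rows)])
    return images
```

Reading all 60000 training images this way is slow but fine for a one-off check.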

Test set label file (t10k-labels-idx1-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000801(2049) magic number (MSB first)
0004     32 bit integer  10000            number of items
0008     unsigned byte   ??               label
0009     unsigned byte   ??               label
........
xxxx     unsigned byte   ??               label

Label values range from 0 to 9.
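A quick way to confirm that a parsed label list respects this range, and to see how the ten digits are distributed, is a per-class count. A minimal sketch; label_histogram is a hypothetical helper that takes any list of ints:

```python
from collections import Counter


def label_histogram(labels):
    """Count how often each digit 0-9 appears; reject out-of-range labels."""
    bad = [x for x in labels if not 0 <= x <= 9]
    if bad:
        raise ValueError('labels outside 0-9: {}'.format(sorted(set(bad))))
    counts = Counter(labels)
    # Index d of the result is the number of samples labeled d
    return [counts[d] for d in range(10)]
```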

Test set image file (t10k-images-idx3-ubyte):

[offset] [type]          [value]          [description]
0000     32 bit integer  0x00000803(2051) magic number
0004     32 bit integer  10000            number of images
0008     32 bit integer  28               number of rows
0012     32 bit integer  28               number of columns
0016     unsigned byte   ??               pixel
0017     unsigned byte   ??               pixel
........
xxxx     unsigned byte   ??               pixel

Convert the dataset to LMDB

The download contains two pairs of files: the training images with their labels, and the test images with their labels. Both pairs have the same structure, so the same function can convert either pair into an LMDB database. It also calls Caffe's compute_image_mean tool to produce a mean file next to the database.

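import os
import shutil
import struct
import subprocess

import lmdb
from caffe.proto import caffe_pb2

# Path to Caffe's compute_image_mean tool; adjust to your own build
compute_image_mean = 'compute_image_mean'


def orgin_to_lmdb(image_file, label_file, lmdb_save_path, force_update=False):
    mean_file = '{}.binaryproto'.format(lmdb_save_path)

    # Skip the work if both outputs already exist, unless an update is forced
    if os.path.exists(mean_file) and os.path.exists(lmdb_save_path) and not force_update:
        return

    # Remove stale outputs (the LMDB is a directory, the mean file is a regular file)
    if os.path.isdir(lmdb_save_path):
        shutil.rmtree(lmdb_save_path)
    if os.path.isfile(mean_file):
        os.remove(mean_file)

    with open(image_file, 'rb') as image_f:
        with open(label_file, 'rb') as label_f:
            # Read the 2 big-endian 32-bit integers of the label file header
            size = struct.calcsize('>2I')
            magic, num_items = struct.unpack_from('>2I', label_f.read(size))
            print(magic, num_items)

            # Read the 4 big-endian 32-bit integers of the image file header
            size = struct.calcsize('>4I')
            magic, num_images, num_rows, num_columns = struct.unpack_from('>4I', image_f.read(size))
            print(magic, num_images, num_rows, num_columns)
            assert num_items == num_images, 'label and image counts differ'

            # lmdb expects an integer map size; 1.5x the raw pixel bytes leaves room for keys and overhead
            map_size = int(num_images * num_rows * num_columns * 1.5)

            # Iterate over all images and write them into the LMDB
            with lmdb.open(lmdb_save_path, map_size=map_size) as in_db:
                with in_db.begin(write=True) as in_txn:
                    im_size = num_rows * num_columns
                    label_size = struct.calcsize('>B')
                    im_idx = 0
                    while im_idx < num_images:
                        img_item = struct.unpack_from('>B', label_f.read(label_size))[0]
                        img_buf = image_f.read(im_size)

                        datum = caffe_pb2.Datum(
                            channels=1,  # MNIST images are grayscale, so there is one channel
                            width=num_columns,
                            height=num_rows,
                            label=int(img_item),
                            data=img_buf
                        )
                        # Zero-padded keys keep LMDB's sorted key order equal to file order
                        in_txn.put('{:0>8d}'.format(im_idx).encode('utf8'), datum.SerializeToString())
                        im_idx += 1

    # Generate the mean file
    cmd = [compute_image_mean, lmdb_save_path, mean_file]
    print(' '.join(cmd))
    subprocess.check_call(cmd)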
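The mean file written by compute_image_mean holds the average value of each pixel position over every image in the database. The pure-Python sketch below (pixel_mean is a hypothetical helper, no Caffe required) shows the same arithmetic on images given as nested lists of pixel rows. Note that the stock lenet_train_test.prototxt only rescales pixels and does not reference a mean file, so for this example the mean file is optional.

```python
def pixel_mean(images):
    """Per-pixel mean over equally sized 2-D images (lists of pixel rows)."""
    if not images:
        raise ValueError('need at least one image')
    rows, cols = len(images[0]), len(images[0][0])
    # Accumulate the sum at every (row, column) position
    sums = [[0] * cols for _ in range(rows)]
    for img in images:
        for r in range(rows):
            for c in range(cols):
                sums[r][c] += img[r][c]
    n = len(images)
    return [[s / n for s in row] for row in sums]
```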

The following function opens the LMDB and displays its first image, which is a quick way to check the conversion:

import cv2
import lmdb
import caffe
from caffe.proto import caffe_pb2


# Show the first image stored in the LMDB
def show_lmdb_first_image(lmdb_save_path):
    with lmdb.open(lmdb_save_path, readonly=True) as lmdb_env:
        with lmdb_env.begin() as lmdb_txn:
            lmdb_cursor = lmdb_txn.cursor()
            datum = caffe_pb2.Datum()

            lmdb_cursor.first()
            key, value = lmdb_cursor.item()
            datum.ParseFromString(value)

            label = datum.label
            data = caffe.io.datum_to_array(datum)  # shape (channels, height, width)
            print(label, datum.channels, data.shape)
            image = data.transpose(1, 2, 0)  # to (height, width, channels) for OpenCV
            cv2.imshow('cv2.png', image)
            cv2.waitKey(0)

            cv2.destroyAllWindows()
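Without a display (or without OpenCV), the raw bytes in datum.data can be printed as text instead. A minimal sketch; ascii_preview and its threshold of 128 are hypothetical choices, and the input is assumed to be row-major 8-bit grayscale, as stored by the conversion above:

```python
def ascii_preview(data, width, height, threshold=128):
    """Render raw 8-bit grayscale pixels (row-major bytes) as text."""
    if len(data) < width * height:
        raise ValueError('not enough pixel data')
    lines = []
    for r in range(height):
        row = data[r * width:(r + 1) * width]
        # '#' for high values (the digit strokes), '.' for the dark background
        lines.append(''.join('#' if p >= threshold else '.' for p in row))
    return '\n'.join(lines)
```

Inside show_lmdb_first_image, `print(ascii_preview(datum.data, datum.width, datum.height))` would print a 28x28 picture of the digit.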

Train the model

Use examples\mnist\lenet_solver.prototxt and examples\mnist\lenet_train_test.prototxt from the Caffe source directory. In lenet_solver.prototxt, change the network file path (the net field) to point to your copy of lenet_train_test.prototxt. In lenet_train_test.prototxt, change the source of the data layers to the LMDB paths just generated.
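These edits can also be scripted. A minimal sketch, assuming your copies still contain the stock quoted paths (for example "examples/mnist/mnist_train_lmdb" and "examples/mnist/mnist_test_lmdb"; check your files), and using a hypothetical helper replace_paths and hypothetical new paths. Forward slashes in the new paths sidestep backslash escaping inside prototxt strings.

```python
def replace_paths(prototxt_text, mapping):
    """Replace quoted path values in prototxt text: "old" -> "new"."""
    for old, new in mapping.items():
        old_q, new_q = '"{}"'.format(old), '"{}"'.format(new)
        # Fail loudly if the expected path is missing instead of silently doing nothing
        if old_q not in prototxt_text:
            raise KeyError('path not found: {}'.format(old))
        prototxt_text = prototxt_text.replace(old_q, new_q)
    return prototxt_text
```

Usage: read lenet_train_test.prototxt into a string, call `replace_paths(text, {'examples/mnist/mnist_train_lmdb': 'D:/mnist/train_lmdb', 'examples/mnist/mnist_test_lmdb': 'D:/mnist/test_lmdb'})`, and write the result back; the same call with the net path works for lenet_solver.prototxt.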

        import caffe
        caffe.set_mode_gpu()  # pycaffe runs on the CPU unless GPU mode is set explicitly
        solver = caffe.SGDSolver('lenet_solver.prototxt')
        solver.solve()

When training finishes, there are two model files, lenet_iter_5000.caffemodel and lenet_iter_10000.caffemodel (the stock solver takes a snapshot every 5000 iterations and stops at 10000).
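The file names follow from the solver settings. A small sketch, assuming the stock values max_iter: 10000, snapshot: 5000 and snapshot_prefix: "examples/mnist/lenet", and following my reading of Caffe's solver logic (periodic snapshots, plus a final one at max_iter when snapshot_after_train is on and no periodic snapshot landed there). snapshot_files is a hypothetical helper.

```python
def snapshot_files(prefix, max_iter, snapshot, snapshot_after_train=True):
    """Names of the .caffemodel snapshots a Caffe solver would write."""
    if snapshot <= 0:
        raise ValueError('snapshot interval must be positive')
    # Periodic snapshots at every multiple of `snapshot` up to max_iter
    iters = list(range(snapshot, max_iter + 1, snapshot))
    # Final snapshot at max_iter, unless a periodic one was already written there
    if snapshot_after_train and max_iter % snapshot != 0:
        iters.append(max_iter)
    return ['{}_iter_{}.caffemodel'.format(prefix, i) for i in iters]
```

With the stock settings this gives examples/mnist/lenet_iter_5000.caffemodel and examples/mnist/lenet_iter_10000.caffemodel.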

Test the trained model

Testing needs a deployment network definition. It is usually adapted from the training network file by replacing the LMDB data layers with a plain input and the loss with a softmax output (prob). Here examples\mnist\lenet.prototxt from the Caffe source directory is used directly.

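        import caffe

        net = caffe.Net(
            'lenet.prototxt',  # network definition file
            caffe.TEST,
            weights='lenet_iter_10000.caffemodel'  # model produced by training
        )
        # lenet.prototxt declares a batch of 64; classify a single image instead
        net.blobs['data'].reshape(1, 1, 28, 28)

        transformer = caffe.io.Transformer({'data': net.blobs['data'].data.shape})
        transformer.set_transpose('data', (2, 0, 1))  # (H, W, C) -> (C, H, W)
        # No set_raw_scale(255): load_image already returns values in [0, 1], matching the
        # training data after the scale: 0.00390625 (1/256) transform in lenet_train_test.prototxt
        # transformer.set_channel_swap('data', (2, 1, 0))  # MNIST is grayscale (1 channel), so no swap is needed

        # MNIST has one channel, so load the test image as grayscale with color=False
        im = caffe.io.load_image('3.jpg', color=False)  # open the test image
        net.blobs['data'].data[0] = transformer.preprocess('data', im)
        res = net.forward()
        print(res['prob'][0].argmax())  # predicted digit for the first (only) image in the batch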

The test images are digits drawn by hand with a drawing tool on Windows. Like the MNIST data, they need white digits on a black background, and they should be resized to 28x28.
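Digits drawn in a paint program usually come out black on white, the opposite of MNIST. A minimal sketch of fixing the polarity on a 2-D list of 0-255 grayscale values; to_mnist_polarity and its border-brightness heuristic are hypothetical, and the output range [0, 1] matches what caffe.io.load_image returns. Resizing is left to an image library (Transformer.preprocess also resizes to the input blob's height and width).

```python
def to_mnist_polarity(pixels):
    """Return pixels as white-on-black values in [0, 1]."""
    rows, cols = len(pixels), len(pixels[0])
    # Look only at the border, which should be background in a centered digit
    border = [pixels[r][c] for r in range(rows) for c in range(cols)
              if r in (0, rows - 1) or c in (0, cols - 1)]
    # A bright border means dark ink on white paper, so invert
    invert = sum(border) / len(border) > 127
    return [[(255 - p if invert else p) / 255.0 for p in row] for row in pixels]
```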


Some of the digits are still recognized incorrectly.


Origin http://43.154.161.224:23101/article/api/json?id=325027681&siteId=291194637