使用C++解析MNIST数据库

遇到的两个大坑
1. 官方主页给出的文件字节数容易误导：training set images (9912422 bytes) 是 gzip 压缩后的大小；解压后的文件应为 47,040,016 字节，即 4 + 4 + 4 + 4 + 60000 × 28 × 28（4 个 4 字节的文件头字段，加上 60000 张 28×28 图片的像素字节）。
2. Windows 下以文本模式 "r" 打开文件时，C 运行库会把字节 0x1A（Ctrl+Z）当作文件结束符，并转换换行符，导致 fgetc 提前返回 EOF；改为以二进制模式 "rb" 打开、按字节流读取即可。

#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <memory>
#include <vector>

#include <opencv2/opencv.hpp>
using namespace std;
using namespace cv;

// How many training images/labels main() asks the parsers to load.
// NOTE(review): 6000 is one tenth of the 60000 training images counted in the notes
// at the top of this file; confirm whether loading only a subset is intentional.
constexpr int MnistTrainNumber = 6000;
// How many test images/labels main() asks the parsers to load.
constexpr int MnistTestNumber = 1000;

//存储像素信息
struct Image
{
    cv::Mat pixs;
    Image()
    {
        pixs.create(Size(28, 28), CV_8U);
    }
};
struct MnistImage
{
    //检验值
    int magicNumber;
    //图片数量
    int number;
    //图片行数
    int rows;
    //图片列数
    int cols;
    //图片数组
    vector<Image> images;
};
// Contents of an MNIST label file: the header fields and the decoded labels.
// Header fields default to 0 so they hold a defined value even when parsing
// fails before they are read (e.g. the file could not be opened).
struct MnistLabel
{
    // Magic number from the file header; compared by callers to validate parsing.
    int magicNumber = 0;
    // Number of labels recorded in the file header.
    int number = 0;
    // Decoded labels, one per byte of the label section.
    std::vector<int> labels;
};
// Training split: the parsed image file together with its label file.
struct MnistTrainSet
{
    MnistImage trainImages;  // parsed training image file
    MnistLabel trainLabels;  // parsed training label file
};
// Test split: the parsed image file together with its label file.
// The members keep their train* names because main() refers to them that way,
// even though they hold the test data.
struct MnistTestSet
{
    MnistImage trainImages;  // parsed test image file
    MnistLabel trainLabels;  // parsed test label file
};
// Reads `length` bytes starting at the current position of `file` and combines
// them most-significant byte first (big-endian) into an integer.
//
// file    stream opened in binary mode.
// length  number of bytes to combine; values above 4 keep only the low 32 bits.
//
// Returns the combined value, or -1 if the stream ends (or a read error occurs)
// before `length` bytes have been read. A stored value of 0xFFFFFFFF is
// indistinguishable from that error.
int readData(FILE *file, int length)
{
    // Accumulate unsigned: shifting a high byte into bit 31 of a signed int would
    // be signed overflow (undefined behavior).
    unsigned int ans = 0;
    for (int i = 0; i < length; i++)
    {
        const int byte = fgetc(file);
        // Without this check EOF (-1) would be silently mixed into the result.
        if (byte == EOF) return -1;
        ans = (ans << 8) | static_cast<unsigned int>(byte);
    }
    return static_cast<int>(ans);
}
//解析图片字节流文件
int parseMnistImage(const char *fileName, MnistImage &mnistImage, int imagesNumber)
{
    FILE *out = fopen(fileName, "rb");
    if (out == NULL) return -1;
    mnistImage.magicNumber = readData(out, 4);
    mnistImage.number = readData(out, 4);
    mnistImage.rows = readData(out, 4);
    mnistImage.cols = readData(out, 4);
    for (int k = 0; k < imagesNumber; k++)
    {
        Image image;
        for (int i = 0; i < 28; i++)
        {
            for (int j = 0; j < 28; j++)
            {
                int x = fgetc(out);
                image.pixs.at<uchar>(i, j) = x;
            }
        }
        mnistImage.images.push_back(image);
    }
    fclose(out);
    return mnistImage.magicNumber;
}
// Parses an MNIST label file: a header of two big-endian 4-byte fields
// (magic number, label count) followed by one byte per label.
//
// fileName     path of the label file, e.g. "train-labels.idx1-ubyte".
// mnistLabel   receives the header fields; decoded labels are appended to
//              mnistLabel.labels.
// labelNumber  maximum number of labels to read; it is clamped to the count
//              stored in the header.
//
// Returns the magic number from the header, which callers compare against the
// expected value, or -1 if the file cannot be opened. If the file ends early,
// only the labels actually present are appended.
int parseMnistLabel(const char *fileName, MnistLabel &mnistLabel, int labelNumber)
{
    // Close the file on every return path, including if an allocation throws.
    auto closeFile = [](FILE *f) { fclose(f); };
    std::unique_ptr<FILE, decltype(closeFile)> file(fopen(fileName, "rb"), closeFile);
    if (!file) return -1;

    mnistLabel.magicNumber = readData(file.get(), 4);
    mnistLabel.number = readData(file.get(), 4);

    // Never read more labels than the header says the file holds.
    const int count = std::max(0, std::min(labelNumber, mnistLabel.number));
    mnistLabel.labels.reserve(mnistLabel.labels.size() + static_cast<size_t>(count));
    for (int i = 0; i < count; i++)
    {
        const int label = fgetc(file.get());
        if (label == EOF) break;  // truncated file: do not store EOF (-1) as a label
        mnistLabel.labels.push_back(label);
    }
    return mnistLabel.magicNumber;
}
void virtualizeData(Mat &mat)
{
    imshow("virtualizeData", mat);
    waitKey();
}
int main()
{
    MnistTrainSet mnistTrainSet;
    MnistTestSet mnistTestSet;
    //magic number分别为2051,2049,2051,2049,与官方提供的检验值比对以确定解析程序是否有误
    cout << parseMnistImage("train-images.idx3-ubyte", mnistTrainSet.trainImages, MnistTrainNumber) << endl;
    cout << parseMnistLabel("train-labels.idx1-ubyte", mnistTrainSet.trainLabels, MnistTrainNumber) << endl;
    cout << parseMnistImage("t10k-images.idx3-ubyte", mnistTestSet.trainImages, MnistTestNumber) << endl;
    cout << parseMnistLabel("t10k-labels.idx1-ubyte", mnistTestSet.trainLabels, MnistTestNumber) << endl;
    //可视化训练集中第k张图片
    int k = 5;
    virtualizeData(mnistTrainSet.trainImages.images[k].pixs);
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/liuzhan709/p/9320724.html