使用tensorflow训练模型时制作自己的mnist集(附代码)

使用tensorflow训练模型时制作自己的mnist集(附代码)

探索过程

(ps:第一次写,写的不好多多见谅!)
mnist集合是一个被用烂的手写数字的图片训练集,但在实际中我们很多时候要用自己的数据集,那么就需要将自己的图片数据集转化成mnist形式的数据,或者用其他方法(之前用过keras,一个第三方库,虽然这个库很好用,简化了很多步骤。可能是因为我能力有限总是训练的模型accuracy和loss都不理想,如果有懂得小伙伴还希望教下我,刚入门的我真的好菜(^ △ ^))。

把数据喂给神经网络主要有两种方法:本地导入,load直接使用。

大概说下mnist集如下,由四个压缩包组成:
在这里插入图片描述
这种文件解压后是一种IDX3-UBYTE 文件,它以一种进制单位储存,所以作为非专业人员,我们是打不开的。
在这里插入图片描述
那么我们如何了解它里面的储存形式呢? 如下:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('mnist_data',one_hot=True)

test_x = mnist.test.images[:3000]
test_y = mnist.test.labels[:3000]
print(len(test_x)," ",len(test_x[0]))
print(test_x)

这样可以得到其train-images.idx3-ubyte的一些数据:

3000   784
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]

那么可以发现它由3000个list组成,每一个里面是784=28*28的图片向量。
同样的我们看下test_y

3000   10
[[0. 0. 0. ... 1. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 [0. 1. 0. ... 0. 0. 0.]
 ...
 [0. 1. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]]

3000的含义同上,10则是它的类标签属于哪一类哪个索引是1,其他为0。
了解到这里我们开始制作自己的mnist集了,存为csv或data都可以。

代码(python)

# coding:utf-8
import cv2
import os
import random
import pandas as pd

def progress_bar(i):
    a = int(i)*10
    b = '='*int(i)
    c = '->'
    d = '·'*(10-int(i))
    if(i==0):
        print("******执行开始******")
    print(a,'\t',"%","[",b+c+d,"]")
    if (i == 10):
        print("******执行结束******")

def mnist_change(path,img_width,img_height):
    images = []
    labels = []
    tags = os.listdir(path)

    n = 0
    for tag in tags:
        _tag_ = os.listdir(path+tag)
        n += len(_tag_)
    key = 0
    for tag in tags:
        _tag_ = os.listdir(path+tag)
        i = 0
        for image in _tag_:
            img_path = path+tag+'/'+image
            img = cv2.imread(img_path)
            img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
            img = cv2.resize(img, (img_width,img_height),interpolation=cv2.INTER_CUBIC)
            img_data = []
            # images中的一个元素
            for data in img:
                img_data.extend(data)
            # labels中的一个元素
            zero_ = [0 for _ in range(len(tags))]
            zero_[i] = 1

            if(key == 0):
                images.append(img_data)
                labels.append(zero_)
            else:
                rand_i = random.randint(0,len(labels)-1)
                temp1,temp2 = images[rand_i],labels[rand_i]
                images[rand_i],labels[rand_i] = img_data,zero_
                images.append(temp1)
                labels.append(temp2)
            if(key%(n//10)==0 and key/(n//10)<=10):
                progress_bar(key/(n//10))
            key += 1
        i += 1
    return images,labels

def text_save(file, data):
    data = pd.DataFrame(data)
    data.to_csv(file,index=None)
    print("保存文件成功")

if __name__ == '__main__':
    # 自定义参数
    path = 'data/'
    width,height = 200,200
    images,labels = mnist_change(path,width,height)
    text_save('images.csv',images)
    text_save('labels.csv',labels)
    '''
    图片存放格式
    data
    	n0
    		4546asd.jpg
    		asdw4145.jpg
    	n1
    		asd4.jpg
    '''

结果:

******执行开始******
0 	 % [ ->·········· ]
10 	 % [ =->········· ]
20 	 % [ ==->········ ]
30 	 % [ ===->······· ]
40 	 % [ ====->······ ]
50 	 % [ =====->····· ]
60 	 % [ ======->···· ]
70 	 % [ =======->··· ]
80 	 % [ ========->·· ]
90 	 % [ =========->· ]
100  % [ ==========-> ]
******执行结束******
保存文件成功
保存文件成功

想法

这是我写的第一篇博客,无论内容还是代码的级别可能不是很高,大家见谅。
如果大家有什么想法,欢迎评论区交流呀。

我建的小群,欢迎大家交流。
在这里插入图片描述

发布了6 篇原创文章 · 获赞 0 · 访问量 73

猜你喜欢

转载自blog.csdn.net/qq_44851357/article/details/104054064