最终生成的效果
我把我制作的训练样本贴出来,供大家参考。
import random
import numpy
import time
nums=[] #训练样本的list
start = time.time()
with open('11.txt') as f: #11.txt为图像数据,每个数位0-255,数与数之间为空格,一个图像数据为一行
for line in f:
a=line.split()
nums+=a
sss=numpy.asarray(nums,dtype=float)
sss=sss/255
sss=sss.reshape(2800,784)
aa=time.time()-start
print 'ok' , aa#print的输出为:ok 2.89900016785
nums1=[] #标签的list
start = time.time()
with open('22.txt') as f: #22.txt为图像标签
for line in f:
a=line.split()
nums1+=a
sss1=numpy.asarray(nums1,dtype=int)
aa1=time.time()-start
print 'ok' , aa1#print的输出为:ok 2.89900016785
ee1=zip(sss,sss1) #zip形成字典
random.shuffle(ee1)#这个是将数据进行随机化排列
dd1,dd2=map(list,zip(*ee1))#将ee1拆分成两个list
ddd1=numpy.asarray(dd1,dtype=float) #将list转变成numpy数组
ddd2=numpy.asarray(dd2,dtype=int)
cc=ddd1,ddd2 #最终得到cc这个元祖
下面为完整代码:
import random
import numpy
import time
import gzip
import cPickle
class shuju(object):
def __init__(self,in1,in2,in3,in4):
self.in1=in1
self.in2=in2
self.in3=in3
self.in4=in4
def duqushuju(self):
nums=[] #训练样本的list
start = time.time()
with open(self.in1) as f: #11.txt为图像数据
for line in f:
a=line.split()
nums+=a
sss=numpy.asarray(nums,dtype=float)
sss=sss/255
sss=sss.reshape(self.in3,self.in4)
aa=time.time()-start
print 'ok' , aa#print的输出为:ok 2.89900016785
nums1=[] #标签的list
start = time.time()
with open(self.in2) as f: #22.txt为图像标签
for line in f:
a=line.split()
nums1+=a
sss1=numpy.asarray(nums1,dtype=int)
aa1=time.time()-start
print 'ok' , aa1#print的输出为:ok 2.89900016785
ee1=zip(sss,sss1) #zip形成字典
random.shuffle(ee1)#这个是将数据进行随机化排列
dd1,dd2=map(list,zip(*ee1))#将ee1拆分成两个list
ddd1=numpy.asarray(dd1,dtype=float) #将list转变成numpy数组
ddd2=numpy.asarray(dd2,dtype=int)
return ddd1,ddd2 #最终得到cc这个元祖
def shu():
a1='11.txt'#训练数据(train)
a2='111.txt'#测试数据(test)
a3='1111.txt'#检验数据(valid)
b1='22.txt'
b2='222.txt'
b3='2222.txt'
o1=shuju(a1,b1,2800,784)
o2=shuju(a2,b2,800,784)
o3=shuju(a3,b3,200,784)
oo1=o1.duqushuju()
oo2=o2.duqushuju()
oo3=o3.duqushuju()
return oo1,oo2,oo3
if __name__ == '__main__':
oo1,oo2,oo3=shu()
d=oo1,oo3,oo2
p1=cPickle.dumps(d,2) #生成pkl.gz文件就和theano中的一样
s=gzip.open('cnn.pkl.gz','wb')#要保存的文件路径,这里用了gzip,压缩文件
s.write(p1)
s.close()
print 'ok'
如果想直接读取文件名列表的话可以参见下面这个博文,包括文件太大上述方法显示错误的解决方案。