深度学习torch之二(数据集的加载,mnist数据集为例)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/littlle_yan/article/details/79180644

加载数据集(在加载数据集之前需要先输入 require'torch'):

本实验以mnist数据集为例。现在网络上下载mnist数据集,要注意的是在torch框架下,需要使用后缀名为.t7格式的文本类型,可以到该链接下载http://download.csdn.net/download/littlle_yan/10220629。下载下来mnist.t7的文件之后,编辑如下代码,进行加载数据集:

trainData=torch.load('train_mnist.t7','ascii')
testData=torch.load('test_mnist.t7',ascii)
print('trainData=',trainData)
print('testData=',testData)
输出结果如图:

如图可以看出加载后的trainData,testData对象包含两部分内容:data和labels,且data和label的数据类型是Byte类型。

有些程序这样加载进入是数据在运行过程中会出现错误,其中一个原因是data数据类型为Byte如法进行数学运算,所以将data数据类型变成Float就可以正常运行,如何将data数据类型变为Float呢,其实很简单,代码如下:

torch.setdefaulttensortype('torch.FloatTensor')
Data=torch.load('train_mnist.t7','ascii')
trainData={
data=Data.data:type(torch.setdefaulttensortype())
labels=Data.labels
}
Data1=torch.load('test_mnist.t7','ascii')
testData={
data=Data1.data:type(torch.setdefaulttensortype())
labels=Data1.labels
}
print('trainData=',trainData)
print('testData=',testData)
输出结果如图所示,就会发现data的数据类型已经变成Float:









猜你喜欢

转载自blog.csdn.net/littlle_yan/article/details/79180644
今日推荐