Recognizing handwritten digits in the MNIST dataset with TensorFlow

Copyright notice: This is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/a_step_further/article/details/54917576



Here is the code:

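#!/usr/bin/env python
#coding:utf-8

# Note: skflow is the old standalone TensorFlow wrapper (later merged into
# tf.contrib.learn, which no longer exists in TensorFlow 2.x); it needs TensorFlow installed.
import pandas as pd
import skflow

# Kaggle "Digit Recognizer" data: train.csv = 'label' column + pixel columns,
# test.csv = pixel columns only
train = pd.read_csv('~/Mnist/train.csv')
test = pd.read_csv('~/Mnist/test.csv')

X_train = train.drop(columns='label')  # pixel features
y_train = train['label']               # digit labels 0-9
X_test = test                          # was undefined in the original code

# Linear (softmax) classifier over the 10 digit classes
classifier = skflow.TensorFlowLinearClassifier(n_classes=10, batch_size=100, steps=1000, learning_rate=0.01)
classifier.fit(X_train, y_train)
linear_y_predict = classifier.predict(X_test)

# Save results for submission to Kaggle (header as in Kaggle's sample_submission.csv: ImageId,Label)
linear_submission = pd.DataFrame({'ImageId': range(1, len(X_test) + 1), 'Label': linear_y_predict})
linear_submission.to_csv('~/Mnist/linear_submission.csv', index=False)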
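TensorFlowLinearClassifier is essentially a single linear layer followed by a softmax over the classes, fitted on mini-batches. For readers without the (now obsolete) skflow package, here is a minimal NumPy sketch of that idea. It assumes plain mini-batch SGD with the same batch_size, steps and learning_rate values as above (skflow's own default optimizer may differ), and it runs on a small synthetic 3-class dataset instead of the Kaggle files. The names train_softmax_classifier and predict are hypothetical helpers, not part of skflow or TensorFlow.

```python
import numpy as np

def train_softmax_classifier(X, y, n_classes=10, batch_size=100, steps=1000,
                             learning_rate=0.01, seed=0):
    """Hypothetical helper: softmax (multinomial logistic) regression via mini-batch SGD.

    X: (n_samples, n_features) float array; y: (n_samples,) integer labels.
    Returns weights W of shape (n_features, n_classes) and bias b of shape (n_classes,).
    """
    rng = np.random.default_rng(seed)
    n_samples, n_features = X.shape
    W = np.zeros((n_features, n_classes))
    b = np.zeros(n_classes)
    for _ in range(steps):
        # Draw a mini-batch of row indices (with replacement, for simplicity)
        idx = rng.integers(0, n_samples, size=batch_size)
        xb, yb = X[idx], y[idx]
        logits = xb @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        probs = np.exp(logits)
        probs /= probs.sum(axis=1, keepdims=True)
        # Gradient of the mean cross-entropy w.r.t. the logits: softmax - one_hot(y)
        probs[np.arange(len(yb)), yb] -= 1.0
        probs /= len(yb)
        W -= learning_rate * (xb.T @ probs)
        b -= learning_rate * probs.sum(axis=0)
    return W, b

def predict(W, b, X):
    """Hypothetical helper: return the class with the highest linear score."""
    return np.argmax(X @ W + b, axis=1)

# Synthetic stand-in for the pixel data: three well-separated 2-D clusters
rng = np.random.default_rng(42)
centers = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, -5.0]])
y = rng.integers(0, 3, size=600)
X = centers[y] + rng.normal(size=(600, 2))

W, b = train_softmax_classifier(X, y, n_classes=3)
acc = (predict(W, b, X) == y).mean()
print(f"training accuracy: {acc:.3f}")
```

To try it on the Kaggle files, one would convert the DataFrames to arrays, for example X = train.drop(columns='label').to_numpy(dtype=float) / 255.0 and y = train['label'].to_numpy(). Scaling the 0-255 pixel values to [0, 1] is an added assumption that keeps plain SGD stable; the original code feeds raw pixel values.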


Reposted from blog.csdn.net/a_step_further/article/details/54917576