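Thanks to the following four articles, which gave me a fairly deep understanding of convolutional neural networks:
CNN: derivation and implementation of convolutional neural networks http://blog.csdn.net/zouxy09/article/details/9993371
A C++ implementation of a convolutional neural network http://www.codeproject.com/Articles/16650/Neural-Network-for-Recognition-of-Handwritten-Digi
A Python implementation of a convolutional neural network http://deeplearning.net/tutorial/lenet.html
A handwriting recognition example: http://www.csdn.net/article/1970-01-01/2825549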
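I wanted to implement a simple CNN myself, but I ran into some problems along the way, so I am leaving that for later when I have time (which also gives me an excuse not to reinvent the wheel).
Today I mainly want to use Lasagne to implement handwriting recognition and push the accuracy from 97% to 99%.
Implementation code: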
```python
from datetime import datetime
from time import perf_counter  # time.clock() was removed in Python 3.8

import lasagne
import numpy as np
from lasagne import layers
from lasagne.updates import nesterov_momentum
from nolearn.lasagne import NeuralNet
from sklearn.metrics import classification_report

net2 = NeuralNet(
    layers=[('input', layers.InputLayer),
            ('conv2d1', layers.Conv2DLayer),
            ('maxpool1', layers.MaxPool2DLayer),
            ('conv2d2', layers.Conv2DLayer),
            ('maxpool2', layers.MaxPool2DLayer),
            ('dropout1', layers.DropoutLayer),
            ('dense', layers.DenseLayer),
            ('dropout2', layers.DropoutLayer),
            ('output', layers.DenseLayer),
            ],
    # input layer: batches of 1-channel 28x28 images
    input_shape=(None, 1, 28, 28),
    # layer conv2d1
    conv2d1_num_filters=32,
    conv2d1_filter_size=(5, 5),
    conv2d1_nonlinearity=lasagne.nonlinearities.rectify,
    conv2d1_W=lasagne.init.GlorotUniform(),
    # layer maxpool1
    maxpool1_pool_size=(2, 2),
    # layer conv2d2
    conv2d2_num_filters=60,
    conv2d2_filter_size=(5, 5),
    conv2d2_nonlinearity=lasagne.nonlinearities.rectify,
    # layer maxpool2
    maxpool2_pool_size=(2, 2),
    # dropout1
    dropout1_p=0.5,
    # dense
    dense_num_units=500,
    dense_nonlinearity=lasagne.nonlinearities.rectify,
    # dropout2
    dropout2_p=0.5,
    # output: 10 classes, one per digit
    output_nonlinearity=lasagne.nonlinearities.softmax,
    output_num_units=10,
    # optimization method params
    update=nesterov_momentum,
    update_learning_rate=0.01,
    update_momentum=0.9,
    max_epochs=10,
    verbose=1,
)


def load_source(filename):
    """Return the lines of a CSV file, skipping the header row."""
    with open(filename, "r") as file:
        lines = file.readlines()
    return lines[1:]


data_lines = load_source("./data/train.csv")
# Each row is: label, pixel0, ..., pixel783
data_lines = np.array([line.split(',') for line in data_lines]).astype(np.float32)
x_data = data_lines[:, 1:].reshape((len(data_lines), 1, 28, 28))
y_data = data_lines[:, 0].astype(np.int32)
x_data /= np.float32(256)  # scale pixel values 0..255 into [0, 1)

# Hold out the last 3000 samples as a test set
X_train = x_data[:-3000]
y_train = y_data[:-3000]
X_test = x_data[-3000:]
y_test = y_data[-3000:]

np.set_printoptions(suppress=True, linewidth=175, precision=3)

# Train the network
print(datetime.now().strftime('%b-%d-%y %H:%M:%S'), "start training net0")
start = perf_counter()
net2.fit(X_train, y_train)
preds = net2.predict(X_test)
print(classification_report(y_test, preds))
print(datetime.now().strftime('%b-%d-%y %H:%M:%S'), "end net0")
end = perf_counter()
print("net0 : %.3f s" % (end - start))
```
Output:
```
# Neural Network with 534402 learnable parameters

## Layer information

  #  name      size
---  --------  --------
  0  input     1x28x28
  1  conv2d1   32x24x24
  2  maxpool1  32x12x12
  3  conv2d2   60x8x8
  4  maxpool2  60x4x4
  5  dropout1  60x4x4
  6  dense     500
  7  dropout2  500
  8  output    10

  epoch    trn loss    val loss    trn/val    valid acc  dur
-------  ----------  ----------  ---------  -----------  ------
      1     0.74042     0.15038    4.92376      0.95527  61.06s
      2     0.20932     0.10588    1.97692      0.96732  62.03s
      3     0.15610     0.08767    1.78063      0.97270  60.36s
      4     0.12981     0.07162    1.81254      0.97783  59.70s
      5     0.11279     0.06259    1.80196      0.97975  61.17s
      6     0.09886     0.05699    1.73463      0.98180  59.02s
      7     0.09049     0.05362    1.68761      0.98231  60.36s
      8     0.08104     0.04864    1.66604      0.98488  61.23s
      9     0.07727     0.04682    1.65020      0.98398  60.72s
     10     0.06763     0.04635    1.45896      0.98552  60.29s

             precision    recall  f1-score   support

          0       0.98      0.99      0.99       327
          1       0.99      0.98      0.99       330
          2       0.98      0.96      0.97       284
          3       0.99      0.99      0.99       300
          4       0.99      0.99      0.99       315
          5       1.00      0.98      0.99       242
          6       0.98      1.00      0.99       317
          7       0.97      0.99      0.98       304
          8       0.99      0.98      0.98       283
          9       0.98      0.98      0.98       298

avg / total       0.99      0.99      0.99      3000
```
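As a sanity check, the layer sizes and the 534402 learnable parameters in the table can be recomputed by hand. This is a minimal sketch that assumes stride-1 convolutions without padding and non-overlapping 2x2 max pooling, which is consistent with the sizes shown; pooling and dropout layers contribute no parameters.

```python
# Recompute net2's spatial sizes and parameter count.
# Assumption: "valid" stride-1 convolutions and non-overlapping 2x2 pooling.

def conv_out(size, filter_size):
    """Spatial size after a stride-1 convolution without padding."""
    return size - filter_size + 1

def pool_out(size, pool_size):
    """Spatial size after non-overlapping max pooling (remainder dropped)."""
    return size // pool_size

def conv_params(in_channels, num_filters, filter_size):
    """Weights (num_filters x in_channels x k x k) plus one bias per filter."""
    return num_filters * in_channels * filter_size * filter_size + num_filters

def dense_params(num_inputs, num_units):
    """Weight matrix plus one bias per unit."""
    return num_inputs * num_units + num_units

size = 28
size = pool_out(conv_out(size, 5), 2)  # conv2d1 -> 24, maxpool1 -> 12
size = pool_out(conv_out(size, 5), 2)  # conv2d2 -> 8,  maxpool2 -> 4

total = (
    conv_params(1, 32, 5)                   # conv2d1: 832
    + conv_params(32, 60, 5)                # conv2d2: 48060
    + dense_params(60 * size * size, 500)   # dense: 960 inputs -> 480500
    + dense_params(500, 10)                 # output: 5010
)
print(size, total)  # → 4 534402
```

Almost all of the parameters (480500) sit in the dense layer, so `dense_num_units` is the main knob for model size.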
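You can build a different network simply by adjusting the configuration, which is very convenient.
Later I printed out the misclassified samples and found that there is still room to improve the accuracy; I will study this in more detail when I have time.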
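In the code above, each layer's settings are passed to nolearn's `NeuralNet` as flat keyword arguments named `<layer name>_<parameter>` (for example `conv2d1_num_filters`), so reconfiguring the network means editing a few keywords. To illustrate that naming convention, here is a hypothetical helper, `group_layer_kwargs`, that groups such keywords by layer. It is a sketch of the idea only, not nolearn's actual implementation; keywords that match no layer (such as `update_learning_rate`) are simply returned separately here.

```python
# Hypothetical helper for illustration: groups flat "<layer>_<param>" keywords
# per layer. This is NOT nolearn's code, just the naming convention it uses.

def group_layer_kwargs(layer_names, flat_kwargs):
    """Split {'conv2d1_num_filters': 32, ...} into {'conv2d1': {'num_filters': 32}, ...}."""
    grouped = {name: {} for name in layer_names}
    leftovers = {}
    for key, value in flat_kwargs.items():
        matches = [n for n in layer_names if key.startswith(n + "_")]
        if matches:
            # If one layer name is a prefix of another (e.g. 'dense' and
            # 'dense_out'), prefer the longest matching name.
            name = max(matches, key=len)
            grouped[name][key[len(name) + 1:]] = value
        else:
            leftovers[key] = value  # e.g. update_learning_rate, max_epochs
    return grouped, leftovers


layer_names = ["input", "conv2d1", "maxpool1", "dense", "output"]
grouped, rest = group_layer_kwargs(layer_names, {
    "input_shape": (None, 1, 28, 28),
    "conv2d1_num_filters": 64,        # changing one keyword changes the layer
    "conv2d1_filter_size": (3, 3),
    "maxpool1_pool_size": (2, 2),
    "dense_num_units": 256,
    "output_num_units": 10,
    "update_learning_rate": 0.01,
})
print(grouped["conv2d1"])  # → {'num_filters': 64, 'filter_size': (3, 3)}
print(rest)                # → {'update_learning_rate': 0.01}
```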
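```python
import matplotlib.pyplot as plt
from matplotlib import cm

# Appended after the training script above (uses X_test, y_test, preds)
print("net0 : %.3f s" % (end - start))
print("error cnt:", int((y_test != preds).sum()))
for i in range(len(y_test)):
    if y_test[i] == preds[i]:
        continue
    print(i, "error", y_test[i], preds[i])
    plt.imshow(X_test[i][0], cmap=cm.binary)
    # File name: <index>-<true label>-<predicted label>
    plt.savefig("/XXX/error/%d-%d-%d" % (i, y_test[i], preds[i]))
    plt.clf()  # clear the figure so images do not pile up in memory
```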
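Before looking at the saved images one by one, it can help to see which digit pairs are confused most often. This is a minimal sketch using only the standard library; `most_confused` is a hypothetical helper, and the labels below are made-up sample data, not results from the run above. In the script it could be called as `most_confused(y_test, preds)`.

```python
from collections import Counter

def most_confused(y_true, y_pred, top=5):
    """Return the most frequent (true label, predicted label) pairs among errors."""
    pairs = Counter((t, p) for t, p in zip(y_true, y_pred) if t != p)
    return pairs.most_common(top)

# Made-up labels for illustration only
y_true = [4, 9, 4, 7, 2, 4, 3]
y_pred = [9, 9, 9, 1, 7, 4, 3]
print(most_confused(y_true, y_pred))  # → [((4, 9), 2), ((7, 1), 1), ((2, 7), 1)]
```

A pair that dominates this list (for example a 4 often predicted as 9) points at where extra data augmentation or a larger network might help most.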