Handwritten Digit Recognition: Python Implementation
# coding: utf-8
import sys, os
sys.path.append(os.pardir)
import numpy as np
import pickle
from dataset.mnist import load_mnist
import time
def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    Works both for a single score vector (shape ``(n,)``) and for a batch of
    score vectors (shape ``(batch, n)``); in the batched case every row is
    normalized independently. A 0-d input is treated as a single score.

    Args:
        x: array-like of scores (logits).

    Returns:
        ndarray with the same shape as ``x`` whose entries along the last
        axis are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Reduce over the last axis so each sample in a batch is normalized on
    # its own; a 0-d scalar has no axis to reduce over.
    axis = -1 if x.ndim > 0 else None
    # Overflow countermeasure: subtracting the (per-row) maximum before
    # exponentiating keeps np.exp finite and does not change the result.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def sigmoid(x):
    """Apply the logistic sigmoid 1 / (1 + e^(-x)) elementwise to ``x``."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def get_data():
    """Return the MNIST test split as ``(x_test, t_test)``.

    ``load_mnist`` is called with normalize=True, flatten=True and
    one_hot_label=False; the training split it also returns is discarded.
    """
    (_, _), test_split = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    x_test, t_test = test_split
    return x_test, t_test
def init_network(path="sample_weight.pkl"):
    """Load the pre-trained network parameters from a pickle file.

    Args:
        path: location of the pickled parameters. Defaults to
            "sample_weight.pkl", resolved against the current working
            directory.

    Returns:
        The unpickled object. The rest of this script indexes it with the
        keys 'W1', 'W2', 'W3', 'b1', 'b2', 'b3'.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load weight files from a trusted source.
    with open(path, 'rb') as f:
        network = pickle.load(f)
    return network
def predict(network, x):
    """Run a forward pass through the three-layer network.

    The two hidden layers use a sigmoid activation; the output layer uses
    softmax.

    Args:
        network: mapping holding weight matrices 'W1', 'W2', 'W3' and bias
            vectors 'b1', 'b2', 'b3'.
        x: input sample(s) — presumably one flattened image or a batch of
            them; confirm against the caller.

    Returns:
        The softmax output of the final layer.
    """
    weights = (network['W1'], network['W2'], network['W3'])
    biases = (network['b1'], network['b2'], network['b3'])
    activation = x
    # Hidden layers: affine transform followed by sigmoid.
    for weight, bias in zip(weights[:-1], biases[:-1]):
        activation = sigmoid(np.dot(activation, weight) + bias)
    # Output layer: affine transform followed by softmax.
    return softmax(np.dot(activation, weights[-1]) + biases[-1])
# Evaluate the network on the test set one image at a time and time the loop.
x, t = get_data()
network = init_network()
accuracy_cnt = 0
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for image, label in zip(x, t):
    y = predict(network, image)
    p = np.argmax(y)  # index of the element with the highest probability
    if p == label:
        accuracy_cnt += 1
end = time.perf_counter()
print("time:", end - start)
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
time: 1.3383796215057373
Accuracy:0.9352
Network shape analysis
# Inspect the shape of the test data and of the three weight matrices.
# NOTE(review): relies on `network` having been loaded earlier in the script.
x, _ = get_data()
print(x.shape)
W1, W2, W3 = (network[key] for key in ('W1', 'W2', 'W3'))
print(W1.shape, W2.shape, W3.shape)
The output is as follows:
(10000, 784)  # the test set has 10,000 images of 784 pixels each
(784, 50) (50, 100) (100, 10)
Processing flow: a one-dimensional array of 784 elements (one flattened image) goes in, passes through hidden layers of 50 and 100 neurons, and a one-dimensional array of 10 elements (one score per digit class) comes out.
What if we feed in more than one image at a time?
If we feed in 100 images at once, the input data has shape 100 × 784 and the output data has shape 100 × 10. The results for all 100 input images are produced together as one batch. Batching greatly benefits the computation and shortens the processing time per image, for two reasons:
- Most numerical computing libraries are optimized to process large arrays efficiently.
- When data transfer is the bottleneck in a neural network, batching reduces the transfer overhead, so a larger share of the time is spent on actual computation.
# coding: utf-8
import sys, os
sys.path.append(os.pardir)
import numpy as np
import pickle
from dataset.mnist import load_mnist
def softmax(x):
    """Return the softmax of ``x`` along its last axis.

    Works both for a single score vector (shape ``(n,)``) and for a batch of
    score vectors (shape ``(batch, n)``); in the batched case every row is
    normalized independently. A 0-d input is treated as a single score.

    Args:
        x: array-like of scores (logits).

    Returns:
        ndarray with the same shape as ``x`` whose entries along the last
        axis are non-negative and sum to 1.
    """
    x = np.asarray(x)
    # Reduce over the last axis so each sample in a batch is normalized on
    # its own; a 0-d scalar has no axis to reduce over.
    axis = -1 if x.ndim > 0 else None
    # Overflow countermeasure: subtracting the (per-row) maximum before
    # exponentiating keeps np.exp finite and does not change the result.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)
def sigmoid(x):
    """Return the elementwise logistic sigmoid of ``x``: 1 / (1 + e^(-x))."""
    exp_neg = np.exp(-x)
    return 1 / (exp_neg + 1)
def get_data():
    """Load MNIST and hand back only the test images and their labels.

    The loader is asked for normalize=True, flatten=True and
    one_hot_label=False; the training pair is unpacked and ignored.
    """
    (_, _), (images, labels) = load_mnist(normalize=True, flatten=True, one_hot_label=False)
    return images, labels
def init_network(path="sample_weight.pkl"):
    """Load the pre-trained network parameters from a pickle file.

    Args:
        path: location of the pickled parameters. Defaults to
            "sample_weight.pkl", resolved against the current working
            directory.

    Returns:
        The unpickled object. The rest of this script indexes it with the
        keys 'W1', 'W2', 'W3', 'b1', 'b2', 'b3'.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file;
    # only load weight files from a trusted source.
    with open(path, 'rb') as f:
        network = pickle.load(f)
    return network
def predict(network, x):
    """Forward-propagate ``x`` through the three-layer network.

    Layers 1 and 2 apply an affine transform followed by sigmoid; layer 3
    applies an affine transform followed by softmax.

    Args:
        network: mapping with weights 'W1', 'W2', 'W3' and biases 'b1',
            'b2', 'b3'.
        x: one input sample or a batch of samples (rows) — presumably
            flattened images; confirm against the caller.

    Returns:
        The softmax output of the final layer.
    """
    w1, w2, w3 = (network[key] for key in ('W1', 'W2', 'W3'))
    b1, b2, b3 = (network[key] for key in ('b1', 'b2', 'b3'))
    hidden1 = sigmoid(np.dot(x, w1) + b1)
    hidden2 = sigmoid(np.dot(hidden1, w2) + b2)
    return softmax(np.dot(hidden2, w3) + b3)
# This script's import block does not include `time`, so import it here;
# without it the timing calls below raise NameError.
import time

# Evaluate the network on the test set in mini-batches and time the loop.
x, t = get_data()
network = init_network()
batch_size = 100  # number of images per batch
accuracy_cnt = 0
# perf_counter is a monotonic, high-resolution clock meant for measuring
# elapsed time; time.time() can jump if the system clock is adjusted.
start = time.perf_counter()
for i in range(0, len(x), batch_size):
    x_batch = x[i:i+batch_size]
    y_batch = predict(network, x_batch)
    # Predicted class for every image in the batch (argmax along each row).
    p = np.argmax(y_batch, axis=1)
    accuracy_cnt += np.sum(p == t[i:i+batch_size])
end = time.perf_counter()
print("time:", end - start)
print("Accuracy:" + str(float(accuracy_cnt) / len(x)))
time: 0.06778287887573242
Accuracy:0.9352
Batch processing is about 20 times faster than processing the images one at a time (about 0.068 s vs. 1.34 s), with identical accuracy.