Kaggle Getting-Started Competition · Handwritten Digit Recognition with a CNN

This post uses a CNN for handwritten digit recognition in a Kaggle getting-started competition and reaches an accuracy of 0.98260.

Digit Recognizer competition page: https://www.kaggle.com/c/digit-recognizer
Data set:
Link: https://pan.baidu.com/s/13f3rM_lhNGyu2Rsqbc1AUw
Extraction code: otoy
train.csv and test.csv are the training and test data, respectively.
sample_submission.csv shows the official submission format.
predict.csv contains this post's predictions on the test set, which scored an accuracy of 0.98260.

1. Data loading and preprocessing

Read the training and test CSV files (main_path is the folder that holds the downloaded files):

import os
import pandas as pd

train_file = pd.read_csv(os.path.join(main_path, "train.csv"))
test_file = pd.read_csv(os.path.join(main_path, "test.csv"))
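The code in this post relies on the column layout of the Kaggle files: train.csv has a label column followed by 784 pixel columns (pixel0 to pixel783, a 28 × 28 image flattened row by row), and test.csv has the same pixel columns without the label. As an illustration of that layout, here is a small parser that uses only the csv module and NumPy and runs on synthetic text rather than the real files; parse_digit_csv is a hypothetical helper, not part of the original post.

```python
import csv
import io

import numpy as np


def parse_digit_csv(text, has_label=True):
    """Parse Digit Recognizer-style CSV text (hypothetical helper).

    Each data row holds 784 pixel values (a 28x28 image flattened row by row),
    optionally preceded by the digit label.
    Returns (labels or None, images with shape (n, 28, 28)).
    """
    reader = csv.reader(io.StringIO(text))
    next(reader)  # skip the header row
    rows = np.array([[int(v) for v in row] for row in reader], dtype=np.int64)
    if has_label:
        labels, pixels = rows[:, 0], rows[:, 1:]
    else:
        labels, pixels = None, rows
    return labels, pixels.reshape(-1, 28, 28)


# Synthetic example: two images, the first with one bright pixel (pixel29)
header = "label," + ",".join(f"pixel{i}" for i in range(784))
row1 = "7," + ",".join("255" if i == 29 else "0" for i in range(784))
row2 = "3," + ",".join("0" for _ in range(784))
labels, images = parse_digit_csv("\n".join([header, row1, row2]))
print(labels)          # [7 3]
print(images.shape)    # (2, 28, 28)
print(images[0, 1, 1]) # 255, since pixel29 is row 29 // 28 = 1, column 29 % 28 = 1
```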

Pixel values in the MNIST images range from 0 to 255, so we rescale them to the range 0 to 1 by dividing by 255. Inputs on a small, consistent scale generally make gradient-based training more stable.

PS: For more on why this scaling helps, you can refer to this blog post.

# Normalization: drop the label column (column 0), then scale pixels to [0, 1]
train_file_norm = train_file.iloc[:, 1:] / 255.0
test_file_norm = test_file / 255.0
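Dividing by 255 is min-max normalization to [0, 1]; the term standardization usually refers to z-score scaling (subtract the mean, divide by the standard deviation). A small NumPy sketch contrasting the two on made-up pixel values; both function names are illustrative, not from the post:

```python
import numpy as np


def min_max_normalize(pixels, max_value=255.0):
    """Scale raw pixel values from [0, max_value] to [0, 1] (what the post does)."""
    return pixels.astype(np.float64) / max_value


def standardize(pixels):
    """Z-score standardization: zero mean and unit standard deviation over all values."""
    pixels = pixels.astype(np.float64)
    return (pixels - pixels.mean()) / pixels.std()


raw = np.array([[0, 51, 255], [102, 204, 0]])
print(min_max_normalize(raw))  # every value now lies in [0, 1]
z = standardize(raw)
print(z.mean(), z.std())       # approximately 0 and 1
```

Min-max scaling keeps 0 (background) at exactly 0, which is one reason it is the common choice for pixel data like MNIST.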

Check the shape of the normalized training data. It should be (42000, 784), i.e. 42,000 images of 28 × 28 = 784 pixels each:

train_file_norm.shape

We can also use matplotlib.pyplot to plot 64 randomly chosen training images in an 8 × 8 grid:

import numpy as np
import matplotlib.pyplot as plt

# Pick 64 distinct random rows from the normalized training data
rand_indices = np.random.choice(train_file_norm.shape[0], 64, replace=False)
examples = train_file_norm.iloc[rand_indices, :]

fig, ax_arr = plt.subplots(8, 8, figsize=(6, 5))
fig.subplots_adjust(wspace=.025, hspace=.025)

ax_arr = ax_arr.ravel()
for i, ax in enumerate(ax_arr):
    # Each row holds 784 pixels; reshape to 28x28 for display
    ax.imshow(examples.iloc[i, :].values.reshape(28, 28), cmap="gray")
    ax.axis("off")
plt.show()
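Plotting 64 separate subplots works; an alternative sketch (the helper make_montage is hypothetical, not from the post) tiles the images into a single array with NumPy reshapes, so one imshow call shows the whole grid. It is demonstrated here on fake images whose pixels all equal the image's index:

```python
import numpy as np


def make_montage(images, rows, cols):
    """Tile rows * cols images of shape (h, w) into one (rows*h, cols*w) array.

    The result can be displayed with a single plt.imshow(..., cmap="gray")
    instead of one subplot per image.
    """
    n, h, w = images.shape
    assert n == rows * cols, "need exactly rows * cols images"
    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w)
    grid = images.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    return grid.reshape(rows * h, cols * w)


# 64 fake 28x28 "digits", image k filled with the value k
fake = np.repeat(np.arange(64, dtype=float), 28 * 28).reshape(64, 28, 28)
montage = make_montage(fake, 8, 8)
print(montage.shape)  # (224, 224)
```

With the post's variables, something like plt.imshow(make_montage(examples.values.reshape(64, 28, 28), 8, 8), cmap="gray") should show the same 64 samples.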

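The CNN defined below expects 32 × 32 images with three channels, so we convert the data into arrays of shape (42000, 32, 32, 3) for training (and (28000, 32, 32, 3) for testing), copying each grayscale digit into all three channels.
Define the sample shape parameters: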

num_examples_train = train_file.shape[0]
num_examples_test = test_file.shape[0]
n_h = 32
n_w = 32
n_c = 3

Allocate zero-filled arrays for the training and test images:

Train_input_images = np.zeros((num_examples_train, n_h, n_w, n_c))
Test_input_images = np.zeros((num_examples_test, n_h, n_w, n_c))

Copy each 28 × 28 image into the top-left corner of its 32 × 32 array, in all three channels. The normalized pixel values (train_file_norm, test_file_norm) are used here so that the scaling to [0, 1] actually reaches the model:

for example in range(num_examples_train):
    # 784 normalized pixels of one image (label column already dropped) -> 28x28
    img = train_file_norm.iloc[example, :].values.reshape(28, 28)
    # Broadcast the grayscale image into all three channels of the top-left corner
    Train_input_images[example, :28, :28, :] = img[:, :, np.newaxis]

for example in range(num_examples_test):
    img = test_file_norm.iloc[example, :].values.reshape(28, 28)
    Test_input_images[example, :28, :28, :] = img[:, :, np.newaxis]
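The two loops above can also be written without a Python loop. This NumPy sketch (pad_to_rgb32 is a hypothetical helper, not from the post) zero-pads each 28 × 28 image to 32 × 32 on the bottom and right and copies it into three channels, which matches the layout the loops produce:

```python
import numpy as np


def pad_to_rgb32(flat_pixels):
    """Turn (n, 784) flattened 28x28 images into (n, 32, 32, 3) arrays.

    Each digit sits in the top-left 28x28 corner, the extra 4 rows and
    columns at the bottom and right are zeros, and the grayscale image is
    repeated in all three channels.
    """
    images = flat_pixels.reshape(-1, 28, 28)
    # Pad only the bottom (axis 1) and right (axis 2) edges with 4 zeros each
    padded = np.pad(images, ((0, 0), (0, 4), (0, 4)), mode="constant")
    # Add a channel axis and repeat the grayscale image 3 times
    return np.repeat(padded[..., np.newaxis], 3, axis=-1)


flat = np.random.default_rng(0).random((5, 784))
out = pad_to_rgb32(flat)
print(out.shape)  # (5, 32, 32, 3)
```

Applied to the post's data this would be, for example, Train_input_images = pad_to_rgb32(train_file_norm.values); the vectorized form avoids tens of thousands of pandas iloc calls.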

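Next, cv2.resize is applied to each image. Because the arrays are already 32 × 32, resizing them to (32, 32) leaves them unchanged, so the digits are effectively zero-padded on the bottom and right rather than scaled up:

import cv2

for example in range(num_examples_train):
    # cv2.resize takes (width, height); the order does not matter here since both are 32
    Train_input_images[example] = cv2.resize(Train_input_images[example], (n_h, n_w))

for example in range(num_examples_test):
    Test_input_images[example] = cv2.resize(Test_input_images[example], (n_h, n_w))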
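To actually scale each digit up rather than pad it, the resize has to be applied to the 28 × 28 image before it is placed in the 32 × 32 array, e.g. cv2.resize(img28, (32, 32)). As a dependency-free sketch of that idea, here is nearest-neighbour resizing in NumPy; resize_nearest is a hypothetical helper, it is not what the post does, and cv2.resize defaults to bilinear interpolation, so its output would differ slightly:

```python
import numpy as np


def resize_nearest(image, out_h, out_w):
    """Nearest-neighbour resize of a 2-D image (a simple stand-in for cv2.resize).

    Output pixel (y, x) copies input pixel
    (floor(y * in_h / out_h), floor(x * in_w / out_w)).
    """
    in_h, in_w = image.shape
    rows = (np.arange(out_h) * in_h) // out_h
    cols = (np.arange(out_w) * in_w) // out_w
    return image[rows[:, np.newaxis], cols[np.newaxis, :]]


digit = np.arange(28 * 28, dtype=float).reshape(28, 28)
scaled = resize_nearest(digit, 32, 32)
print(scaled.shape)  # (32, 32)

# Resizing to the image's own size is the identity mapping
assert np.array_equal(resize_nearest(digit, 28, 28), digit)
```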

Extract the training labels, which are stored in the first column of train.csv:

Train_labels = np.array(train_file.iloc[:, 0])

Print the shapes of the preprocessed arrays. They should be (42000, 32, 32, 3), (28000, 32, 32, 3) and (42000,) respectively:

print("Shape of train input images : ", Train_input_images.shape)
print("Shape of test input images : ", Test_input_images.shape)
print("Shape of train labels : ", Train_labels.shape)

This completes data loading and preprocessing.

2. Training the model and predicting results

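Before introducing the model, a quick word on one-hot encoding.
One-hot encoding turns a categorical label into a binary vector containing a single 1: with 10 digit classes, label 3 becomes [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]. This is the target format expected by the categorical_crossentropy loss used below. The implementation is as follows: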

def one_hot(labels):
    onehot_labels = np.zeros(shape=[len(labels), 10])
    for i in range(len(labels)):
        index = labels[i]
        onehot_labels[i][index] = 1
    return onehot_labels
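The same encoding can be done without a loop: row k of the 10 × 10 identity matrix is exactly the one-hot vector for class k, so indexing np.eye with the label array picks one row per label. A NumPy sketch (one_hot_vectorized is an illustrative name, not from the post):

```python
import numpy as np


def one_hot_vectorized(labels, num_classes=10):
    """Same result as one_hot() above, without a Python loop.

    np.eye(num_classes)[k] is the one-hot row for class k, so fancy
    indexing with the whole label array returns one row per label.
    """
    return np.eye(num_classes)[np.asarray(labels)]


print(one_hot_vectorized([3, 0]))
```

Keras also ships a helper for this, tf.keras.utils.to_categorical.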

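Build the CNN model:

import tensorflow as tf
from tensorflow import keras

def mnist_cnn(input_shape):
    '''
    Build a CNN model.
    :param input_shape: shape of one input sample, e.g. (32, 32, 3)
    :return: an uncompiled keras.Sequential model
    '''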
    model = keras.Sequential()
    model.add(keras.layers.Conv2D(filters=32, kernel_size=5, strides=(1, 1),
                                  padding='same', activation=tf.nn.relu, input_shape=input_shape))
    model.add(keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(keras.layers.Conv2D(filters=64, kernel_size=3, strides=(1, 1), padding='same', activation=tf.nn.relu))
    model.add(keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    model.add(keras.layers.Dropout(0.25))
    model.add(keras.layers.Flatten())
    model.add(keras.layers.Dense(units=128, activation=tf.nn.relu))
    model.add(keras.layers.Dropout(0.5))
    model.add(keras.layers.Dense(units=10, activation=tf.nn.softmax))
    return model
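As a sanity check on the architecture, the sketch below hand-computes the output shape and trainable parameter count of each layer of mnist_cnn for a (32, 32, 3) input. It is a manual calculation that should agree with Keras's model.summary(), not actual Keras output; the helper names are illustrative.

```python
def conv_params(kernel, in_ch, filters):
    # kernel * kernel weights per input channel per filter, plus one bias per filter
    return kernel * kernel * in_ch * filters + filters


def dense_params(n_in, n_out):
    # full weight matrix plus one bias per output unit
    return n_in * n_out + n_out


def mnist_cnn_summary(h=32, w=32, c=3):
    """Trace output shapes and parameter counts through mnist_cnn.

    'same' convolutions with stride 1 keep height and width; each 2x2
    max-pool with stride 2 ('valid') halves them (rounding down).
    Dropout and Flatten add no parameters, so Dropout is omitted here.
    """
    layers = [("Conv2D 5x5, 32", (h, w, 32), conv_params(5, c, 32))]
    h, w = h // 2, w // 2
    layers.append(("MaxPool2D", (h, w, 32), 0))
    layers.append(("Conv2D 3x3, 64", (h, w, 64), conv_params(3, 32, 64)))
    h, w = h // 2, w // 2
    layers.append(("MaxPool2D", (h, w, 64), 0))
    flat = h * w * 64
    layers.append(("Flatten", (flat,), 0))
    layers.append(("Dense 128", (128,), dense_params(flat, 128)))
    layers.append(("Dense 10", (10,), dense_params(128, 10)))
    return layers


total = 0
for name, shape, params in mnist_cnn_summary():
    total += params
    print(f"{name:<16}{str(shape):<16}{params}")
print("Total params:", total)  # Total params: 546634
```

About 96% of the parameters sit in the first Dense layer (8 × 8 × 64 = 4096 inputs times 128 units), which is typical when a Flatten feeds a wide fully connected layer.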

Train the model and save it:

def train_model(train_images, train_labels):
    # Sanity check on the input shape
    print("train_images :{}".format(train_images.shape))
    # Convert integer labels (0-9) to one-hot vectors for categorical_crossentropy
    train_labels = one_hot(train_labels)

    # Build and compile the model
    model = mnist_cnn(input_shape=(32, 32, 3))
    model.compile(optimizer=tf.optimizers.Adam(), loss="categorical_crossentropy", metrics=['accuracy'])
    # 5 passes over the training set in mini-batches of 256 images
    model.fit(x=train_images, y=train_labels, epochs=5, batch_size=256)
    # Save architecture and weights to an HDF5 file
    model.save('MYCNN2MNIST.h5')
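To make epochs=5 and batch_size=256 concrete: one epoch over 42,000 training images is ceil(42000 / 256) = 165 gradient steps, the last one on only 16 images, so 5 epochs amount to 825 steps. A small NumPy sketch of how an epoch splits into mini-batches; model.fit shuffles by default, and this hypothetical helper only mimics the batching, it is not Keras's implementation:

```python
import math

import numpy as np


def iterate_minibatches(n_samples, batch_size, rng=None):
    """Yield index arrays for one epoch of mini-batches.

    Shuffles first if an rng is given; the last batch is smaller when
    n_samples is not a multiple of batch_size.
    """
    order = np.arange(n_samples) if rng is None else rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]


batches = list(iterate_minibatches(42000, 256))
print(len(batches), math.ceil(42000 / 256))  # 165 165
print(len(batches[-1]))                      # 16
```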

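Load the saved model and predict the labels of the test set; save_path is the path of the saved model file (MYCNN2MNIST.h5 above). The predictions are written to predict.csv in the submission format, with an ImageId column starting at 1 and a Label column:

def pred(save_path, test_images):
    # Load the saved model and predict a digit for every test image
    model = keras.models.load_model(save_path)
    predictions = model.predict(test_images)  # shape (n_test, 10): class probabilities
    labels = np.argmax(predictions, axis=1)   # most probable digit per image
    # Kaggle expects the columns ImageId (starting at 1) and Label, with no index column
    submission = pd.DataFrame({"ImageId": np.arange(1, len(labels) + 1), "Label": labels})
    submission.to_csv("predict.csv", index=False)
    return labels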
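For reference, here is a standalone NumPy sketch of the conversion from model.predict output (one row of 10 class probabilities per test image) to the submission format, run on made-up probabilities; to_submission_csv is a hypothetical helper, not part of the post:

```python
import numpy as np


def to_submission_csv(probabilities):
    """Convert an (n, 10) array of class probabilities into Kaggle CSV text.

    Columns are ImageId (1-based, in test.csv row order) and Label
    (the most probable digit), as in sample_submission.csv.
    """
    labels = np.argmax(probabilities, axis=1)
    lines = ["ImageId,Label"]
    lines += [f"{i},{label}" for i, label in enumerate(labels, start=1)]
    return "\n".join(lines) + "\n"


fake_probs = np.array([
    [0.05, 0.90, 0.05, 0, 0, 0, 0, 0, 0, 0],  # most likely digit 1
    [0, 0, 0, 0, 0, 0, 0, 0.2, 0, 0.8],       # most likely digit 9
])
print(to_submission_csv(fake_probs))
# ImageId,Label
# 1,1
# 2,9
```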

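To write the results in the official format, you can also start from the sample submission: load it, set its Label column to the predicted digits, and save it with to_csv(..., index=False):

submission = pd.read_csv('DataSet/sample_submission.csv')

Then upload the file on the competition page to see your accuracy and leaderboard ranking.
PS: Kaggle's servers are located abroad, so you may need to ** (you know what I mean).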

Source code link: https://pan.baidu.com/s/1HXt24GiUrRZliUXN-0N06g
Extraction code: pza9

If you have any questions, please leave a message in the comment area~


Origin blog.csdn.net/zjlwdqca/article/details/110702538