Getting Started with TensorFlow - Classifying Iris with TensorFlow (Linear Model)

This example follows the YouTube video "Plain and Simple Estimators". For Chinese subtitles and a detailed explanation, see "Machine Learning | Going further, using the estimator to classify flowers". This article focuses on the concrete implementation and adds more detailed comments to the code.

This example is part of the author's graduation project, and the author vouches that it is correct and works. Reproduction in any form is prohibited; please do not copy and paste it casually.

Environment setup


This example is implemented in Jupyter Notebook, installed here through Anaconda. For the setup steps, see the author's earlier blog post, "Ubuntu 16.04: a detailed illustrated tutorial on building a TensorFlow environment with Anaconda 5".
Jupyter Notebook, formerly IPython Notebook, is a web application that saves all input and output together as a document.

Iris dataset


The iris dataset is a classic machine learning dataset and a good one to start with. It has 5 columns. The first 4 columns are features: sepal length, sepal width, petal length, and petal width. The last column, Species, is the type of iris. It is our training target, called the label in machine learning, and data of this kind is called labeled data.
[Figure: Iris dataset]
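
As a rough illustration of what "labeled data" means here, the sketch below stores a few rows as four feature values plus one integer label. The values are the first three rows of the training output printed further down (rounded), and the 0/1/2 encoding is the one this article uses. The `IrisExample` type and the `SPECIES` list are hypothetical names introduced only for this sketch.

```python
from collections import namedtuple

# Hypothetical container for one labeled example: 4 features + 1 label.
IrisExample = namedtuple(
    "IrisExample",
    ["sepal_length", "sepal_width", "petal_length", "petal_width", "label"],
)

# Label encoding used in this article: 0, 1, 2.
SPECIES = ["Iris Setosa", "Iris Versicolour", "Iris Virginica"]

# First three training rows (rounded) with their labels.
examples = [
    IrisExample(6.4, 2.8, 5.6, 2.2, 2),
    IrisExample(5.0, 2.3, 3.3, 1.0, 1),
    IrisExample(4.9, 2.5, 4.5, 1.7, 2),
]

for ex in examples:
    features = list(ex[:4])  # what the model sees as input
    print(features, "->", SPECIES[ex.label])  # what it should predict
```

The model only ever receives the four numbers; the label is what training compares its predictions against.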

Implementation

Start Jupyter Notebook inside the tensorflow virtual environment:
steve@steve-Lenovo-V2000:~$ source activate tensorflow
(tensorflow) steve@steve-Lenovo-V2000:~$ jupyter notebook
In[1]       
import tensorflow as tf
import numpy as np

print(tf.__version__)

1.3.0
In[2]       
from tensorflow.contrib.learn.python.learn.datasets import base

# Dataset files used in this example
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# Load the datasets
training_set = base.load_csv_with_header(filename = IRIS_TRAINING,
                                         features_dtype = np.float32,
                                         target_dtype = np.int)
test_set = base.load_csv_with_header(filename = IRIS_TEST,
                                     features_dtype = np.float32,
                                     target_dtype = np.int)

print(training_set.data)
print(training_set.target)

[[ 6.4000001   2.79999995  5.5999999   2.20000005]
 [ 5.          2.29999995  3.29999995  1.        ]
 [ 4.9000001   2.5         4.5         1.70000005]
 [ 4.9000001   3.0999999   1.5         0.1       ]
 [ 5.69999981  3.79999995  1.70000005  0.30000001]
 [ 4.4000001   3.20000005  1.29999995  0.2       ]
 [ 5.4000001   3.4000001   1.5         0.40000001]
 [ 6.9000001   3.0999999   5.0999999   2.29999995]
 [ 6.69999981  3.0999999   4.4000001   1.39999998]
 [ 5.0999999   3.70000005  1.5         0.40000001]
 [ 5.19999981  2.70000005  3.9000001   1.39999998]
 [ 6.9000001   3.0999999   4.9000001   1.5       ]
 [ 5.80000019  4.          1.20000005  0.2       ]
 [ 5.4000001   3.9000001   1.70000005  0.40000001]
 [ 7.69999981  3.79999995  6.69999981  2.20000005]
 [ 6.30000019  3.29999995  4.69999981  1.60000002]
 [ 6.80000019  3.20000005  5.9000001   2.29999995]
 [ 7.5999999   3.          6.5999999   2.0999999 ]
 [ 6.4000001   3.20000005  5.30000019  2.29999995]
 [ 5.69999981  4.4000001   1.5         0.40000001]
 [ 6.69999981  3.29999995  5.69999981  2.0999999 ]
 [ 6.4000001   2.79999995  5.5999999   2.0999999 ]
 [ 5.4000001   3.9000001   1.29999995  0.40000001]
 [ 6.0999999   2.5999999   5.5999999   1.39999998]
 [ 7.19999981  3.          5.80000019  1.60000002]
 [ 5.19999981  3.5         1.5         0.2       ]
 [ 5.80000019  2.5999999   4.          1.20000005]
 [ 5.9000001   3.          5.0999999   1.79999995]
 [ 5.4000001   3.          4.5         1.5       ]
 [ 6.69999981  3.          5.          1.70000005]
 [ 6.30000019  2.29999995  4.4000001   1.29999995]
 [ 5.0999999   2.5         3.          1.10000002]
 [ 6.4000001   3.20000005  4.5         1.5       ]
 [ 6.80000019  3.          5.5         2.0999999 ]
 [ 6.19999981  2.79999995  4.80000019  1.79999995]
 [ 6.9000001   3.20000005  5.69999981  2.29999995]
 [ 6.5         3.20000005  5.0999999   2.        ]
 [ 5.80000019  2.79999995  5.0999999   2.4000001 ]
 [ 5.0999999   3.79999995  1.5         0.30000001]
 [ 4.80000019  3.          1.39999998  0.30000001]
 [ 7.9000001   3.79999995  6.4000001   2.        ]
 [ 5.80000019  2.70000005  5.0999999   1.89999998]
 [ 6.69999981  3.          5.19999981  2.29999995]
 [ 5.0999999   3.79999995  1.89999998  0.40000001]
 [ 4.69999981  3.20000005  1.60000002  0.2       ]
 [ 6.          2.20000005  5.          1.5       ]
 [ 4.80000019  3.4000001   1.60000002  0.2       ]
 [ 7.69999981  2.5999999   6.9000001   2.29999995]
 [ 4.5999999   3.5999999   1.          0.2       ]
 [ 7.19999981  3.20000005  6.          1.79999995]
 [ 5.          3.29999995  1.39999998  0.2       ]
 [ 6.5999999   3.          4.4000001   1.39999998]
 [ 6.0999999   2.79999995  4.          1.29999995]
 [ 5.          3.20000005  1.20000005  0.2       ]
 [ 7.          3.20000005  4.69999981  1.39999998]
 [ 6.          3.          4.80000019  1.79999995]
 [ 7.4000001   2.79999995  6.0999999   1.89999998]
 [ 5.80000019  2.70000005  5.0999999   1.89999998]
 [ 6.19999981  3.4000001   5.4000001   2.29999995]
 [ 5.          2.          3.5         1.        ]
 [ 5.5999999   2.5         3.9000001   1.10000002]
 [ 6.69999981  3.0999999   5.5999999   2.4000001 ]
 [ 6.30000019  2.5         5.          1.89999998]
 [ 6.4000001   3.0999999   5.5         1.79999995]
 [ 6.19999981  2.20000005  4.5         1.5       ]
 [ 7.30000019  2.9000001   6.30000019  1.79999995]
 [ 4.4000001   3.          1.29999995  0.2       ]
 [ 7.19999981  3.5999999   6.0999999   2.5       ]
 [ 6.5         3.          5.5         1.79999995]
 [ 5.          3.4000001   1.5         0.2       ]
 [ 4.69999981  3.20000005  1.29999995  0.2       ]
 [ 6.5999999   2.9000001   4.5999999   1.29999995]
 [ 5.5         3.5         1.29999995  0.2       ]
 [ 7.69999981  3.          6.0999999   2.29999995]
 [ 6.0999999   3.          4.9000001   1.79999995]
 [ 4.9000001   3.0999999   1.5         0.1       ]
 [ 5.5         2.4000001   3.79999995  1.10000002]
 [ 5.69999981  2.9000001   4.19999981  1.29999995]
 [ 6.          2.9000001   4.5         1.5       ]
 [ 6.4000001   2.70000005  5.30000019  1.89999998]
 [ 5.4000001   3.70000005  1.5         0.2       ]
 [ 6.0999999   2.9000001   4.69999981  1.39999998]
 [ 6.5         2.79999995  4.5999999   1.5       ]
 [ 5.5999999   2.70000005  4.19999981  1.29999995]
 [ 6.30000019  3.4000001   5.5999999   2.4000001 ]
 [ 4.9000001   3.0999999   1.5         0.1       ]
 [ 6.80000019  2.79999995  4.80000019  1.39999998]
 [ 5.69999981  2.79999995  4.5         1.29999995]
 [ 6.          2.70000005  5.0999999   1.60000002]
 [ 5.          3.5         1.29999995  0.30000001]
 [ 6.5         3.          5.19999981  2.        ]
 [ 6.0999999   2.79999995  4.69999981  1.20000005]
 [ 5.0999999   3.5         1.39999998  0.30000001]
 [ 4.5999999   3.0999999   1.5         0.2       ]
 [ 6.5         3.          5.80000019  2.20000005]
 [ 4.5999999   3.4000001   1.39999998  0.30000001]
 [ 4.5999999   3.20000005  1.39999998  0.2       ]
 [ 7.69999981  2.79999995  6.69999981  2.        ]
 [ 5.9000001   3.20000005  4.80000019  1.79999995]
 [ 5.0999999   3.79999995  1.60000002  0.2       ]
 [ 4.9000001   3.          1.39999998  0.2       ]
 [ 4.9000001   2.4000001   3.29999995  1.        ]
 [ 4.5         2.29999995  1.29999995  0.30000001]
 [ 5.80000019  2.70000005  4.0999999   1.        ]
 [ 5.          3.4000001   1.60000002  0.40000001]
 [ 5.19999981  3.4000001   1.39999998  0.2       ]
 [ 5.30000019  3.70000005  1.5         0.2       ]
 [ 5.          3.5999999   1.39999998  0.2       ]
 [ 5.5999999   2.9000001   3.5999999   1.29999995]
 [ 4.80000019  3.0999999   1.60000002  0.2       ]
 [ 6.30000019  2.70000005  4.9000001   1.79999995]
 [ 5.69999981  2.79999995  4.0999999   1.29999995]
 [ 5.          3.          1.60000002  0.2       ]
 [ 6.30000019  3.29999995  6.          2.5       ]
 [ 5.          3.5         1.60000002  0.60000002]
 [ 5.5         2.5999999   4.4000001   1.20000005]
 [ 5.69999981  3.          4.19999981  1.20000005]
 [ 4.4000001   2.9000001   1.39999998  0.2       ]
 [ 4.80000019  3.          1.39999998  0.1       ]
 [ 5.5         2.4000001   3.70000005  1.        ]]
[2 1 2 0 0 0 0 2 1 0 1 1 0 0 2 1 2 2 2 0 2 2 0 2 2 0 1 2 1 1 1 1 1 2 2 2 2
2 0 0 2 2 2 0 0 2 0 2 0 2 0 1 1 0 1 2 2 2 2 1 1 2 2 2 1 2 0 2 2 0 0 1 0 2 2 0 1 1 1 2 0 1 1 1 2 0 1 1 1 0 2 1 0 0 2 0 0 2 1 0 0 1 0 1 0 0 0 0 1 0 2 1 0 2 0 1 1 0 0 1]
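(The first array holds the 4 feature values of each sample. The second array holds the targets, i.e. the iris species, where the integers 0, 1, and 2 stand for Iris Setosa, Iris Versicolour, and Iris Virginica.)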
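
To make the loading step less of a black box, here is a minimal standard-library sketch of what a "CSV with header" loader does. It is not the library's implementation. It assumes the first row of the file has the layout `n_samples,n_features,class names...` and that the label sits in the last column; the function name `load_csv_with_header_sketch` and the in-memory sample are made up for illustration (the three data rows are taken, rounded, from the output above).

```python
import csv
import io

def load_csv_with_header_sketch(text):
    """Parse CSV text whose first row is `n_samples,n_features,...`.

    Returns (data, target): a list of float feature rows and a list of
    integer labels.
    """
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    n_samples, n_features = int(header[0]), int(header[1])
    data, target = [], []
    for row in reader:
        if not row:  # skip blank lines
            continue
        data.append([float(v) for v in row[:n_features]])
        target.append(int(row[n_features]))  # label is the last column
    if len(data) != n_samples:
        raise ValueError("row count does not match the header")
    return data, target

# Tiny in-memory sample; the header layout is an assumption about the format.
sample = """3,4,setosa,versicolor,virginica
6.4,2.8,5.6,2.2,2
5.0,2.3,3.3,1.0,1
4.9,3.1,1.5,0.1,0
"""
data, target = load_csv_with_header_sketch(sample)
print(data[0], target)
```

The two return values correspond to `training_set.data` and `training_set.target` in the notebook cell above.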
In[3]   
# Build the model
# Assume every feature is a real-valued number
feature_name = "flower_features"
feature_columns = [tf.feature_column.numeric_column(feature_name, shape = [4])]
classifier = tf.estimator.LinearClassifier(
    feature_columns = feature_columns,
    n_classes = 3,
    model_dir = "/tmp/iris_model")

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': '/tmp/iris_model', '_tf_random_seed': 1, '_save_summary_steps': 100, '_save_checkpoints_secs': 600, '_save_checkpoints_steps': None, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100}
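
Conceptually, a linear classifier of this kind scores each class with a weighted sum of the features and turns the scores into probabilities with a softmax. The following is a minimal pure-Python sketch of that idea, not TensorFlow's implementation. The function names are illustrative, and the hand-picked weights are an assumption chosen to make the example readable, not the trained values.

```python
import math

def softmax(logits):
    """Turn raw class scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(features, weights, biases):
    """weights is n_features x n_classes, biases has n_classes entries."""
    n_classes = len(biases)
    logits = [
        sum(x * weights[i][k] for i, x in enumerate(features)) + biases[k]
        for k in range(n_classes)
    ]
    probs = softmax(logits)
    return probs.index(max(probs)), probs

flower = [6.4, 2.8, 5.6, 2.2]  # first training row, rounded

# All-zero weights: every class is equally likely.
zero_w = [[0.0] * 3 for _ in range(4)]
_, uniform_probs = predict(flower, zero_w, [0.0, 0.0, 0.0])
print(uniform_probs)

# Hand-picked (not trained) weights: a long petal pushes toward class 2.
toy_w = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]]
label, probs = predict(flower, toy_w, [0.0, 0.0, 0.0])
print(label, probs)
```

Training amounts to adjusting the entries of `weights` and `biases` so that the highest probability lands on the correct species as often as possible.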
In[4]   
# Define an input function that produces data for the model
def input_fn(dataset):
    def _fn():
        features = {feature_name: tf.constant(dataset.data)}
        label = tf.constant(dataset.target)
        return features, label
    return _fn

print(input_fn(training_set)())

({'flower_features': <tf.Tensor 'Const:0' shape=(120, 4) dtype=float32>}, <tf.Tensor 'Const_1:0' shape=(120,) dtype=int64>)
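
Note that `input_fn(dataset)` does not return data; it returns a function that produces the data when called. The estimator wants such a zero-argument callable so that it can decide when to invoke it. The plain-Python sketch below shows the same closure pattern without TensorFlow; `make_input_fn` and `fake_train` are hypothetical stand-ins written only for this illustration.

```python
def make_input_fn(data, target, feature_name="flower_features"):
    """Return a zero-argument function that yields (features, labels)."""
    def _fn():
        # `data` and `target` are captured from the enclosing call.
        features = {feature_name: data}
        return features, target
    return _fn

def fake_train(input_fn):
    # Stand-in for an estimator: it calls input_fn itself, when it needs data.
    features, labels = input_fn()
    return len(labels)

fn = make_input_fn([[6.4, 2.8, 5.6, 2.2], [5.0, 2.3, 3.3, 1.0]], [2, 1])
print(callable(fn))    # the factory hands back a function, not data
print(fake_train(fn))  # number of examples the stand-in consumed
```

This is why the same `input_fn` factory can be reused unchanged for both the training set and the test set.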
In[5]   
# Data flow:
# raw data -> input_fn -> feature columns -> model
# Fit (train) the model
classifier.train(input_fn = input_fn(training_set), steps = 1000)
print('fit already done.')

INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/iris_model/model.ckpt.
INFO:tensorflow:loss = 131.833, step = 1
INFO:tensorflow:global_step/sec: 1396.3
INFO:tensorflow:loss = 37.1391, step = 101 (0.072 sec)
INFO:tensorflow:global_step/sec: 1279.85
INFO:tensorflow:loss = 27.8594, step = 201 (0.078 sec)
INFO:tensorflow:global_step/sec: 1400.15
INFO:tensorflow:loss = 23.0449, step = 301 (0.071 sec)
INFO:tensorflow:global_step/sec: 1293.92
INFO:tensorflow:loss = 20.058, step = 401 (0.077 sec)
INFO:tensorflow:global_step/sec: 1610.43
INFO:tensorflow:loss = 18.0083, step = 501 (0.062 sec)
INFO:tensorflow:global_step/sec: 1617.19
INFO:tensorflow:loss = 16.505, step = 601 (0.062 sec)
INFO:tensorflow:global_step/sec: 1602.84
INFO:tensorflow:loss = 15.3496, step = 701 (0.062 sec)
INFO:tensorflow:global_step/sec: 1799.5
INFO:tensorflow:loss = 14.43, step = 801 (0.056 sec)
INFO:tensorflow:global_step/sec: 1577.18
INFO:tensorflow:loss = 13.6782, step = 901 (0.063 sec)
INFO:tensorflow:Saving checkpoints for 1000 into /tmp/iris_model/model.ckpt.
INFO:tensorflow:Loss for final step: 13.0562.
fit already done.
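
One way to read the first loss value in the log above: assuming the model starts from all-zero weights (so each of the three classes gets probability 1/3) and the reported loss is the sum of the per-example cross-entropy over the 120 training rows, the starting loss should be 120 · ln 3. Both points are assumptions about the estimator's defaults rather than something stated in this article, but the arithmetic lands on the logged value.

```python
import math

n_examples = 120  # rows in the training set, per the tensor shape (120, 4)
n_classes = 3

# With a uniform prediction of 1/3 per class, each example costs ln(3).
per_example_loss = -math.log(1.0 / n_classes)
summed_loss = n_examples * per_example_loss

print(round(summed_loss, 3))  # 131.833
```

From there the loss drops as training moves the weights away from that uninformed starting point.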
In[6]   
# Evaluate the model's accuracy
accuracy_score = classifier.evaluate(input_fn = input_fn(test_set),
                                     steps = 100)["accuracy"]
print('\nAccuracy: {0:f}'.format(accuracy_score))

INFO:tensorflow:Starting evaluation at 2018-03-03-12:07:04
INFO:tensorflow:Restoring parameters from /tmp/iris_model/model.ckpt-1000
INFO:tensorflow:Evaluation [1/100]
INFO:tensorflow:Evaluation [2/100]
INFO:tensorflow:Evaluation [3/100]
INFO:tensorflow:Evaluation [4/100]
INFO:tensorflow:Evaluation [5/100]
INFO:tensorflow:Evaluation [6/100]
INFO:tensorflow:Evaluation [7/100]
INFO:tensorflow:Evaluation [8/100]
……
INFO:tensorflow:Evaluation [98/100]
INFO:tensorflow:Evaluation [99/100]
INFO:tensorflow:Evaluation [100/100]
INFO:tensorflow:Finished evaluation at 2018-03-03-12:07:05
INFO:tensorflow:Saving dict for global step 1000: accuracy = 0.966667, average_loss = 0.120964, global_step = 1000, loss = 3.62893
Accuracy: 0.966667
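
The numbers in the evaluation log can be cross-checked with a little arithmetic. The sketch below assumes that `loss` is the sum over one evaluation batch and `average_loss` is the per-example mean, so their ratio gives the number of test examples. The `accuracy` helper and the made-up prediction lists are hypothetical and only show how such a score comes about.

```python
def accuracy(predictions, labels):
    """Fraction of predictions that match the true labels."""
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)

# Values copied from the evaluation log above.
loss, average_loss, logged_accuracy = 3.62893, 0.120964, 0.966667

# Assumption: loss is the per-batch sum, average_loss the per-example mean.
n_test = round(loss / average_loss)
n_correct = round(logged_accuracy * n_test)
print(n_test, n_correct)

# Made-up predictions with one mistake among n_test examples.
labels = [0] * n_test
predictions = [0] * (n_test - 1) + [1]
print('{0:f}'.format(accuracy(predictions, labels)))
```

Under those assumptions the reported accuracy corresponds to a single misclassified flower in the test set.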

Summary and Explanation


This example mainly uses Estimator, the high-level API that TensorFlow provides. Estimator already encapsulates the training process, so we only need to configure it before using it.

    classifier = tf.estimator.LinearClassifier(
        feature_columns = feature_columns,
        n_classes = 3,
        model_dir = "/tmp/iris_model")

This is the code that builds the model. It defines a simple linear model and sets three parameters: feature_columns, the feature columns defined earlier; n_classes, the total number of classes, 3 in this case; and model_dir, the directory where the model is stored. The linear model built in this example reaches a final accuracy of 96.6667% on the test set. This is a good number: at that rate the model would classify about 96 or 97 out of every 100 irises correctly. A person asked to tell 100 irises apart might well get 4 or more of them wrong. Of course, this does not mean we should be satisfied. This is only a simple model built for an example, and for a real-world model we should aim for an accuracy of over 99%!
