tensorflow学习笔记(3) Eager Execution

本文学习了tensorFlow官网上最新的Eager Execution,使用环境同上文使用pycharm

强烈建议在控制台中进行代码输入,否则会出错,因为Eager Execution需要在开始执行并且会持续激活。

官网范例代码

from __future__ import absolute_import, division, print_function

import os
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tf.enable_eager_execution()

print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))

如果执行不了,在settings中安装matplotlib包

得到如下结果

TensorFlow version: 1.8.0

Eager execution: True

鸢尾花分类问题

想象一下,您是一名植物学家,正在寻找一种能够对所发现的每株鸢尾花进行自动归类的方法。机器学习可提供多种从统计学上分类花卉的算法。例如,一个复杂的机器学习程序可以根据照片对花卉进行分类。我们的要求并不高 - 我们将根据鸢尾花花萼花瓣的长度和宽度对其进行分类。

鸢尾属约有 300 个种,但我们的程序将仅对下列三个种进行分类:

  • 山鸢尾
  • 维吉尼亚鸢尾
  • 变色鸢尾
三个鸢尾花品种的花瓣几何对比:山鸢尾、维吉尼亚鸢尾和变色鸢尾
图 1. 山鸢尾(提供者:Radomil,依据 CC BY-SA 3.0 使用)、变色鸢尾(提供者:Dlanglois,依据 CC BY-SA 3.0 使用)和维吉尼亚鸢尾(提供者:Frank Mayfield,依据 CC BY-SA 2.0 使用)。
 

幸运的是,有人已经创建了一个包含 120 株鸢尾花的数据集(其中有花萼和花瓣的测量值)。这是一个在入门级机器学习分类问题中经常使用的经典数据集。

导入和解析训练数据集

我们需要下载数据集文件,并将其转换为可供此 Python 程序使用的结构。

下载数据集

使用 tf.keras.utils.get_file 函数下载训练数据集文件。该函数会返回下载文件的文件路径。

官网没加头文件,我这里补上去

import tensorflow as tf
import os
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"

train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),
                                           origin=train_dataset_url)

print("Local copy of the dataset file: {}".format(train_dataset_fp))

执行完以后下载好了数据

检查数据

数据集 iris_training.csv 是一个纯文本文件,其中存储了逗号分隔值 (CSV) 格式的表格式数据。使用 head -n5 命令查看前 5 个条目:

head -n5 {train_dataset_fp}

这里我就报错了,发现是没有装pandas,在settings里面安装pandas的包

并在开头声明

import pandas as pd

不过这个 head -n5命令我执行总是报错

我在命令行中改为

df = pd.read_csv('C:/Users/Masuzu/.keras/datasets/iris_training.csv') #改为文件下载的地址
df.head()

即可运行

   120    4  setosa  versicolor  virginica
0  6.4  2.8     5.6         2.2          2
1  5.0  2.3     3.3         1.0          1
2  4.9  2.5     4.5         1.7          2

3  4.9  3.1     1.5         0.1          0

得到了前5行数据

我们可以从该数据集视图中看到以下信息:

  1. 第一行是标题,其中包含数据集信息。
    • 共有 120 个样本。每个样本都有四个特征和一个标签名称,标签名称有三种可能。
  2. 后面的行是数据记录,每个样本各占一行,其中:
    • 前四个字段是特征:即样本的特点。在此数据集中,这些字段存储的是代表花卉测量值的浮点数。
    • 最后一列是标签:即我们想要预测的值。对于此数据集,该值为 0、1 或 2 中的某个整数值(每个值分别对应一个花卉名称)。

每个标签都分别与一个字符串名称(例如“setosa”)相关联,但机器学习通常依赖于数字值。标签编号会映射到一个指定的名称表示,例如:

  • 0:山鸢尾
  • 1:变色鸢尾
  • 2:维吉尼亚鸢尾

要详细了解特征和标签,请参阅《机器学习速成课程》的“机器学习术语”部分

解析数据集

由于我们的数据集是 CSV 格式的文本文件,因此我们会将特征和标签值解析为 Python 模型能够使用的格式。系统会将文件中的每一行传递给 parse_csv 函数,该函数会获取前四个特征字段,并将它们合并为一个张量。然后,系统会将最后一个字段解析为标签。该函数会返回 features 和 label 这两个张量

解析数据集

由于我们的数据集是 CSV 格式的文本文件,因此我们会将特征和标签值解析为 Python 模型能够使用的格式。系统会将文件中的每一行传递给 parse_csv 函数,该函数会获取前四个特征字段,并将它们合并为一个张量。然后,系统会将最后一个字段解析为标签。该函数会返回 features 和 label 这两个张量

 
 
def parse_csv(line):
  example_defaults
= [[0.], [0.], [0.], [0.], [0]]  # sets field types,前四位是浮点数,最后一个是整数
  parsed_line
= tf.decode_csv(line, example_defaults)
 
# First 4 fields are features, combine into single tensor
  features
= tf.reshape(parsed_line[:-1], shape=(4,))
 
# Last field is the label
  label
= tf.reshape(parsed_line[-1], shape=())
 
return features, label

猜你喜欢

转载自blog.csdn.net/u013003318/article/details/80865284
今日推荐