keras分布式训练

先来个简单的分布式训练

keras分布式训练

#导入依赖 

#from __future__ import absolute_import, division, print_function, unicode_literals

# 导入 TensorFlow 和 TensorFlow 数据集
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

如果把第一行屏蔽会有什么效果?用没有影响? 也可以执行,应该是么有用到导入的特性

2020-05-28 21:34:05.468570: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA

2020-05-28 21:34:05.497031: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fe6ff957c10 initialized for platform Host (this does not guarantee that XLA will be used). Devices:

2020-05-28 21:34:05.497053: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version

Number of devices: 1

WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.

WARNING:tensorflow:There are non-GPU devices in `tf.distribute.Strategy`, not using nccl allreduce.

Epoch 1/12

      1/Unknown - 2s 2s/step - loss: 2.3115 - accuracy: 0.14062020-05-28 21:34:08.723148: I tensorflow/core/profiler/lib/profiler_session.cc:225] Profiler session started.

下载数据集

下载 MNIST 数据集并从 TensorFlow Datasets 加载。 这会返回 tf.data 格式的数据集。

将 with_info 设置为 True 会包含整个数据集的元数据,其中这些数据集将保存在 info 中。 除此之外,该元数据对象包括训练和测试示例的数量。 

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']

定义分配策略

创建一个 MirroredStrategy 对象。这将处理分配策略,并提供一个上下文管理器(tf.distribute.MirroredStrategy.scope)来构建你的模型。

strategy = tf.distribute.MirroredStrategy()
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

设置输入管道(pipeline)

在训练具有多个 GPU 的模型时,您可以通过增加批量大小(batch size)来有效地使用额外的计算能力。通常来说,使用适合 GPU 内存的最大批量大小(batch size),并相应地调整学习速率。

0-255 的像素值, 必须标准化到 0-1 范围。在函数中定义标准化。

生成模型

在 strategy.scope 的上下文中创建和编译 Keras 模型。

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])

训练和评估

在该部分,以普通的方式训练模型,在模型上调用 fit 并传入在教程开始时创建的数据集。 无论您是否分布式训练,此步骤都是相同的。

model.fit(train_dataset, epochs=12, callbacks=callbacks)

要查看模型的执行方式,请加载最新的检查点(checkpoint)并在测试数据上调用 evaluate 。

使用适当的数据集调用 evaluate 。

model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))
eval_loss, eval_acc = model.evaluate(eval_dataset)
print('Eval loss: {}, Eval Accuracy: {}'.format(eval_loss, eval_acc))

导出到 SavedModel

将图形和变量导出为与平台无关的 SavedModel 格式。 保存模型后,可以在有或没有 scope 的情况下加载模型。

在无需 strategy.scope 加载模型

通过load_from_saved_model(载入上面保存的模型)

运行结果,精度为99

运行效果部分截取

Learning rate for epoch 1 is 0.0010000000474974513

938/938 [==============================] - 23s 25ms/step - loss: 0.2050 - accuracy: 0.9399

Epoch 2/12

937/938 [============================>.] - ETA: 0s - loss: 0.0661 - accuracy: 0.9804

Learning rate for epoch 2 is 0.0010000000474974513

938/938 [==============================] - 17s 18ms/step - loss: 0.0661 - accuracy: 0.9804

Epoch 3/12

936/938 [============================>.] - ETA: 0s - loss: 0.0464 - accuracy: 0.9862

Learning rate for epoch 3 is 0.0010000000474974513

938/938 [==============================] - 14s 15ms/step - loss: 0.0464 - accuracy: 0.9862

Epoch 4/12

936/938 [============================>.] - ETA: 0s - loss: 0.0266 - accuracy: 0.9927

Learning rate for epoch 4 is 9.999999747378752e-05

938/938 [==============================] - 18s 19ms/step - loss: 0.0266 - accuracy: 0.9927

Epoch 5/12

934/938 [============================>.] - ETA: 0s - loss: 0.0234 - accuracy: 0.9939

Learning rate for epoch 5 is 9.999999747378752e-05

......

938/938 [==============================] - 17s 18ms/step - loss: 0.0151 - accuracy: 0.9966

猜你喜欢

转载自blog.csdn.net/keny88888/article/details/106414031