代码前面加这些
import re
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE
from kaggle_datasets import KaggleDatasets
# Detect hardware, return appropriate distribution strategy
try:
tpu = tf.distribute.cluster_resolver.TPUClusterResolver() # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
print('Running on TPU ', tpu.master())
except ValueError:
tpu = None
if tpu:
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.
print("REPLICAS: ", strategy.num_replicas_in_sync)
这个使用的数据集需要是公开的,不然没法使用,可以公开自己的数据,我是公开了。
最后一点区别就是下面这一句
with strategy.scope():
model = create_model()
model.summary()
其他的和tensorflow和keras都一样,不过只是模型创建在scope下不同而已。
接下来开始享受tpu飞一般的速度吧