kaggle使用tpu

代码前面加这些

import re
import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
print("Tensorflow version " + tf.__version__)
AUTO = tf.data.experimental.AUTOTUNE
from kaggle_datasets import KaggleDatasets
# Detect hardware, return appropriate distribution strategy
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    print('Running on TPU ', tpu.master())
except ValueError:
    tpu = None

if tpu:
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy() # default distribution strategy in Tensorflow. Works on CPU and single GPU.

print("REPLICAS: ", strategy.num_replicas_in_sync)

这个使用的数据集需要是公开的,不然没法使用,可以公开自己的数据,我是公开了。


最后一点区别就是下面这一句

with strategy.scope():
model = create_model()
model.summary()

其他的和tensorflow和keras都一样,不过只是模型创建在scope下不同而已。
接下来开始享受tpu飞一般的速度吧


发布了60 篇原创文章 · 获赞 15 · 访问量 4060

猜你喜欢

转载自blog.csdn.net/qq_15557299/article/details/104272560