TensorFlow Hub安装及使用

1.介绍

TensorFlow Hub是一个库,用于实现机器学习模型可重用部分的发行,发行和使用。一个模块是一个独立的tensorflow图,还有它的权重等,可以在迁移学习过程中的不同的任务里重复使用。
模块包含使用大型数据集对任务进行预训练的变量。通过在相关任务上重用模块,您可以:

  • 用较小的数据集训练模型
  • 提高泛化能力
  • 显著加快训练速度
    这有一个例子,使用英语嵌入模块,将一个字符串数组映射到此嵌入:
import tensorflow as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  embed = hub.Module("https://tfhub.dev/google/nnlm-en-dim128-with-normalization/1")
  embeddings = embed(["A long sentence.", "single-word", "http://example.com"])

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())

    print(sess.run(embeddings))

2.安装

Tensorflow Hub取决于bug修复和增强,在tensorflow1.7之前的版本没有出现。必须安装或升级您的tensorflow包到至少1.7版本才可使用TensorFlow Hub:

pip install "tensorflow>=1.7.0"
pip install tensorflow-hub

当一个兼容的版本可用时,本节将会更新以包括一个特定的tensorflow版本要求。
参考文档:https://github.com/tensorflow/hub

猜你喜欢

转载自blog.csdn.net/daydayup_668819/article/details/80004188