TensorFlow Wide and Deep Model in Practice

Setting up the TensorFlow environment

The official installation guide is already clear and concise:
https://www.tensorflow.org/install/install_mac

wide and deep demo

https://github.com/tensorflow/models/tree/master/official/wide_deep

Preparing the code and downloading the dataset

Clone the demo code and data:

git clone git@github.com:tensorflow/models.git

or

git clone https://github.com/tensorflow/models.git

Add the cloned models directory to the PYTHONPATH and PATH environment variables; otherwise the commands below fail with an error. See https://github.com/tensorflow/models/blob/master/official/#running-the-models
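A quick way to check whether the path setup took effect is to ask Python whether it can locate the top-level `official` package. The sketch below is only a diagnostic under stated assumptions: the helper names are made up for illustration, and `~/models` is a hypothetical clone location. Changing `sys.path` this way affects the current Python process only, so it does not replace setting the environment variables before running the demo scripts.

```python
import importlib.util
import os
import sys


def add_models_to_path(models_dir):
    """Put the cloned models directory on sys.path for this process only."""
    models_dir = os.path.abspath(os.path.expanduser(models_dir))
    if models_dir not in sys.path:
        sys.path.insert(0, models_dir)
    return models_dir


def official_is_importable():
    """Return True if the top-level `official` package can be located."""
    return importlib.util.find_spec("official") is not None


if __name__ == "__main__":
    # "~/models" is a hypothetical clone location; adjust to your checkout.
    add_models_to_path("~/models")
    print("official importable:", official_is_importable())
```

If the check prints False, the directory on the path is probably not the repository root (the one that contains the `official` folder).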

Enter the demo directory (official/wide_deep):

$ ls
README.md             census_test.csv       movielens_test.py
__init__.py           census_test.py        wide_deep_run_loop.py
census_dataset.py     movielens_dataset.py
census_main.py        movielens_main.py
$ python census_dataset.py
$ ll /tmp/census_data
total 10312
-rw-r--r--  1 leocai  wheel   3.4M  9  1 21:00 adult.data
-rw-r--r--  1 leocai  wheel   1.7M  9  1 21:01 adult.test
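To get a feel for the downloaded files before training, the rows can be read with the standard `csv` module. This is a minimal sketch, not the demo's own input pipeline: the column names are an assumption based on the UCI Adult dataset description, the helper name `read_rows` is made up, and the sample row is an illustrative line in the Adult format rather than something read from disk.

```python
import csv
import io

# Column names assumed from the UCI Adult dataset description.
CSV_COLUMNS = [
    "age", "workclass", "fnlwgt", "education", "education_num",
    "marital_status", "occupation", "relationship", "race", "gender",
    "capital_gain", "capital_loss", "hours_per_week", "native_country",
    "income_bracket",
]


def read_rows(lines):
    """Yield one dict per well-formed CSV line, skipping blank or short lines."""
    reader = csv.reader(lines, skipinitialspace=True)
    for fields in reader:
        if len(fields) != len(CSV_COLUMNS):
            continue  # blank or malformed line
        yield dict(zip(CSV_COLUMNS, (f.strip() for f in fields)))


# Illustrative row in the Adult format (not read from disk here).
sample = io.StringIO(
    "39, State-gov, 77516, Bachelors, 13, Never-married, Adm-clerical, "
    "Not-in-family, White, Male, 2174, 0, 40, United-States, <=50K\n"
)
rows = list(read_rows(sample))
print(rows[0]["age"], rows[0]["education"], rows[0]["income_bracket"])
```

To inspect the real data, pass an open file handle for /tmp/census_data/adult.data to `read_rows` instead of the in-memory sample.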

Training:

python census_main.py

Analysis of the demo's main source code

def build_estimator(model_dir, model_type, model_column_fn, inter_op, intra_op):
  """Build an estimator appropriate for the given model type."""
  wide_columns, deep_columns = model_column_fn()
  hidden_units = [100, 75, 50, 25]

  # Create a tf.estimator.RunConfig to ensure the model is run on CPU, which
  # trains faster than GPU for this model.
  run_config = tf.estimator.RunConfig().replace(
      session_config=tf.ConfigProto(device_count={'GPU': 0},
                                    inter_op_parallelism_threads=inter_op,
                                    intra_op_parallelism_threads=intra_op))

  if model_type == 'wide':
    return tf.estimator.LinearClassifier(
        model_dir=model_dir,
        feature_columns=wide_columns,
        config=run_config)
  elif model_type == 'deep':
    return tf.estimator.DNNClassifier(
        model_dir=model_dir,
        feature_columns=deep_columns,
        hidden_units=hidden_units,
        config=run_config)
  else:
    return tf.estimator.DNNLinearCombinedClassifier(
        model_dir=model_dir,
        linear_feature_columns=wide_columns,
        dnn_feature_columns=deep_columns,
        dnn_hidden_units=hidden_units,
        config=run_config)

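In other words, unless model_type is 'wide' or 'deep', the demo calls tf.estimator.DNNLinearCombinedClassifier.
Its arguments are the directory where the model is saved, the feature columns for the wide part, the feature columns for the deep part, the hidden-layer configuration, and the run configuration.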
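To make the roles of the wide and deep parts more concrete, here is a toy forward pass in plain Python. It is a sketch of the general wide and deep idea (a linear logit plus a feed-forward network logit, summed and passed through a sigmoid), not TensorFlow's actual implementation. All function names, weights, and inputs are made up for illustration, and the network is far smaller than the demo's hidden_units = [100, 75, 50, 25].

```python
import math


def relu(values):
    return [max(0.0, v) for v in values]


def dense(inputs, weights, biases):
    """Fully connected layer; `weights` holds one row per output unit."""
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]


def wide_logit(x_wide, w_wide, b_wide):
    """Linear (wide) part: a weighted sum of sparse or crossed features."""
    return sum(w * x for w, x in zip(w_wide, x_wide)) + b_wide


def deep_logit(x_deep, hidden_layers, w_out, b_out):
    """Deep part: ReLU hidden layers followed by a linear output unit."""
    activations = x_deep
    for weights, biases in hidden_layers:
        activations = relu(dense(activations, weights, biases))
    return sum(w * a for w, a in zip(w_out, activations)) + b_out


def wide_deep_probability(x_wide, x_deep, params):
    """Sum both logits, then squash with a sigmoid."""
    logit = (wide_logit(x_wide, params["w_wide"], params["b_wide"])
             + deep_logit(x_deep, params["hidden"],
                          params["w_out"], params["b_out"]))
    return 1.0 / (1.0 + math.exp(-logit))


# Hypothetical toy parameters: 3 wide features, 2 deep features, 1 hidden layer.
toy_params = {
    "w_wide": [0.5, -0.25, 0.25],
    "b_wide": 0.0,
    "hidden": [([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0])],
    "w_out": [1.0, -0.5],
    "b_out": 0.0,
}
p = wide_deep_probability([1.0, 0.0, 1.0], [1.0, 2.0], toy_params)
print(p)
```

In this picture the wide part memorizes specific feature combinations through its per-feature weights, while the deep part generalizes through the hidden layers; both contribute to a single prediction and are trained together.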

Next, look at tf.estimator.DNNLinearCombinedClassifier itself:
https://www.tensorflow.org/api_docs/python/tf/estimator/DNNLinearCombinedClassifier

Introduction to wide and deep

https://github.com/tensorflow/models/tree/master/official/wide_deep
https://ai.googleblog.com/2016/06/wide-deep-learning-better-together-with.html


Reposted from blog.csdn.net/cql342624757/article/details/82288072