Tensorflow 0.9.0开发的代码升级到Tensorflow 1.0.0版本的一些经验

以下是我在做代码升级时碰到的一些问题,这里和大家分享一下,避免做相同事情的人在同样的事情上浪费时间。

1. tensorfow.models模块在1.0.0版本被单独分离出来,调用seq2seq_model.py,seq2seq_model_utils.py和data_utils.py等文件接口的代码需做修改。

相关修改如下:

    1) 删除from tensorflow.models.rnn.translate import data_utils句,直接将tensorflow.models.rnn.translate的data_utils代码复制过来并import。tensorflow.models模块中对应的data_utils代码链接为https://github.com/tensorflow/models/blob/master/tutorials/rnn/translate/data_utils.py

    2)
def sampled_loss(inputs, labels):
        labels = tf.reshape(labels, [-1, 1])
        return tf.nn.sampled_softmax_loss(w_t, b, inputs, labels, num_samples,
                self.target_vocab_size)
      softmax_loss_function = sampled_loss

     改为

      def sampled_loss(labels, inputs):
        labels = tf.reshape(labels, [-1, 1])
        return tf.nn.sampled_softmax_loss(w_t, b, labels, inputs, num_samples,
                self.target_vocab_size)
      softmax_loss_function = sampled_loss
    原因在于sampled_softmax_loss方法的labels和inputs参数互换了位置
    3)
single_cell = tf.nn.rnn_cell.GRUCell(size)
    if use_lstm:
      single_cell = tf.nn.rnn_cell.BasicLSTMCell(size)
    cell = single_cell
    if num_layers > 1:
    cell = tf.nn.rnn_cell.MultiRNNCell([single_cell] * num_layers)

    改为

    single_cell = tf.contrib.rnn.GRUCell(size)
    if use_lstm:
      single_cell = tf.contrib.rnn.BasicLSTMCell(size)
    cell = single_cell
    if num_layers > 1:
      cell = tf.contrib.rnn.MultiRNNCell([single_cell] * num_layers)

   tf.nn.rnn_cell.*和tf.nn.rnn.*中的大多数函数(除了dynamic_rnn和raw_rnn)在1.0版本暂时移到tf.contrib.rnn中,在1.1版本被重新移回。

    4)

tf.nn.seq2seq.embedding_attention_seq2seq 更改为 tf.contrib.legacy_seq2seq.embedding_attention_seq2seq 

tf.nn.seq2seq.model_with_buckets 更改为 tf.contrib.legacy_seq2seq.model_with_buckets

tf.nn.seq2seq.sequence_loss_by_example 更改为 tf.contrib.legacy_seq2seq.sequence_loss_by_example

    Tensorflow1.0.0版本以后,开发了新的seq2seq接口,放在tf.contrib.seq2seq下,弃用了原来的接口,旧的接口移到tf.contrib.legacy_seq2seq

    新旧seq2seq接口的主要区别在于新接口是动态展开的,旧接口是静态展开的。
    5)
tf.initialize_all_variables() 方法替换为 tf.global_variables_initializer()

    6)在0.12版本tensorflow更新了checkpoint版本,默认情况下写入和读取的checkpoint都是新的V2版本,新版本能够在restore过程中显著降低峰值内存。

    两种版本模型保存方式如下:    

v1 v2
model.ckpt-66000 model.ckpt-66000.index
model.ckpt-66000.meta model.ckpt-66000.meta
  model.ckpt-66000.data-00000-of-00001
    为适应以后版本的更新,将模型重新训练并保存为V2格式。
    代码做了以下更改:
ckpt = tf.train.get_checkpoint_state(FLAGS.model_dir)
    if ckpt and gfile.Exists(ckpt.model_checkpoint_path):
        print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
        model.saver.restore(session, ckpt.model_checkpoint_path)
    else:
        print("Created model with fresh parameters.")
        session.run(tf.initialize_all_variables())
    改为
    checkpoint_file = tf.train.latest_checkpoint(FLAGS.nn_model_dir)
    if checkpoint_file is None:
      print("Created model with fresh parameters.")
      session.run(tf.global_variables_initializer())
    else:
      print("Reading model parameters from %s" % checkpoint_file)
      model.saver.restore(session, checkpoint_file)

     7)

tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(y, y_)) 更改为 tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))

  1.0.0版本的Tensorflow要求使用命名参数的形式调用。

猜你喜欢

转载自blog.csdn.net/hfutdog/article/details/79333386
今日推荐