【深度学习框架】Tensorflow Session.run()函数的进一步理解

在tensorflow中session.run()用来将数据传入计算图,计算并返回出给定变量/placeholder的结果。

在看论文代码的时候遇到一段复杂的feed_dict, 本文记录了对sess.run()的复习。

1.tensorflow Session.run()

session.run()的函数定义如下,可以在交互式python中sess = tf.Session; ?sess.run,也可以在源码 line846中查看到。首先来看函数的参数定义:

run(self, fetches, feed_dict=None, options=None, run_metadata=None)

其中常用的fetchesfeed_dict就是常用的传入参数。fetches主要指从计算图中取回计算结果进行放回的那些placeholder和变量,而feed_dict则是将对应的数据传入计算图中占位符,它是字典数据结构只在调用方法内有效。
参考这个例子额的解释,最下面的fetch和feed原始定义 在make_callable
下面让我们来看看官方代码中对run()函数的解释:

def run(self, fetches, feed_dict=None, options=None, run_metadata=None):
    """Runs operations and evaluates tensors in `fetches`.
    运行操作和对fetches中的张量进行计算
    
    This method runs one "step" of TensorFlow computation, by
    running the necessary graph fragment to execute every `Operation`
    and evaluate every `Tensor` in `fetches`, substituting the values in
    `feed_dict` for the corresponding input values.
    这一方法将在tensorflow中运行一次计算,通过将feed_dict中的数据馈入计算图中,
    运行计算图定义的操作并最终得到fectch中tensor的评测结果
    
    The `fetches` argument may be a single graph element, or an arbitrarily
    nested list, tuple, namedtuple, dict, or OrderedDict containing graph
    elements at its leaves.  A graph element can be one of the following types:
    fecches是从计算图中取出对应变量的参数,可以是单个图元素、任意的列表、元组、字典等等形式的图元素。
    图元素包括操作、张量、稀疏张量、句柄、字符串等等。
    * A `tf.Operation`.
      The corresponding fetched value will be `None`.
    * A `tf.Tensor`.
      The corresponding fetched value will be a numpy ndarray containing the
      value of that tensor.
    * A `tf.SparseTensor`.
      The corresponding fetched value will be a
      `tf.compat.v1.SparseTensorValue`
      containing the value of that sparse tensor.
    * A `get_tensor_handle` op.  The corresponding fetched value will be a
      numpy ndarray containing the handle of that tensor.
    * A `string` which is the name of a tensor or operation in the graph.
    The value returned by `run()` has the same shape as the `fetches` argument,
    where the leaves are replaced by the corresponding values returned by
    TensorFlow.
    run的返回值与fetches的形状一致
    
    Example:
    ```python
       a = tf.constant([10, 20])
       b = tf.constant([1.0, 2.0])
       # 'fetches' can be a singleton
       v = session.run(a)
       # v is the numpy array [10, 20]  # 这里就是单个元素作为fetch数值
       # 'fetches' can be a list.
       v = session.run([a, b])          # 这里作为list取回值
       # v is a Python list with 2 numpy arrays: the 1-D array [10, 20] and the
       # 1-D array [1.0, 2.0]
       # 'fetches' can be arbitrary lists, tuples, namedtuple, dicts:
       MyData = collections.namedtuple('MyData', ['a', 'b'])
       v = session.run({'k1': MyData(a, b), 'k2': [b, a]})
       # v is a dict with
       # v['k1'] is a MyData namedtuple with 'a' (the numpy array [10, 20]) and
       # 'b' (the numpy array [1.0, 2.0])
       # v['k2'] is a list with the numpy array [1.0, 2.0] and the numpy array
       # [10, 20].
    ```
    feed_dict可以使得输入的值覆盖图中定义的tensorflow,并在保持dtype一致的情况下在调用函数内起作用
    The optional `feed_dict` argument allows the caller to override
    the value of tensors in the graph. Each key in `feed_dict` can be
    one of the following types:
    * If the key is a `tf.Tensor`, the
      value may be a Python scalar, string, list, or numpy ndarray
      that can be converted to the same `dtype` as that
      tensor. Additionally, if the key is a
      `tf.compat.v1.placeholder`, the shape of
      the value will be checked for compatibility with the placeholder.
    * If the key is a
      `tf.SparseTensor`,
      the value should be a
      `tf.compat.v1.SparseTensorValue`.
    * If the key is a nested tuple of `Tensor`s or `SparseTensor`s, the value
      should be a nested tuple with the same structure that maps to their
      corresponding values as above.
    Each value in `feed_dict` must be convertible to a numpy array of the dtype
    of the corresponding key.
    The optional `options` argument expects a [`RunOptions`] proto. The options
    allow controlling the behavior of this particular step (e.g. turning tracing
    on).
    The optional `run_metadata` argument expects a [`RunMetadata`] proto. When
    appropriate, the non-Tensor output of this step will be collected there. For
    example, when users turn on tracing in `options`, the profiled info will be
    collected into this argument and passed back.
    
    #----------------------常用输入变量--------------------#
    fetches:图元素,需要从中取出对应运行结果
    feed_dict:字典映射图元素对应的值
    Args:
      fetches: A single graph element, a list of graph elements, or a dictionary
        whose values are graph elements or lists of graph elements (described
        above).
      feed_dict: A dictionary that maps graph elements to values (described
        above).
      options: A [`RunOptions`] protocol buffer
      run_metadata: A [`RunMetadata`] protocol buffer
    Returns:
      Either a single value if `fetches` is a single graph element, or
      a list of values if `fetches` is a list, or a dictionary with the
      same keys as `fetches` if that is a dictionary (described above).
      Order in which `fetches` operations are evaluated inside the call
      is undefined.
    Raises:
      RuntimeError: If this `Session` is in an invalid state (e.g. has been
        closed).
      TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type.
      ValueError: If `fetches` or `feed_dict` keys are invalid or refer to a
        `Tensor` that doesn't exist.
    """

2.代码实例

下面这段代码来源于PU-Net的主函数源码,作者自定义了很多ops。我们直接聚焦在代码的第九行feed_dict来看:

# copy from:https://github.com/yulequan/PU-Net/blob/master/code/main.py
def train_one_epoch(sess, ops, fetchworker, train_writer):
    loss_sum = []
    fetch_time = 0
    for batch_idx in range(fetchworker.num_batches):
        start = time.time()
        batch_input_data, batch_data_gt, radius =fetchworker.fetch()
        end = time.time()
        fetch_time+= end-start
        feed_dict = {ops['pointclouds_pl']: batch_input_data,    #<<<<<<<<<看这里--------------------------
                     ops['pointclouds_gt']: batch_data_gt[:,:,0:3],
                     ops['pointclouds_gt_normal']:batch_data_gt[:,:,0:3],
                     ops['pointclouds_radius']: radius}
        summary,step, _, pred_val,gen_loss_emd = sess.run( [ops['pretrain_merged'],ops['step'],ops['pre_gen_train'],
                                                            ops['pred'], ops['gen_loss_emd']], feed_dict=feed_dict)
        # 这里定义了一个feed_dict
        # 其中键为ops['xxx']  值为各种对应的输入数据
        # 在运行的时候 fetches构成一整个list放到run中计算出summary  step  预训出的结果pred  计算出的损失
        # 为了更直观的看到这里馈入和抓取的张量/变量,我们来看看他们的定义:
        """
        ops = {'pointclouds_pl': pointclouds_pl,
               'pointclouds_gt': pointclouds_gt,
               'pointclouds_gt_normal':pointclouds_gt_normal,
               'pointclouds_radius': pointclouds_radius,
               'pointclouds_image_input':pointclouds_image_input,
               'pointclouds_image_pred': pointclouds_image_pred,
               'pointclouds_image_gt': pointclouds_image_gt,
               'pretrain_merged':pretrain_merged,
               'image_merged': image_merged,
               'gen_loss_emd': gen_loss_emd,
               'pre_gen_train':pre_gen_train,
               'pred': pred,
               'step': step,
               }
       """
       # 其中包含了输入输出的占位符以及对应的计算图元素,还有记录网络运行过程的变量如step.

        
        
        train_writer.add_summary(summary, step)
        loss_sum.append(gen_loss_emd)

        if step%30 == 0:
            pointclouds_image_input = pc_util.point_cloud_three_views(batch_input_data[0,:,0:3])
            pointclouds_image_input = np.expand_dims(np.expand_dims(pointclouds_image_input,axis=-1),axis=0)
            pointclouds_image_pred = pc_util.point_cloud_three_views(pred_val[0, :, :])
            pointclouds_image_pred = np.expand_dims(np.expand_dims(pointclouds_image_pred, axis=-1), axis=0)
            pointclouds_image_gt = pc_util.point_cloud_three_views(batch_data_gt[0, :, 0:3])
            pointclouds_image_gt = np.expand_dims(np.expand_dims(pointclouds_image_gt, axis=-1), axis=0)
            
            # 下面两句则定义了需要从计算图中拿到的一个merged可视化结果,并馈入三个对应数据来获取
            feed_dict ={ops['pointclouds_image_input']:pointclouds_image_input,
                        ops['pointclouds_image_pred']: pointclouds_image_pred,
                        ops['pointclouds_image_gt']: pointclouds_image_gt,
                        }
            summary = sess.run(ops['image_merged'],feed_dict)
            train_writer.add_summary(summary,step)

    loss_sum = np.asarray(loss_sum)
    log_string('step: %d mean gen_loss_emd: %f\n' % (step, round(loss_sum.mean(),4)))
    print 'read data time: %s mean gen_loss_emd: %f' % (round(fetch_time,4), round(loss_sum.mean(),4))

在上面的例子中可以看到run()函数可以根据馈入的feed_dict字典来依据计算图进行计算,而后更具fetches的元素返回出对应元素的计算结果,完成一次运行过程。

https://images.pexels.com/photos/1707825/pexels-photo-1707825.jpeg

ref:
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-slp52jz8.html
https://www.w3cschool.cn/tensorflow_python/tensorflow_python-fibz28ss.html
博客简介:https://www.cnblogs.com/gengyi/p/9865915.html
博客learner_ctr讲解:https://blog.csdn.net/a1066196847/article/details/84104655
一个视频讲解:https://www.aiworkbox.com/lessons/use-feed_dict-to-feed-values-to-tensorflow-placeholders
教程:https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Run.html#Session.Run()
https://www.tensorflow.org/api_docs/java/reference/org/tensorflow/Session.Runner.html#runAndFetchMetadata()
https://www.cnblogs.com/yao62995/p/5773043.html

发布了357 篇原创文章 · 获赞 307 · 访问量 56万+

猜你喜欢

转载自blog.csdn.net/u014636245/article/details/101698300