TensorFlow学习笔记-feed与fetch

 今天学习Feed 与Feach。他要配合着占位符来使用!!!

会话运行完成之后,如果我们想查看会话运行的结果,就需要使用fetch来实现,feed,fetch同样可以fetch单个或者多个值。 

TensorFlow中数据的feed与fetch

一:占位符(placeholder)与feed

当我们构建一个模型的时候,有时候我们需要在运行时候输入一些初始数据,这个时候定义模型数据输入在tensorflow中就是用placeholder(占位符)来完成。它的定义如下:

def placeholder(dtype, shape=None, name=None):

其中dtype表示数据类型,shape表示维度,name表示名称。它支持单个数值与任意维度的数组输入。

1. 单个数值占位符定义

a = tf.placeholder(tf.float32)
b = tf.placeholder(tf.float32)
c = tf.add(a, b)

当我们需要执行得到c的运行结果时候我们就需要在会话运行时候,通过feed来插入a与b对应的值,代码演示如下:

with tf.Session() as sess:
    result = sess.run(c, feed_dict={a:3, b:4})
    print(result)

其中 feed_dict就是完成了feed数据功能,feed中文有喂饭的意思,这里还是很形象的,对定义的模型来说,数据就是最好的食物,所以就通过feed_dict来实现。

2. 多维数据

同样对于模型需要多维数据的情况下通过feed一样可以完成,定义二维数据的占位符,然后相加,代码如下:

_x = tf.placeholder(shape=[None, 2], dtype=tf.float32, name="x")
_y = tf.placeholder(shape=[None, 2], dtype=tf.float32, name="y")
z = tf.add(_x, _y);

运行时候需要feed二维数组,实现如下:

with tf.Session() as sess:
    result = sess.run(z, feed_dict={_x:[[3, 4], [1, 2]], _y:[[8, 8],[9, 9]]})
    print(result)

二:fetch用法

会话运行完成之后,如果我们想查看会话运行的结果,就需要使用fetch来实现,feed,fetch同样可以fetch单个或者多个值。
1. fetch单个值
矩阵a与b相乘之后输出结果,通过会话运行接受到值c_res这个就是fetch单个值,fetch这个单词在数据库编程中比较常见,这里称为fetch也比较形象。代码演示如下:

import tensorflow as tf

a = tf.Variable(tf.random_normal([3, 3], stddev=3.0), dtype=tf.float32)
b = tf.Variable(tf.random_normal([3, 3], stddev=3.0), dtype=tf.float32)
c = tf.matmul(a, b);
init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    c_res = sess.run(c)
    print(c_res)
  1. fetch多个值
    还是以feed中代码为例,我们把feed与fetch整合在一起,实现feed与fetch多个值,代码演示如下:
import tensorflow as tf

_x = tf.placeholder(shape=[None, 2], dtype=tf.float32, name="x")
_y = tf.placeholder(shape=[None, 2], dtype=tf.float32, name="y")
z = tf.add(_x, _y);
data = tf.random_normal([2, 2], stddev=5.0)
Y = tf.add(data, z)

with tf.Session() as sess:

    z_res, Y_res = sess.run((z, Y), feed_dict={_x:[[3, 4], [1, 2]], _y:[[8, 8],[9, 9]]})
    print(z_res)
    print(Y_res)

上述代码我们就fetch了两个值,这个就是feed与fetch的基本用法。下面我们就集合图像来通过feed与fetch实现一些图像ROI截取操作。代码演示如下:

扫描二维码关注公众号,回复: 2838155 查看本文章
import tensorflow as tf
import cv2 as cv

src = cv.imread("D:/javaopencv/test.png")
cv.imshow("input", src)
_image = tf.placeholder(shape=[None, None, 3], dtype=tf.uint8, name="image")
roi_image = tf.slice(_image, [40, 130, 0], [180, 180, -1])

with tf.Session() as sess:
    slice = sess.run(roi_image, feed_dict={_image:src})
    print(slice.shape)
    cv.imshow("roi", slice)
    cv.waitKey(0)
    cv.destroyAllWindows()

原图
这里写图片描述
ROI图像

这里写图片描述

    在神经网络里面,主要用到 Feed和Feach 还是为了看出来的结果数据!

    下面学习的目的:找到输入728维数据得到是4的函数关系,为了得到模型wx+b中的模型系数。标签4,对应一堆728维的向量数据。可以输入任何满足正态分布的728维数据和标签4,通过数学模型wx+b得到数值,使得它与标签4越来越近,进行了100次迭代,训练出系数,得到线性回归系数,w(728*1)和b(1*1)。

      一个原值为4的图片,28*28, np.reshape(image, [1, 784]) 成一维向量,共784个数,然后通过按照一定正态分布的随机数,生成x,w,b,然后让他们进行一定的运算生成y_(期望得到4),通过梯度下降法,学习率为0.01,然后MSE(均方根误差)去作为优化函数,然后按照这种学习方法去找到w,b,使得与4很相近。

mport tensorflow as tf
import numpy as np
import cv2 as cv


def get_x(a):
    image = np.zeros([28, 28], dtype=np.uint8)
    cv.putText(image, str(a), (7, 21), cv.FONT_HERSHEY_PLAIN, 1.3, (255), 2, 8)
    cv.imshow("image", image)
    data = np.reshape(image, [1, 784])
    return data / 255


def feed_fetch():
    x = tf.placeholder(shape=[1, 784], dtype=tf.float32)
    y = tf.placeholder(dtype=tf.float32)
    w = tf.Variable(tf.random_normal([784, 1]))
    b = tf.Variable(tf.random_normal([1, 1]))

    y_ = tf.add(tf.matmul(x, w), b)
    loss = tf.reduce_sum(tf.square(tf.subtract(y, y_)))
    train = tf.train.GradientDescentOptimizer(0.01)
    step = train.minimize(loss)
    init = tf.global_variables_initializer()

    x_input = get_x(4)
    with tf.Session() as sess:
        sess.run(init)
        for i in range(100):
            y_result, curr_loss, curr_step = sess.run([y_, loss, step], feed_dict={x: x_input, y: 4})
            if i % 10 == 0:
                print("y_ :%f ,loss :%f" % (y_result, curr_loss))
        curr_w, curr_b = sess.run([w, b], feed_dict={x: x_input, y: 4})
        print("curr_w :", curr_w)
        print("curr_b : ", curr_b)


if __name__ == "__main__":
    feed_fetch()

结果: 

.......省略很多个curr_w.....

当前的w很复杂,b就是一个数,

猜你喜欢

转载自blog.csdn.net/qq_37791134/article/details/81712772