tensorflow实战一

概述

本来笔者想从tensorflow的基础概念开始写起,后来感觉还是边写代码边学概念来的靠谱,下面开始首次学习。

构建并运行一个基本图

tensorflow中的运算被组织成一个有向无环图。图的每个节点是一次操作,这样一个图在构成的时候是一个静态图,并没有数据流动。图要在session中被运行,一个session可以运行多个图。运行起来的图才是有数据流动的图。话不多说,开始代码:

#定义常量
m1=tf.constant([[3,3]])
m2=tf.constant([[2],[3]])
mul=tf.matmul(m1,m2)
#上面代码构建了一个向量乘法图,有三个节点m1,m2,mul。

#定义一个session,启动上述图
sess=tf.Session()
result=sess.run(product)
print m1,mul,result
sess.close()
#可以发现只有运行之后图中节点才有数据,m1,mul的值均为0;result的值不是零。

变量和变量赋值

tensorflow中的变量需要先进行赋值操作才能有数据。

#定义部分
state=tf.Variable(0,name='counter') #创建一个变量初始化为0
new_value=tf.add(state,1)
update=tf.assign(state,new_value)#tensor中的赋值操作,把new_value的值赋给state
init=tf.global_variables_initializer() #一个全局变量初始化,只要使用全局变量就要初始化
#运行部分
sess=tf.Session()
sess.run(init)
print sess.run(state)
for _ in range(5):
    sess.run(update)
    print (sess.run(state))
sess.close()

看了两个小例子,大家应该有所发现了,tensorflow中代码分为两个过程,首先定义一个静态图,然后在session中运行。

feed和fetch

feed用于给静态图喂数据,tensorflow中提供占位符这种节点,feed就是给占位符赋值的操作。
fetch用于查看会话运行结果

#fetch
input1=tf.constant(3.0)
input2=tf.constant(2.0)
input3=tf.constant(3.0)
add=tf.add(input2,input3)
mul=tf.multiply(input1,add)
sess=tf.Session()
result=sess.run([mul,add])
print result #发现打印出来的result包含mul和add的值

#feed
input1 = tf.placeholder(tf.float32) #创建占位符,占位符就是空白的节点
input2 = tf.placeholder(tf.float32)
mul = tf.multiply(input1,input2)
result = sess.run(mul,feed_dict={input1:9,input2:1.1}) #喂数据
print result
sess.close()

占位符还可以定义类型,维度,名称属性,上面仅展示了类型属性,有需要的童鞋可以自己再深入学习。

一个简单回归模型训练

下面构建一个梯度下降的回归模型来拟合y=k*x+b

#定义100组样本点
x_data = np.random.rand(100)
y_data = x_data*0.1 + 0.2
#初始化参数k,b都为0
b = tf.Variable(0.)
k = tf.Variable(0.)
y = k*x_data + b
#定义代价函数,误差的平方和的均值(不明白代价函数为什么这样的童鞋可以去学一下原理)
loss = tf.reduce_mean(tf.square(y_data-y))
#定义梯度下降法来训练优化器,学习率设置为0.01
optimizer = tf.train.GradientDescentOptimizer(0.01)
#定义优化器的目标为最小化loss
train = optimizer.minimize(loss)
init = tf.global_variables_initializer()
#运行图,迭代20次
sess = tf.Session()
sess.run(init)
for step in range(201):
    sess.run(train)
    if step%20==0:
        print step,sess.run([k,b])
sess.close()

下面为迭代20次的结果,大家也可以随意改动一下k、b、学习率来看看最后结果是否仍然拟合的很好。

小结

每一个tensorflow模型都是首先构建静态图,然后再运行。想必大家对模型构建有了一定的了解,大家可以去实现更为复杂的回归或分类例子。今天暂时先讲到这里,楼主会继续更文的!

猜你喜欢

转载自blog.csdn.net/qq_40504899/article/details/84970635