TensorFlow 基本语法

1.常量/变量/占位符:

s=tf.constant('abc',dtype=tf.string)

ac=tf.constant(2.0,dtype=tf.float64)

bv=tf.Variable(1.0,dtype=tf.float32)          # Gets its real value only after an initializer runs in a session; until then it is just a graph node (printing it shows a tensor such as 'read:0', not a value)

pl=tf.placeholder(tf.float32)          # Filled only at run time, e.g. sess.run(..., feed_dict={pl:[1,2,3,4,5]}); in the graph it is just an (empty) edge with no data

2.session:

sess=tf.Session()

# pl is float32 but ac is float64; tf.add requires both operands to share a dtype,
# so cast ac to float32 first.
op=tf.add(pl,tf.cast(ac,tf.float32))

sess.run(tf.global_variables_initializer())

# feed_dict values must be concrete Python/NumPy values, not graph objects such as
# the tf.Variable bv, so evaluate bv first and feed its current value.
print(sess.run(op,feed_dict={pl:sess.run(bv)}))

3.一些OP:

# Reference list of common element-wise ops; x and y stand for tensors of a matching dtype.
# TF 1.0 renamed tf.sub/tf.mul/tf.neg/tf.inv to tf.subtract/tf.multiply/tf.negative/tf.reciprocal.
tf.add(x,y)          # x + y
tf.subtract(x,y)     # x - y
tf.multiply(x,y)     # x * y (element-wise)
tf.div(x,y)          # x / y with Python 2 semantics (floor division for integers); tf.divide always does true division
tf.mod(x,y)          # remainder
tf.maximum(x,y)      # element-wise maximum
tf.minimum(x,y)      # element-wise minimum

tf.abs(x)            # absolute value
tf.negative(x)       # -x
tf.sign(x)           # sign of each element: -1, 0 or 1
tf.reciprocal(x)     # 1 / x (formerly tf.inv; this is the reciprocal, not negation)
tf.square(x)         # x ** 2
tf.sqrt(x)           # square root
tf.exp(x)            # e ** x

tf.log(x)            # natural logarithm
tf.cos(x)
tf.sin(x)
tf.tan(x)
tf.atan(x)

4.TFRecord读写:

写入:

# np.random.random returns float64, so img holds 3 float64 values = 24 raw bytes;
# whoever reads this record must decode the bytes as tf.float64.
img=np.random.random((1,3)).tobytes()
example=tf.train.Example(features=tf.train.Features(feature={
    'label':tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
    'img_raw':tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),
}))
# The context manager closes (and flushes) the file even if writing raises.
with tf.python_io.TFRecordWriter('A.tfrecords') as writer:
    writer.write(example.SerializeToString())

读入:

reader=tf.TFRecordReader()
# num_epochs=None: cycle through the file list indefinitely.
file_queue= tf.train.string_input_producer(['A.tfrecords'],num_epochs=None)
_,example=reader.read(file_queue)
feature_map={'label':tf.FixedLenFeature([],tf.int64),'img_raw':tf.FixedLenFeature([],tf.string)}
features=tf.parse_single_example(example,features=feature_map)
label=features['label']
# The writer stored np.random.random((1,3)).tobytes(): 3 float64 values (24 bytes).
# Decoding them as tf.uint8 would reinterpret the bytes as 24 meaningless integers,
# so decode with the original dtype and restore the original [1,3] shape.
img=tf.reshape(tf.decode_raw(features['img_raw'],tf.float64),[1,3])
l,im=tf.train.shuffle_batch([label,img],batch_size=1,capacity=100,min_after_dequeue=10,num_threads=2)
sess=tf.Session()
# Local variables are required by string_input_producer whenever num_epochs is set.
sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
coord=tf.train.Coordinator()
threads=tf.train.start_queue_runners(sess=sess,coord=coord)
try:
    print(sess.run(im))
finally:
    # Stop the queue-runner threads before closing the session they run in.
    coord.request_stop()
    coord.join(threads)
    sess.close()

5.MNIST手写字符识别:

 train:

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
# Clear any previously built default graph so the ops and variables below are created fresh.
tf.reset_default_graph()
def get_weight(shape):
    """Return a new trainable tf.Variable of `shape`, initialized from a truncated normal (stddev 0.1)."""
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial_value)

def get_bias(shape):
    """Return a new trainable tf.Variable of `shape` with every element set to 0.1."""
    initial_value = tf.constant(0.1, shape=shape)
    return tf.Variable(initial_value)

def conv2d(x, w):
    """Convolve `x` with kernel `w` using stride 1 in every dimension and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_strides, padding='SAME')

def max_pool_2x2(x):
    """Max-pool `x` over 2x2 windows with stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

import os  # needed to create the checkpoint directory

sess=tf.InteractiveSession()

x=tf.placeholder('float32',shape=[None,784])
y=tf.placeholder('float32',shape=[None,10])
# Dropout keep probability: fed as KEEP_PROB for training steps, and defaults to 1.0
# (dropout disabled) when measuring accuracy/loss.
keep_prob=tf.placeholder_with_default(1.0,shape=())
KEEP_PROB=0.98
x_img=tf.reshape(x,[-1,28,28,1])
## net struct (variables are created in a fixed order; their auto-generated names
## must match those rebuilt by the restore script)
w_conv1=get_weight([5,5,1,32])
b_conv1=get_bias([32])
net=tf.nn.relu(conv2d(x_img,w_conv1)+b_conv1)
net=max_pool_2x2(net)#[-1,14,14,32]

w_conv2=get_weight([5,5,32,64])
b_conv2=get_bias([64])
net=tf.nn.relu(conv2d(net,w_conv2)+b_conv2)
net=max_pool_2x2(net)#[-1,7,7,64]

net=tf.reshape(net,[-1,7*7*64])
w_fc1=get_weight([7*7*64,1024])
b_fc1=get_bias([1024])
net=tf.nn.relu(tf.matmul(net,w_fc1)+b_fc1)#[-1,1024]
net=tf.nn.dropout(net,keep_prob)

w_fc2=get_weight([1024,10])
b_fc2=get_bias([10])
logits=tf.matmul(net,w_fc2)+b_fc2#[-1,10]
net=tf.nn.softmax(logits)

#loss function: computed from the logits because -sum(y*log(softmax)) becomes NaN
#as soon as a predicted probability underflows to 0.
CE=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y,logits=logits))
train=tf.train.RMSPropOptimizer(0.0001).minimize(CE)

#ACC
correct=tf.equal(tf.argmax(y,1),tf.argmax(net,1))
ACC=tf.reduce_mean(tf.cast(correct,'float32'))

sess.run(tf.global_variables_initializer())
data_set=input_data.read_data_sets('MNIST_data',one_hot=True)
#save sess
start_time=time.time()
save_model='.//model//model.ckpt'
# Saver.save refuses to write when the parent directory does not exist.
model_dir=os.path.dirname(save_model)
if not os.path.isdir(model_dir):
    os.makedirs(model_dir)
saver=tf.train.Saver()
train_writer=tf.summary.FileWriter('.//log',sess.graph)
loss=[]
acc=[]
#train: optimize on every batch, record metrics and checkpoint every 10 steps
for i in range(1000):
    batch_x,batch_y=data_set.train.next_batch(200)
    sess.run(train,feed_dict={x:batch_x,y:batch_y,keep_prob:KEEP_PROB})
    if i%10==0:
        # keep_prob is not fed here, so it defaults to 1.0 (no dropout).
        acc_value,loss_value=sess.run([ACC,CE],feed_dict={x:batch_x,y:batch_y})
        acc.append(acc_value)
        loss.append(loss_value)
        #model save
        saver.save(sess,save_model)
        print('step %d,acc = %g,loss=%g' %(i,acc_value,loss_value))
print('training took %.1f s' % (time.time()-start_time))
train_writer.close()
sess.close()
# Accuracy lies in [0,1] while the summed loss is much larger, so use separate axes.
fig,(ax_acc,ax_loss)=plt.subplots(2,1)
ax_acc.plot(acc)
ax_acc.set_title('acc')
ax_loss.plot(loss)
ax_loss.set_title('loss')
plt.tight_layout()
plt.show()

restore and test:

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time
import matplotlib.pyplot as plt
# Clear any previously built default graph so the ops and variables below are created fresh.
tf.reset_default_graph()
def get_weight(shape):
    """Return a new trainable tf.Variable of `shape`, initialized from a truncated normal (stddev 0.1)."""
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial_value)

def get_bias(shape):
    """Return a new trainable tf.Variable of `shape` with every element set to 0.1."""
    initial_value = tf.constant(0.1, shape=shape)
    return tf.Variable(initial_value)

def conv2d(x, w):
    """Convolve `x` with kernel `w` using stride 1 in every dimension and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, w, strides=unit_strides, padding='SAME')

def max_pool_2x2(x):
    """Max-pool `x` over 2x2 windows with stride 2 and 'SAME' padding."""
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

sess=tf.InteractiveSession()

x=tf.placeholder('float32',shape=[None,784])
# Dropout must be off at inference time; keep_prob defaults to 1.0.
keep_prob=tf.placeholder_with_default(1.0,shape=())
x_img=tf.reshape(x,[-1,28,28,1])
## net struct -- the variables are created in the same order as in the training
## script so their auto-generated names match the ones stored in the checkpoint.
w_conv1=get_weight([5,5,1,32])
b_conv1=get_bias([32])
net=tf.nn.relu(conv2d(x_img,w_conv1)+b_conv1)
net=max_pool_2x2(net)#[-1,14,14,32]

w_conv2=get_weight([5,5,32,64])
b_conv2=get_bias([64])
net=tf.nn.relu(conv2d(net,w_conv2)+b_conv2)
net=max_pool_2x2(net)#[-1,7,7,64]

net=tf.reshape(net,[-1,7*7*64])
w_fc1=get_weight([7*7*64,1024])
b_fc1=get_bias([1024])
net=tf.nn.relu(tf.matmul(net,w_fc1)+b_fc1)#[-1,1024]
net=tf.nn.dropout(net,keep_prob)

w_fc2=get_weight([1024,10])
b_fc2=get_bias([10])
net=tf.nn.softmax(tf.matmul(net,w_fc2)+b_fc2)#[-1,10]

# data_set is not defined in this script, so load it here. Draw the sample from the
# held-out test split rather than the training split.
data_set=input_data.read_data_sets('MNIST_data',one_hot=True)
batch_x,batch_y=data_set.test.next_batch(1)

model=tf.train.latest_checkpoint('.//model')
if model is None:
    raise IOError('no checkpoint found in .//model; run the training script first')
saver=tf.train.Saver()
# restore() assigns every variable saved in the checkpoint, so no initializer is needed.
saver.restore(sess,model)
prediction=sess.run(net,feed_dict={x:batch_x,keep_prob:1.0})
max_index=np.argmax(prediction)
cv2.imshow('test',batch_x.reshape([28,28]))
cv2.waitKey(0)
cv2.destroyAllWindows()
print(max_index)
print('label: %d' % np.argmax(batch_y))
sess.close()

猜你喜欢

转载自blog.csdn.net/s729193140/article/details/89108162