TensorFlow(一) 鸢尾花采用批量数据进行线性模拟

import matplotlib.pyplot as plt
import  numpy as np
from sklearn import  datasets
import tensorflow as tf
sess=tf.Session()
iris=datasets.load_iris()
#print(iris)
target=np.array([1. if x==0 else 0. for x in iris.target  ])
#print(target.shape)
iris_data=np.array([ [x[2],x[3]]  for x in iris.data] ) #shape[none,2]

#声明批量
batch_size=20
#宽度长度 均为【NOne,1】
x1_data=tf.placeholder(tf.float32,shape=[None,1])
x2_data=tf.placeholder(tf.float32,shape=[None,1])
y_target=tf.placeholder(tf.float32,shape=[None,1])
#初始化类型为1,1  可以和x1 相乘
A=tf.Variable(tf.random_normal(shape=[1,1]))
b=tf.Variable(tf.random_normal(shape=[1,1]))

#线性模型 x1=x2*A+b  --》 f=x1-x2*A-b
my_mult=tf.matmul(x2_data,A)
my_add=tf.add(my_mult,b)
my_output=tf.subtract(x1_data,my_add)

#损失函数 (交叉熵损失函数 非归一化 常用于两类验证)
sigmoid_logits=tf.nn.sigmoid_cross_entropy_with_logits(labels=y_target,logits=my_output)

#梯度下降取最小值 (选择学习率0.05)
my_opt=tf.train.GradientDescentOptimizer(0.05)
train_step=my_opt.minimize(sigmoid_logits)

#初始化所有声明的变量
init=tf.global_variables_initializer()
sess.run(init)

#迭代100次 训练模型 传入三种数据 长度 宽度 和目标
for i in range(1500):
    #随机获取批量数据 根据(iris_data)的长度已经确定
    rand_index=np.random.choice(len(iris_data) ,batch_size)
    #shape=[batchsize,1]
    x1_rand= np.array([[iris_data[x][0]]  for x in rand_index],dtype=np.float32)
    x2_rand = np.array([[iris_data[x][1]] for x in rand_index],dtype=np.float32)
    y_rand=np.array([[target[x]] for x in rand_index],dtype=np.float32)

    sess.run(train_step,feed_dict={x1_data:x1_rand,x2_data:x2_rand,y_target:y_rand})
    if (i+1)%200 ==0:
        print('Step %s :A= %s  ; b=%s   ' % ( i+1,str(sess.run(A)), str(sess.run(b))   ))

#保存A,b
[[slope]]=sess.run(A)
[[intercept]]=sess.run(b)
x=np.linspace(0,3,num=50)
abline=[]
for i in x:
    abline.append(slope*i+intercept)
#重新选取数据  从目标1中选取 长度 宽度
set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 1]
set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 1]
#重新选取数据 从目标0中 选取长度宽度
no_set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 0]
no_set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 0]

plt.plot(set1_x,set1_y,'rx',ms=10,mew=2,label='set1')
#plt.clabel('set1')

plt.plot(no_set1_x,no_set1_y,'ro',label='set0')
#plt.clabel('set0')

plt.plot(x,abline,'b-',label='my')
plt.xlim([0.0,2.7])
plt.ylim([0.0,7.1])
plt.xlabel('length')
plt.ylabel('width')
plt.legend(loc='lower right')
plt.show()

猜你喜欢

转载自www.cnblogs.com/x0216u/p/9167229.html