import matplotlib.pyplot as plt import numpy as np from sklearn import datasets import tensorflow as tf sess=tf.Session() iris=datasets.load_iris() #print(iris) target=np.array([1. if x==0 else 0. for x in iris.target ]) #print(target.shape) iris_data=np.array([ [x[2],x[3]] for x in iris.data] ) #shape[none,2] #声明批量 batch_size=20 #宽度长度 均为【NOne,1】 x1_data=tf.placeholder(tf.float32,shape=[None,1]) x2_data=tf.placeholder(tf.float32,shape=[None,1]) y_target=tf.placeholder(tf.float32,shape=[None,1]) #初始化类型为1,1 可以和x1 相乘 A=tf.Variable(tf.random_normal(shape=[1,1])) b=tf.Variable(tf.random_normal(shape=[1,1])) #线性模型 x1=x2*A+b --》 f=x1-x2*A-b my_mult=tf.matmul(x2_data,A) my_add=tf.add(my_mult,b) my_output=tf.subtract(x1_data,my_add) #损失函数 (交叉熵损失函数 非归一化 常用于两类验证) sigmoid_logits=tf.nn.sigmoid_cross_entropy_with_logits(labels=y_target,logits=my_output) #梯度下降取最小值 (选择学习率0.05) my_opt=tf.train.GradientDescentOptimizer(0.05) train_step=my_opt.minimize(sigmoid_logits) #初始化所有声明的变量 init=tf.global_variables_initializer() sess.run(init) #迭代100次 训练模型 传入三种数据 长度 宽度 和目标 for i in range(1500): #随机获取批量数据 根据(iris_data)的长度已经确定 rand_index=np.random.choice(len(iris_data) ,batch_size) #shape=[batchsize,1] x1_rand= np.array([[iris_data[x][0]] for x in rand_index],dtype=np.float32) x2_rand = np.array([[iris_data[x][1]] for x in rand_index],dtype=np.float32) y_rand=np.array([[target[x]] for x in rand_index],dtype=np.float32) sess.run(train_step,feed_dict={x1_data:x1_rand,x2_data:x2_rand,y_target:y_rand}) if (i+1)%200 ==0: print('Step %s :A= %s ; b=%s ' % ( i+1,str(sess.run(A)), str(sess.run(b)) )) #保存A,b [[slope]]=sess.run(A) [[intercept]]=sess.run(b) x=np.linspace(0,3,num=50) abline=[] for i in x: abline.append(slope*i+intercept) #重新选取数据 从目标1中选取 长度 宽度 set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 1] set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 1] #重新选取数据 从目标0中 选取长度宽度 no_set1_x=[ a[1] for i,a in enumerate(iris_data) if target[i] == 0] no_set1_y=[ a[0] for i,a in enumerate(iris_data) if target[i] == 0] plt.plot(set1_x,set1_y,'rx',ms=10,mew=2,label='set1') #plt.clabel('set1') plt.plot(no_set1_x,no_set1_y,'ro',label='set0') #plt.clabel('set0') plt.plot(x,abline,'b-',label='my') plt.xlim([0.0,2.7]) plt.ylim([0.0,7.1]) plt.xlabel('length') plt.ylabel('width') plt.legend(loc='lower right') plt.show()
TensorFlow(一) 鸢尾花采用批量数据进行线性模拟
猜你喜欢
转载自www.cnblogs.com/x0216u/p/9167229.html
今日推荐
周排行