Solving a Binary Classification Problem with a Neural Network

 
 
import tensorflow as tf
from numpy.random import RandomState
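# Note: this is TensorFlow 1.x graph-mode code. On TensorFlow 2.x it can
# still be run by replacing the import above with:
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()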


batch_size = 8

# Parameters of a 2-3-1 fully connected network.
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

# Placeholders for the input features and the labels.
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")

# Forward propagation.
a = tf.matmul(x, w1)
y = tf.matmul(a, w2)
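# Note: there is no nonlinear activation between the two layers, so the two
# matmuls collapse into a single linear transform; the hidden layer does not
# add expressive power here.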


# Define the loss function and the backpropagation algorithm.
# clip_by_value keeps y inside [1e-10, 1.0] so tf.log never sees zero.
cross_entropy = -tf.reduce_mean(
        y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdadeltaOptimizer(0.001).minimize(cross_entropy)
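# Note: AdadeltaOptimizer at a learning rate of 0.001 moves the loss very
# slowly, which matches the nearly flat cross entropy in the output below;
# tf.train.AdamOptimizer(0.001) is a common, much faster alternative here.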


# Generate a simulated dataset from random numbers.
rdm = RandomState(1)
dataset_size = 128
X = rdm.rand(dataset_size, 2)
# Label a sample as positive (1) when x1 + x2 < 1, otherwise negative (0).
Y = [[int(x1+x2 < 1)] for (x1, x2) in X]


with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    # Print the weights before training.
    print(sess.run(w1))
    print(sess.run(w2))


    # Set the number of training iterations.
    STEPS = 5000
    for i in range(STEPS):
        # Select batch_size samples at a time for training.
        start = (i * batch_size) % dataset_size
        end = min(start + batch_size, dataset_size)

        # Train the network on the selected samples and update the parameters.
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y[start:end]})
        if i % 1000 == 0:
            total_cross_entropy = sess.run(cross_entropy, feed_dict={x: X, y_: Y})
            print("After %d training step(s), cross entropy on all data is %g" %(i, total_cross_entropy))
    # Print the weights after training.
    print(sess.run(w1))
    print(sess.run(w2))
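Running the script prints the weights before training, the cross entropy over the whole dataset every 1000 steps, and the weights after training: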
[[-0.8113182   1.4845988   0.06532937]
 [-2.4427042   0.0992484   0.5912243 ]]
[[-0.8113182 ]
 [ 1.4845988 ]
 [ 0.06532937]]
After 0 training step(s), cross entropy on all data is 0.0677411
After 1000 training step(s), cross entropy on all data is 0.0676762
After 2000 training step(s), cross entropy on all data is 0.0676077
After 3000 training step(s), cross entropy on all data is 0.0675391
After 4000 training step(s), cross entropy on all data is 0.0674705
[[-0.812737    1.48603     0.06671807]
 [-2.444031    0.10049716  0.5924467 ]]
[[-0.812682  ]
 [ 1.4861072 ]
 [ 0.06662364]]
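A caveat on the loss function: the expression above keeps only the y_ * log(y) term of binary cross entropy, and y is an unactivated linear output, so the clipping is doing the work of keeping the logarithm defined. Below is a minimal sketch of a more standard formulation under the same TF 1.x setup; the switch to sigmoid_cross_entropy_with_logits and to AdamOptimizer are substitutions on my part, not the original post's choices.

import tensorflow as tf

# Same placeholders and 2-3-1 weights as the script above.
x = tf.placeholder(tf.float32, shape=(None, 2), name="x-input")
y_ = tf.placeholder(tf.float32, shape=(None, 1), name="y-input")
w1 = tf.Variable(tf.random_normal([2, 3], stddev=1, seed=1))
w2 = tf.Variable(tf.random_normal([3, 1], stddev=1, seed=1))

a = tf.matmul(x, w1)
logits = tf.matmul(a, w2)  # raw scores; the sigmoid is applied inside the loss

# Numerically stable binary cross entropy that covers both the y_ = 1 and
# y_ = 0 terms, so no manual clipping is needed.
cross_entropy = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=y_, logits=logits))

# Adam is a substitution for the post's Adadelta; it converges much faster
# on this toy problem.
train_step = tf.train.AdamOptimizer(0.001).minimize(cross_entropy)

# To classify a sample, threshold the sigmoid probability at 0.5.
prediction = tf.cast(tf.sigmoid(logits) > 0.5, tf.float32)

The training loop from the script above can be reused unchanged with this graph.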



Reposted from blog.csdn.net/qq_34000894/article/details/80104643