基于tensorflow的线性回归

import tensorflow as tf

# 初始化变量和模型参数,定义训练闭环中的运算
W = tf.Variable(tf.zeros([2, 1]), name="weights")
b = tf.Variable(0., name="bias")


def inference(X):  # 计算推断模型在数据X上的输出,并将结果保存
    return tf.matmul(X, W) + b


def loss(X, Y):  # 依据训练数据X和期望输出Y计算损失
    Y_predicted = inference(X)
    return tf.reduce_sum(tf.squared_difference(Y, Y_predicted))


def inputs():  # 读取或生成训练数据X及其期望输出Y
    weight_age = [[84, 46], [73, 20], [65, 52], [70, 30], [76, 57], [69, 25], [63, 28], [72, 36], [79, 57],
                  [75, 44], [27, 24], [89, 31], [65, 52], [57, 23], [59, 60], [69, 48], [60, 34], [79, 51],
                  [75, 50], [82, 34], [59, 46], [67, 23], [85, 37], [55, 40], [63, 30]]
    blood_fat_content = [354, 190, 405, 263, 451, 302, 288, 385, 402, 365, 209, 346, 254, 395, 434, 220, 374, 308, 220,
                         311, 181, 274, 303, 244]
    return tf.to_float(weight_age), tf.to_float(blood_fat_content)


def train(total_loss):  # 依据计算的总损失训练或调整模型参数
    learning_rate = 0.0000001
    return tf.train.GradientDescentOptimizer(learning_rate).minimize(total_loss)


def evaluate(sess, X, Y):  # 对训练得到的模型进行评估
    print(sess.run(inference([[80., 25.]])))
    print(sess.run(inference([[65., 25.]])))


# 在一个会话对象中启动数据流图,搭建流程
with tf.Session() as sess:
    tf.initialize_all_variables().run()
    X, Y = inputs()

    total_loss = loss(X, Y)
    train_op = train(total_loss)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    # 实际的训练迭代次数
    training_steps = 1000
    for step in range(training_steps):
        sess.run([train_op])
        # 处于调试和学习的目的,查看损失在训练过程中的递减情况
        if step % 10 == 0:
            print("loss:", sess.run([total_loss]))
    evaluate(sess, X, Y)
    coord.request_stop()
    coord.join(threads)
    sess.close()

对于这种简单的模型,将采用总平方误差,即模型对每个训练样本的预测值与期望输出之差的平方的总和。从代数角度看,这个损失函数实际上是预测的输出向量与期望向量之间欧氏距离的平方。对于2D数据集,总平方误差对应于每个数据点在垂直方向上到所预测的回归直线的距离的平方总和。这种损失函数也称为L2范数或L2损失函数。这里之所以采用平方,是为了避免计算平方根,因为对于最小化损失这个目标,有无平方并无本质区别,但有平方可以节省一定的计算量。

数据来源:http://people.sc.fsu.edu/~jburkardt/datasets/regression/x09.txt

猜你喜欢

转载自blog.csdn.net/qq_25366173/article/details/80223452
今日推荐