# TensorFlow 实现最小二乘法（以 Boston 房价为例）
# Ordinary least squares with TensorFlow, using the Boston housing-price data as an example.

import tensorflow as tf
import pandas as pd
import numpy as np

# Load the Boston housing data.
# ``sklearn.datasets.load_boston`` was deprecated in scikit-learn 1.0 and
# removed in 1.2 (importing it then raises ImportError). On newer versions,
# fall back to the original StatLib copy of the dataset, using the loading
# recipe given in scikit-learn's deprecation notice.
# NOTE: the fallback downloads the file and therefore needs network access.
try:
    from sklearn.datasets import load_boston
except ImportError:
    _boston_raw = pd.read_csv(
        "http://lib.stat.cmu.edu/datasets/boston",
        sep=r"\s+",
        skiprows=22,
        header=None,
    )
    # Each record spans two physical lines in that file: even rows hold the
    # leading features, odd rows hold the last 2 features followed by the target.
    features = np.hstack([_boston_raw.values[::2, :], _boston_raw.values[1::2, :2]])
    labels = _boston_raw.values[1::2, 2]
else:
    boston = load_boston()
    features = np.array(boston.data)
    labels = np.array(boston.target)

n_training_samples = features.shape[0]  # number of samples (rows)
n_dim = features.shape[1]               # number of features (columns)


def normalize(dataset):
    """Standardize data to zero mean and unit variance along axis 0.

    For a 2-D ``(n_samples, n_features)`` array this z-scores every column
    independently: ``(x - mean) / std`` using the population std.

    Args:
        dataset: array-like of numbers; statistics are taken over axis 0.

    Returns:
        ``np.ndarray`` of the same shape with each column standardized.
        Columns with zero variance (constant columns) come back as all zeros
        instead of NaN.
    """
    data = np.asarray(dataset)
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    # A constant column has sigma == 0; dividing by it would produce 0/0 = NaN,
    # which silently poisons gradient descent. Dividing by 1 leaves such a
    # column centred at 0 instead.
    sigma = np.where(sigma == 0, 1.0, sigma)
    return (data - mu) / sigma


# Standardize the features, then lay the data out the way the graph below
# expects it: one column per sample. ``features`` is (n_samples, n_dim), so
# ``train_x`` is (n_dim, n_samples) and ``train_y`` is (1, n_samples).
features_norm = normalize(features)
train_x = features_norm.T
# ``labels`` is a 1-D vector; transposing it would be a no-op, so reshape it
# directly into a single row.
train_y = labels.reshape(1, -1)
# Model: ordinary least squares as a single linear neuron, y = W X + b,
# trained by gradient descent on the mean squared error.
#
# The graph uses TF1-style placeholders and sessions. Going through
# ``tf.compat.v1`` (and disabling eager execution) keeps this working on
# TensorFlow 2.x, where ``tf.placeholder``, ``tf.reset_default_graph``,
# ``tf.global_variables_initializer`` and ``tf.train.GradientDescentOptimizer``
# no longer exist at the top level. ``tf.compat.v1`` is also present in late
# TF 1.x releases.
tf.compat.v1.disable_eager_execution()
tf.compat.v1.reset_default_graph()
X = tf.compat.v1.placeholder(tf.float32, [n_dim, None])  # features, one column per sample
Y = tf.compat.v1.placeholder(tf.float32, [1, None])      # targets, one column per sample

W = tf.Variable(tf.ones([1, n_dim]))  # weights, initialised to 1
b = tf.Variable(tf.zeros(1))          # bias, initialised to 0

# The learning rate is fed at run time so the same graph can be trained with
# different rates.
learn_rate = tf.compat.v1.placeholder(tf.float32, shape=())
# Created after W and b so that running it initialises both variables.
init = tf.compat.v1.global_variables_initializer()
y_ = tf.matmul(W, X) + b
cost = tf.reduce_mean(tf.square(Y - y_))

train_step = tf.compat.v1.train.GradientDescentOptimizer(learn_rate).minimize(cost)


def run_linear_model(learn_r, training_epochs, train_obs, train_labels):
    """Train the module-level linear model with gradient descent.

    Opens a new session, runs ``init``, then performs ``training_epochs + 1``
    gradient-descent steps (epochs 0 through ``training_epochs`` inclusive).
    After each step it evaluates ``cost``, records it, and prints it every
    200 epochs.

    Args:
        learn_r: learning rate fed to the ``learn_rate`` placeholder.
        training_epochs: index of the last epoch; ``training_epochs + 1``
            steps are run.
        train_obs: feature matrix fed to the ``X`` placeholder.
        train_labels: targets fed to the ``Y`` placeholder.

    Returns:
        ``(sess, cost_history)``. ``sess`` is the still-open session holding
        the trained variables; the caller is responsible for closing it.
        ``cost_history`` is a float array with the post-step cost of every epoch.
    """
    sess = tf.compat.v1.Session()
    try:
        sess.run(init)
        feed = {X: train_obs, Y: train_labels, learn_rate: learn_r}
        costs = []
        for epoch in range(training_epochs + 1):
            sess.run(train_step, feed_dict=feed)
            # Evaluate after the update so the history records the cost
            # reached at the end of each epoch.
            cost_ = sess.run(cost, feed_dict=feed)
            costs.append(cost_)

            if epoch % 200 == 0:
                print("Reached epoch", epoch, "cost J = ", f"{cost_:.6f}")
    except BaseException:
        # The session is only handed back on success; don't leak it otherwise.
        sess.close()
        raise

    # Collect into a list and convert once: calling np.append in the loop
    # copies the whole history every epoch (quadratic in the epoch count).
    cost_history = np.array(costs, dtype=float)
    return sess, cost_history


# Train on the standardized Boston features at a fixed learning rate of 0.01
# for epochs 0..2000. The returned session stays open and holds the trained
# W and b.
sess, cost_history = run_linear_model(
    train_obs=train_x,
    train_labels=train_y,
    learn_r=0.01,
    training_epochs=2000,
)

# Learning curve: the cost J recorded after each epoch, indexed by epoch number.
import matplotlib.pyplot as plt

plt.plot(range(cost_history.shape[0]), cost_history)
plt.show()

# 猜你喜欢 (blog page residue: "You may also like")

# 转载自 blog.csdn.net/yuanzhoulvpi/article/details/84257603 (reposted from this CSDN article)