机器学习之特征缩放

一般用第三种:

来源

相关代码:乳腺癌数据集(load_breast_cancer)上的逻辑回归分类,特征先做标准化缩放:

代码来源:代码

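# Overview: run forward propagation to get predictions, compute the gradients
# of the cost (back-propagation), then update the parameters by gradient descent.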
import numpy as np
from sklearn.datasets import load_breast_cancer
import  matplotlib.pyplot as plt


def feature_scalling(X):
    """Standardize X column-wise: subtract the column mean, divide by the column std.

    Args:
        X: NumPy array; statistics are taken along axis 0.

    Returns:
        Array of the same shape as ``X``. Columns whose standard deviation is
        zero (constant features) come back as all zeros instead of NaN.
    """
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    # A constant column has std == 0, and (X - mean) / 0 would be 0/0 = NaN.
    # Dividing by 1 instead leaves the already-centred zeros untouched.
    safe_std = np.where(std == 0, 1.0, std)
    return (X - mean) / safe_std


def load_data(shuffled=False):
    """Load the scikit-learn breast-cancer dataset for training.

    The feature matrix is passed through ``feature_scalling`` and the target
    vector is reshaped into a single column.

    Args:
        shuffled: When true, rows of both arrays are reordered with one
            shared random permutation drawn from ``np.random``.

    Returns:
        Tuple ``(x, y)`` of the scaled features and the column-vector labels.
    """
    dataset = load_breast_cancer()
    features = feature_scalling(dataset.data)
    labels = dataset.target.reshape(-1, 1)
    if not shuffled:
        return features, labels
    # One permutation indexes both arrays so samples stay aligned with labels.
    order = np.random.permutation(labels.shape[0])
    return features[order], labels[order]


def sigmoid(z):
    """Return the logistic function 1 / (1 + exp(-z)), applied element-wise."""
    return 1 / (1 + np.exp(-z))


def gradDescent(X, y, W, b, alpha, maxIt):
    """Fit logistic-regression parameters with batch gradient descent.

    Each iteration computes the prediction error ``sigmoid(X @ W + b) - y``,
    takes one gradient step on ``W`` and ``b``, and records the cost obtained
    with the updated parameters.

    Args:
        X: Feature matrix; its first dimension is the sample count ``m``.
        y: Targets, subtracted element-wise from the predictions.
        W: Initial weights, multiplied as ``np.dot(X, W)``.
        b: Initial bias.
        alpha: Learning rate.
        maxIt: Number of iterations to run.

    Returns:
        Tuple ``(W, b, cost_history)`` where ``cost_history`` is a list with
        one cost value per iteration.
    """
    cost_history = []
    m = X.shape[0]
    # Loop-invariant step factor, evaluated in the same order as the original
    # expression ``(1 / m) * alpha * grad``.
    step = (1 / m) * alpha
    for _ in range(maxIt):
        error = sigmoid(np.dot(X, W) + b) - y
        W = W - step * np.dot(X.T, error)
        b = b - step * np.sum(error)
        # Cost is evaluated after the update so the history reflects the
        # parameters that are eventually returned.
        cost_history.append(cost_function(X, y, W, b))
    return W, b, cost_history


def accuracy(X, y, W, b):
    """Return the fraction of samples whose thresholded prediction equals y.

    The first column of ``sigmoid(X @ W + b)`` is thresholded at 0.5: values
    strictly below 0.5 become class 0, everything else becomes class 1.

    Args:
        X: 2-D feature matrix.
        y: Labels compared element-wise against an ``(m, 1)`` prediction.
        W: Weights.
        b: Bias.

    Returns:
        ``1 - sum(|y - prediction|) / m``.
    """
    m, _ = np.shape(X)
    probabilities = sigmoid(np.dot(X, W) + b)
    # Keep the ``< 0.5`` comparison (not ``>= 0.5``) so that NaN scores fall
    # into class 1, exactly as the element-by-element loop would.
    predicted = np.where(probabilities[:, 0] < 0.5, 0.0, 1.0).reshape(m, 1)
    mismatches = np.sum(np.abs(y - predicted))
    return 1 - mismatches / m


def cost_function(X, y, W, b):
    """Return the mean binary cross-entropy of the logistic model on (X, y).

    The loss is evaluated directly from the logits ``z = X @ W + b`` as
    ``log(1 + exp(z)) - y * z``. This is algebraically equal to
    ``-(y * log(s) + (1 - y) * log(1 - s))`` with ``s = sigmoid(z)``, but it
    never takes ``log(0)``, so the cost stays finite when the sigmoid
    saturates to exactly 0 or 1.

    Args:
        X: Feature matrix; its first dimension is the sample count ``m``.
        y: Targets, broadcast against the logits.
        W: Weights.
        b: Bias.

    Returns:
        Scalar cost: the summed per-element loss divided by ``m``.
    """
    m = X.shape[0]
    z = np.dot(X, W) + b
    # np.logaddexp(0, z) == log(exp(0) + exp(z)) computed without overflow.
    losses = np.logaddexp(0, z) - y * z
    return (1 / m) * np.sum(losses)

if __name__ == '__main__':  # script entry point
    # Hyper-parameters and initial bias.
    alpha = 0.1
    maxIt = 200
    b = 0.1

    X, y = load_data()
    m, n = X.shape
    print(X.shape)  # original author's note: 569*30

    # Random initial weights, one per feature column.
    W = np.random.randn(n, 1)
    W, b, cost_history = gradDescent(X, y, W, b, alpha, maxIt)

    # Plot the cost against the iteration index.
    plt.plot(np.arange(len(cost_history)), cost_history)
    plt.show()

    print("******************")
    print("W is :             ")
    print(W)
    print("accuracy is :         " + str(accuracy(X, y, W, b)))
    print("******************")

猜你喜欢

转载自blog.csdn.net/weixin_40849273/article/details/82964197