人工智能实战2019 第5次作业 王铈弘

项目 内容
课程 人工智能实战2019
作业要求 第5次作业
课程目标 学习人工智能基础知识
本次作业对我的帮助 学习并练习应用线性二分类的知识
理论课程 二分类原理

一、作业要求


训练一个逻辑或门和逻辑与门

二、样本数据


1. 逻辑或门的样本数据

样本序号 1 2 3 4
x1 0 0 1 1
x2 0 1 0 1
Y 0 1 1 1

2. 逻辑与门的样本数据

样本序号 1 2 3 4
x1 0 0 1 1
x2 0 1 0 1
Y 0 0 0 1

三、关键公式


1. 分类函数

\[ A(z)=Sigmoid(z)=\frac{1}{1+e^{-z}} \]

2. 二分类交叉熵损失函数

\[ J=-[Y \ln A+(1-Y) \ln (1-A)] \]

3. 使用方式

训练时,一个样本x经过神经网络的最后一层的矩阵运算结果作为输入z,经过Sigmoid函数后,输出一个[0,1]之间的预测值。对于标签值为1的样本数据,预测值越接近0,惩罚越大,反向传播的力度越大;反之同理

四、程序实现


import numpy as np
import matplotlib.pyplot as plt
import math


# 读取样本数据
def ReadData(logic):
    if logic == "OR":
        X = np.array([0, 0, 1, 1, 0, 1, 0, 1]).reshape(2, 4)
        Y = np.array([0, 1, 1, 1]).reshape(1, 4)
    elif logic == "AND":
        X = np.array([0, 0, 1, 1, 0, 1, 0, 1]).reshape(2, 4)
        Y = np.array([0, 0, 0, 1]).reshape(1, 4)
    return X, Y


# 分类函数
def Sigmoid(x):
    s = 1 / (1 + np.exp(-x))
    return s


# 前向计算
def ForwardCalculationBatch(W, B, batch_X):
    Z = np.dot(W, batch_X) + B
    A = Sigmoid(Z)
    return A


# 计算损失函数值,二分类交叉熵损失函数
def CheckLoss(W, B, X, Y):
    m = X.shape[1]
    A = ForwardCalculationBatch(W, B, X)

    p1 = 1 - Y
    p2 = np.log(1 - A)
    p3 = np.log(A)

    p4 = np.multiply(p1, p2)
    p5 = np.multiply(Y, p3)

    LOSS = np.sum(-(p4 + p5))
    loss = LOSS / m
    return loss


# 反向计算
def BackPropagationBatch(X, Y, A):
    m = X.shape[1]
    dZ = A - Y
    dB = dZ.sum(axis=1, keepdims=True) / m
    dW = np.dot(dZ, X.T) / m
    return dW, dB


# 更新权重参数
def UpdateWeights(W, B, dW, dB, eta):
    W = W - eta * dW
    B = B - eta * dB
    return W, B


# 训练
def train(X, Y):

    num_example = X.shape[1]
    num_feature = X.shape[0]
    num_category = Y.shape[0]
    eta = 0.5
    max_epoch = 10000
    loss = 5  # initialize loss (larger than 0)
    error = 1e-3  # stop condition
    w = np.zeros((num_category, num_feature))
    b = np.zeros((num_category, 1))

    for epoch in range(max_epoch):
        print("epoch=%d" % epoch)
        for i in range(num_example):
            x = X[:, i].reshape(2, 1)
            y = Y[:, i].reshape(1, 1)
            z = ForwardCalculationBatch(w, b, x)
            dW, dB = BackPropagationBatch(x, y, z)
            w, b = UpdateWeights(w, b, dW, dB, eta)
        # end for
        loss = CheckLoss(w, b, X, Y)
        #print(epoch, i, loss, w, b)
        if math.isnan(loss):
            break
        if loss < error:
            break
    return w, b
    # end for


# 显示结果
def ShowResult(W, B, X, Y, title):
    # 根据w,b值画出分割线
    w = -W[0, 0] / W[0, 1]
    b = -B[0, 0] / W[0, 1]
    x = np.array([0, 1])
    y = w * x + b
    plt.plot(x, y)

    # 画出原始样本值
    for i in range(X.shape[1]):
        if Y[0, i] == 0:
            plt.scatter(X[0, i], X[1, i], marker="o", c='b', s=64)
        else:
            plt.scatter(X[0, i], X[1, i], marker="^", c='r', s=64)
    plt.axis([-0.1, 1.1, -0.1, 1.1])
    plt.title(title)
    plt.show()


# 测试
def Test(W, B, logic):
    n1 = input("input number one:")
    x1 = float(n1)
    n2 = input("input number two:")
    x2 = float(n2)
    a = ForwardCalculationBatch(W, B, np.array([x1, x2]).reshape(2, 1))
    print(a)
    if (logic == "OR"):
        y = x1 or x2
    if (logic == "AND"):
        y = x1 and x2
    if np.abs(a - y) < 1e-2:
        print("True")
    else:
        print("False")


# 主程序
if __name__ == '__main__':

    logic = "OR"
    X, Y = ReadData(logic)
    w, b = train(X, Y)

    print("w=", w)
    print("b=", b)
    ShowResult(w, b, X, Y, logic)
    while True:
        Test(w, b, logic)

五、运行结果

A4sEGj.png

epoch=4520
w= [[13.13318012 13.13406338]]
b= [[-6.10742937]]
Qt: Untested Windows version 10.0 detected!
input number one:0
input number two:1
[[0.99911287]]
True
input number one:1
input number two:0
[[0.99911209]]
True
input number one:0
input number two:0
[[0.00222132]]
True
input number one:1
input number two:1
[[1.]]
True

A4sMZT.png

epoch=8504
w= [[13.15361034 13.15287496]]
b= [[-19.89698044]]
Qt: Untested Windows version 10.0 detected!
input number one:0
input number two:1
[[0.00117642]]
True
input number one:1
input number two:0
[[0.00117728]]
True
input number one:0
input number two:0
[[2.28481577e-09]]
True
input number one:1
input number two:1
[[0.99835687]]
True

猜你喜欢

转载自www.cnblogs.com/wangshihong/p/10663942.html