神经网络学习中的损失函数及mini-batch学习

# 损失函数(loss function)。这个损失函数可以使用任意函数,
# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等
一切都在代码时有注释哈。
import numpy as np
from minst import load_mnist


# 损失函数(loss function)。这个损失函数可以使用任意函数,
# 但一般用均方误差(mean squared error)和交叉熵误差(cross entropy error)等

# 均方误差会计算神经网络的输出和正确解监督数据的各个元素之差的平方,再求总和
def mean_quared_error(y, t):
    return 0.5 * np.sum((y-t)**2)


# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(mean_quared_error(np.array(y), np.array(t)))


def cross_entropy_error(y, t):
    # 保护性对策,添加一个微小值delta可以防止负无限大的发生
    delta = 1e-7
    if y.ndim == 1:
        t = t.reshape(1, t.size)
        y = y.reshape(1, y.size)
    batch_size = y.shape[0]
    # t 为 one-hot 表示
    return -np.sum(t * np.log(y+delta)) / batch_size
    #  t 为标签形式时
    # return -np.sum(np.log(y[np.arange(batch_size), t] + delta)) / batch_size


# 设“2”为正确解
t = [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]
# “2”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.6, 0.0, 0.05, 0.1, 0.0, 0.1, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))
# “7”的概率最高的情况(0.6)
y = [0.1, 0.05, 0.1, 0.0, 0.05, 0.1, 0.0, 0.6, 0.0, 0.0]
print(cross_entropy_error(np.array(y), np.array(t)))

# 当数据集的训练数据有很大时,如果以全部数据为对象求损失函数的和,则计算过程需要花费较长的时间。
# 再者,如果遇到大数据,数据量会有几百万、几千万之多,这种情况下以全部数据为对象计算损失函数是不现实的。
# 因此,我们从全部数据中选出一部分,作为全部数据的“近似”。
# 神经网络的学习也是从训练数据中选出一批数据(称为mini-batch,小批量),然后对每个mini-batch进行学习。
# 比如,从60000个训练数据中随机选择100笔,再用这100笔数据进行学习。
# 这种学习方式称为mini-batch学习。

(x_train, t_train), (x_test, t_test) = load_mnist(normalize=True, one_hot_label=True)
print(x_train.shape)
print(t_train.shape)
train_size = x_train.shape[0]
batch_size = 10
batch_mask = np.random.choice(train_size, batch_size)
x_batch = x_train[batch_mask]
t_batch = t_train[batch_mask]
print(x_batch)
print(t_batch)
C:\Python36\python.exe C:/Users/Sahara/PycharmProjects/test1/test.py
C:\Users\Sahara\PycharmProjects\test1
0.09750000000000003
0.5975
0.510825457099338
2.302584092994546
(60000, 784)
(60000, 10)
[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[[0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]]

Process finished with exit code 0

  

猜你喜欢

转载自www.cnblogs.com/aguncn/p/10859128.html