《动手学深度学习》-64注意力机制

沐神版《动手学深度学习》学习笔记,记录学习过程,详细的内容请大家购买书籍查阅。
b站视频链接
开源教程链接

注意力机制

生物学中的注意力提示

在这里插入图片描述

灵长类动物的视觉系统接受了大量的感官输入,这些感官输入远远超出了大脑所能够完全处理的能力。然而并非所有刺激的影响都是等同的。意识的汇聚和专注使灵长类动物能够在复杂的视觉环境中将注意力引向感兴趣的物体,如猎物和天敌。

举例注意力在视觉世界中的应用,引出非自主性提示自主性提示,非自主性提示是基于环境中物体的突出性和易见性,如左图中,咖啡杯在当前视觉环境中是最突出和显眼的,不由自主地引起人们地注意力。当我们喝完咖啡想要读书时,重新聚集,这时因为受到认知和意识地控制,注意力在基于自主性提示来进行选择时会更为谨慎。

查询、键和值

在这里插入图片描述

自主性与非自主性地注意力提示解释了人类的注意力地方式。
考虑一个相对简单的情况,即只是用非自主性提示。要想将选择偏向感官输入,则可以简单地使用参数化的全连接层,甚至是非参数化的最大汇聚层或平均汇聚层。

因此“是否包含自主性提示”将注意力机制与全连接层或汇聚层区分开来。在注意力机制的背景下,自主性提示被称为查询(query)。给定任何查询,注意力机制通过注意力汇聚(attention pooling)将选择引导至感官输入(sensory input),例如中间特征表示。在注意力机制中,这些感官输入被称为(value)。更通俗的讲,每个值都与一个(key)匹配,这可以想象为感官输入的非自主性提示。

概括就是:通过设计注意力汇聚的方式,便于给定的查询(自主性提示)与键(非自主性提示)进行匹配,这将引导得出最匹配的值(感官输入)。

在这里插入图片描述

Nadaraya-Watson核回归是一个非参数的注意力汇聚模型,非参数的Nadaraya-Watson核回归具有一致性的优点:如果有足够的数据,此模型会收敛到最佳结果。

在这里插入图片描述

将可学习的参数集成到注意力汇聚中:

在这里插入图片描述

总结

在这里插入图片描述

动手学

注意力的可视化

import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

在这里插入图片描述

注意力汇聚: Nadaraya-Watson核回归

from torch import nn

# 生成数据集
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本 0-5之间随机生成50个数

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

# 绘制所有的训练样本(圆圈)、不带噪声项的真实数据生成函数、学习得到的预测函数
def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
# 平均汇聚
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

在这里插入图片描述

# 非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

在这里插入图片描述
观察注意力的权重。 这里测试数据的输入相当于查询,而训练数据的输入相当于键。 因为两个输入都是经过排序的,因此由观察可知“查询-键”对越接近, 注意力汇聚的注意力权重就越高。

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述
将训练数据集变换为键和值用于训练注意力模型。 在带参数的注意力汇聚模型中, 任何一个训练样本的输入都会和除自己以外的所有训练样本的“键-值”对进行计算, 从而得到其对应的预测输出。

# 带参数注意力汇聚
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(50):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {
      
      epoch + 1}, loss {
      
      float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

在这里插入图片描述
如下所示,训练完带参数的注意力汇聚模型后可以发现: 在尝试拟合带噪声的训练数据时, 预测结果绘制的线不如之前非参数模型的平滑。

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

在这里插入图片描述

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/cjw838982809/article/details/132080982