"Dive into Deep Learning" - 64 Attention Mechanism

Study notes on Mu Li's "Dive into Deep Learning" course, recording my learning process. Please buy the book for the detailed content.
Bilibili video link
Open-source tutorial link

Attention mechanism

Attention cues in biology


The primate visual system receives an enormous amount of sensory input, far more than the brain can fully process. Not all stimuli are created equal, however. The ability to focus and concentrate consciousness allows primates to direct attention to objects of interest, such as prey and predators, in complex visual environments.

In the visual world, attention is guided by two kinds of cues: involuntary cues and voluntary cues. Involuntary cues are based on the salience and conspicuity of objects in the environment. For example, in the left picture, the coffee cup is the most salient and conspicuous object in the current visual environment, so it attracts attention involuntarily. When we finish the coffee and want to read a book, we refocus. This time the choice is made under cognitive and conscious control, so attention guided by voluntary cues is more deliberate.

Queries, keys, and values


Voluntary and involuntary attentional cues explain how humans pay attention.
Consider a relatively simple case where only involuntary cues are used. To bias the selection toward sensory input, one can simply use a parametric fully connected layer, or even a non-parametric max pooling or average pooling layer.

Whether voluntary cues are included is therefore what distinguishes the attention mechanism from a fully connected layer or a pooling layer. In the context of attention mechanisms, voluntary cues are called queries. Given any query, the attention mechanism biases the selection over sensory inputs (for example, intermediate feature representations) via attention pooling. In the attention mechanism, these sensory inputs are called values. More precisely, each value is paired with a key, which can be thought of as the involuntary cue of that sensory input.

In short: attention pooling can be designed so that a given query (voluntary cue) interacts with the keys (involuntary cues), and this interaction guides the selection toward the best-matching values (sensory inputs).
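To make the contrast between pooling without a query and attention pooling concrete, here is a minimal sketch. The function name `attention_pooling`, the toy tensors, and the use of a negative squared distance as the query-key score are illustrative assumptions, not something defined in the original notes.

```python
import torch

def attention_pooling(query, keys, values):
    """Weighted average of values; weights come from query-key similarity."""
    # Assumed scoring rule: negative squared distance between query and keys.
    scores = -(query - keys) ** 2 / 2
    weights = torch.softmax(scores, dim=0)  # non-negative and sums to 1
    return weights @ values, weights

keys = torch.tensor([0.0, 1.0, 2.0, 3.0])        # involuntary cues
values = torch.tensor([10.0, 20.0, 30.0, 40.0])  # sensory inputs
query = torch.tensor(2.0)                        # voluntary cue

# Average pooling ignores the query entirely.
avg_out = values.mean()
# Attention pooling biases the result toward values whose keys match the query.
attn_out, attn_weights = attention_pooling(query, keys, values)
print(avg_out, attn_out, attn_weights)
```

With these toy numbers the largest weight falls on the key equal to the query, so the pooled output moves from the plain average toward the value paired with that key.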


Nadaraya-Watson kernel regression is a non-parametric attention pooling model. Non-parametric Nadaraya-Watson kernel regression has the advantage of consistency: given enough data, the model converges to the optimal solution.
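For reference, the estimator can be written out as below. This is the standard form of Nadaraya-Watson kernel regression; the Gaussian kernel is assumed here because it is the choice that reproduces the `softmax(-(X_repeat - x_train)**2 / 2)` expression used in the code of these notes. The symbols x (query), x_i (keys), y_i (values), and K (kernel) are notation introduced for this sketch.

```latex
f(x) = \sum_{i=1}^{n} \frac{K(x - x_i)}{\sum_{j=1}^{n} K(x - x_j)} \, y_i,
\qquad
K(u) = \frac{1}{\sqrt{2\pi}} \exp\!\left(-\frac{u^2}{2}\right)

f(x) = \sum_{i=1}^{n} \mathrm{softmax}\!\left(-\frac{1}{2}(x - x_i)^2\right) y_i
```

The second line follows from the first because the constant factor of the Gaussian kernel cancels between numerator and denominator, leaving a softmax over the keys.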


Learnable parameters can also be integrated into attention pooling, which gives a parametric model.
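A sketch of the parametric form, written to match the `-((queries - keys) * self.w)**2 / 2` expression in the `NWKernelRegression` class of these notes. The scalar w is the single learnable parameter; the remaining symbols (x for the query, x_i for the keys, y_i for the values) are notation assumed for this sketch.

```latex
f(x) = \sum_{i=1}^{n} \mathrm{softmax}\!\left(-\frac{1}{2}\big((x - x_i)\, w\big)^2\right) y_i
```

Setting w = 1 recovers the non-parametric Gaussian-kernel case, while a larger |w| narrows the kernel.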


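Summary

Nadaraya-Watson kernel regression is an example of attention pooling: its prediction is a weighted average of the training outputs, where the weight on each value depends on how close the query is to the corresponding key. Attention pooling can be non-parametric or parametric.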

Hands-on learning

Visualization of attention

import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Show heatmaps of matrices."""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')


Attention Pooling: Nadaraya-Watson Kernel Regression

from torch import nn

# Generate the dataset
n_train = 50  # Number of training examples
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # Sorted training inputs: 50 random numbers in [0, 5)

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # Training outputs (with noise)
x_test = torch.arange(0, 5, 0.1)  # Test inputs
y_truth = f(x_test)  # Ground-truth outputs for the test inputs
n_test = len(x_test)  # Number of test examples

# Plot all training examples (circles), the noise-free ground-truth function,
# and the learned prediction function
def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);

# Average pooling
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)


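# Non-parametric attention pooling

# Shape of X_repeat: (n_test, n_train);
# each row contains the same test input (i.e., the same query)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train contains the keys. Shape of attention_weights: (n_test, n_train);
# each row holds the attention weights assigned to the values (y_train) for one query
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# Each element of y_hat is a weighted average of the values,
# where the weights are the attention weights
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)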

Now look at the attention weights. Here the test inputs act as queries and the training inputs act as keys. Since both sets of inputs are sorted, the heatmap shows that the closer a query-key pair is, the higher the attention weight that attention pooling assigns to it.

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
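The two properties read off the heatmap can also be checked numerically. The sketch below regenerates inputs of the same shape as in these notes (the fixed seed and the use of broadcasting instead of `repeat_interleave` are assumptions made only to keep the example short and repeatable) and checks that each row of weights sums to 1 and that the heaviest weight sits on the nearest key.

```python
import torch

torch.manual_seed(0)  # fixed seed, only so the sketch is repeatable
x_train, _ = torch.sort(torch.rand(50) * 5)  # keys
x_test = torch.arange(0, 5, 0.1)             # queries

# Pairwise query-key differences via broadcasting: shape (n_test, n_train)
diff = x_test.unsqueeze(1) - x_train.unsqueeze(0)
weights = torch.softmax(-diff**2 / 2, dim=1)

row_sums = weights.sum(dim=1)           # each entry should be close to 1
nearest_key = diff.abs().argmin(dim=1)  # index of the closest key per query
heaviest_key = weights.argmax(dim=1)    # index of the largest weight per query
print(row_sums[:3], nearest_key[:5], heaviest_key[:5])
```

Because softmax is monotone in its scores, the key with the smallest distance to a query always receives the largest weight, which is what produces the bright diagonal band in the heatmap.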

Next, transform the training dataset into keys and values for training the attention model. In the parametric attention pooling model, the input of each training example is compared with the key-value pairs of all training examples except itself in order to obtain its predicted output.

# Parametric attention pooling
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # Shape of queries and attention_weights:
        # (number of queries, number of key-value pairs)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # Shape of values: (number of queries, number of key-value pairs)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# Shape of X_tile: (n_train, n_train); each row contains the same training inputs
X_tile = x_train.repeat((n_train, 1))
# Shape of Y_tile: (n_train, n_train); each row contains the same training outputs
Y_tile = y_train.repeat((n_train, 1))
# Shape of keys: (n_train, n_train - 1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# Shape of values: (n_train, n_train - 1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(50):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {epoch + 1}, loss {float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))
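The masking step that builds `keys` and `values` above is easier to follow on a tiny example. The sketch below applies the same off-diagonal selection to three toy training points; the tensors `x` and `y` and the name `off_diagonal` are made up for illustration.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])     # toy training inputs
y = torch.tensor([10.0, 20.0, 30.0])  # toy training outputs
n = len(x)

# Boolean mask that is False on the diagonal and True elsewhere.
off_diagonal = torch.eye(n) == 0

# Row i keeps every training example except example i itself.
keys = x.repeat((n, 1))[off_diagonal].reshape((n, -1))
values = y.repeat((n, 1))[off_diagonal].reshape((n, -1))
print(keys)
print(values)
```

Row 0 of `keys` holds the inputs 2 and 3, row 1 holds 1 and 3, and row 2 holds 1 and 2, so no training example can predict its own output by attending to itself.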

As shown below, after training the parametric attention pooling model, the predicted curve follows the noisy training data more closely and is therefore less smooth than the curve from the earlier non-parametric model.

# Shape of keys: (n_test, n_train); each row contains the same training inputs (i.e., the same keys)
keys = x_train.repeat((n_test, 1))
# Shape of values: (n_test, n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)


d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')
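One way to see why a parametric fit can be less smooth is to look at how the scalar `w` changes the shape of the attention weights. The sketch below uses the same kernel form as `NWKernelRegression`, but with evenly spaced toy keys and hand-picked values of `w`; these are illustrative assumptions, not the value the model actually learns.

```python
import torch

keys = torch.linspace(0, 5, 50)  # evenly spaced toy keys (an assumption)
query = torch.tensor(2.5)

def kernel_weights(w):
    # Same form as the parametric model: softmax of -((query - key) * w)^2 / 2
    return torch.softmax(-((query - keys) * w) ** 2 / 2, dim=0)

for w in (1.0, 5.0, 20.0):  # illustrative values, not trained ones
    weights = kernel_weights(w)
    print(f'w = {w:>4}: largest weight = {float(weights.max()):.3f}')
```

As `w` grows in magnitude, the weights concentrate on the few keys nearest the query, so each prediction averages over fewer training points and tracks the noise more closely.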


Origin blog.csdn.net/cjw838982809/article/details/132080982