"Aprendizaje profundo práctico"-64 Mecanismo de atención

Notas de estudio de la versión de Mushen de "Learning Deep Learning by Hands", registrando el proceso de aprendizaje, compre libros para obtener contenido detallado.
enlace de video de la estación B
enlace de tutorial de código abierto

mecanismo de atención

Señales de atención en biología.

inserte la descripción de la imagen aquí

El sistema visual de los primates recibe una enorme cantidad de información sensorial mucho más allá de lo que el cerebro puede procesar por completo. Sin embargo, no todos los estímulos son iguales. La convergencia y el enfoque de la conciencia permiten a los primates dirigir la atención a objetos de interés, como presas y depredadores, en entornos visuales complejos.

Por ejemplo, la aplicación de la atención en el mundo visual conduce a avisos involuntarios y avisos autónomos . Los avisos involuntarios se basan en la prominencia y visibilidad de los objetos en el entorno. Por ejemplo, en la imagen de la izquierda, la taza de café está en el actual entorno visual Los más destacados y conspicuos atraen involuntariamente la atención de las personas. Cuando terminamos nuestro café y queremos leer un libro, nos reagrupamos.En este momento, debido al control cognitivo y consciente, la atención será más cautelosa al tomar decisiones basadas en señales voluntarias.

consulta, clave y valor

inserte la descripción de la imagen aquí

Las señales de atención voluntarias e involuntarias explican cómo los humanos prestan atención.
Considere un caso relativamente simple en el que solo se utilizan señales involuntarias. Para sesgar la selección hacia la entrada sensorial, uno puede simplemente usar una capa paramétrica totalmente conectada, o incluso una capa de agrupación máxima no paramétrica o una capa de agrupación promedio.

Entonces, "si se incluyen sugerencias de autonomía" distingue el mecanismo de atención de la capa totalmente conectada o la capa de agrupación. En el contexto de los mecanismos de atención, las señales autónomas se denominan consultas ( query). Ante cualquier consulta, el mecanismo de atención guía las selecciones a entradas sensoriales ( ), como representaciones de características intermedias, a través de la agrupación de atención ( ). En el mecanismo de atención, estas entradas sensoriales se denominan valores ( ). Más coloquialmente, cada valor se empareja con una clave ( ), que se puede imaginar como una señal involuntaria para la entrada sensorial.attention poolingsensory inputvaluekey

La generalización es: mediante el diseño de la agrupación de atención, es fácil hacer coincidir una consulta dada (señal autónoma) con una clave (señal involuntaria), lo que conducirá al mejor valor de coincidencia (entrada sensorial).

inserte la descripción de la imagen aquí

La regresión del kernel de Nadaraya-Watson es un modelo de agrupación de atención no paramétrica.La regresión del kernel de Nadaraya-Watson no paramétrica tiene la ventaja de la consistencia: si hay suficientes datos, este modelo convergerá al mejor resultado.

inserte la descripción de la imagen aquí

Integre parámetros aprendibles en la agrupación de atención:

inserte la descripción de la imagen aquí

Resumir

inserte la descripción de la imagen aquí

aprendizaje práctico

visualización de la atención

import torch
from d2l import torch as d2l

#@save
def show_heatmaps(matrices, xlabel, ylabel, titles=None, figsize=(2.5, 2.5),
                  cmap='Reds'):
    """显示矩阵热图"""
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    fig, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize,
                                 sharex=True, sharey=True, squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            pcm = ax.imshow(matrix.detach().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    fig.colorbar(pcm, ax=axes, shrink=0.6);

attention_weights = torch.eye(10).reshape((1, 1, 10, 10))
show_heatmaps(attention_weights, xlabel='Keys', ylabel='Queries')

inserte la descripción de la imagen aquí

Agrupación de atención: regresión del kernel de Nadaraya-Watson

from torch import nn

# 生成数据集
n_train = 50  # 训练样本数
x_train, _ = torch.sort(torch.rand(n_train) * 5)   # 排序后的训练样本 0-5之间随机生成50个数

def f(x):
    return 2 * torch.sin(x) + x**0.8

y_train = f(x_train) + torch.normal(0.0, 0.5, (n_train,))  # 训练样本的输出
x_test = torch.arange(0, 5, 0.1)  # 测试样本
y_truth = f(x_test)  # 测试样本的真实输出
n_test = len(x_test)  # 测试样本数

# 绘制所有的训练样本(圆圈)、不带噪声项的真实数据生成函数、学习得到的预测函数
def plot_kernel_reg(y_hat):
    d2l.plot(x_test, [y_truth, y_hat], 'x', 'y', legend=['Truth', 'Pred'],
             xlim=[0, 5], ylim=[-1, 5])
    d2l.plt.plot(x_train, y_train, 'o', alpha=0.5);
# 平均汇聚
y_hat = torch.repeat_interleave(y_train.mean(), n_test)
plot_kernel_reg(y_hat)

inserte la descripción de la imagen aquí

# 非参数注意力汇聚

# X_repeat的形状:(n_test,n_train),
# 每一行都包含着相同的测试输入(例如:同样的查询)
X_repeat = x_test.repeat_interleave(n_train).reshape((-1, n_train))
# x_train包含着键。attention_weights的形状:(n_test,n_train),
# 每一行都包含着要在给定的每个查询的值(y_train)之间分配的注意力权重
attention_weights = nn.functional.softmax(-(X_repeat - x_train)**2 / 2, dim=1)
# y_hat的每个元素都是值的加权平均值,其中的权重是注意力权重
y_hat = torch.matmul(attention_weights, y_train)
plot_kernel_reg(y_hat)

inserte la descripción de la imagen aquí
Observa el peso de la atención. Aquí la entrada de los datos de prueba es equivalente a la consulta, y la entrada de los datos de entrenamiento es equivalente a la clave. Dado que ambas entradas están ordenadas, se puede observar que cuanto más cerca esté el par de claves de consulta, mayor será el peso de atención del grupo de atención.

d2l.show_heatmaps(attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

inserte la descripción de la imagen aquí
Transforme el conjunto de datos de entrenamiento en claves y valores para entrenar el modelo de atención. En el modelo de agrupación de atención con parámetros, la entrada de cualquier muestra de entrenamiento se calculará con los pares "clave-valor" de todas las muestras de entrenamiento excepto ella misma, para obtener su salida prevista correspondiente.

# 带参数注意力汇聚
class NWKernelRegression(nn.Module):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.w = nn.Parameter(torch.rand((1,), requires_grad=True))

    def forward(self, queries, keys, values):
        # queries和attention_weights的形状为(查询个数,“键-值”对个数)
        queries = queries.repeat_interleave(keys.shape[1]).reshape((-1, keys.shape[1]))
        self.attention_weights = nn.functional.softmax(
            -((queries - keys) * self.w)**2 / 2, dim=1)
        # values的形状为(查询个数,“键-值”对个数)
        return torch.bmm(self.attention_weights.unsqueeze(1),
                         values.unsqueeze(-1)).reshape(-1)

# X_tile的形状:(n_train,n_train),每一行都包含着相同的训练输入
X_tile = x_train.repeat((n_train, 1))
# Y_tile的形状:(n_train,n_train),每一行都包含着相同的训练输出
Y_tile = y_train.repeat((n_train, 1))
# keys的形状:('n_train','n_train'-1)
keys = X_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
# values的形状:('n_train','n_train'-1)
values = Y_tile[(1 - torch.eye(n_train)).type(torch.bool)].reshape((n_train, -1))
net = NWKernelRegression()
loss = nn.MSELoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.5)
animator = d2l.Animator(xlabel='epoch', ylabel='loss', xlim=[1, 5])

for epoch in range(50):
    trainer.zero_grad()
    l = loss(net(x_train, keys, values), y_train)
    l.sum().backward()
    trainer.step()
    print(f'epoch {
      
      epoch + 1}, loss {
      
      float(l.sum()):.6f}')
    animator.add(epoch + 1, float(l.sum()))

inserte la descripción de la imagen aquí
Como se muestra a continuación, después de entrenar el modelo de agrupación de atención con parámetros, se puede encontrar que: Al intentar ajustar los datos de entrenamiento ruidosos, los resultados de la predicción dibujan una línea que no es tan suave como el modelo no paramétrico anterior.

# keys的形状:(n_test,n_train),每一行包含着相同的训练输入(例如,相同的键)
keys = x_train.repeat((n_test, 1))
# value的形状:(n_test,n_train)
values = y_train.repeat((n_test, 1))
y_hat = net(x_test, keys, values).unsqueeze(1).detach()
plot_kernel_reg(y_hat)

inserte la descripción de la imagen aquí

d2l.show_heatmaps(net.attention_weights.unsqueeze(0).unsqueeze(0),
                  xlabel='Sorted training inputs',
                  ylabel='Sorted testing inputs')

inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/cjw838982809/article/details/132080982
Recomendado
Clasificación