Red de atención gráfica GAT

Derivación de la red convolucional gráfica

Este enlace es el proceso de derivación.

REVÓLVER

En resumen, la red es la matriz de adyacencia del gráfico de la FIG de enfoque borrosa , borrosa adaptativamente.
Inserte la descripción de la imagen aquí

En términos simples, el proceso consiste en incorporar nuevas muestras obtenidas a través de la transformación de la red, unirlas en pares e ingresarlas en la red para calcular el coeficiente de atención.
Inserte la descripción de la imagen aquí

  • Los datos de entrada son X ∈ RN ∗ FX \ in R ^ {N * F}XRN FFFF es la dimensión de la característica,NNN es el número de muestras.

  • La característica de salida es H ∈ RN ∗ f H \ in R ^ {N * f}HRN f , la salida final esY ∈ RN ∗ CY \ in R ^ {N * C}YRN C ,CCC es el número total de categorías.

  • W ∈ RF ∗ f W \ en R ^ {F * f}WRF f es la matriz de ponderaciones de la red gráfica, que se puede generalizar como una transformación de mapeo.

  • a ∈ R 2 f ∗ 1 a \ in R ^ {2f * 1}unR2 f 1 es la matriz de ponderaciones para calcular el coeficiente de atención entre muestras El mecanismo de atención a es una red neuronal feedforward de una sola capa, que puede modificarse y promoverse.

El proceso detallado es el siguiente:

  1. Transforme los datos de entrada para obtener una nueva incorporación, h = XW, h ∈ RN ∗ fh = XW, h \ in R ^ {N * f}h=X W ,hRN f
  2. Combine las muestras en pares para obtener una entrada _ ∈ RN ∗ N ∗ 2 fa \ _input \ in R ^ {N * N * 2f}a _ i n p u tRN N 2 f
  3. a _ inputa \ _inputa _ i n p u t se ingresa en la red, y la matriz de coeficientese = a _ ingresa ∗ a, e ∈ RN ∗ N e = a \ _input * a, e \ in R ^ {N * N}mi=a _ i n p u tuna ,miRN N
  4. De acuerdo con la matriz de adyacencia, mantenga los coeficientes de la posición distinta de cero de la matriz de adyacencia y reemplace el resto con una constante mínima.
  5. Después de softmax softmaxS O F T m A Matriz normalizada de X filas para prestar atenciónatención atencióna t t e n t i o n
  6. 计算H = δ (atención ∗ h), H ∈ RN ∗ f H = \ delta (atención * h), H \ in R ^ {N * f}H=δ ( a t t e n t i o nh ) ,HRN f es la suma ponderada de la nueva incorporación de la muestra obtenida después del cambio y otras muestras.
  7. Si hay m atenciones de múltiples cabezales, vuelva a 1, intercambie m veces y empalme para obtener H ^ ∈ RN ∗ mf \ hat {H} \ in R ^ {N * mf}H^RN m f
  8. Finalmente, la salida de múltiples cabezales empalmados se ingresa nuevamente en la red de atención, pero los pesos W, a W, aW , es necesario cambiar el tamaño deW ∈ R mf ∗ CW \ in R ^ {mf * C}WRm f Ca ∈ R 2 C ∗ 1 a \ in R ^ {2C * 1}unR2 C 1 , el proceso es el mismo que 1-6, y la salida finalY ∈ RN ∗ CY \ en R ^ {N * C}YRN C

La siguiente fórmula expresa el significado del proceso.
Inserte la descripción de la imagen aquí

código de versión de pytorch

enlace

class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Parameter(torch.empty(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, h, adj):
        Wh = torch.mm(h, self.W) # h.shape: (N, in_features), Wh.shape: (N, out_features)
        a_input = self._prepare_attentional_mechanism_input(Wh)
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))

        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

Supongo que te gusta

Origin blog.csdn.net/weixin_42764932/article/details/112124896
Recomendado
Clasificación