[Red neuronal gráfica] Implementación de código de red ViG (Vision GNN)

Interpretación del papel:

[Red neuronal gráfica] Red neuronal gráfica visual ViG (Vision GNN)--Lectura en papel https://blog.csdn.net/weixin_37878740/article/details/130124772?spm=1001.2014.3001.5501 Dirección del código:

ViG https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

1. Estructura de la red

        ViG se puede apilar en una arquitectura isotrópica (similar a ViT) y una arquitectura piramidal (similar a ResNet). Este artículo analiza principalmente la estructura piramidal PyramidViG-B como ejemplo. El código involucrado son los tres archivos en las carpetas pyramid.py y gcn_lib en git.

         Como se muestra en la figura anterior, se puede construir una red piramidal con cuatro etapas apilando bloques ViG de diferentes especificaciones. Después del trasplante, puede reemplazar a Resnet50 como red troncal en Faster RCNN (pero el efecto del trasplante directo no es ideal).

        Código de definición de red:

def pvig_b_224_gelu(num_classes =1000,pretrained=False, **kwargs):
    class OptInit:
        # 参数列表
        def __init__(self, num_classes=1000, drop_path_rate=0.0, **kwargs):
            self.k = 9 # 邻居节点数,默认为9
            self.conv = 'mr' # 图卷积层类型,可选 {edge, mr}
            self.act = 'gelu' # 激活层类型,可选 {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch' # 归一化方式,可选 {batch, instance}
            self.bias = True # 卷积层是否使用偏置
            self.dropout = 0.0 # dropout率
            self.use_dilation = True # 是否使用扩张knn
            self.epsilon = 0.2 # gcn的随机采样率
            self.use_stochastic = False # gcn的随机性
            self.drop_path = drop_path_rate
            self.blocks = [2,2,18,2] # 各层的block个数
            self.channels = [128, 256, 512, 1024] # 各层的通道数
            self.n_classes = num_classes # 分类器输出通道数
            self.emb_dims = 1024 # 嵌入尺寸

    opt = OptInit(**kwargs)
    model = DeepGCN(opt)    #构造gcn
    model.default_cfg = default_cfgs['vig_b_224_gelu']    #注入参数
    return model
#  网络参数计算代码
class DeepGCN(torch.nn.Module):
    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        # ...
        #  参数赋值省略
        # ...

        blocks = opt.blocks            # 获取各层block个数列表[2,2,18,2]
        self.n_blocks = sum(blocks)    # 获取block层数总数
        channels = opt.channels        # 获取输出通道数(用于分类器赋值)
        reduce_ratios = [4, 2, 1, 1]   # 下采样率
        #  获取FFN的随机深度衰减规律
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]
        # 获取各层knn的数量
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]
        max_dilation = 49 // max(num_knn)    #最大相关数目
        HW = 224 // 4 * 224 // 4

 2. Módulo ViG

      La construcción de la red real utiliza bloques ViG para el apilamiento. Los bloques ViG están compuestos por módulos GCN y módulos FFN . La construcción utiliza bucles de código para apilar bloques ViG.

# 构造骨干网络
self.backbone = nn.ModuleList([])
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                #  如果不是第一层需要额外在层间添加下采样
                self.backbone.append(Downsample(channels[i-1], channels[i]))
                HW = HW // 4
            for j in range(blocks[i]):
                self.backbone += [
                    # 构造GCN
                    Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm,
                                    bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],
                                    relative_pos=True),
                    # 构造FFN
                          FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx])
                         )]
                idx += 1
        self.backbone = Seq(*self.backbone)
        # 构造分类器
        self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(act),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, opt.n_classes, 1, bias=True))
        self.model_init()

En la función de transferencia directa de la red, puede ver que la imagen está sujeta a la codificación de tallo (es decir, la operación de parche en ViT) y posición (la matriz correspondiente a la posición)         antes de ingresar a la red gráfica.

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed    #patch分割和位置嵌入
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)

       Las operaciones de tallo y las incrustaciones de posición son las siguientes:

self.stem = Stem(out_dim=channels[0], act=act)
#返回整数部分
self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224//4, 224//4))

        1. Módulo gráfico

                Primer vistazo a la función de transferencia directa de Grapher

def forward(self, x):
        _tmp = x
        x = self.fc1(x)
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + _tmp
        return x

                Se puede ver que para cada módulo de Grapher, el flujo de procesamiento básico es:

                ① Capa completamente conectada fc1

# 由一个1x1Conv和一个BatchNorm组成
self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )

                ②La posición asociada se actualiza mediante la función _get_relative_pos(.)

                De hecho, desde el punto de vista del código, se usa para hacer coincidir el cambio de tamaño causado por la reducción de resolución (ajustar el tamaño)

    def _get_relative_pos(self, relative_pos, H, W):
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)

                 Cuando se inicializa el bloque, el valor inicial lo proporciona la función get_2d_relative_pos_embed(.) (si no está habilitada, se establecerá en Ninguno directamente);

# 获取位置嵌入
relative_pos_tensor = torch.from_numpy(np.float32(
        get_2d_relative_pos_embed(in_channels,int(n**0.5)))).unsqueeze(0).unsqueeze(1)
# 进行双线性插值
relative_pos_tensor = F.interpolate(relative_pos_tensor, size=(n, n//(r*r)), 
        mode='bicubic', align_corners=False)
# 转换为nn参数
self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)

                        get_2d_relative_pos_embed(.) Función de incrustación posicional, ubicada en gcn_lib/pos_embed.py . La función es construir una cuadrícula y obtener la ubicación incrustada (incluido cls_token)

                ③Convolución de gráficos (graph_conv)

self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size,
                     dilation, conv, act, norm, bias, stochastic, epsilon, r)

                Vaya a graph_conv para ver su función de paso hacia adelante:

def forward(self, x, relative_pos=None):
    B, C, H, W = x.shape
    y = None
    if self.r > 1:    #  此参数为下采样率,金字塔池化情况下默认开启(始终大于1)
        y = F.avg_pool2d(x, self.r, self.r)
        y = y.reshape(B, C, -1, 1).contiguous()            
    x = x.reshape(B, C, -1, 1).contiguous()

    # 获取邻居节点的聚合信息(基于knn)
    edge_index = self.dilated_knn_graph(x, y, relative_pos)
    # 图卷积
    x = super(DyGraphConv2d, self).forward(x, edge_index, y)
    # 将tensor变形为四维并输出
    return x.reshape(B, -1, H, W).contiguous()

                Entre ellos, self.dilated_knn_graph es DenseDilatedKnnGraph, que proviene de gcn_lib/torch_edge.py Como la mayoría de los algoritmos de red de gráficos, torch.topk(.) se usa para la escasez de matriz de adyacencia. Al mismo tiempo, use la función part_pairwise_distance para extraer tres valores de x_square_part, x_inner y x_square de la característica.

                ④ Capa totalmente conectada fc2

        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )

                Esto es lo mismo que la capa totalmente conectada anterior, excepto que los canales de entrada se duplican.

                ⑤Eliminación aleatoria de DropPath

self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

                Se utiliza para evitar el sobreajuste, y la red también tiene una estructura similar a los residuales.

x = self.drop_path(x) + _tmp

        2. Módulo FNN

                El módulo FNN es un perceptrón multicapa, implementado por dos capas de conexiones completas, y también tiene una estructura residual

shortcut = x
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop_path(x) + shortcut
return x
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(hidden_features),
        )
        self.act = act_layer(act)
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(out_features),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

                La capa de activación actúa aquí por defecto a la función de activación relu .

3. Migración de red

        Gracias a las características multiescala que ofrece la estructura piramidal, ViG se puede utilizar como una red troncal para la extracción de características como Swin Transformer Aquí, se trasplanta como una red troncal a Faster RCNN para reemplazar el ResNet50 original. Después de quitar el cabezal de predicción de predicción y la agrupación promedio adaptive_avg_pool2d , se puede obtener una característica de 7x7x1024 a partir de una entrada de 224x224x3.

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        # x = F.adaptive_avg_pool2d(x, 1)
        return x

        Después de la prueba, ViG puede obtener más del 70 % de mAP en el conjunto de datos, pero el efecto es inferior a resnet50 y mobilenetv3, y se desconoce el motivo específico.

Supongo que te gusta

Origin blog.csdn.net/weixin_37878740/article/details/130642630
Recomendado
Clasificación