pytorch en resnet cómo carga rápida modelo de pre-formación proporcionada por el funcionario

En el proceso de hacer las estructuras de redes neuronales, resnet pytorch utiliza a menudo como la columna vertebral, en particular resnet50, la configuración de red como la siguiente

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models

class base_resnet(nn.Module):
    def __init__(self):
        super(base_resnet, self).__init__()
        self.model = models.resnet50(pretrained=True)
        #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)

        # x = x.view(x.size(0), x.size(1))
        return x

Los corresponde estructura de la red resnet50 hereda todos los parámetros, pero es en el futuro, cambiar el proceso de transmisión de datos, no después de que las características de despliegue finales y clasificación lineal. Al seguir esta línea de código es equivalente a llamar a la red pytoch resnet50 se define en, y automáticamente descargar y cargar los parámetros de red capacitados, si se ajusta a pretrained = False, el parámetro no está cargado entrenado, pero parámetros de asignación al azar. Pero me encontré con este tipo de código en el servidor cuando lo encuentra, cuando vuelva a ejecutar un programa, si está configurado como TRUE volver a descargar los parámetros resnet50 capacitado, sino porque a veces la red particularmente malo, porque yo descargo un fundamento resnet50 me costaría mucho tiempo, así que quería ser capaz de descargar los parámetros de este avance resnet50 buen uso del tiempo para cargar directamente. Por supuesto, se habilita.

self.model = models.resnet50(pretrained=True)

Podemos utilizar nuestra estructura, para descargar el modelo correspondiente a la dirección correspondiente a la resnet local común siguiente dirección:

 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',

Para descargarlo, y se coloca en los archivos de modelo y de sus net.py en la misma carpeta directorio de abajo, y luego usar el siguiente código puede evitar la re-descarga cada vez que el tema del modelo.

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

 

Publicado 36 artículos originales · ganado elogios 11 · vistas 6539

Supongo que te gusta

Origin blog.csdn.net/t20134297/article/details/103885879
Recomendado
Clasificación