Pytorch学习(二) --- 模型定义之torchvivsion.models快速构建预训练模型

torchvision.model是torchvision一个很重要的包,里面包含了以下模型结构:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet

并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
在进行深度学习的图像分类任务时,我们可以利用torchvision.model这个包快速构建模型,做适当调整即可运用于分类训练。

使用例子1

import torchvision.models as models
# pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)

上述就创建了resnet18和alexnet两个模型,并且用预训练模型的参数来初始化。
如果不需要预训练模型,那么pretrained=False即可。如下:

import torchvision.models as models
# pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=False)
alexnet = models.alexnet(pretrained=False)

不过,构建的这些模型,都是在imagenet上训练得到的,他们的默认输出类别数是1000,那如果我们需要训练自己的数据,并且数据类别数目不是1000时,我们需要在最后一层微调。

使用例子2

class ResNet_101(nn.Module):
    def __init__(self, num_classes):
        super(ResNet_101, self).__init__()
        model = models.resnet101(pretrained=True)
        model.fc = nn.Sequential(
                nn.Linear(2048, num_classes, bias=True),
        )
        self.net = model
    
    def forward(self, img):
        output = self.net(img)
        return output

下面代码就构建了用预训练模型进行参数初始化的输出类别数目为3的resnet101模型。

model = ResNet_101(num_classes=3)

从torchvision 0.3.0开始,torchvision.models中就集成了目标检测、分割、关键点检测的models。

Semantic Segmentation:

  • FCN ResNet101
  • DeepLabV3 ResNet101

Object Detection:

  • Faster R-CNN ResNet-50 FPN

Instance Segmentation:

  • Mask R-CNN ResNet-50 FPN

Person Keypoint Detection:

  • keypointrcnn_resnet50_fpn

https://pytorch.org/docs/stable/torchvision/models.html#

原创文章 96 获赞 24 访问量 3万+

猜你喜欢

转载自blog.csdn.net/c2250645962/article/details/105211429