torchvision.model是torchvision一个很重要的包,里面包含了以下模型结构:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- …
并且提供了预训练模型,可以通过简单调用来读取网络结构和预训练模型。
在进行深度学习的图像分类任务时,我们可以利用torchvision.model这个包快速构建模型,做适当调整即可运用于分类训练。
使用例子1:
import torchvision.models as models
# pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
上述就创建了resnet18和alexnet两个模型,并且用预训练模型的参数来初始化。
如果不需要预训练模型,那么pretrained=False
即可。如下:
import torchvision.models as models
# pretrained=True就可以使用预训练的模型
resnet18 = models.resnet18(pretrained=False)
alexnet = models.alexnet(pretrained=False)
不过,构建的这些模型,都是在imagenet上训练得到的,他们的默认输出类别数是1000,那如果我们需要训练自己的数据,并且数据类别数目不是1000时,我们需要在最后一层微调。
使用例子2:
class ResNet_101(nn.Module):
def __init__(self, num_classes):
super(ResNet_101, self).__init__()
model = models.resnet101(pretrained=True)
model.fc = nn.Sequential(
nn.Linear(2048, num_classes, bias=True),
)
self.net = model
def forward(self, img):
output = self.net(img)
return output
下面代码就构建了用预训练模型进行参数初始化的输出类别数目为3的resnet101模型。
model = ResNet_101(num_classes=3)
从torchvision 0.3.0开始,torchvision.models中就集成了目标检测、分割、关键点检测的models。
Semantic Segmentation:
- FCN ResNet101
- DeepLabV3 ResNet101
Object Detection:
- Faster R-CNN ResNet-50 FPN
Instance Segmentation:
- Mask R-CNN ResNet-50 FPN
Person Keypoint Detection:
- keypointrcnn_resnet50_fpn