在pytorch中,如何对标准的预训练模型进行修改以适应三通道以上的输入

net = resnet50(pretrained=pretrained)
with torch.no_grad():
    pretrained_conv1 = net.conv1.weight.clone()
    # Assign new conv layer with 4 input channels
    net.conv1 = torch.nn.Conv2d(4, 64, 7, 2, 3, bias=False)
    # Use same initialization as vanilla ResNet (Don't know if good idea)
    torch.nn.init.kaiming_normal_(
        net.conv1.weight, mode='fan_out', nonlinearity='relu')
    # Re-assign pretraiend weights to first 3 channels
    # (assuming alpha channel is last in your input data)
    net.conv1.weight[:, :3] = pretrained_conv1

此代码的作用是修改标准预训练模型来适应四通道的输入,前三个通道保持原来的参数,最后一个通道kaiming初始化

发布了33 篇原创文章 · 获赞 3 · 访问量 5556

猜你喜欢

转载自blog.csdn.net/weixin_42990464/article/details/98079235