如何使用 libtorch 实现 LeNet 网络?

如何使用 libtorch 实现 LeNet 网络?

LeNet 网络论文地址:
http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

LeNet

C1 卷积层

{1,1,28,28} 是什么?

1 输入的批次
1 图像的通道大小
28 图像的高
28 图像的宽

输入:{1,1,28,28}

通过填充一个边界 2 ,使得输入变成 {1,1,32,32}

滑动窗口大小:{5,5}

输出:{1,6,32,32}

S2 降采样

输入:{1,6,32,32}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,6,14,14}

C3 卷积层

输入:{1,16,14,14}

扫描二维码关注公众号,回复: 5897736 查看本文章

滑动窗口大小:{5,5}

输出:{1,16,10,10}

S4 降采样

输入:{1,16,10,10}

滑动窗口大小:{2,2,}
滑动步长:{2,2}

输出:{1,16,5,5}

C5 卷积层

输入:{1,16,5,5}

滑动窗口大小:{5,5}

输出:{1,120,1,1}

F6 全连接层

这里要把网络形状从 {1,120,1,1} 改变改变成 {1,120}

第一个全连接
输入:{1,120}
输出:{1,84}

第二个全连接
输入:{1,84}
输出:{84,10}

0~9 总共是 10 个类别嘛,这里就输出 10个就行了。

全连接就是线性层,网络形状不一样不能全连接的,所以这里要把形状改变成一样的。
基本按照那图写一遍就明白了。

关于输入和输出的网络推断公式可以去参考 pytorch 里面的函数说明,上面都有写推断公式滴。

// Define a new Module.
struct Net : torch::nn::Module {
    Net() {
        conv1 = register_module("conv1", torch::nn::Conv2d(torch::nn::Conv2dOptions(1, 6, /*kernel_size*/{ 5,5 }).padding(/*28->32*/{2,2})));
        conv2 = register_module("conv2", torch::nn::Conv2d(torch::nn::Conv2dOptions(6, 16, /*kernel_size*/{5,5})));
        conv3 = register_module("conv3", torch::nn::Conv2d(torch::nn::Conv2dOptions(16, 120, /*kernel_size*/{5,5})));
        fc1 = register_module("fc1", torch::nn::Linear(torch::nn::LinearOptions(120, 84)));
        fc2 = register_module("fc2", torch::nn::Linear(torch::nn::LinearOptions(84, 10)));
    }

    // Implement the Net's algorithm.
    torch::Tensor forward(torch::Tensor x) {
        x = conv1->forward(x);//6@28x28
        x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//6@14x14
        x = conv2->forward(x);//16@10x10
        x = torch::max_pool2d(x, { 2,2 }, { 2,2 });//16@10x10
        
        x = conv3->forward(x);//120@1x1
        x = x.view({ x.size(0),-1 });//-1 表示自动推理计算出该值
        x = fc1->forward(x);//120->84
        x = fc2->forward(x);//84->10
        x = torch::log_softmax(x,/*dim=*/1);
        return x;
    }

    // Use one of many "standard library" modules.
    torch::nn::Conv2d conv1 { nullptr };
    torch::nn::Conv2d conv2 { nullptr };
    torch::nn::Conv2d conv3 { nullptr };
    torch::nn::Linear fc1{ nullptr };
    torch::nn::Linear fc2{ nullptr };
};

猜你喜欢

转载自www.cnblogs.com/cheungxiongwei/p/10710968.html