以下是使用 C++(基于 LibTorch,即 PyTorch 的 C++ API)实现 ResNet34 网络的简单示例:
#include <iostream>

#include <torch/torch.h>
/// Builds a ResNet-34 style image classifier.
///
/// Layout: 7x7 stride-2 stem convolution (3 -> 64 channels), batch norm, ReLU,
/// 3x3 stride-2 max pool, four stages of 3 / 4 / 6 / 3 residual blocks with
/// 64 / 128 / 256 / 512 channels, adaptive average pooling to 1x1 and a
/// 512 -> 1000 linear layer.
///
/// @return A torch::nn::Sequential holding the network as its single child.
torch::nn::Sequential resnet34()
{
    // Residual block: two 3x3 convolution + batch-norm pairs whose result is
    // added to the (optionally projected) block input.
    struct BasicBlock : torch::nn::Module
    {
        /// @param inplanes   Number of input channels.
        /// @param planes     Number of channels produced by both convolutions.
        /// @param stride     Stride of the first convolution.
        /// @param downsample Optional projection for the skip path; an empty
        ///                   holder (nullptr) adds the input unchanged.
        BasicBlock(int64_t inplanes, int64_t planes, int64_t stride = 1,
                   torch::nn::Sequential downsample = nullptr)
            : conv1(torch::nn::Conv2dOptions(inplanes, planes, 3)
                        .stride(stride)
                        .padding(1)
                        .bias(false)),
              bn1(planes),
              conv2(torch::nn::Conv2dOptions(planes, planes, 3)
                        .stride(1)
                        .padding(1)
                        .bias(false)),
              bn2(planes),
              downsample(downsample)
        {
            register_module("conv1", conv1);
            register_module("bn1", bn1);
            register_module("conv2", conv2);
            register_module("bn2", bn2);
            // Only register the projection when one was supplied.
            if (this->downsample)
            {
                register_module("downsample", this->downsample);
            }
        }

        /// Applies conv-bn-relu-conv-bn, adds the skip path, then ReLU.
        torch::Tensor forward(torch::Tensor x)
        {
            torch::Tensor identity = x;

            x = conv1(x);
            x = bn1(x);
            x = torch::relu_(x);

            x = conv2(x);
            x = bn2(x);

            if (downsample)
            {
                // Call forward() explicitly, the same way Net invokes its
                // Sequential stages.
                identity = downsample->forward(identity);
            }

            x += identity;
            x = torch::relu_(x);
            return x;
        }

        torch::nn::Conv2d conv1{nullptr};
        torch::nn::BatchNorm2d bn1{nullptr};
        torch::nn::Conv2d conv2{nullptr};
        torch::nn::BatchNorm2d bn2{nullptr};
        torch::nn::Sequential downsample{nullptr};
    };

    // Full network: stem, four residual stages, pooling and classifier.
    struct Net : torch::nn::Module
    {
        Net()
            : conv1(torch::nn::Conv2dOptions(3, 64, 7)
                        .stride(2)
                        .padding(3)
                        .bias(false)),
              bn1(64),
              layer1(make_layer(64, 64, 3)),
              layer2(make_layer(64, 128, 4, 2)),
              layer3(make_layer(128, 256, 6, 2)),
              layer4(make_layer(256, 512, 3, 2)),
              avgpool(torch::nn::AdaptiveAvgPool2dOptions({1, 1})),
              fc(512, 1000)
        {
            register_module("conv1", conv1);
            register_module("bn1", bn1);
            register_module("layer1", layer1);
            register_module("layer2", layer2);
            register_module("layer3", layer3);
            register_module("layer4", layer4);
            register_module("avgpool", avgpool);
            register_module("fc", fc);
        }

        /// Builds one stage made of `blocks` BasicBlocks.
        ///
        /// The first block maps inplanes -> planes with the given stride and
        /// gets a 1x1 conv + batch-norm projection on its skip path whenever
        /// the stride is not 1 or the channel count changes. The remaining
        /// blocks map planes -> planes with stride 1.
        ///
        /// Reads no member state, so it is safe to call from the constructor's
        /// initializer list.
        torch::nn::Sequential make_layer(int64_t inplanes, int64_t planes, int64_t blocks,
                                         int64_t stride = 1)
        {
            torch::nn::Sequential downsample{nullptr};
            if (stride != 1 || inplanes != planes)
            {
                // Parenthesised construction: Sequential also has an
                // initializer_list constructor, so braces are avoided here.
                downsample = torch::nn::Sequential(
                    torch::nn::Conv2d(torch::nn::Conv2dOptions(inplanes, planes, 1)
                                          .stride(stride)
                                          .bias(false)),
                    torch::nn::BatchNorm2d(planes));
            }

            torch::nn::Sequential layers;
            layers->push_back(BasicBlock(inplanes, planes, stride, downsample));
            for (int64_t i = 1; i < blocks; ++i)
            {
                layers->push_back(BasicBlock(planes, planes));
            }
            return layers;
        }

        /// Runs the stem, the four stages, pooling and the linear classifier.
        torch::Tensor forward(torch::Tensor x)
        {
            x = conv1(x);
            x = bn1(x);
            x = torch::relu_(x);
            x = torch::max_pool2d(x, 3, 2, 1);

            x = layer1->forward(x);
            x = layer2->forward(x);
            x = layer3->forward(x);
            x = layer4->forward(x);

            x = avgpool(x);
            // Flatten everything except the leading dimension for the
            // linear layer.
            x = x.view({x.size(0), -1});
            x = fc->forward(x);
            return x;
        }

        torch::nn::Conv2d conv1{nullptr};
        torch::nn::BatchNorm2d bn1{nullptr};
        torch::nn::Sequential layer1{nullptr};
        torch::nn::Sequential layer2{nullptr};
        torch::nn::Sequential layer3{nullptr};
        torch::nn::Sequential layer4{nullptr};
        torch::nn::AdaptiveAvgPool2d avgpool{nullptr};
        torch::nn::Linear fc{nullptr};
    };

    // Net is a plain Module and is not convertible to Sequential, so wrap it
    // in a one-element Sequential to satisfy the declared return type.
    torch::nn::Sequential model;
    model->push_back(Net());
    return model;
}
/// Demo entry point: builds the model, prints its structure, runs a single
/// forward pass on a random 1x3x224x224 input and prints the output shape.
int main()
{
    // Inference only: do not record an autograd graph for the forward pass.
    torch::NoGradGuard no_grad;

    // Random input with shape {1, 3, 224, 224}.
    torch::Tensor input = torch::rand({1, 3, 224, 224});

    // Build the model and switch it to evaluation mode so batch norm uses its
    // running statistics instead of the statistics of this one sample.
    torch::nn::Sequential model = resnet34();
    model->eval();
    std::cout << model << '\n';

    // Forward pass.
    torch::Tensor output = model->forward(input);
    std::cout << output.sizes() << '\n';
    return 0;
}
这是一个基于 PyTorch C++ API(LibTorch)实现的简单示例，欢迎参考。