PyTorch C++ 官方教程摘要(1) Using the PyTorch C++ Frontend

0. 前言

1. 为什么要用C++

  • 其实就是相比Python,C++的优势。
    • C++前端的目标不是替代Python前端,而是补充。
  • Low Latency Systems
    • 低延时,说白了就是模型推理速度以及其他数据处理的速度加快。
  • Highly Multithreaded Environments
    • Python存在GIL的限制,多个线程无法同时并行执行Python字节码,因而难以利用多核。
    • 虽然有替代方案,即多进程,但扩展性一般且有很多限制。
  • Existing C++ Codebases
    • C++有很多成熟的项目,使用C++部署更方便。

2. DCGAN PyTorch C++ 示例

2.1. 使用基本流程

  • 从官网下载指定cuda版本、指定pytorch版本的libtorch包。
    • 里面包含头文件、动态链接库等内容。
  • PyTorch C++程序基本结构
    • 一般通过 cmake 编译,一般会指定 CMAKE_PREFIX_PATH 参数,指向libtorch文件夹。
    • 其他就是 C++ API 学习了。
# Minimal CMake build for a libtorch application. Configure with
#   cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
# so find_package can locate TorchConfig.cmake inside the libtorch folder.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

# Imports TORCH_LIBRARIES (and related variables) from the libtorch package.
find_package(Torch REQUIRED)

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
# NOTE(review): recent libtorch releases require C++17 — raise this value if
# the build fails against a newer libtorch.
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)
#include <torch/torch.h>
#include <iostream>

int main() {
    
    
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}
// 输出如下:
// 1  0  0
// 0  1  0
// 0  0  1
// [ CPUFloatType{3,3} ]

2.2. 网络结构定义

  • 可以直接通过 PyTorch C++ API 定义模型网络结构
    • 使用与Python API类似。
// DCGAN generator: maps a noise tensor of shape (N, kNoiseSize, 1, 1) to a
// single-channel image via four transposed convolutions; the first three are
// each followed by batch norm + ReLU, the final one by tanh.
// NOTE(review): `nn::` is unqualified — this snippet assumes
// `namespace nn = torch::nn` (or a using-directive) earlier in the file.
struct DCGANGeneratorImpl : nn::Module {


  // kNoiseSize: channel count of the input noise vector.
  // bias(false) everywhere: on conv1-conv3 the bias would be absorbed by the
  // following batch norm; conv4 simply omits it as well.
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
 {


   // register_module() is needed if we want to use the parameters() method later on
   register_module("conv1", conv1);
   register_module("conv2", conv2);
   register_module("conv3", conv3);
   register_module("conv4", conv4);
   register_module("batch_norm1", batch_norm1);
   register_module("batch_norm2", batch_norm2);
   register_module("batch_norm3", batch_norm3);
 }

 // Forward pass: upsamples x through conv1..conv4; the tanh on the last layer
 // bounds output values to [-1, 1].
 torch::Tensor forward(torch::Tensor x) {


   x = torch::relu(batch_norm1(conv1(x)));
   x = torch::relu(batch_norm2(conv2(x)));
   x = torch::relu(batch_norm3(conv3(x)));
   x = torch::tanh(conv4(x));
   return x;
 }

 // Submodule holders; declaration order must match the constructor's
 // member-initializer list above.
 nn::ConvTranspose2d conv1, conv2, conv3, conv4;
 nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
// Generates the `DCGANGenerator` holder type (a shared_ptr-like wrapper
// around DCGANGeneratorImpl), following PyTorch's module-holder idiom.
TORCH_MODULE(DCGANGenerator);

// kNoiseSize is defined elsewhere in the full program (not shown in this excerpt).
DCGANGenerator generator(kNoiseSize);
  • 数据导入相关代码
    • 其实也就是dataset/dataloader的C++ API
// MNIST pipeline: normalize pixel values (mean 0.5, stddev 0.5), then stack
// individual examples into batched tensors.
auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());

// Simplest form, relying on default options:
//   auto data_loader = torch::data::make_data_loader(std::move(dataset));
// The original snippet declared `data_loader` twice in the same scope (a
// redefinition that does not compile), so the simple variant is kept only as
// a reference comment and the configured variant below is the live one.
auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));

// Iterate one pass over the data, printing each batch's size and labels.
for (torch::data::Example<>& batch : *data_loader) {
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}
  • 训练相关代码
    • optimizer等API
// One Adam optimizer per network (generator lr 2e-4, discriminator lr 5e-4).
// NOTE(review): .beta1() exists only in older libtorch releases; newer ones
// use .betas(std::make_tuple(b1, b2)) — confirm against the libtorch in use.
torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).beta1(0.5));
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {


  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {


    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    // Soft "real" targets drawn uniformly from [0.8, 1.0] rather than hard 1s.
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();

    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({


    batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    // detach() blocks gradients from flowing into the generator while the
    // discriminator is being updated.
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();

    // Gradients from the real and fake passes accumulate in the
    // discriminator; this single step applies both.
    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();

    // Train generator: fake_labels is reused, refilled with 1s, so the
    // generator is pushed toward making the discriminator output "real".
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();

    // NOTE(review): %ld paired with int64_t is not portable (int64_t may be
    // long long); PRId64 from <cinttypes> would be safer.
    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}
  • 其他还有一些模型保存、to GPU、模型推理等API,写起来没花头

猜你喜欢

转载自blog.csdn.net/irving512/article/details/114261561
今日推荐