Resumen del tutorial oficial de PyTorch C ++ (1) Uso de la interfaz de PyTorch C ++

0. Prefacio

1. Por qué usar C ++

  • De hecho, es la ventaja de C ++ en comparación con Python.
    • El objetivo del front-end de C ++ no es reemplazar el front-end de Python, sino complementarlo.
  • Sistemas de baja latencia
    • La baja latencia, para decirlo sin rodeos, es la aceleración de la velocidad de inferencia del modelo y otras velocidades de procesamiento de datos.
  • Entornos altamente multiproceso
    • Python tiene restricciones GIL y no puede ejecutar varios subprocesos al mismo tiempo.
    • Aunque existe una alternativa, es decir, multiproceso, la escalabilidad es general y tiene muchas limitaciones.
  • Bases de código C ++ existentes
    • C ++ tiene muchos proyectos maduros y es más conveniente usar C ++ para implementar.

2. Ejemplo de DCGAN PyTorch C ++

2.1. Utilice el proceso básico

  • Descargue el paquete libtorch de la versión cuda especificada y la versión pytorch especificada del sitio web oficial.
    • Contiene archivos de encabezado, bibliotecas de enlaces dinámicos, etc.
  • Estructura básica del programa PyTorch C ++
    • Al compilar cmake general, CMAKE_PREFIX_PATHparámetros generalmente designados , apunte la carpeta libtorch.
    • El otro es el aprendizaje de la API de C ++.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

find_package(Torch REQUIRED)

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)
#include <torch/torch.h>
#include <iostream>

int main() {
    
    
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}
// 输入如下:
// 1  0  0
// 0  1  0
// 0  0  1
// [ CPUFloatType{3,3} ]

2.2. Definición de la estructura de la red

  • La estructura de red del modelo se puede definir directamente a través de la API de PyTorch C ++
    • El uso es similar al de la API de Python.
struct DCGANGeneratorImpl : nn::Module {
    
    
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
 {
    
    
   // register_module() is needed if we want to use the parameters() method later on
   register_module("conv1", conv1);
   register_module("conv2", conv2);
   register_module("conv3", conv3);
   register_module("conv4", conv4);
   register_module("batch_norm1", batch_norm1);
   register_module("batch_norm2", batch_norm2);
   register_module("batch_norm3", batch_norm3);
 }

 torch::Tensor forward(torch::Tensor x) {
    
    
   x = torch::relu(batch_norm1(conv1(x)));
   x = torch::relu(batch_norm2(conv2(x)));
   x = torch::relu(batch_norm3(conv3(x)));
   x = torch::tanh(conv4(x));
   return x;
 }

 nn::ConvTranspose2d conv1, conv2, conv3, conv4;
 nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

DCGANGenerator generator(kNoiseSize);
  • Código relacionado con la importación de datos
    • De hecho, es la API de C ++ del conjunto de datos / cargador de datos
auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());
auto data_loader = torch::data::make_data_loader(std::move(dataset));
auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
for (torch::data::Example<>& batch : *data_loader) {
    
    
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    
    
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}
  • Código relacionado con la formación
    • optimizador 等 API
torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).beta1(0.5));
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    
    
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {
    
    
    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();

    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({
    
    batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();

    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();

    // Train generator.
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();

    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}
  • También hay algunas API, como el almacenamiento de modelos, la GPU, la inferencia de modelos, etc., que son fáciles de escribir.

Supongo que te gusta

Origin blog.csdn.net/irving512/article/details/114261561
Recomendado
Clasificación