Résumé du didacticiel officiel PyTorch C ++ (1) Utilisation du frontend PyTorch C ++

0. Préface

1. Pourquoi utiliser C ++

  • En fait, c'est l'avantage du C ++ par rapport à Python.
    • Le but du front-end C ++ n'est pas de remplacer le front-end Python, mais de le compléter.
  • Systèmes à faible latence
    • Une faible latence, pour le dire franchement, est l'accélération de la vitesse d'inférence du modèle et d'autres vitesses de traitement des données.
  • Environnements hautement multithread
    • Python a des restrictions GIL et ne peut pas exécuter plusieurs threads en même temps.
    • Bien qu'il existe une alternative, c'est-à-dire multi-processus, l'évolutivité est générale et présente de nombreuses limitations.
  • Bases de code C ++ existantes
    • C ++ a de nombreux projets matures, et il est plus pratique d'utiliser C ++ pour déployer.

2. Exemple de DCGAN PyTorch C ++

2.1. Utiliser le processus de base

  • Téléchargez le package libtorch de la version cuda spécifiée et de la version pytorch spécifiée à partir du site officiel.
    • Il contient des fichiers d'en-tête, des bibliothèques de liens dynamiques, etc.
  • Structure de base du programme PyTorch C ++
    • En compilant cmake général, CMAKE_PREFIX_PATHparamètres généralement désignés , pointez le dossier libtorch.
    • L'autre est l'apprentissage de l'API C ++.
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(dcgan)

find_package(Torch REQUIRED)

add_executable(dcgan dcgan.cpp)
target_link_libraries(dcgan "${TORCH_LIBRARIES}")
set_property(TARGET dcgan PROPERTY CXX_STANDARD 14)
#include <torch/torch.h>
#include <iostream>

int main() {
    
    
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}
// 输入如下:
// 1  0  0
// 0  1  0
// 0  0  1
// [ CPUFloatType{3,3} ]

2.2. Définition de la structure du réseau

  • La structure du réseau de modèle peut être définie directement via l'API PyTorch C ++
    • L'utilisation est similaire à l'API Python.
struct DCGANGeneratorImpl : nn::Module {
    
    
  DCGANGeneratorImpl(int kNoiseSize)
      : conv1(nn::ConvTranspose2dOptions(kNoiseSize, 256, 4)
                  .bias(false)),
        batch_norm1(256),
        conv2(nn::ConvTranspose2dOptions(256, 128, 3)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm2(128),
        conv3(nn::ConvTranspose2dOptions(128, 64, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false)),
        batch_norm3(64),
        conv4(nn::ConvTranspose2dOptions(64, 1, 4)
                  .stride(2)
                  .padding(1)
                  .bias(false))
 {
    
    
   // register_module() is needed if we want to use the parameters() method later on
   register_module("conv1", conv1);
   register_module("conv2", conv2);
   register_module("conv3", conv3);
   register_module("conv4", conv4);
   register_module("batch_norm1", batch_norm1);
   register_module("batch_norm2", batch_norm2);
   register_module("batch_norm3", batch_norm3);
 }

 torch::Tensor forward(torch::Tensor x) {
    
    
   x = torch::relu(batch_norm1(conv1(x)));
   x = torch::relu(batch_norm2(conv2(x)));
   x = torch::relu(batch_norm3(conv3(x)));
   x = torch::tanh(conv4(x));
   return x;
 }

 nn::ConvTranspose2d conv1, conv2, conv3, conv4;
 nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;
};
TORCH_MODULE(DCGANGenerator);

DCGANGenerator generator(kNoiseSize);
  • Code lié à l'importation de données
    • En fait, c'est l'API C ++ de dataset / dataloader
auto dataset = torch::data::datasets::MNIST("./mnist")
    .map(torch::data::transforms::Normalize<>(0.5, 0.5))
    .map(torch::data::transforms::Stack<>());
auto data_loader = torch::data::make_data_loader(std::move(dataset));
auto data_loader = torch::data::make_data_loader(
    std::move(dataset),
    torch::data::DataLoaderOptions().batch_size(kBatchSize).workers(2));
for (torch::data::Example<>& batch : *data_loader) {
    
    
  std::cout << "Batch size: " << batch.data.size(0) << " | Labels: ";
  for (int64_t i = 0; i < batch.data.size(0); ++i) {
    
    
    std::cout << batch.target[i].item<int64_t>() << " ";
  }
  std::cout << std::endl;
}
  • Code lié à la formation
    • optimiseur 等 API
torch::optim::Adam generator_optimizer(
    generator->parameters(), torch::optim::AdamOptions(2e-4).beta1(0.5));
torch::optim::Adam discriminator_optimizer(
    discriminator->parameters(), torch::optim::AdamOptions(5e-4).beta1(0.5));
for (int64_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
    
    
  int64_t batch_index = 0;
  for (torch::data::Example<>& batch : *data_loader) {
    
    
    // Train discriminator with real images.
    discriminator->zero_grad();
    torch::Tensor real_images = batch.data;
    torch::Tensor real_labels = torch::empty(batch.data.size(0)).uniform_(0.8, 1.0);
    torch::Tensor real_output = discriminator->forward(real_images);
    torch::Tensor d_loss_real = torch::binary_cross_entropy(real_output, real_labels);
    d_loss_real.backward();

    // Train discriminator with fake images.
    torch::Tensor noise = torch::randn({
    
    batch.data.size(0), kNoiseSize, 1, 1});
    torch::Tensor fake_images = generator->forward(noise);
    torch::Tensor fake_labels = torch::zeros(batch.data.size(0));
    torch::Tensor fake_output = discriminator->forward(fake_images.detach());
    torch::Tensor d_loss_fake = torch::binary_cross_entropy(fake_output, fake_labels);
    d_loss_fake.backward();

    torch::Tensor d_loss = d_loss_real + d_loss_fake;
    discriminator_optimizer.step();

    // Train generator.
    generator->zero_grad();
    fake_labels.fill_(1);
    fake_output = discriminator->forward(fake_images);
    torch::Tensor g_loss = torch::binary_cross_entropy(fake_output, fake_labels);
    g_loss.backward();
    generator_optimizer.step();

    std::printf(
        "\r[%2ld/%2ld][%3ld/%3ld] D_loss: %.4f | G_loss: %.4f",
        epoch,
        kNumberOfEpochs,
        ++batch_index,
        batches_per_epoch,
        d_loss.item<float>(),
        g_loss.item<float>());
  }
}
  • Il existe également des API telles que la sauvegarde de modèle, le GPU, l'inférence de modèle, etc., qui sont faciles à écrire.

Je suppose que tu aimes

Origine blog.csdn.net/irving512/article/details/114261561
conseillé
Classement