La estructura de red del modelo se puede definir directamente mediante la API de C++ de PyTorch.
El uso es similar al de la API de Python.
struct DCGANGeneratorImpl : nn::Module {
DCGANGeneratorImpl(int kNoiseSize):conv1(nn::ConvTranspose2dOptions(kNoiseSize,256,4).bias(false)),batch_norm1(256),conv2(nn::ConvTranspose2dOptions(256,128,3).stride(2).padding(1).bias(false)),batch_norm2(128),conv3(nn::ConvTranspose2dOptions(128,64,4).stride(2).padding(1).bias(false)),batch_norm3(64),conv4(nn::ConvTranspose2dOptions(64,1,4).stride(2).padding(1).bias(false)){
// register_module() is needed if we want to use the parameters() method later onregister_module("conv1", conv1);register_module("conv2", conv2);register_module("conv3", conv3);register_module("conv4", conv4);register_module("batch_norm1", batch_norm1);register_module("batch_norm2", batch_norm2);register_module("batch_norm3", batch_norm3);}
torch::Tensor forward(torch::Tensor x){
x = torch::relu(batch_norm1(conv1(x)));
x = torch::relu(batch_norm2(conv2(x)));
x = torch::relu(batch_norm3(conv3(x)));
x = torch::tanh(conv4(x));return x;}
nn::ConvTranspose2d conv1, conv2, conv3, conv4;
nn::BatchNorm2d batch_norm1, batch_norm2, batch_norm3;};TORCH_MODULE(DCGANGenerator);
// Create the generator instance. The holder declared by TORCH_MODULE forwards
// kNoiseSize to the DCGANGeneratorImpl constructor, where it becomes the
// input channel count of conv1.
// NOTE(review): kNoiseSize is not defined in this snippet — presumably a
// constant declared earlier in the file; confirm.
DCGANGenerator generator(kNoiseSize);
Código relacionado con la importación de datos
De hecho, se trata de la API de C++ para el conjunto de datos (dataset) y el cargador de datos (dataloader).