A few PyTorch notes

I previously built deep learning networks with Keras. The results were decent, but some parameters could not be changed and others could not be monitored in real time. On the advice of a senior labmate, I switched to PyTorch. The start was not very smooth, mainly because exploring a new field on my own was hard.

This post covers two topics:

  • Using PyTorch with TensorBoard on Windows
  • Modifying a pretrained model

Using PyTorch with TensorBoard on Windows

Reference link

Process:

  • Run the main script; it writes the event files to the logs directory.
  • Switch to the directory that contains logs and run: tensorboard --logdir=./logs --port=6006
  • Open http://localhost:6006/ in the browser.
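A minimal sketch of the logging side, assuming PyTorch's built-in `torch.utils.tensorboard.SummaryWriter` (it requires the `tensorboard` package). The toy linear-regression loop, the tag `loss/train`, and the helper name `train_and_log` are illustrative only; the point is that `log_dir="./logs"` matches the `--logdir` passed to the `tensorboard` command.

```python
import os

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter


def train_and_log(log_dir="./logs", steps=100):
    """Hypothetical helper: fit a toy linear model and log its loss for TensorBoard."""
    torch.manual_seed(0)
    # Hypothetical toy data: y = 3x + 1 plus a little noise
    x = torch.randn(256, 1)
    y = 3 * x + 1 + 0.1 * torch.randn(256, 1)

    model = nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()

    writer = SummaryWriter(log_dir=log_dir)  # creates log_dir if needed
    for step in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
        # One point per step under the "Scalars" tab, tag "loss/train"
        writer.add_scalar("loss/train", loss.item(), step)
    writer.close()  # flush pending events to disk

    # TensorBoard reads files named events.out.tfevents.* from log_dir
    event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
    return loss.item(), event_files


if __name__ == "__main__":
    final_loss, files = train_and_log()
    print(f"final loss: {final_loss:.4f}, event files: {files}")
```

Calling `writer.add_scalar` inside the training loop is what makes values visible while training is still running, which is the real-time monitoring mentioned at the start.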

Modifying a pretrained model

For example, Faster R-CNN extracts features with a VGG network (the original paper uses VGG16), but it only uses part of the model, so you need to know how to modify a pretrained model. See the reference link.

Steps:

  • Download the .pth file for vgg19. Letting the download happen automatically with pretrained=True (e.g. in an Anaconda environment) is usually slow, so I download the file directly with a browser or Thunder instead. The model_zoo source lists download links for the various pretrained models:
model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
    'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
    'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
}
  • After downloading, you can try loading and modifying the model with the code below. Create test.py in the same directory as the downloaded .pth file (this example loads vgg16-397923af.pth):
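import torch
import torch.nn as nn
import torchvision.models as models

# Build the VGG16 architecture without downloading any weights,
# then load the manually downloaded checkpoint from the current directory.
vgg16 = models.vgg16()
state_dict = torch.load('vgg16-397923af.pth', map_location='cpu')
vgg16.load_state_dict(state_dict)
print('vgg16:\n', vgg16)

# Drop the last module of vgg16.features (the final max-pool),
# so the truncated network ends at relu5_3.
modified_features = nn.Sequential(*list(vgg16.features.children())[:-1])
print('modified_features:\n', modified_features)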
  • After this modification, modified_features can serve as the feature extractor for Faster R-CNN.

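On the slow download in the first step: in recent PyTorch versions, the pretrained=True path caches weights under `torch.hub.get_dir()` in a `checkpoints` subfolder, named after the file in the URL, and skips the download when that file already exists. A small sketch, assuming that cache layout, for finding where to copy a manually downloaded file so the automatic download is skipped; the helper name `expected_checkpoint_path` is hypothetical.

```python
import os
from urllib.parse import urlparse

import torch


def expected_checkpoint_path(url):
    """Hypothetical helper: local cache path torch.hub uses for `url`."""
    # torch.hub.get_dir() defaults to $TORCH_HOME/hub (usually ~/.cache/torch/hub)
    filename = os.path.basename(urlparse(url).path)
    return os.path.join(torch.hub.get_dir(), "checkpoints", filename)


url = "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth"
path = expected_checkpoint_path(url)
print(path)  # copy the manually downloaded vgg19-dcbb9e9d.pth here
print("already cached:", os.path.exists(path))
```

If you prefer a different folder, `torch.hub.load_state_dict_from_url(url, model_dir=...)` accepts an explicit cache directory.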
Note: in my experience, the same code ran about twice as fast under Linux as under Windows.

Reprinted from: pytorch tips
