PyTorch Study Notes (3)

1. Walkthrough of a code snippet:

# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion()   # interactive mode

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch
out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

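Where: Resize: resizes the image to the given size; Normalize: normalizes a tensor image using the given per-channel mean and standard deviation.

        RandomHorizontalFlip: flips the given PIL image horizontally with probability 0.5; RandomVerticalFlip: flips the given PIL image vertically with probability 0.5

        RandomResizedCrop: crops a region of random size and aspect ratio from the PIL image, then resizes it to the given size; CenterCrop: crops the central region of the image

        ToTensor: converts a PIL image or numpy.ndarray of shape (H*W*C) with values in [0,255] to a torch.Tensor of shape (C*H*W) with values in [0.0,1.0]

        Additional: RandomGrayscale: converts the image to grayscale with a given probability; FiveCrop: crops the image into its four corners and the center

        Pad: pads the borders of the image; ColorJitter: randomly changes the brightness, contrast and saturation of the image; Grayscale: converts the image to grayscale; RandomCrop: crops the image at a random location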

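2. The no_grad function

        torch.no_grad is used as a context manager (with torch.no_grad():). Inside the block, gradient computation is cut off: operations are not recorded by autograd, so their results do not require gradients. It is typically used during evaluation or inference, where no backward pass is needed.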
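A small sketch of this behavior, using a toy tensor rather than a model (the variable names are chosen here for illustration): the same computation is tracked by autograd outside the context and untracked inside it.

```python
import torch

x = torch.ones(3, requires_grad=True)

# Outside the context: the operation is recorded for autograd.
y = (x * 2).sum()

# Inside the context: the same operation is not recorded.
with torch.no_grad():
    z = (x * 2).sum()

y.backward()  # works, because y is attached to the graph

print(y.requires_grad)  # True
print(z.requires_grad)  # False
print(x.grad)           # tensor([2., 2., 2.])
```

Calling z.backward() would raise an error, since z has no grad_fn. In a validation loop, wrapping the forward pass this way avoids storing intermediate results for a backward pass that never happens.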

Reposted from blog.csdn.net/qisheng_com/article/details/82151306