图像处理(1):PyTorch垃圾分类 数据预处理

基于深度学习框架PyTorch transforms 方法进行数据的预处理

作者:沈福利 北京工业大学硕士学位,高级算法专家。产品和技术负责人,专注于NLP、图像、推荐系统

整个过程主要包括:缩放、裁剪、归一化、标准化几个基本步骤。

图像归一化是计算机视觉、模式识别等领域广泛使用的一种技术。所谓图像归一化, 就是通过一系列变换, 将待处理的原始图像转换成相应的唯一标准形式(该标准形式图像对平移、旋转、缩放等仿射变换具有不变特性)

基于矩的图像归一化过程包括 4 个步骤 即坐标中心化、x-shearing 归一化、缩放归一化和旋转归一化。

Pytorch:transforms方法

  1. 裁剪——Crop

中心裁剪:transforms.CenterCrop
随机裁剪:transforms.RandomCrop
随机长宽比裁剪:transforms.RandomResizedCrop
上下左右中心裁剪:transforms.FiveCrop
上下左右中心裁剪后翻转,transforms.TenCrop

  1. 翻转和旋转——Flip and Rotation

依概率p水平翻转:transforms.RandomHorizontalFlip(p=0.5)
依概率p垂直翻转:transforms.RandomVerticalFlip(p=0.5)
随机旋转:transforms.RandomRotation

  1. 图像变换

resize:transforms.Resize
标准化:transforms.Normalize
转为tensor,并归一化至[0-1]:transforms.ToTensor
填充:transforms.Pad
修改亮度、对比度和饱和度:transforms.ColorJitter
转灰度图:transforms.Grayscale
线性变换:transforms.LinearTransformation()
仿射变换:transforms.RandomAffine
依概率p转为灰度图:transforms.RandomGrayscale
将数据转换为PILImage:transforms.ToPILImage
transforms.Lambda:Apply a user-defined lambda as a transform.

  1. 对transforms操作,使数据增强更灵活

transforms.RandomChoice(transforms), 从给定的一系列transforms中选一个进行操作
transforms.RandomApply(transforms, p=0.5),给一个transform加上概率,依概率进行操作
transforms.RandomOrder,将transforms中的操作随机打乱

导入库

# 导入torch 包
import torch
import torch.nn as nn
from torchvision import transforms

加载图片数据

file_name ='../images/dog.jpg'
from PIL import Image
input_image = Image.open(file_name)
print(input_image)
print(input_image.size) # 尺寸大小:长=1546,宽1213

# 数据处理后,我们看看处理后图片
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示符号
plt.imshow(input_image)

在这里插入图片描述

图片数据预处理

preprocess = transforms.Compose([
    # 1. 图像变换:重置图像分辨率,图片缩放256 * 256 
    transforms.Resize(256),
    # 2. 裁剪: 中心裁剪 ,依据给定的size从中心裁剪
    transforms.CenterCrop(224),
    # 3. 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1].注意事项:归一化至[0-1]是直接除以255
    transforms.ToTensor(),
    # 4. 对数据按通道进行标准化,即先减均值,再除以标准差
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),#图片归一化
])

# 原始数据预处理
input_tensor = preprocess(input_image)
print('input_tensor.shape = ',input_tensor.shape)
print('input_tensor = ',input_tensor)
input_tensor.shape =  torch.Size([3, 224, 224])
input_tensor =  tensor([[[-1.9295, -1.9295, -1.9124,  ..., -2.0323, -1.9467, -1.9295],
         [-1.9980, -1.8953, -1.9124,  ..., -1.9638, -1.9295, -1.7754],
         [-1.9980, -1.9467, -1.9124,  ..., -2.0494, -1.9638, -1.8953],
         ...,
         [-1.4843, -1.6042, -1.6213,  ..., -0.8678, -1.1075, -1.0733],
         [-1.5357, -1.6042, -1.6213,  ..., -1.0390, -1.6213, -1.4500],
         [-1.5528, -1.4843, -1.2445,  ..., -0.9192, -1.2788, -1.2617]],

        [[-1.8256, -1.8256, -1.8081,  ..., -1.9832, -1.9132, -1.9132],
         [-1.8256, -1.8431, -1.8431,  ..., -1.9657, -1.9307, -1.8782],
         [-1.8256, -1.8431, -1.8606,  ..., -1.9657, -1.9482, -1.9132],
         ...,
         [-0.9853, -0.9678, -0.9853,  ..., -0.4601, -0.6352, -0.6702],
         [-0.9853, -0.9853, -1.0028,  ..., -0.5651, -1.1253, -0.8978],
         [-0.9503, -0.9853, -0.8102,  ..., -0.3901, -0.8102, -0.7402]],

        [[-1.6127, -1.5953, -1.5604,  ..., -1.7173, -1.6999, -1.7173],
         [-1.6650, -1.6302, -1.6302,  ..., -1.7173, -1.7173, -1.6476],
         [-1.6476, -1.6650, -1.6476,  ..., -1.7347, -1.6999, -1.6824],
         ...,
         [-1.3164, -1.5081, -1.5953,  ..., -0.8981, -1.1596, -1.0898],
         [-1.2990, -1.5256, -1.6127,  ..., -0.9504, -1.4907, -1.3861],
         [-1.2816, -1.4210, -1.2293,  ..., -0.7413, -1.1247, -1.2816]]])

数据预处理后的效果

import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示符号

input_tensor = input_tensor.permute(1,2,0) #Changing from 3x224x224 to 224x224x3
print('input_tensor.matplotlib.shape = ',input_tensor.shape)
#clamp表示夹紧,夹住的意思,torch.clamp(input,min,max,out=None)-> Tensor
input_tensor = torch.clamp(input_tensor,0,1)
print('input_tensor.matplotlib.clamp.shape = ',input_tensor.shape)
plt.imshow(input_tensor)

在这里插入图片描述
更多精彩内容请参考: https://github.com/shenfuli/ai

发布了267 篇原创文章 · 获赞 66 · 访问量 43万+

猜你喜欢

转载自blog.csdn.net/shenfuli/article/details/102822918