图像的读取与加载-PyTorch

图像的读取与加载-PyTorch

学习网站

https://www.jianshu.com/p/cfca9c4338e7

1. 从文件中读取图像数据

import matplotlib.pyplot as plt
import skimage.io as io
import cv2
from PIL import Image
import numpy as np
import torch

# 128035.jpg width = 481(Column), height = 321 (Row), channel = 3

# 1. 使用skimage读取图像
# skimage.io imread()-----np.array,  (H x W x C), [0, 255],RGB
img_skimage = io.imread('128035.jpg')

# 2. 使用opencv读取图像
# cv2.imread()------np.ndarray, (H x W xC), [0, 255], BGR
img_opencv = cv2.imread('128035.jpg')
# BGR to RGB
img_opencv_to_RGB = cv2.cvtColor(img_opencv, cv2.COLOR_BGR2RGB)
# 3. 使用PIL读取图像
# img_PIL : IpegImageFile Image.Image 对象
# img_PIL_array : np.ndarray  (H x W xC) RGB
img_PIL = Image.open('128035.jpg')
img_PIL_array = np.array(img_PIL)

# 4. 分别显示3种方式读取的图像
img_set = [img_skimage, img_opencv, img_opencv_to_RGB, img_PIL_array]
img_name = ['img_skimage', 'img_opencv', 'img_opencv_to_RGB', 'img_PIL_array']
plt.figure()

for i, img in enumerate(img_set):
    ax = plt.subplot(1, 4, i + 1)
    ax.imshow(img)
    # plt.pause(2)
plt.show()

2. torch.tensor 对象与numpy.ndarray之间的转化

# 5. 图像转换为torch对象
# 在深度学习中,原始图像需要转换为深度学习框架自定义的数据格式,在pytorch中,需要转为torch.Tensor。
# pytorch提供了torch.Tensor 与numpy.ndarray转换为接口:
# torch.from_numpy  (nSample)x C x H x W
# tensor.numpy()     H x W x C
# 所以转换的时候需要使用 numpy.transpose()

# 5.1 numpy.ndarray to torch.tnsor
tensor_img_skimage = torch.from_numpy(np.transpose(img_skimage, (2, 0, 1)))
tensor_img_opencv = torch.from_numpy(np.transpose(img_opencv, (2, 0, 1)))
tensor_img_PIL_array = torch.from_numpy(np.transpose(img_PIL_array, (2, 0, 1)))

# 5.2 torch.tensor to numpy.ndarray
tensor_img_skimage_to_array = np.transpose(tensor_img_skimage.numpy(), (1, 2, 0))
tensor_img_opencv_to_array = np.transpose(tensor_img_opencv.numpy(), (1, 2, 0))
tensor_img_PIL_array_to_array = np.transpose(tensor_img_PIL_array.numpy(), (1, 2, 0))

img_set = [tensor_img_skimage_to_array, tensor_img_opencv_to_array, tensor_img_PIL_array_to_array]
plt.figure()
for i, img in enumerate(img_set):
    ax = plt.subplot(1, 3, i + 1)
    ax.imshow(img)
plt.show()
发布了38 篇原创文章 · 获赞 29 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/ruotianxia/article/details/104474493