Understanding PyTorch ToTensor

Reference: https://www.cnblogs.com/ocean1100/p/9494640.html

Understanding ToTensor after PyTorch loads an image (with a comparison of reading images via PIL and OpenCV)

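Overview

For typical deep-learning image tasks, PyTorch reads images through the Dataset and DataLoader classes and applies a transform as each image is loaded. The transform almost always includes ToTensor(), which converts the PIL or OpenCV image data read inside the Dataset's __getitem__() method into a torch.FloatTensor. The details follow: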

PIL and OpenCV data formats

  1. PIL (RGB)
    PIL (Python Imaging Library) is the most basic image processing library in Python. Typical usage:
```python
from PIL import Image
import numpy as np

image = Image.open('test.jpg')  # the image is 400x300 (width x height)
print(type(image))  # out: <class 'PIL.JpegImagePlugin.JpegImageFile'>
print(image.size)  # out: (400, 300)
print(image.mode)  # out: RGB
print(image.getpixel((0, 0)))  # out: (143, 198, 201)
# resize takes (w, h)
image = image.resize((200, 100), Image.NEAREST)
print(image.size)  # out: (200, 100)
'''
Notes on the code
**Note that image is a class:`~PIL.Image.Image` object.** It has many attributes: its size is (w, h) and its channels are RGB. It also has many methods, such as getpixel((x, y)), which returns the three channel values at a position; x can be at most w-1 and y at most h-1.
Another is resize, which scales the image. Its parameters are:
resize(self, size, resample=0) method of PIL.Image.Image instance
    Returns a resized copy of this image.

    :param size: The requested size in pixels, as a 2-tuple:
       (width, height).
    Note: size is (w, h), the same order as image.size.
    :param resample: An optional resampling filter.  This can be
       one of :py:attr:`PIL.Image.NEAREST`, :py:attr:`PIL.Image.BOX`,
       :py:attr:`PIL.Image.BILINEAR`, :py:attr:`PIL.Image.HAMMING`,
       :py:attr:`PIL.Image.BICUBIC` or :py:attr:`PIL.Image.LANCZOS`.
       If omitted, or if the image has mode "1" or "P", it is
       set :py:attr:`PIL.Image.NEAREST`.
       See: :ref:`concept-filters`.
    Note on the filters: in the docstring quoted here the default is NEAREST (nearest neighbour, common for segmentation); classification commonly uses BILINEAR or BICUBIC. Newer Pillow releases (7.0 onward) default to BICUBIC, so it is safer to pass the filter explicitly.
    :returns: An :py:class:`~PIL.Image.Image` object.

'''
image = np.array(image, dtype=np.float32)  # np.array(image) defaults to uint8
print(image.shape)  # out: (100, 200, 3)
# w and h have swapped places: the array is (h, w, c)
# An ndarray is rows x columns x channels, so the row count is the height and the column count is the width
```
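  2. OpenCV (Python bindings) (BGR)
    OpenCV is a powerful image processing library with a broader scope. It is used widely, performs well, and has plenty of example code. Typical usage: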
```python
import cv2
import numpy as np

image = cv2.imread('test.jpg')
print(type(image))  # out: <class 'numpy.ndarray'>
print(image.dtype)  # out: uint8
print(image.shape)  # out: (300, 400, 3), i.e. (h, w, c), like skimage
print(image)  # channels are in BGR order
'''
The RGB pixel (143, 198, 201) from the PIL example appears here as [201, 198, 143]:
array([
        [ [201, 198, 143 (dim=3)],[201, 198, 143],... (w=400)],
        [ [201, 198, 143],[201, 198, 143],... ],
        ...(h=300)
      ], dtype=uint8)

'''
# dsize is (w, h)
image = cv2.resize(image, (100, 200), interpolation=cv2.INTER_LINEAR)
print(image.dtype)  # out: uint8
print(image.shape)  # out: (200, 100, 3)
'''
Careful: this differs from skimage.
resize(src, dsize[, dst[, fx[, fy[, interpolation]]]])
The keyword arguments are dst, fx, fy and interpolation.
dst is the resized output image.
dsize is (w, h), even though the image array is (h, w, c).
fx and fy are the scale factors along the x and y directions.
interpolation selects the interpolation method. Common choices:
cv2.INTER_AREA: resampling using the pixel area relation. When shrinking an image it avoids moire patterns; when enlarging it behaves like cv2.INTER_NEAREST.
cv2.INTER_CUBIC: bicubic interpolation
cv2.INTER_LINEAR: bilinear interpolation (the default)
cv2.INTER_NEAREST: nearest-neighbour interpolation
[See this blog post for details](http://www.tuicool.com/articles/rq6fIn)
'''
'''
cv2.imread(filename, flags=None):
flag:
cv2.IMREAD_COLOR 1: Loads a color image. Any transparency of image will be neglected. It is the default flag. A regular 3-channel image.
cv2.IMREAD_GRAYSCALE 0: Loads image in grayscale mode. A single-channel grayscale image.
cv2.IMREAD_UNCHANGED -1: Loads image as such including alpha channel. 4 channels when the file has an alpha channel.
Note: the default is cv2.IMREAD_COLOR, so cv2.imread('gray.png') on a grayscale file still returns a 3-channel image whose three channels hold identical values.

'''
```
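Both libraries take sizes as (width, height), while the resulting arrays are indexed as (height, width, channels). The sketch below checks that relationship with Pillow and NumPy only. It builds the image in memory rather than reading test.jpg, and the marked pixel position is an arbitrary choice for illustration.

```python
import numpy as np
from PIL import Image

# A 400x300 (width x height) RGB image created in memory.
image = Image.new("RGB", (400, 300), color=(143, 198, 201))
# Mark one pixel at x=10, y=20 so the index order can be checked.
image.putpixel((10, 20), (255, 0, 0))

pil_size = image.size      # PIL reports (width, height)
array = np.array(image)    # uint8 unless a dtype is given
array_shape = array.shape  # NumPy reports (height, width, channels)

# PIL addresses pixels as (x, y); the array is indexed [row, column] = [y, x].
pixel_from_pil = image.getpixel((10, 20))
pixel_from_array = tuple(int(v) for v in array[20, 10])

print(pil_size, array_shape, pixel_from_pil, pixel_from_array)
```

The same swap applies to cv2.resize: its dsize argument is (w, h), but the returned array has shape (h, w, c).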

In summary, a PIL image converted to numpy.ndarray has layout (h, w, c) with channels in RGB order.
After cv2.imread(), OpenCV data is a numpy.ndarray with layout (h, w, c) and channels in BGR order.
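Because the two libraries disagree on channel order, an array loaded with OpenCV usually needs its channels reversed before it is mixed with PIL-based code. The sketch below uses a small hand-built array as a stand-in for the result of cv2.imread, so it needs only NumPy. With OpenCV available, `cv2.cvtColor(image, cv2.COLOR_BGR2RGB)` does the same job.

```python
import numpy as np

# Stand-in for an array returned by cv2.imread: (h, w, c) with channels in BGR order.
bgr = np.zeros((2, 3, 3), dtype=np.uint8)
bgr[..., 0] = 201  # blue
bgr[..., 1] = 198  # green
bgr[..., 2] = 143  # red

# Reversing the last axis turns BGR into RGB.
# .copy() materialises the result; the reversed view has a negative stride,
# which torch.from_numpy does not accept.
rgb = bgr[..., ::-1].copy()

print(rgb[0, 0])
```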

torchvision.transforms.ToTensor()

torchvision.transforms.transforms.py:61

```python
class ToTensor(object):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """

    def __call__(self, pic):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'
```

torchvision.transforms.functional.py:32

```python
def to_tensor(pic):
    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.

    See ``ToTensor`` for more details.

    Args:
        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

    Returns:
        Tensor: Converted image.
    """
    if not(_is_pil_image(pic) or _is_numpy_image(pic)):
        raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

    if isinstance(pic, np.ndarray):
        # handle numpy array
        img = torch.from_numpy(pic.transpose((2, 0, 1)))
        # backward compatibility
        if isinstance(img, torch.ByteTensor):
            return img.float().div(255)
        else:
            return img

    if accimage is not None and isinstance(pic, accimage.Image):
        nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32)
        pic.copyto(nppic)
        return torch.from_numpy(nppic)

    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    elif pic.mode == 'F':
        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
    elif pic.mode == '1':
        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image mode: L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float().div(255)
    else:
        return img
```

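As to_tensor() shows, the function accepts a PIL Image or a numpy.ndarray and first reorders it from HWC to CHW. If the data is uint8 (a torch.ByteTensor), it is then converted to float and every value is divided by 255. Any other dtype is returned without scaling, so an array created with dtype=np.float32 keeps its original [0, 255] range.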
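The dtype-dependent scaling is easy to overlook, so here is a NumPy-only imitation of the ndarray branch of the to_tensor() source quoted above. `to_tensor_like` is a hypothetical helper written for this illustration. It returns NumPy arrays rather than tensors and is not part of torchvision.

```python
import numpy as np


def to_tensor_like(pic):
    """Imitate the ndarray branch of to_tensor using NumPy only (illustrative)."""
    chw = pic.transpose((2, 0, 1))       # HWC -> CHW
    if chw.dtype == np.uint8:            # stands in for the ByteTensor check
        return chw.astype(np.float32) / 255
    return chw                           # other dtypes pass through unscaled


hwc_uint8 = np.full((100, 200, 3), 255, dtype=np.uint8)
hwc_float = hwc_uint8.astype(np.float32)  # same values, different dtype

scaled = to_tensor_like(hwc_uint8)
unscaled = to_tensor_like(hwc_float)

print(scaled.shape, scaled.max(), unscaled.max())
```

With uint8 input the result lies in [0.0, 1.0]. With float32 input the values stay in [0, 255], so a later normalization step that assumes [0, 1] would be off by a factor of 255.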


Reposted from blog.csdn.net/jacke121/article/details/104385331