inferno Pytorch: inferno.io.transform 介绍及使用

inferno简介

Inferno是一个库,提供了围绕PyTorch的实用程序和方便的函数/类,为深度学习和实现神经网络提供便利。关于inferno的其他模块介绍:
inferno Pytorch: inferno.extensions.layers.convolutional 介绍及使用
inferno Pytorch: inferno.io.box.cifar下载cifar10 cifar100数据集 介绍及使用
inferno Pytorch: inferno.io.transform 介绍及使用

inferno安装

pip install inferno-pytorch

inferno.io.transform

源码介绍

inferno.io.transform下面有多个模块实现对数据的transformation。

inferno.io.transform.base

Transform

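`Transform` 是下面介绍的所有其他变换的基类。子类实现 `batch_function`、`tensor_function`、`volume_function` 或 `image_function` 之一即可,调用时按此优先级选择。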

class Transform(object):
    """
    Base class for a Transform. The argument `apply_to` (list) specifies the indices of
    the tensors this transform will be applied to.

    The following methods are recognized (in order of descending priority):
        - `batch_function`: Applies to all tensors in a batch simultaneously
          (`apply_to` is ignored in this case).
        - `tensor_function`: Applies to just __one__ tensor at a time.
        - `volume_function`: For 3D volumes, applies to just __one__ volume at a time.
        - `image_function`: For 2D or 3D volumes, applies to just __one__ image at a time.

    For example, if both `volume_function` and `image_function` are defined, only the
    former will be called. `volume_function` is dispatched for 3D (zyx), 4D (czyx) and
    5D (bczyx) tensors; tensors of any other rank raise a `NotImplementedError`.

    Random variables are cleared at the start of every call and rebuilt lazily the first
    time they are requested, so all tensors transformed within one call share them.
    """
    def __init__(self, apply_to=None):
        """
        Parameters
        ----------
        apply_to : list or tuple
            Indices of tensors to apply this transform to. The indices are with respect
            to the list of arguments this object is called with. If None, the transform
            is applied to every tensor.
        """
        self._random_variables = {}
        self._apply_to = list(apply_to) if apply_to is not None else None

    def build_random_variables(self, **kwargs):
        """Hook for subclasses: populate random variables with `set_random_variable`.

        Called lazily by `get_random_variable`, which forwards its extra keyword
        arguments here. The base implementation does nothing.
        """
        pass

    def clear_random_variables(self):
        """Forget all random variables (done at the start of every `__call__`)."""
        self._random_variables = {}

    def get_random_variable(self, key, default=None, build=True,
                            **random_variable_building_kwargs):
        """Return the random variable `key`, building random variables if necessary.

        If `key` is missing and `build` is True, `build_random_variables` is called once
        with `random_variable_building_kwargs` and the lookup is retried. If the key is
        still missing (or `build` is False), `default` is returned.
        """
        if key in self._random_variables:
            return self._random_variables[key]
        if not build:
            return default
        self.build_random_variables(**random_variable_building_kwargs)
        return self.get_random_variable(key, default, build=False)

    def set_random_variable(self, key, value):
        """Store `value` as the random variable `key`."""
        self._random_variables[key] = value

    def __call__(self, *tensors, **transform_function_kwargs):
        tensors = pyu.to_iterable(tensors)
        # Indices of the tensors to which we're going to apply the transform
        apply_to = list(range(len(tensors))) if self._apply_to is None else self._apply_to
        # Flush random variables; the transform functions rebuild them lazily
        self.clear_random_variables()
        if hasattr(self, 'batch_function'):
            transformed = self.batch_function(tensors, **transform_function_kwargs)
            return pyu.from_iterable(transformed)
        # Pick the per-tensor function by priority
        if hasattr(self, 'tensor_function'):
            per_tensor = self.tensor_function
        elif hasattr(self, 'volume_function'):
            per_tensor = self._apply_volume_function
        elif hasattr(self, 'image_function'):
            per_tensor = self._apply_image_function
        else:
            raise NotImplementedError
        transformed = [per_tensor(tensor, **transform_function_kwargs)
                       if tensor_index in apply_to else tensor
                       for tensor_index, tensor in enumerate(tensors)]
        return pyu.from_iterable(transformed)

    # noinspection PyUnresolvedReferences
    def _apply_image_function(self, tensor, **transform_function_kwargs):
        """Apply `image_function` to every 2D (yx) image contained in `tensor`.

        Supports 2D (yx), 3D (zyx), 4D (bcyx) and 5D (bczyx) tensors.
        """
        assert callable(getattr(self, 'image_function', None))
        # 2D case
        if tensor.ndim == 4:
            return np.array([np.array([self.image_function(image, **transform_function_kwargs)
                                       for image in channel_image])
                             for channel_image in tensor])
        # 3D case
        elif tensor.ndim == 5:
            return np.array([np.array([np.array([self.image_function(image,
                                                                     **transform_function_kwargs)
                                                 for image in volume])
                                       for volume in channel_volume])
                             for channel_volume in tensor])
        elif tensor.ndim == 3:
            # Assume we have a 3D volume (signature zyx) and apply the image function
            # on all yx slices.
            return np.array([self.image_function(image, **transform_function_kwargs)
                             for image in tensor])
        elif tensor.ndim == 2:
            # Assume we really do have an image.
            return self.image_function(tensor, **transform_function_kwargs)
        else:
            # `tensor.ndim` was already read above, so it is safe to use here.
            raise NotImplementedError(f"{type(tensor)} {tensor.ndim}")

    # noinspection PyUnresolvedReferences
    def _apply_volume_function(self, tensor, **transform_function_kwargs):
        """Apply `volume_function` to every 3D (zyx) volume contained in `tensor`.

        Supports 3D (zyx), 4D (czyx) and 5D (bczyx) tensors.
        """
        assert callable(getattr(self, 'volume_function', None))
        if tensor.ndim == 5:
            # tensor is bczyx: loop over b and c only, so that `volume_function`
            # receives zyx volumes. (The previous version looped one level deeper
            # and passed 2D yx slices instead.)
            return np.array([np.array([self.volume_function(volume,
                                                            **transform_function_kwargs)
                                       for volume in sample])
                             for sample in tensor])
        elif tensor.ndim == 4:
            # czyx tensor: loop over c and apply the volume function to each zyx volume
            return np.array([self.volume_function(volume, **transform_function_kwargs)
                             for volume in tensor])
        elif tensor.ndim == 3:
            # We're applying the volume function on the volume itself
            return self.volume_function(tensor, **transform_function_kwargs)
        else:
            raise NotImplementedError(f"{type(tensor)} {tensor.ndim}")
Compose

将多个变换组合起来,类似torchvision的Compose

class Compose(object):
    """Chain several callables so that each one's output feeds the next.

    Any callable works, not only `Transform` instances. Every intermediate result is
    normalized to an iterable and unpacked as the positional arguments of the next
    callable.
    """
    def __init__(self, *transforms: Callable):
        """
        Parameters
        ----------
        transforms : callable
            The callables to chain, in the order in which they are applied.
        """
        assert all(callable(fn) for fn in transforms)
        self.transforms = [fn for fn in transforms]

    def add(self, transform):
        """Append `transform` to the end of the chain and return `self` (fluent API)."""
        assert callable(transform)
        self.transforms.append(transform)
        return self

    def remove(self, name):
        """Drop the first transform whose class name equals `name`; no-op if absent.

        Returns `self` (fluent API).
        """
        position = next((idx for idx, fn in enumerate(self.transforms)
                         if type(fn).__name__ == name), None)
        if position is not None:
            del self.transforms[position]
        return self

    def __call__(self, *tensors):
        """Run `tensors` through every transform in order and return the final result."""
        current = tensors
        for fn in self.transforms:
            current = pyu.to_iterable(fn(*current))
        return pyu.from_iterable(current)

inferno.io.transform.image 用于图像处理

PILImage2NumPyArray

将一个PIL Image对象转换为一个numpy数组

class PILImage2NumPyArray(Transform):
    """Turn a PIL Image (or anything `np.asarray` accepts) into a numpy array.

    A trailing channel axis is moved to the front, so an RGB image of shape
    (100, 100, 3) becomes an array of shape (3, 100, 100). 2D (single channel)
    images are returned without reordering.
    """
    def tensor_function(self, tensor):
        array = np.asarray(tensor)
        rank = array.ndim
        if rank == 2:
            return array
        if rank == 3:
            # Channels-last -> channels-first
            return np.moveaxis(array, source=-1, destination=0)
        raise NotImplementedError("Expected tensor to be a 2D or 3D "
                                  "numpy array, got a {}D array instead."
                                  .format(rank))

Scale

实现图像缩放变换:使用指定阶数(`interpolation_order`)的样条插值将图像缩放到给定大小。

class Scale(Transform):
    """Scales an image to a given size with spline interpolation of requested order.

    Unlike torchvision.transforms.Scale, this does not depend on PIL and therefore works
    with numpy arrays. If you do have a PIL image and wish to use this transform, consider
    applying `PILImage2NumPyArray` first.

    Warnings
    --------
    This transform uses `scipy.ndimage.zoom` and requires scipy >= 0.13.0 to work correctly.
    """
    def __init__(self, output_image_shape, interpolation_order=3, zoom_kwargs=None, **super_kwargs):
        """
        Parameters
        ----------
        output_image_shape : list or tuple or int or None
            Target size (height, width) of the output image. An int `n` is interpreted
            as `(n, n)`. Aspect ratio may not be preserved. If None, images are returned
            unchanged (their input size is preserved).
        interpolation_order : int
            Interpolation order for the spline interpolation.
        zoom_kwargs : dict
            Keyword arguments for `scipy.ndimage.zoom`.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        super(Scale, self).__init__(**super_kwargs)
        if output_image_shape is not None:
            output_image_shape = (output_image_shape, output_image_shape) \
                if isinstance(output_image_shape, int) else tuple(output_image_shape)
            assert_(len(output_image_shape) == 2,
                    "`output_image_shape` must be an integer or a tuple of length 2.",
                    ValueError)
        self.output_image_shape = output_image_shape
        self.interpolation_order = interpolation_order
        self.zoom_kwargs = {} if zoom_kwargs is None else dict(zoom_kwargs)

    def image_function(self, image):
        """Rescale a 2D `image` to `output_image_shape`, or return it as-is if that is None."""
        if self.output_image_shape is None:
            # Documented contract: without a target shape the input size is preserved.
            # (Previously this crashed while unpacking None below.)
            return image
        source_height, source_width = image.shape
        target_height, target_width = self.output_image_shape
        # Python 3 true division: zoom factors are floats.
        zoom_height, zoom_width = (target_height / source_height), (target_width / source_width)
        with catch_warnings():
            # Ignore the warning that scipy should be > 0.13
            simplefilter('ignore')
            rescaled_image = zoom(image, (zoom_height, zoom_width),
                                  order=self.interpolation_order, **self.zoom_kwargs)
        # This should never happen
        assert_(rescaled_image.shape == (target_height, target_width),
                "Shape mismatch that shouldn't have happened if you were on scipy > 0.13.0. "
                "Are you on scipy > 0.13.0?",
                ShapeError)
        return rescaled_image
RandomCrop

类似于 `torchvision.transforms.RandomCrop`,对图像进行随机裁剪;不同之处在于这里处理的是 numpy 数组而不是 PIL 图像(PIL 图像可先用 `PILImage2NumPyArray` 转换)。

class RandomCrop(Transform):
    """Randomly crop images to a fixed size.

    Works like torchvision.transforms.RandomCrop, but on numpy arrays rather than PIL
    images; convert PIL images with `PILImage2NumPyArray` first.

    All images handled within one call share the same crop offsets, so they must all
    have the same shape.

    Warnings
    --------
    Along any dimension where the image is not larger than `output_image_shape`, the
    image is left uncropped.
    """
    def __init__(self, output_image_shape, **super_kwargs):
        """
        Parameters
        ----------
        output_image_shape : tuple or list or int
            Requested (height, width) of the crop. An int `n` means `(n, n)`. Along a
            dimension where the image is smaller, e.g. an image of shape `(50, 200)`
            with a requested shape of `(100, 100)`, the output keeps the image's size
            and the result has shape `(50, 100)`.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        super(RandomCrop, self).__init__(**super_kwargs)
        # Shape of the first image seen in the current call; every later image in the
        # same call is checked against it.
        self._image_shape_cache = None
        if isinstance(output_image_shape, int):
            output_image_shape = (output_image_shape,) * 2
        else:
            output_image_shape = tuple(output_image_shape)
        assert_(len(output_image_shape) == 2,
                "`output_image_shape` must be an integer or a tuple of length 2.",
                ValueError)
        self.output_image_shape = output_image_shape

    def clear_random_variables(self):
        """Reset the cached image shape along with the random crop offsets."""
        self._image_shape_cache = None
        super(RandomCrop, self).clear_random_variables()

    def build_random_variables(self, height_leeway, width_leeway):
        """Draw crop offsets uniformly from [0, leeway] for each dimension with leeway > 0."""
        for key, leeway in (('height_location', height_leeway),
                            ('width_location', width_leeway)):
            if leeway > 0:
                self.set_random_variable(key, np.random.randint(low=0, high=leeway + 1))

    def image_function(self, image):
        expected_shape = self._image_shape_cache
        if expected_shape is None:
            self._image_shape_cache = image.shape
        else:
            assert_(expected_shape == image.shape,
                    "RandomCrop works on multiple images simultaneously only "
                    "if they have the same shape. Was expecting an image of "
                    "shape {}, got one of shape {} instead."
                    .format(expected_shape, image.shape),
                    ShapeError)
        in_height, in_width = image.shape
        out_height, out_width = self.output_image_shape
        leeways = dict(height_leeway=in_height - out_height,
                       width_leeway=in_width - out_width)
        result = image
        if leeways['height_leeway'] > 0:
            top = self.get_random_variable('height_location', **leeways)
            result = result[top:(top + out_height), :]
            assert result.shape[0] == self.output_image_shape[0], "Well, shit."
        if leeways['width_leeway'] > 0:
            left = self.get_random_variable('width_location', **leeways)
            result = result[:, left:(left + out_width)]
            assert result.shape[1] == self.output_image_shape[1], "Well, shit."
        return result
RandomSizedCrop

从图像中裁剪一个随机大小的图像。

class RandomSizedCrop(Transform):
    """Extract a randomly sized crop from the image.

    The ratio of the sizes of the cropped and the original image can be limited within
    specified bounds along both axes. To resize back to a constant sized image, compose
    with `Scale`.
    """
    def __init__(self, ratio_between=None, height_ratio_between=None, width_ratio_between=None,
                 preserve_aspect_ratio=False, relative_target_aspect_ratio=None, **super_kwargs):
        """
        Parameters
        ----------
        ratio_between : tuple
            Specify the bounds between which to sample the crop ratio. This applies to
            both height and width if not overridden. Can be None if both height and width
            ratios are specified individually.
        height_ratio_between : tuple
            Specify the bounds between which to sample the vertical crop ratio.
            Can be None if `ratio_between` is not None.
        width_ratio_between : tuple
            Specify the bounds between which to sample the horizontal crop ratio.
            Can be None if `ratio_between` is not None.
        preserve_aspect_ratio : bool
            Whether to preserve aspect ratio. If both `height_ratio_between`
            and `width_ratio_between` are specified, the former is used if this
            is set to True.
        relative_target_aspect_ratio : float
            Specify the target aspect ratio (W x H) relative to the input image
            (i.e. by mapping the input image ratio to 1:1). For instance, if an image
            has the size 1024 (H) x 2048 (W), a relative target aspect ratio of 0.5
            might yield images of size 1024 x 1024. Note that this only applies if
            `preserve_aspect_ratio` is set to False.
        super_kwargs : dict
            Keyword arguments for the super class.
        """
        # NOTE(review): this excerpt stops after the docstring. The constructor body
        # (including the superclass `__init__` call) and the cropping method are not
        # shown here; as written, none of the arguments is stored.
AdditiveGaussianNoise

添加高斯噪声

class AdditiveGaussianNoise(Transform):
    """Add zero-mean gaussian noise with standard deviation `sigma` to the input.

    One noise field is drawn per call, shaped like the first image processed in that
    call, and the same field is added to every image of the call.
    """
    def __init__(self, sigma, **super_kwargs):
        super(AdditiveGaussianNoise, self).__init__(**super_kwargs)
        self.sigma = sigma

    def build_random_variables(self, **kwargs):
        # Re-seed numpy's global RNG from fresh entropy (presumably so that forked
        # data-loader workers do not produce identical noise — TODO confirm).
        np.random.seed()
        noise_shape = kwargs.get('imshape')
        noise = np.random.normal(loc=0, scale=self.sigma, size=noise_shape)
        self.set_random_variable('noise', noise)

    def image_function(self, image):
        noise = self.get_random_variable('noise', imshape=image.shape)
        return image + noise
RandomRotate

随机旋转 90 度的整数倍(即 0、90、180 或 270 度)

class RandomRotate(Transform):
    """Rotate images by a random multiple of 90 degrees (0, 90, 180 or 270)."""
    def __init__(self, **super_kwargs):
        super(RandomRotate, self).__init__(**super_kwargs)

    def build_random_variables(self, **kwargs):
        # Re-seed numpy's global RNG from fresh entropy before drawing.
        np.random.seed()
        quarter_turns = np.random.randint(0, 4)
        self.set_random_variable('k', quarter_turns)

    def image_function(self, image):
        quarter_turns = self.get_random_variable('k')
        return np.rot90(image, k=quarter_turns)
RandomTranspose

随机 2D 转置(transpose):以 50% 的概率把数组的维度顺序反转,可参考下面的示例。对二维图像来说,相当于把 $H \times W$ 变成 $W \times H$。

>>> a = np.random.random((2,1,3,4))
>>> b = np.transpose(a)
>>> b.shape
(4, 3, 1, 2)
class RandomTranspose(Transform):
    """Transpose images (reverse their axes) with probability one half."""
    def __init__(self, **super_kwargs):
        super(RandomTranspose, self).__init__(**super_kwargs)

    def build_random_variables(self, **kwargs):
        # Re-seed numpy's global RNG from fresh entropy before drawing.
        np.random.seed()
        coin = np.random.uniform() > 0.5
        self.set_random_variable('do_transpose', coin)

    def image_function(self, image):
        return np.transpose(image) if self.get_random_variable('do_transpose') else image
RandomFlip

随机的左右翻转或上下翻转

class RandomFlip(Transform):
    """Randomly flip images left-right and/or up-down, each with probability one half.

    Either kind of flip can be disabled via `allow_lr_flips` / `allow_ud_flips`.
    """
    def __init__(self, allow_lr_flips=True, allow_ud_flips=True, **super_kwargs):
        super(RandomFlip, self).__init__(**super_kwargs)
        self.allow_lr_flips = allow_lr_flips
        self.allow_ud_flips = allow_ud_flips

    def build_random_variables(self, **kwargs):
        # Re-seed numpy's global RNG, then draw both coins (left-right first).
        np.random.seed()
        for key in ('flip_lr', 'flip_ud'):
            self.set_random_variable(key, np.random.uniform() > 0.5)

    def image_function(self, image):
        flipped = image
        # `and` short-circuits: disabled flips never trigger building random variables.
        if self.allow_lr_flips and self.get_random_variable('flip_lr'):
            flipped = np.fliplr(flipped)
        if self.allow_ud_flips and self.get_random_variable('flip_ud'):
            flipped = np.flipud(flipped)
        return flipped
CenterCrop

就是中心裁剪

class CenterCrop(Transform):
    """Crop a patch of size `size` from the center of the image.

    Along any dimension where the image is not larger than the requested size, the
    image is left uncropped.
    """
    def __init__(self, size, **super_kwargs):
        """
        Parameters
        ----------
        size : int or tuple or list
            Target (height, width). An int `n` is interpreted as `(n, n)`.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        super(CenterCrop, self).__init__(**super_kwargs)
        assert isinstance(size, (int, tuple, list))
        self.size = (size, size) if isinstance(size, int) else tuple(size)
        assert len(self.size) == 2

    def image_function(self, image):
        h, w = image.shape
        th, tw = self.size
        # Floor division gives a consistent offset for odd margins. The previous
        # `round(x / 2.)` used banker's rounding, so the crop shifted left or right by
        # one pixel depending on the parity of the margin.
        if h > th:
            y1 = (h - th) // 2
            image = image[y1:y1 + th, :]
        if w > tw:
            x1 = (w - tw) // 2
            image = image[:, x1:x1 + tw]
        return image
BinaryMorphology

对图像应用二值形态学操作,支持的操作为膨胀(dilation)和腐蚀(erosion)。

class BinaryMorphology(Transform):
    """
    Apply a binary morphology operation on an image. Supported operations are dilation
    and erosion. The result is cast back to the input image's dtype.
    """
    def __init__(self, mode, num_iterations=1, morphology_kwargs=None, **super_kwargs):
        """
        Parameters
        ----------
        mode : {'dilate', 'erode'}
            Whether to dilate or erode.
        num_iterations : int
            Number of iterations to apply the operation for.
        morphology_kwargs: dict
            Keyword arguments to the morphology function
            (i.e. `scipy.ndimage.binary_dilation` or
            `scipy.ndimage.binary_erosion`)
        super_kwargs : dict
            Keyword arguments to the superclass.

        Raises
        ------
        ValueError
            If `mode` is not one of 'dilate' or 'erode'.
        """
        super(BinaryMorphology, self).__init__(**super_kwargs)
        # Validate and assign mode
        assert_(mode in ['dilate', 'erode'],
                "Mode must be one of ['dilate', 'erode']. Got {} instead.".format(mode),
                ValueError)
        self.mode = mode
        self.num_iterations = num_iterations
        self.morphology_kwargs = {} if morphology_kwargs is None else dict(morphology_kwargs)

    def image_function(self, image):
        if self.mode == 'dilate':
            operation = binary_dilation
        elif self.mode == 'erode':
            operation = binary_erosion
        else:
            # Only reachable if `self.mode` was changed after construction.
            raise ValueError("Mode must be one of ['dilate', 'erode']. "
                             "Got {} instead.".format(self.mode))
        transformed_image = operation(image, iterations=self.num_iterations,
                                      **self.morphology_kwargs)
        # The morphology functions return boolean arrays; cast back to the input dtype.
        return transformed_image.astype(image.dtype)
class BinaryDilation(BinaryMorphology):
    """Binary dilation of an image; shorthand for `BinaryMorphology(mode='dilate', ...)`."""
    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):
        super().__init__(mode='dilate',
                         num_iterations=num_iterations,
                         morphology_kwargs=morphology_kwargs,
                         **super_kwargs)


class BinaryErosion(BinaryMorphology):
    """Binary erosion of an image; shorthand for `BinaryMorphology(mode='erode', ...)`."""
    def __init__(self, num_iterations=1, morphology_kwargs=None, **super_kwargs):
        super().__init__(mode='erode',
                         num_iterations=num_iterations,
                         morphology_kwargs=morphology_kwargs,
                         **super_kwargs)
FineRandomRotations

以均匀分布的随机角度同时旋转输入图像和标签图像(标签图像使用最近邻插值,不会混合标签值)。

class FineRandomRotations(Transform):
    """ Rotate an (input, label) pair by the same random angle.

        The angle is drawn uniformly from [-angle_range, angle_range] once per call.
        `batch_function` expects the input at position 0 and the label at position 1 and
        returns only these two rotated arrays. The label is rotated with order-0
        (nearest-neighbour) interpolation so label values are not blended, and areas
        rotated in from outside are filled with `mask_label`. Output shapes equal input
        shapes (`reshape=False`).

        Parameters
        ----------
        angle_range : int
                      maximum absolute angle of rotation (degrees, assuming `rotate`
                      is `scipy.ndimage.rotate`)
        axes        : tuple, default (1,2) assuming that channel axis is 0
                      pair of axis that define the 2d-plane of rotation
        mask_label  : constant value that is used to pad the label images
    """
    def __init__(self, angle_range, axes=(1,2), mask_label=0, **super_kwargs):
        super(FineRandomRotations, self).__init__(**super_kwargs)
        self.angle_range = angle_range
        self.axes = axes
        self.ml = mask_label

    def build_random_variables(self, **kwargs):
        # `**kwargs` keeps the signature compatible with the base class, whose
        # `get_random_variable` forwards arbitrary keyword arguments here.
        np.random.seed()
        self.set_random_variable('angle',
                                 np.random.uniform(low=-self.angle_range,
                                                   high=self.angle_range))

    def batch_function(self, image):
        # `image` is the tuple of all tensors passed to the transform: (input, label, ...).
        angle = self.get_random_variable('angle')
        return rotate(image[0], angle, axes=self.axes, reshape=False), \
               rotate(image[1], angle, axes=self.axes, order=0, cval=self.ml, reshape=False)
RandomScaleSegmentation

以相同的随机比例同时缩放输入图像和标签图像,并可选择裁剪或填充回原始尺寸。

class RandomScaleSegmentation(Transform):
    """ Randomly scale an (input, segmentation) pair by the same factor.

        The input is zoomed with cubic splines (order 3) and the segmentation with
        nearest-neighbour interpolation (order 0), channel by channel.

        Parameters
        ----------
        scale_range : tuple of floats defining (min, max) scales; the scale factor
                      is drawn uniformly from this range once per call
        resize  : if True, image is cropped or padded to the original size
        pad_const: value used for constant padding
    """
    def __init__(self, scale_range, resize=True, pad_const=0, **super_kwargs):
        super(RandomScaleSegmentation, self).__init__(**super_kwargs)
        self.scale_range = scale_range
        self.resize = resize
        self.pad_const = pad_const

    def build_random_variables(self, **kwargs):
        # `**kwargs` keeps the signature compatible with the base class, whose
        # `get_random_variable` forwards arbitrary keyword arguments here.
        np.random.seed()
        self.set_random_variable('seg_scale',
                                 np.random.uniform(low=self.scale_range[0],
                                                   high=self.scale_range[1]))

    def batch_function(self, image):
        # `image` is the pair of tensors passed to the transform: (input, segmentation).
        scale = self.get_random_variable('seg_scale')
        input_image, segmentation = image
        image_shape = np.array(input_image.shape[1:])
        if input_image.ndim == segmentation.ndim + 1:
            # Give the segmentation a leading channel axis so it is zoomed per channel too.
            segmentation = segmentation[None]
        with catch_warnings():
            simplefilter('ignore')
            img = np.stack([zoom(x, scale, order=3) for x in input_image])
            seg = np.stack([zoom(x, scale, order=0) for x in segmentation])
        new_shape = np.array(img.shape[1:])
        if self.resize:
            if scale > 1.:
                # Zoomed image is larger: crop it back to the original size
                crop_l = (new_shape - image_shape) // 2
                crop_r = new_shape - image_shape - crop_l
                # Must be a tuple: NumPy no longer accepts a list of slices as an index.
                cropping = tuple([slice(None)] +
                                 [slice(lo if lo > 0 else None, -hi if hi > 0 else None)
                                  for lo, hi in zip(crop_l, crop_r)])
                img = img[cropping]
                seg = seg[cropping]
            else:
                # Zoomed image is smaller (or equal): pad it back to the original size
                pad_l = (image_shape - new_shape) // 2
                pad_r = image_shape - new_shape - pad_l
                padding = [(0, 0)] + list(zip(pad_l, pad_r))
                img = np.pad(img, padding, 'constant', constant_values=self.pad_const)
                seg = np.pad(seg, padding, 'constant', constant_values=self.pad_const)
        return img, seg
RandomGammaCorrection

使用随机 gamma 值的伽马校正,参见 https://en.wikipedia.org/wiki/Gamma_correction

class RandomGammaCorrection(Transform):
    """Applies gamma correction [1] with a random gamma.

    This transform uses `skimage.exposure.adjust_gamma`, which requires the input be positive.
    Gamma is drawn once per call and shared by all images of that call.

    References
    ----------
    [1] https://en.wikipedia.org/wiki/Gamma_correction
    """
    def __init__(self, gamma_between=(0.5, 2.), gain=1, **super_kwargs):
        """
        Parameters
        ----------
        gamma_between : tuple or list
            Specifies the range within which to sample gamma (uniformly).
        gain : int or float
            The resulting gamma corrected image is multiplied by this `gain`.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        super(RandomGammaCorrection, self).__init__(**super_kwargs)
        self.gamma_between = list(gamma_between)
        self.gain = gain

    def build_random_variables(self, **kwargs):
        # `**kwargs` keeps the signature compatible with the base class, whose
        # `get_random_variable` forwards arbitrary keyword arguments here.
        np.random.seed()
        self.set_random_variable('gamma',
                                 np.random.uniform(low=self.gamma_between[0],
                                                   high=self.gamma_between[1]))

    def image_function(self, image):
        gamma_adjusted = adjust_gamma(image,
                                      gamma=self.get_random_variable('gamma'),
                                      gain=self.gain)
        return gamma_adjusted
ElasticTransform

随机弹性变换

class ElasticTransform(Transform):
    """Random elastic deformation of 2D images.

    A random displacement field is drawn uniformly from [-1, 1], scaled by `alpha`,
    smoothed with a gaussian of width `sigma`, and used to resample the image with
    `map_coordinates`. The field is built once per call, from the shape of the first
    image processed, and reused for every image of that call.
    """
    # Dtype names used as-is; any other dtype is cast to PREFERRED_DTYPE first.
    NATIVE_DTYPES = {'float32', 'float64'}
    PREFERRED_DTYPE = 'float32'

    def __init__(self, alpha, sigma, order=1, invert=False, **super_kwargs):
        """
        Parameters
        ----------
        alpha : float
            Scale of the random displacement field (applied before smoothing).
        sigma : float
            Standard deviation of the gaussian that smooths the displacement field.
        order : int
            Spline interpolation order passed to `map_coordinates`.
        invert : bool
            If True, the displacement field is negated.
        super_kwargs : dict
            Keyword arguments for the superclass.
        """
        # dtype of the image being transformed, recorded only if `cast` converted it.
        self._initial_dtype = None
        super(ElasticTransform, self).__init__(**super_kwargs)
        self.alpha = alpha
        self.sigma = sigma
        self.order = order
        self.invert = invert

    def build_random_variables(self, **kwargs):
        # All this is done just once per call (i.e. until `clear_random_variables` is called)
        np.random.seed()
        imshape = kwargs.get('imshape')
        # Build and scale random fields
        random_field_x = np.random.uniform(-1, 1, imshape) * self.alpha
        random_field_y = np.random.uniform(-1, 1, imshape) * self.alpha
        # Smooth random fields
        sdx = gaussian_filter(random_field_x, self.sigma, mode='reflect')
        sdy = gaussian_filter(random_field_y, self.sigma, mode='reflect')
        # Make meshgrid (imshape is (height, width))
        x, y = np.meshgrid(np.arange(imshape[1]), np.arange(imshape[0]))
        # Make inversion coefficient
        _inverter = 1. if not self.invert else -1.
        # Distort meshgrid indices (invert if required)
        flow_y, flow_x = (y + _inverter * sdy).reshape(-1, 1), (x + _inverter * sdx).reshape(-1, 1)
        # Set random states
        self.set_random_variable('flow_x', flow_x)
        self.set_random_variable('flow_y', flow_y)

    def cast(self, image):
        """Cast `image` to `PREFERRED_DTYPE` unless its dtype is already native.

        The original dtype is remembered so that `uncast` can restore it.
        """
        # Compare by name: `NATIVE_DTYPES` holds strings and a `np.dtype` does not hash
        # like its name, so `image.dtype in NATIVE_DTYPES` was always False and even
        # float64 images were (lossily) round-tripped through float32.
        if str(image.dtype) not in self.NATIVE_DTYPES:
            self._initial_dtype = image.dtype
            image = image.astype(self.PREFERRED_DTYPE)
        return image

    def uncast(self, image):
        """Restore the dtype recorded by `cast` (if any) and reset the record."""
        if self._initial_dtype is not None:
            image = image.astype(self._initial_dtype)
        self._initial_dtype = None
        return image

    def image_function(self, image):
        # Cast image to one of the native dtypes (one that is supported by scipy)
        image = self.cast(image)
        # Take measurements
        imshape = image.shape
        # Obtain flows
        flows = self.get_random_variable('flow_y', imshape=imshape), \
                self.get_random_variable('flow_x', imshape=imshape)
        # Map coordinates from image to distorted index set
        transformed_image = map_coordinates(image, flows,
                                            mode='reflect', order=self.order).reshape(imshape)
        # Uncast image to the original dtype
        transformed_image = self.uncast(transformed_image)
        return transformed_image

inferno.io.transform.volume 处理三维体积数据

功能和 inferno.io.transform.image 中的变换大多类似,这里只给出各个类的定义入口(构造函数签名)。

RandomFlip3D
# NOTE(review): excerpt only — the method(s) that actually transform volumes
# (presumably `volume_function`) are not shown here.
class RandomFlip3D(Transform):
    def __init__(self, **super_kwargs):
        super(RandomFlip3D, self).__init__(**super_kwargs)
RandomRot3D
# NOTE(review): excerpt only — the constructor body and transform methods are omitted,
# so this fragment is not valid Python on its own. The meaning of `rot_range`, `p` and
# `only_one` cannot be determined from this excerpt.
class RandomRot3D(Transform):
    def __init__(self, rot_range, p=0.125,  only_one=True, **super_kwargs):
AdditiveRandomNoise3D
# NOTE(review): excerpt only — the constructor body and transform methods are omitted,
# so this fragment is not valid Python on its own.
class AdditiveRandomNoise3D(Transform):
    """ Add gaussian noise to 3d volume

    Need to know input shape before application, but can be
    synchronized between different inputs (cf. `AdditiveNoise`)
    Arguments:
        shape: shape of input volumes
        std: standard deviation of gaussian
        super_kwargs: keyword arguments for `Transform` base class
    """
    def __init__(self, shape, std, **super_kwargs):
AdditiveNoise
class AdditiveNoise(Transform):
    """ Add noise to 3d volume

    Do NOT need to know input shape before application, but CANNOT be
    synchronized between different inputs (cf. `AdditiveRandomNoise`)
    Arguments:
        sigma: sigma for noise
        mode: mode of distribution (only gaussian supported for now)
        super_kwargs: keyword arguments for `Transform` base class
    """
    def __init__(self, sigma, mode='gaussian', **super_kwargs):
VolumeCenterCrop
class VolumeCenterCrop(Transform):
    """ Crop patch of size `size` from the center of the volume """
    def __init__(self, size, **super_kwargs):
class VolumeAsymmetricCrop(Transform):
    """ Crop `crop_left` from the left borders and `crop_right` from the right borders """
    def __init__(self, crop_left, crop_right, **super_kwargs):
Slices2Channels

对于输入数据,将某一个维度(x、y或z)转换为通道维;
对于目标数据,只保留中央切片,丢弃其余所有切片。

class Slices2Channels(Transform):
    """ Needed for training 2D network with slices above/below as additional channels
        For the input data transforms one dimension (x, y or z) into channels
        For the target data just takes the central slice and discards all the rest"""
    def __init__(self, num_channels, downsampling = 1, **super_kwargs):

inferno.io.transform.generic 通用的处理方法

Normalize

将输入标准化为均值为0、方差为1(注意:这只是平移和缩放,并不意味着数据服从正态分布)。未给定mean/std时,使用当前张量自身的均值和标准差。

class Normalize(Transform):
    """Normalizes input to zero mean unit variance."""
    def __init__(self, eps=1e-4, mean=None, std=None, **super_kwargs):
        """
        Parameters
        ----------
        eps : float
            Added to the standard deviation before dividing, to avoid division by zero.
        mean : list or float or numpy.ndarray
            Global dataset mean for all channels. If None, the mean of each
            incoming tensor is used.
        std : list or float or numpy.ndarray
            Global dataset std for all channels. If None, the std of each
            incoming tensor is used.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(Normalize, self).__init__(**super_kwargs)
        self.eps = eps
        self.mean = None if mean is None else np.asarray(mean)
        self.std = None if std is None else np.asarray(std)

    def tensor_function(self, tensor):
        # Fall back to statistics of this very tensor when no global ones were given.
        centre = self.mean if self.mean is not None else np.asarray(tensor.mean())
        scale = self.std if self.std is not None else np.asarray(tensor.std())
        # Shape (-1, 1, ..., 1): the statistics broadcast along the leading axis
        # (presumably the channel axis when per-channel lists are passed).
        target_shape = (-1,) + (1,) * (tensor.ndim - 1)
        centred = tensor - centre.reshape(target_shape)
        return centred / (scale.reshape(target_shape) + self.eps)
NormalizeRange

将输入除以一个常数进行归一化(默认除以255,例如把取值在[0, 255]的像素缩放到[0, 1])。

class NormalizeRange(Transform):
    """Normalizes input by a constant."""
    def __init__(self, normalize_by=255., **super_kwargs):
        """
        Parameters
        ----------
        normalize_by : float or int
            Divisor applied to every tensor; stored as a float.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(NormalizeRange, self).__init__(**super_kwargs)
        self.normalize_by = float(normalize_by)

    def tensor_function(self, tensor):
        # Plain element-wise division by the configured constant.
        divisor = self.normalize_by
        scaled = tensor / divisor
        return scaled
Project

给定一个投影映射(即一个dict)和一个输入张量,这个变换把张量中等于某个键的所有元素替换为该键对应的值;不等于任何键的元素会被置为0,输出的数据类型与输入张量相同。

class Project(Transform):
    """
    Given a projection mapping (i.e. a dict) and an input tensor, this transform replaces
    all values in the tensor that equal a key in the mapping with the value corresponding to
    the key.

    Elements that match no key are set to zero, and the output has the same dtype as the
    input (it is built with `np.zeros_like`), so target values are cast to that dtype.
    """
    def __init__(self, projection, **super_kwargs):
        """
        Parameters
        ----------
        projection : dict
            Mapping from source values to target values; copied into a new dict.
        super_kwargs : dict
            Keywords to the super class.
        """
        super(Project, self).__init__(**super_kwargs)
        self.projection = dict(projection)

    def tensor_function(self, tensor):
        projected = np.zeros_like(tensor)
        for key in self.projection:
            # Masks are computed on the untouched input, so mappings never chain.
            projected[tensor == key] = self.projection[key]
        return projected
Label2OneHot

对整数标签进行one-hot编码,输出形状为 (num_classes,) + 输入形状。

class Label2OneHot(Transform, DTypeMapping):
    """Convert integer labels to one-hot vectors for arbitrary dimensional data."""
    def __init__(self, num_classes, dtype='float', **super_kwargs):
        """
        Parameters
        ----------
        num_classes : int
            Number of classes.
        dtype : str
            Datatype of the output; must be a key of `DTYPE_MAPPING`.
        super_kwargs : dict
            Keyword arguments to the superclass.

        Raises
        ------
        AssertionError
            If `dtype` is not a key of `DTYPE_MAPPING` (same check as `Cast`).
        """
        super(Label2OneHot, self).__init__(**super_kwargs)
        # Validate like `Cast`: with `.get`, an unknown key mapped to None and
        # `astype(None)` silently produced float64 instead of failing.
        assert dtype in self.DTYPE_MAPPING, \
            "Unknown dtype %r; expected one of %s" % (dtype, list(self.DTYPE_MAPPING.keys()))
        self.num_classes = num_classes
        self.dtype = self.DTYPE_MAPPING[dtype]

    def tensor_function(self, tensor):
        # Class indices shaped (num_classes, 1, ..., 1) broadcast against `tensor`,
        # so the comparison yields an array of shape (num_classes,) + tensor.shape.
        # Labels outside [0, num_classes) produce an all-zero vector.
        class_ids = np.arange(self.num_classes).reshape((-1,) + (1,) * tensor.ndim)
        return np.equal(class_ids, tensor).astype(self.dtype)
Cast

将输入强制转换为指定的数据类型

class Cast(Transform, DTypeMapping):
    """Casts inputs to a specified datatype."""
    def __init__(self, dtype='float', **super_kwargs):
        """
        Parameters
        ----------
        dtype : {'float16', 'float32', 'float64', 'half', 'float', 'double'}
            Key into `DTYPE_MAPPING` naming the target datatype.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(Cast, self).__init__(**super_kwargs)
        known_dtypes = self.DTYPE_MAPPING.keys()
        assert dtype in known_dtypes
        self.dtype = self.DTYPE_MAPPING.get(dtype)

    def tensor_function(self, tensor):
        # The mapped value names a numpy scalar type (looked up on the `np` module),
        # which is called on the tensor to perform the conversion.
        converter = getattr(np, self.dtype)
        return converter(tensor)
AsTorchBatch

将给定的numpy数组转换为torch张量(不含batch维,batch维由DataLoader的collate函数添加)。
参数dimensionality的取值为{1, 2, 3}:

1 表示向量,2 表示图像,3 表示体积数据。

class AsTorchBatch(Transform):
    """Converts a given numpy array to a torch batch tensor.

    The result is a torch tensor __without__ the leading batch axis. For example,
    if the input is an image of shape `(100, 100)`, the output is a batch of shape
    `(1, 100, 100)`. The collate function will add the leading batch axis to obtain
    a tensor of shape `(N, 1, 100, 100)`, where `N` is the batch-size.
    """
    def __init__(self, dimensionality, add_channel_axis_if_necessary=True, **super_kwargs):
        """
        Parameters
        ----------
        dimensionality : {1, 2, 3}
            Dimensionality of the data: 1 if vector, 2 if image, 3 if volume.
        add_channel_axis_if_necessary : bool
            Whether to add a channel axis where necessary. For example, if `dimensionality = 2`
            and the input tensor has 2 dimensions (i.e. an image), setting
            `add_channel_axis_if_necessary` to True results in the output being a 3 dimensional
            tensor, where the leading dimension is a singleton and corresponds to `channel`.
        super_kwargs : dict
            Kwargs to the superclass `inferno.io.transform.base.Transform`.
        """
        super(AsTorchBatch, self).__init__(**super_kwargs)
        assert dimensionality in [1, 2, 3]
        self.dimensionality = dimensionality
        self.add_channel_axis_if_necessary = bool(add_channel_axis_if_necessary)

    @staticmethod
    def _from_numpy(array):
        """Wrap `array` as a torch tensor, copying it only if it has negative strides.

        `torch.from_numpy` rejects arrays with negative strides (e.g. views produced by
        `np.flip` / `arr[::-1]`, as random flips do). Other arrays are shared, not copied.
        """
        if any(stride < 0 for stride in array.strides):
            array = array.copy()
        return torch.from_numpy(array)

    def _to_batch(self, tensor):
        assert_(isinstance(tensor, np.ndarray),
                "Expected numpy array, got %s" % type(tensor),
                DTypeError)
        if self.dimensionality == 1:
            # We're dealing with a vector - it has to be 1D (no channel axis is added).
            assert tensor.ndim == 1
            return self._from_numpy(tensor)
        elif self.dimensionality in (2, 3):
            # Image (2D) or volume (3D): accepted with or without a leading channel axis.
            assert tensor.ndim in [self.dimensionality, self.dimensionality + 1]
            if tensor.ndim == self.dimensionality and self.add_channel_axis_if_necessary:
                # Add a singleton channel axis.
                tensor = tensor[None, ...]
            return self._from_numpy(tensor)
        else:
            raise NotImplementedError

    def tensor_function(self, tensor):
        assert_(isinstance(tensor, (list, np.ndarray)),
                "Expected numpy array or list, got %s" % type(tensor),
                DTypeError)
        if isinstance(tensor, np.ndarray):
            return self._to_batch(tensor)
        else:
            # A list of arrays is converted element-wise into a list of tensors.
            return [self._to_batch(elem) for elem in tensor]

使用示例

下面用上面介绍的几个模块写了一个示例。可以看出,它的用法与torchvision.transforms基本相同,区别在于Compose直接接收多个变换作为参数,而不需要把它们放进列表[]中。如果你习惯了torchvision.transforms,那么inferno.io.transform上手也没有什么难度。

from inferno.io.transform.base import Transform, Compose
from inferno.io.transform.generic import Normalize
from inferno.io.transform.image import RandomCrop, RandomRotate, RandomFlip, PILImage2NumPyArray
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

def main():
    """Load the paired data/label image folders and print the batch shapes."""
    transform = Compose(PILImage2NumPyArray(), RandomCrop(64), RandomRotate(), RandomFlip())
    transform.add(Normalize())  # append one more transform to the pipeline

    # NOTE(review): data and labels are loaded by two independent datasets, so the random
    # crop/rotate/flip parameters are presumably drawn separately for each image. If the
    # data and label images are meant to be pixel-aligned pairs, they will generally not
    # receive the same transformation. Confirm before using this for paired training.
    dataset_data = ImageFolder('../video2image/Data/Data_train', transform=transform)
    loader_data = DataLoader(dataset_data, batch_size=4, shuffle=False,
                             num_workers=1, drop_last=True)
    dataset_label = ImageFolder('../video2image/Data/Label_train', transform=transform)
    loader_label = DataLoader(dataset_label, batch_size=4, shuffle=False,
                              num_workers=1, drop_last=True)

    # Each ImageFolder sample is (image, class_index); a collated batch is [images, class_indices].
    for (data_batch, _), (label_batch, _) in zip(loader_data, loader_label):
        print(data_batch.shape, label_batch.shape)


# With num_workers > 0 and the 'spawn' start method (the default on Windows and macOS),
# DataLoader workers re-import this module; the guard stops them from re-running the script.
if __name__ == "__main__":
    main()

结果
(原文此处为运行结果截图,图片未能保留。)

猜你喜欢

转载自blog.csdn.net/qq_36937684/article/details/110143080