PyTorch学习笔记(10)transforms(4)

自定义transforms

自定义transforms要素

  1. 仅接收一个参数,返回一个参数
  2. 注意上下游的输出与输入(上一个transform的输出类型必须与下一个transform的输入类型匹配)

通过类实现多参数传入:

# Template for a custom transform. Extra settings are supplied through
# __init__, which lets __call__ keep the single-argument signature.
# NOTE: the `...` in the parameter list is a placeholder for whatever
# parameters the transform needs; it is not literal Python syntax.
class YourTransforms(object):
     def __init__(self,...):
         ...
     # Receives one image and returns one image.
     def __call__(self, img):
         ...
         return img

椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或黑点,白点称为盐噪声,黑点称为椒噪声
信噪比(Signal-to-Noise Ratio,SNR)是衡量噪声的比例,在图像中表示保持原始值(未被噪声替换)的像素所占的比例

class AddPepperNoise(object):
    """Skeleton of a salt-and-pepper noise transform (noise logic omitted)."""

    def __init__(self, snr, p):
        # Both settings are stored unchanged for later use by __call__.
        self.snr, self.p = snr, p

    def __call__(self, img):
        """Placeholder for the noise step; hands ``img`` back untouched.

        :param img: the object passed through the transform
        :return: the very same ``img`` object
        """
        return img
class Compose(object):
    """Chain several transforms into one callable.

    Args:
        transforms (iterable of callables, optional): transforms applied in
            order, each receiving the previous one's output. ``None`` (the
            default) means "no transforms", so the input is returned as is.
    """

    def __init__(self, transforms=None):
        # Without this initializer ``self.transforms`` was never set and any
        # call failed with AttributeError. The default keeps the no-argument
        # construction ``Compose()`` working.
        self.transforms = [] if transforms is None else transforms

    def __call__(self, img):
        """Run ``img`` through every transform in order and return the result."""
        for t in self.transforms:
            img = t(img)
        return img
# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import random
import torchvision.transforms as transforms
from PIL import Image
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tools.my_dataset import RMBDataset
from tools.common_tools import transform_invert


def set_seed(seed=1):
    """Seed ``random``, ``np.random``, ``torch`` and ``torch.cuda`` alike.

    Args:
        seed (int): value handed to every seeding function. Defaults to 1.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
    )
    # Same call order as listing the four calls one after another.
    for seed_fn in seeders:
        seed_fn(seed)


set_seed(1)  # seed the random number generators

# Parameter settings
# NOTE(review): LR, log_interval and val_interval are not referenced in the
# visible part of this script.
MAX_EPOCH = 10
BATCH_SIZE = 1
LR = 0.01
log_interval = 10
val_interval = 1
# presumably maps a class/folder name to its integer label - confirm against
# RMBDataset
rmb_label = {"1": 0, "100": 1}


class AddPepperNoise(object):
    """Randomly corrupt an image with salt-and-pepper (impulse) noise.

    Args:
        snr (float): fraction of pixels left untouched, in [0, 1]. The
            remaining ``1 - snr`` is split evenly between salt (white, 255)
            and pepper (black, 0) pixels. It has no default value.
        p (float): probability, in [0, 1], that the noise is applied on a
            given call. Defaults to 0.9.

    Raises:
        TypeError: if ``snr`` or ``p`` cannot be compared with numbers.
        ValueError: if ``snr`` or ``p`` lies outside [0, 1].
    """

    def __init__(self, snr, p=0.9):
        # The previous ``assert isinstance(snr, float) or isinstance(p, float)``
        # let one invalid argument through whenever the other was a float, and
        # disappeared entirely under ``python -O``. Both values are now checked
        # explicitly; integers such as 0 or 1 are accepted as well.
        self.snr = self._check_fraction("snr", snr)
        self.p = self._check_fraction("p", p)

    @staticmethod
    def _check_fraction(name, value):
        """Return ``value`` if it is a number within [0, 1], otherwise raise."""
        try:
            in_range = 0 <= value <= 1
        except TypeError:
            raise TypeError(
                "{} must be a real number, got {}".format(name, type(value).__name__)
            )
        if not in_range:
            raise ValueError("{} must be within [0, 1], got {!r}".format(name, value))
        return value

    def _add_noise(self, img):
        """Return a noisy ndarray copy of ``img``.

        Args:
            img: PIL image or array-like shaped (H, W) or (H, W, C).

        Returns:
            numpy.ndarray: copy of the input in which the selected pixels are
            set to 255 (salt) or 0 (pepper) across all channels.

        Raises:
            ValueError: if the input is neither 2-D nor 3-D.
        """
        noisy = np.array(img)  # np.array copies, so the input is never modified
        if noisy.ndim not in (2, 3):
            raise ValueError(
                "expected an image with 2 or 3 dimensions, got {}".format(noisy.ndim)
            )
        # Share of pixels that get replaced by noise.
        noise_pct = 1 - self.snr
        # One draw per pixel: 0 keeps the pixel, 1 -> salt, 2 -> pepper.
        mask = np.random.choice(
            (0, 1, 2),
            size=noisy.shape[:2],
            p=[self.snr, noise_pct / 2., noise_pct / 2.],
        )
        # A 2-D boolean mask indexes whole pixels, so every channel of a chosen
        # pixel changes together and grayscale (2-D) input works too.
        noisy[mask == 1] = 255   # salt: white
        noisy[mask == 2] = 0     # pepper: black
        return noisy

    def __call__(self, img):
        """Apply the noise with probability ``p``.

        Args:
            img (PIL Image): input image.

        Returns:
            PIL Image: a new image converted to RGB when noise is applied;
            otherwise the input object itself, unchanged.
        """
        if random.uniform(0, 1) >= self.p:
            return img
        noisy = self._add_noise(img)
        return Image.fromarray(noisy.astype('uint8')).convert('RGB')


# ============================ step 1/5 data ============================
# Dataset locations, relative to the current working directory.
split_dir = os.path.join( "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# Per-channel mean / std handed to transforms.Normalize below.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]


# Training pipeline. AddPepperNoise sits before ToTensor because it operates
# on a PIL image; snr=0.9 keeps 90% of the pixels and p=0.5 applies the noise
# on roughly half of the calls.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    AddPepperNoise(0.9, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: same steps without the noise transform.
valid_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std)
])

# Build the dataset instances
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# Build the DataLoaders (only the training loader shuffles)
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)


# ============================ step 5/5 training ============================
# NOTE: despite the section title, the visible loop builds no model and
# computes no loss; each iteration only displays the first image of the batch
# so the effect of train_transform can be inspected.
for epoch in range(MAX_EPOCH):
    for i, data in enumerate(train_loader):

        inputs, labels = data   # B C H W

        img_tensor = inputs[0, ...]     # C H W
        # transform_invert comes from tools.common_tools; presumably it undoes
        # the tensor conversion / normalization of train_transform so the
        # result can be shown - confirm against that helper.
        img = transform_invert(img_tensor, train_transform)
        plt.imshow(img)
        plt.show()
        plt.pause(0.5)
        plt.close()

如何制定数据增强策略

原则:让训练集与测试集更接近
空间位置:平移
色彩:灰度图,色彩抖动
形状:仿射变换
上下文场景:遮挡,填充

发布了21 篇原创文章 · 获赞 0 · 访问量 228

猜你喜欢

转载自blog.csdn.net/qq_33357094/article/details/104452291