Pytorchスタディノート1:独自の画像データセットを作成してロードする

Pytorchスタディノート1:独自の画像データセットを作成してロードする



序文

まず、pytorchを使用してネットワークの既存のデータセットをロードする方法を紹介し、次に独自の画像データセットを作成してバッチで読み取り、独自のネットワークをトレーニングする方法を紹介します。


ヒント:以下はこの記事の内容です。以下のケースは参照用です。

1つは、データセットをダウンロードする

Pytorchを使用して、ローカルのMINISTデータセットを読み取り、ロードします

# 训练数据和测试数据的下载
trainDataset = torchvision.datasets.MNIST( # torchvision可以实现数据集的训练集和测试集的下载
  root="./data", # 下载数据,并且存放在data文件夹中
  train=True, # train用于指定在数据集下载完成后需要载入哪部分数据,如果设置为True,则说明载入的是该数据集的训练集部分;如果设置为False,则说明载入的是该数据集的测试集部分。
  transform=transforms.ToTensor(), # 数据的标准化等操作都在transforms中,此处是转换
  download=True 
)
 
testDataset = torchvision.datasets.MNIST(
  root="./data",
  train=False,
  transform=transforms.ToTensor(),
  download=True
)

2、独自のデータセットをロードします

1.データセットを作成します

ニューラルネットワークのトレーニングには、標準入力画像とそのグラウンドトゥルースラベルが必要です。
猫、犬、ボート、車などの分類問題では、数字を使用してさまざまな分類を表すことができます。入力画像のアドレスとそれに対応するラベル番号を保存するために、txtファイルを作成できます。
画像を入力として取得し、別の処理済み画像をその真理値として取得する必要があるタスクがあるため、txtテキストの下に書き込んだのはそれらのパスです。トレーニング画像のプロジェクトパスの下に新しいtrainフォルダが作成され、トレーニング画像とラベル画像にラベルを付けるためにtrainフォルダの下に新しいtrainingtxtが作成されます。
ここに画像の説明を挿入

2.データセットをロードします

データセットクラス

PyTorchは、主にDatasetクラスを介して画像を読み取ります。これは、Pytorchのすべてのデータセット読み込みクラスから継承する必要がある親クラスです。Datasetクラスを継承して書き換えることにより、独自の画像データセットを読み取ります。次の3つの関数を書き直す必要があります。
データファイルを読み取るための__init__メソッド

__getitem__メソッドは添え字アクセスをサポートします

__len__メソッドは、後のトラバーサルを容易にするためにカスタムデータセットのサイズを返します

class OpticalSARDataset(Data.Dataset):
    """
      定义自己的数据集、读取数据、初始化数据
    """

    def __init__(self, data_dir, part):
        # 所有图片的绝对路径
        assert part in ["train", "val"]
        self.image_dir = os.path.join(data_dir, part)
        self.img_names = []
        self.label_names = []

        with open(os.path.join(data_dir, part, "label.txt")) as f:
            while True:
                il = f.readline(1500)  # 如果样本数据名称大于1500,修改该值
                if not il:
                    break
                a = il.split(sep=' ')
                self.img_names.append(a[0])
                self.label_names.append(a[1][0:-1])  # remove '\n'
        self.samples_num = len(self.img_names)
        # print(self.samples_num)

        self.transform = torchvision.transforms.Compose([
            # 将 PIL 图片转换成位于[0.0, 1.0]的floatTensor, shape (C x H x W)
            torchvision.transforms.ToTensor()])

    def __len__(self):
        # 返回图像的数量
        return self.samples_num

    def __getitem__(self, index):
        tp_img = Image.open(os.path.join(self.image_dir,  self.img_names[index])
                                 ).convert('RGB')
        tp_label = Image.open(os.path.join(self.image_dir, self.label_names[index])
                               ).convert('RGB')
        # PIL.Image.open 读取的图片数据是RGB格式;
        tp_img = cv2.cvtColor(np.asarray(tp_img), cv2.COLOR_RGB2BGR)
        tp_label = cv2.cvtColor(np.asarray(tp_label), cv2.COLOR_RGB2BGR) # 转换为BGR便于cv2.imshow,跟下面imshow之前RGB2BGR只用一种方法,这里统一为cv2的BGR格式
        img = self.transform(tp_img)
        label = self.transform(tp_label)


        sample = {
    
    
            "label": label,  # shape
            "image": img  # shape: (3, *image_size)
        }


        return sample

データセットを定義する

# 利用之前创建好的OpticalSARDataset类去创建数据对象
train_dataset = OpticalSARDataset(data_dir, 'train')  # 训练数据集

データローダークラス

前述のDatasetクラスは、データセットを読み取り、読み取ったデータにインデックスを付けます。
しかし、この関数だけでは十分ではありません。データセットをロードする実際のプロセスでは、データ量が非常に大きいことがよくあります。これには、いくつかの関数が必要です。
バッチで読み取ることができます。バッチサイズ
でデータをランダムに読み取ることができます。シャッフル操作(シャッフル)であり、データセット内のデータ分散の順序を乱し、データ
を並列にロードできます(マルチコアプロセッサを使用してデータのロードの効率を高速化します)
Dataloaderクラスは私たちを必要としません独自のコードを設計する、DataLoaderクラスを使用するだけで、設計したクラスを読み取ることができます。

データセットをインスタンス化する

# 利用dataloader读取我们的数据对象,并设定batch-size和工作现场
train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=0)
batch = iter(train_iter).next()
print(batch["image"].shape, batch["label"].shape)
print(batch["image"][0].shape)

総括する

リファレンスブログ:
独自のデータセットを定義する
pytorch独自のデータセットをロードする独自のデータを
設計する独自のデータを
トレーニングする完全な手順
データセットクラス



おすすめ

転載: blog.csdn.net/qq_43173239/article/details/108948228