03_PyTorch モデルのトレーニング [Dataset クラスがデータセットを読み取る]

PyTorch は主に Dataset クラスを介して画像を読み取る ため、最初にDatasetクラス を簡単に理解しましょう データセット
このクラスはすべてのデータセット の基本クラスとして存在し C++の仮想ベースと同様に、すべての データセットがそれを継承する必要があります。
親切。

 

ここでgetitem 関数 に注目します。getitem index を受け取り 、画像データとラベルを返します。
index は通常、 リストインデックス を指し 、このリストの各要素には画像データのパスとラベル情報が含まれます
息。
ただ、このリストの 作り方は 、写真のパスとラベル情報をtxtに保存する方法が一般的 です
、そしてその txtから 読み取ります。
次に、独自のデータを読み取る基本的なプロセスは次のとおりです。
1. 画像のパスとラベル情報を格納する txtを作成する
2. この情報を リストに変換する と、 リストの 各要素がサンプルに対応します
3. getitem関数 を使用して 、データとタグを読み取り、データとタグを返します
したがって、 PyTorch が 独自のデータセットを読み取れるようにするには、次の 2 つの手順のみが必要です。
1. 画像データのインデックスを作る
2. データセットサブクラス を構築する
1.メモ帳コードを生成する
import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
'''
    为数据集生成对应的txt文件
'''
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
train_txt_path = os.path.join(base_dir, "Data", "train.txt")
train_dir = os.path.join(base_dir, "Data", "train")
valid_txt_path = os.path.join(base_dir, "Data", "valid.txt")
valid_dir = os.path.join(base_dir, "Data", "valid")
print(train_txt_path)
print(train_dir)
print(valid_txt_path)
print(valid_dir)
def gen_txt(txt_path, img_dir):
    f = open(txt_path, 'w')
    
    for root, s_dirs, _ in os.walk(img_dir, topdown=True):  # 获取 train文件下各文件夹名称
        for sub_dir in s_dirs:
            i_dir = os.path.join(root, sub_dir)             # 获取各类的文件夹 绝对路径
            img_list = os.listdir(i_dir)                    # 获取类别文件夹下所有png图片的路径
            for i in range(len(img_list)):
                if not img_list[i].endswith('png'):         # 若不是png文件,跳过
                    continue
                label = img_list[i].split('_')[0]
                img_path = os.path.join(i_dir, img_list[i])
                line = img_path + ' ' + label + '\n'
                f.write(line)
    f.close()
gen_txt(train_txt_path, train_dir)
gen_txt(valid_txt_path, valid_dir)

2.効果

 

 3. データセット クラス コード

class MyDataset(Dataset):
    def __init__(self, txt_path, transform=None, target_transform=None):
        fh = open(txt_path, 'r') 
        imgs = []
        for line in fh:
            line = line.rstrip() #rstrip函数返回字符串副本,该副本是从字符串最右边删除了参数指定字符后的字符串,不带参数进去则是去除最右边的空格
            words = line.split() #默认以空格为分隔符
            imgs.append((words[0], int(words[1])))

        self.imgs = imgs        # 最主要就是要生成这个list, 然后DataLoader中给index,通过getitem读取图片数据
        # transform 是一个 Compose 类型,里边有一个 list,list 中就会定义了各种对图像进行处理的操作,
        #可以设置减均值,除标准差,随机裁剪,旋转,翻转,仿射变换等操作
        #在这里我们可以知道,一张图片读取进来之后,会经过数据处理(数据增强),
        #最终变成输入模型的数据。这里就有一点需要注意,PyTorch 的数据增强是将原始图片进行了处理
        #并不会生成新的一份图片,而是“覆盖”原图
        self.target_transform = target_transform
        self.transform = transform 

    def __getitem__(self, index):
        fn, label = self.imgs[index]
        #对图片进行读取
        img = Image.open(fn).convert('RGB')     # 像素值 0~255,在transfrom.totensor会除以255,使像素值变成 0~1

        if self.transform is not None:
            img = self.transform(img)   # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.imgs)

4.データロード

Mydatasetがビルドさ れる 、残りの操作は DataLoderに 引き渡されます
Mydataset getiterm関数は 、画像のデータとラベルを読み取り、 バッチにまとめて 返します。
モデルの実際の入力。

 

おすすめ

転載: blog.csdn.net/zhang2362167998/article/details/128807544