Pytorch 시작부터 숙달까지: 2. 데이터 세트 및 데이터로다

데이터는 딥 러닝의 기초입니다 일반적으로 데이터 양이 많을수록 학습된 모델의 성능이 높아집니다. 지금 데이터가 있는 경우 이 데이터를 모델에 어떻게 추가합니까? Pytorch는 데이터 세트와 데이터 로더를 제공합니다 . 함께 배우겠습니다. 데이터 세트와 데이터 로더 블로거가 몇 가지 예를 사용하여 설명하겠습니다. 지원해 주셔서 감사합니다!
여기에 이미지 설명 삽입

1. 데이터세트

데이터 및 레이블을 가져오는 방법 제공
● 각 데이터 및 레이블을 가져오는 방법 ● pytorch를 사용할 수 있는지 확인하기 위해
얼마나 많은 데이터가 있는지 알려주십시오.

print(torch.cuda.is_available()) # 查看当前cuda是否可用
True

둘째, 데이터 세트 보기

from torch.utils.data import Dataset
help(Dataset) # 用帮助文档查看Dataset

모듈 torch.utils.data.dataset의 class Dataset에 대한 도움말:
class Dataset(typing.Generic)
| 데이터세트(*args, **kwds)
|
| :class: 를 나타내는 추상 클래스 Dataset.
|
| 키에서 데이터 샘플로의 맵을 나타내는 모든 데이터 세트는
| 그것. 모든 서브클래스는 :meth: 를 덮어써야 __getitem__합니다
. 주어진 키에 대한 데이터 샘플. 하위 클래스는 선택적으로 덮어쓸 수도
| :meth: __len__, 많은 사람들이 데이터 세트의 크기를 반환할 것으로 예상됩니다
| :class: ~torch.utils.data.Sampler구현 및 기본 옵션
| :클래스: ~torch.utils.data.DataLoader.
|
| … 참고::
| :class: ~torch.utils.data.DataLoader기본적으로 인덱스를 구성합니다.
| 적분 지수를 생성하는 샘플러. 지도 스타일로 작동하게 하려면
| 필수가 아닌 인덱스/키가 있는 데이터 세트의 경우 사용자 지정 샘플러를 제공해야 합니다.
|
| 방법 해결 순서:
| 데이터 세트
| 타이핑.일반
| builtins.object
|
| 여기에 정의된 방법:
|
| add (self, other: 'Dataset[T_co]') -> 'ConcatDataset[T_co]'
|
| getattr (자신, 속성 이름)
|
| getitem (자체, 인덱스) -> +T_co
|
| -----------------------------------------------------------------------
| 여기에 정의된 클래스 메서드:
|
| builtins.type에서 register_datapipe_as_function(function_name, cls_to_register, enable_df_api_tracing=False)
|
| builtins.type의 register_function(function_name, function)
|
| -----------------------------------------------------------------------
| 여기에 정의된 데이터 설명자:
|
| 딕셔너리
| 인스턴스 변수 사전(정의된 경우)
|
| 약한 참조
| 개체에 대한 약한 참조 목록(정의된 경우)
|
| -----------------------------------------------------------------------
| 여기에 정의된 데이터 및 기타 속성:
|
| 주석= {'functions': typing.Dict[str, typing.Callable]}
|
| orig_bases = (입력.일반[+T_co],)
|
| 매개변수 = (+T_co,)
|
| functions = {'concat': functools.partial(<function Dataset.register_da…
|
| -----------------------------------------------------------------------
Typeing.Generic에서 상속된 클래스 메서드:
|
| builtins.type의 class_getitem (params)
|
| builtins.type의 init_subclass (*args, **kwargs)
| 이 메서드는 클래스가 서브클래싱될 때 호출됩니다.
|
| 기본 구현은 아무 작업도 수행하지 않습니다.
| 하위 클래스를 확장하도록 재정의되었습니다.
|
| -----------------------------------------------------------------------
| typing.Generic에서 상속된 정적 메서드:
|
| 새로운 (cls, *args, **kwds)
| 새 객체를 생성하고 반환합니다. 정확한 서명은 help(type)를 참조하십시오.

3. os 작업은 폴더 아래의 개체를 읽습니다.

import os
dir_path = "hymenoptera_data\\hymenoptera_data\\train\\ants"  # 文件夹目录
data_dir = os.listdir(dir_path)  # 获取文件夹目录中的对象
data_dir

[‘0013035.jpg’,
‘1030023514_aad5c608f9.jpg’,
‘1095476100_3906d8afde.jpg’,
‘1099452230_d1949d3250.jpg’,
‘116570827_e9c126745d.jpg’,
‘1225872729_6f0856588f.jpg’,
‘1262877379_64fcada201.jpg’,
‘1269756697_0bce92cdab.jpg’,
‘1286984635_5119e80de1.jpg’,
‘132478121_2a430adea2.jpg’,
‘1360291657_dc248c5eea.jpg’,
‘1368913450_e146e2fb6d.jpg’,
‘1473187633_63ccaacea6.jpg’,
‘148715752_302c84f5a4.jpg’,
‘1489674356_09d48dde0a.jpg’,
‘149244013_c529578289.jpg’,
‘150801003_3390b73135.jpg’,
‘150801171_cd86f17ed8.jpg’,
‘154124431_65460430f2.jpg’,
‘162603798_40b51f1654.jpg’,
‘1660097129_384bf54490.jpg’,
‘167890289_dd5ba923f3.jpg’,
‘1693954099_46d4c20605.jpg’,
‘175998972.jpg’,
‘178538489_bec7649292.jpg’,
‘1804095607_0341701e1c.jpg’,
‘1808777855_2a895621d7.jpg’,
‘188552436_605cc9b36b.jpg’,
‘1917341202_d00a7f9af5.jpg’,
‘1924473702_daa9aacdbe.jpg’,
‘196057951_63bf063b92.jpg’,
‘196757565_326437f5fe.jpg’,
‘201558278_fe4caecc76.jpg’,
‘201790779_527f4c0168.jpg’,
‘2019439677_2db655d361.jpg’,
‘207947948_3ab29d7207.jpg’,
‘20935278_9190345f6b.jpg’,
‘224655713_3956f7d39a.jpg’,
‘2265824718_2c96f485da.jpg’,
‘2265825502_fff99cfd2d.jpg’,
‘226951206_d6bf946504.jpg’,
‘2278278459_6b99605e50.jpg’,
‘2288450226_a6e96e8fdf.jpg’,
‘2288481644_83ff7e4572.jpg’,
‘2292213964_ca51ce4bef.jpg’,
‘24335309_c5ea483bb8.jpg’,
‘245647475_9523dfd13e.jpg’,
‘255434217_1b2b3fe0a4.jpg’,
‘258217966_d9d90d18d3.jpg’,
‘275429470_b2d7d9290b.jpg’,
‘28847243_e79fe052cd.jpg’,
‘318052216_84dff3f98a.jpg’,
‘334167043_cbd1adaeb9.jpg’,
‘339670531_94b75ae47a.jpg’,
‘342438950_a3da61deab.jpg’,
‘36439863_0bec9f554f.jpg’,
‘374435068_7eee412ec4.jpg’,
‘382971067_0bfd33afe0.jpg’,
‘384191229_5779cf591b.jpg’,
‘386190770_672743c9a7.jpg’,
‘392382602_1b7bed32fa.jpg’,
‘403746349_71384f5b58.jpg’,
‘408393566_b5b694119b.jpg’,
‘424119020_6d57481dab.jpg’,
‘424873399_47658a91fb.jpg’,
‘450057712_771b3bfc91.jpg’,
‘45472593_bfd624f8dc.jpg’,
‘459694881_ac657d3187.jpg’,
‘460372577_f2f6a8c9fc.jpg’,
‘460874319_0a45ab4d05.jpg’,
‘466430434_4000737de9.jpg’,
‘470127037_513711fd21.jpg’,
‘474806473_ca6caab245.jpg’,
‘475961153_b8c13fd405.jpg’,
‘484293231_e53cfc0c89.jpg’,
‘49375974_e28ba6f17e.jpg’,
‘506249802_207cd979b4.jpg’,
‘506249836_717b73f540.jpg’,
‘512164029_c0a66b8498.jpg’,
‘512863248_43c8ce579b.jpg’,
‘518773929_734dbc5ff4.jpg’,
‘522163566_fec115ca66.jpg’,
‘522415432_2218f34bf8.jpg’,
‘531979952_bde12b3bc0.jpg’,
‘533848102_70a85ad6dd.jpg’,
‘535522953_308353a07c.jpg’,
‘540889389_48bb588b21.jpg’,
‘541630764_dbd285d63c.jpg’,
‘543417860_b14237f569.jpg’,
‘560966032_988f4d7bc4.jpg’,
‘5650366_e22b7e1065.jpg’,
‘6240329_72c01e663e.jpg’,
‘6240338_93729615ec.jpg’,
‘649026570_e58656104b.jpg’,
‘662541407_ff8db781e7.jpg’,
‘67270775_e9fdf77e9d.jpg’,
‘6743948_2b8c096dda.jpg’,
‘684133190_35b62c0c1d.jpg’,
‘69639610_95e0de17aa.jpg’,
‘707895295_009cf23188.jpg’,
‘7759525_1363d24e88.jpg’,
‘795000156_a9900a4a71.jpg’,
‘822537660_caf4ba5514.jpg’,
‘82852639_52b7f7f5e3.jpg’,
‘841049277_b28e58ad05.jpg’,
‘886401651_f878e888cd.jpg’,
‘892108839_f1aad4ca46.jpg’,
‘938946700_ca1c669085.jpg’,
‘957233405_25c1d1187b.jpg’,
‘9715481_b3cb4114ff.jpg’,
‘998118368_6ac1d91f81.jpg’,
‘ant photos.jpg’,
‘Ant_1.jpg’,
‘army-ants-red-picture.jpg’,
‘formica.jpeg’,
‘hormiga_co_por.jpg’,
‘imageNotFound.gif’,
‘kurokusa.jpg’,
‘MehdiabadiAnt2_600.jpg’,
‘Nepenthes_rafflesiana_ant.jpg’,
‘swiss-army-ant.jpg’,
‘termite-vs-ant.jpg’,
‘trap-jaw-ant-insect-bg.jpg’,
‘VietnameseAntMimicSpider.jpg’]
注意在windows下,路径使用双斜线\

4. 데이터세트

데이터 세트 연습 1

from torch.utils.data import Dataset
import os
from PIL import Image


class Mydata(Dataset):
    def __init__(self,root_path,label_path):
        self.root_path = root_path  # hymenoptera_data/hymenoptera_data/train
        self.label_path = label_path  # /ants
        self.path = os.path.join(self.root_path,self.label_path)  # 从根目录开始的绝对路径
        self.image_path = os.listdir(self.path) # 从根目录开始绝对路径文件夹下的对象 hymenoptera_data/hymenoptera_data/train/ants下的图片 type--> list
    def __getitem__(self, idx):
        image_name = self.image_path[idx] # 单一的图片名称
        image_item_path = os.path.join(self.root_path,self.label_path,image_name)
        img = Image.open(image_item_path)
        label = self.label_path
        return img,label
    def __len__(self):
        return len(self.image_path)

ants_root_path = "hymenoptera_data\\hymenoptera_data\\train"
ants_label_path = "ants"
Ants = Mydata(ants_root_path,ants_label_path)
Ants[0][0].show() # 第一个0是索引,拿到第一个图像和标签,第二个0是拿到第一个图像,并显示出来

D:\anaconda\envs\Gpu-Pytorch\lib\site-packages\tqdm\auto.py:21: TqdmWarning: IProgress를 찾을 수 없습니다. jupyter 및 ipywidgets를 업데이트하십시오. https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm 참조
여기에 이미지 설명 삽입

bee_label_path = "bees"
Bees = Mydata(bee_root_path,bee_label_path)
Bees[0][0].show()

여기에 이미지 설명 삽입

# 创建训练集

train = Ants + Bees   # 直接将数据集加起来
print("the length of Ants is ",Ants.__len__())
print("the length of Bees is ",Bees.__len__())
print("the length of train is ",train.__len__())
the length of Ants is  124
the length of Bees is  121
the length of train is  245
# 查看是否正确
train[123][0].show() # 应该为蚂蚁
train[124][0].show() # 应该为蜜蜂

여기에 이미지 설명 삽입
여기에 이미지 설명 삽입

데이터셋 실사용 2

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch学习 
@File    :task_3.py
@IDE     :PyCharm 
@Author  :咋
@Date    :2023/6/29 14:29 
"""
from torch.utils.data import Dataset
import os
from PIL import Image

class Mydata(Dataset):
    def __init__(self,root_path,image_path,label_path):
        self.root_path = root_path
        self.image_path = image_path
        self.label_path = label_path
        self.A_image_path = os.path.join(self.root_path,self.image_path)
        self.A_label_path = os.path.join(self.root_path,self.label_path)
        self.img_item = os.listdir(self.A_image_path)
        self.label_item = os.listdir(self.A_label_path)

    def __getitem__(self, idx):
        img_name = self.img_item[idx]
        img_path = os.path.join(self.A_image_path, img_name)
        label_list = [i.split(".")[0] for i in self.label_item if i.count(".") == 1]
        # print(label_list)
        if img_name.split(".")[0] in label_list:
            img = Image.open(img_path)
            label_path = os.path.join(self.A_label_path,img_name.split(".")[0])
            label_path += ".txt"
            file = open(label_path, 'r')
            label = file.read()
            file.close()
            return img,label
        else:
            print("{0}没有对应的标签".format(img_name))
            return 0

    def __len__(self):
        return len(self.img_item)





train_ants_root_path = "练手数据集\\train"
train_ants_image_path = "ants_image"
train_ants_label_path = "ants_label"
Ants = Mydata(train_ants_root_path,train_ants_image_path,train_ants_label_path)
for i in range(Ants.__len__()):
    try:
        print(Ants[i][1])
    except TypeError:
        print("跳过此张图片!")
# Ants[122][0].show()
# print(Ants[122][1])

개미
개미
개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미
개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미











































개미
개미
개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미
개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미 개미











































ants
ants
ants
ants
ants
ants
ants
ants
ants
ants ants ants
ants
ants
ants
ants
ants
ants
ants ants
ants ants formica.jpeg 해당 태그 없음 이 이미지 건너뛰기
! 개미 imageNotFound.gif에는 해당 태그가 없습니다. 이 사진을 건너뛰세요! ants ants ants ants ants ants 예외 캡처를 추가하여 사진에 해당 레이블이 없는 문제를 해결합니다!












데이터 세트 실제 작업 3

torchvision의 데이터 세트를 사용하여 데이터 세트 생성

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project :Pytorch_learn 
@File    :dataset_3.py
@IDE     :PyCharm 
@Author  :咋
@Date    :2023/7/2 14:58 
"""
import torchvision
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from torchvision import transforms
dataset = torchvision.datasets.MNIST("./Mnist",train=True,download=True,transform=transforms.ToTensor())
dataloader = DataLoader(dataset,batch_size=64,shuffle=False,num_workers=0)
# 使用tensorboard将dataloader展示出来
'''方式一
# write = SummaryWriter("log_2")
# count = 0
# for data in dataloader:
#     image,label = data
#     # print(data[1])
#     # print(image.shape)
#     write.add_images("dataloader",image,count)
#     count += 1
'''

# 方式二
write = SummaryWriter("log_3")
for i,data in enumerate(dataloader):
    image,label = data
    write.add_images("dataloader",image,i)

write.close()

여기에 이미지 설명 삽입
enumerate는 인덱스와 함께 반복 가능한 객체의 내용을 반환합니다.

例如对于一个seq,得到:
(0, seq[0]), (1, seq[1]), (2, seq[2])

다섯, 데이터 로더

후속 네트워크에 다른 데이터 유형 제공

데이터 세트를 사용자 지정하고 datalodar로 로드

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from net import Net
import softmax
from torch.utils.data import Dataset
import os
from PIL import Image
import numpy as np


transform_tool = transforms.ToTensor()  # 创建一个transform工具
# # image_tensor = transform_tool(image)
with open("mnist-label.txt", 'r') as f:
    label_str = f.read().strip()   # 打开文件读入缓存
class Mydata(Dataset):
    def __init__(self,image_path):
        self.image_path = image_path
        # self.label_path = label_path  # /ants
        self.image = os.listdir(self.image_path) # 从根目录开始绝对路径文件夹下的对象 hymenoptera_data/hymenoptera_data/train/ants下的图片 type--> list
    def __getitem__(self, idx):
        image_name = self.image[idx] # 单一的图片名称
        image_item_path = os.path.join(self.image_path,image_name)
        img = Image.open(image_item_path)
        # transform_tool = transforms.ToTensor()  # 创建一个transform工具
        img = transform_tool(img)
        labels_list = [int(label) for label in label_str.split(',')]  # 读取标签,不用每次都打开
        labels = np.array(labels_list)
        label = labels[idx]
        return img,label
    def __len__(self):
        return len(self.image)
# trainset = Mydata("mnist-dataset")

# 设置训练参数
batch_size = 32
epochs = 5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 数据集
# transform = transforms.Compose([transforms.ToTensor(),
#                                 transforms.Normalize((0.5,), (0.5,))])
# trainset =
# trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainset = Mydata("mnist-dataset")

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False,num_workers=0)
print(len(trainloader))
# 输出提示信息
print("batch_size:", batch_size)
print("data_batches:", len(trainloader))
print("epochs:", epochs)

# 神经网络
net = Net().to(device)
# net.load_state_dict(torch.load('./model/model.pth'))

# 损失函数和优化器
# 负对数似然损失
criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=0.0005, momentum=0.9)
total_correct = 0
total_samples = 0
# 训练网络
```python
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader):
        inputs, labels = data
        inputs, labels = Variable(inputs).to(device), Variable(labels).to(device)

        # 反向传播优化参数
        optimizer.zero_grad()
        outputs = net(inputs)
        # outputs = int(net(inputs))
        # print(outputs)
        labels = labels.long()
        # print(labels)
        # print(type(labels))
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # 计算每个batch的准确率
        _, predicted = torch.max(outputs.data, 1)
        total_samples += labels.size(0)
        total_correct += (predicted == labels).sum().item()

        if i % 5 == 0:    # 每轮输出损失值
            accuracy = 100.0 * total_correct / total_samples
            print('[epoch: %d, batches: %d] loss: %.5f accuracy: %.2f%%' %
                  (epoch + 1, i + 1, running_loss / 2000, accuracy))
            total_correct = 0
            total_samples = 0
            running_loss = 0.0
torch.save(net.state_dict(), 'model.pth')  # 每轮保存模型参数

print('Finished Training')

__getitem__을 실행할 때마다 파일을 열지 않고도 클래스를 정의하기 전에 파일을 열고, 파일 정보를 캐시로 읽고, __getitem__의 각 레이블을 읽을 수 있습니다.

여섯, OS의 일부 작업

windows使用两个\\表示路径
import os
dir_path = "/home/aistudio"  # 文件夹目录
data_dir = os.listdir(dir_path)  # 获取文件夹目录中的对象
label_path = "label"
all_path = os.path.join(dir_path,label_path)

Supongo que te gusta

Origin blog.csdn.net/weixin_63866037/article/details/131839109
Recomendado
Clasificación