PyTorch学习笔记(一) ---------数据集的简单创建

一、图像基本处理以及数据集的简单创建

初次接触pytorch,配置环境还是比较麻烦的,主要是用到anaconda

下面是对图像处理的基本操作

from PIL import Image  # 图像处理的库
img_path = r'D://情绪图片测试/path1.jpg'  # 图片路径
img = Image.open(img_path)  # 调用方法,打开该图像
print(img.size)  # 输出该图像的尺寸
img.show()  # 显示数据集

这是将目的地址的图片形成列表

import os
dir_path = r'D://情绪图片测试5张'  # 目的地址 
img_path_list = os.listdir(dir_path)  # 运用os库将目的地址里面的图片放在一个列表里面
print(img_path_list[0])  # 测试 输出list【0】

下面来创建一个类,主要包含三种函数

class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 初始化类,为class提供全局变量
    def __getitem__(self, idx):  # 获取列表中每一个图片
    def __len__(self):  # 返回长度

一个一个看,来看第一个初始化值

    def __init__(self, root_dir, label_dir):  # 初始化类,为class提供全局变量
        self.root_dir = root_dir  # 根文件位置
        self.label_dir = label_dir  # 子文件名
        self.path = os.path.join(self.root_dir, self.label_dir)  # 合并,即具体位置
        self.img_path = os.listdir(self.path)  # 转换成列表的形式

这边的self其实就类似于C++的指针,是为了将它变成全局变量

来看第二个获取图片

    def __getitem__(self, idx):  # 获取列表中每一个图片
        img_name = self.img_path[idx]  # idx表示下标,即对应位置
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 每一个图片的位置
        img = Image.open(img_item_path)  # 调用方法,拿到该图像
        label = self.label_dir  # 标签
        return img, label  # 返回img 图片 label 标签

这边返回的值有两个,一个是img图片的信息,一个是label标签,在下面调用的时候会用到

下面是第三部分获取长度,这部分比较简单,就是一个len()

    def __len__(self):  # 返回长度
        return len(self.img_path)

下面看完整的类以及调用测试

# -*- coding = utf-8 -*-
from torch.utils.data import Dataset
import cv2
from PIL import Image  # 图像处理的库
import os
'''
img_path = r'D://情绪图片测试/path1.jpg'  # 图片路径
img = Image.open(img_path)  # 调用方法,打开该图像
print(img.size)  # 输出该图像的尺寸
img.show()  # 显示数据集
'''
'''
dir_path = r'D://情绪图片测试5张'  # 目的地址
img_path_list = os.listdir(dir_path)  # 运用os库将目的地址里面的图片放在一个列表里面
print(img_path_list)  # 测试 输出list【0】
'''
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):  # 初始化类,为class提供全局变量
        self.root_dir = root_dir  # 根文件位置
        self.label_dir = label_dir  # 子文件名
        self.path = os.path.join(self.root_dir, self.label_dir)  # 合并,即具体位置
        self.img_path = os.listdir(self.path)  # 转换成列表的形式

    def __getitem__(self, idx):  # 获取列表中每一个图片
        img_name = self.img_path[idx]  # idx表示下标,即对应位置
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)  # 每一个图片的位置
        img = Image.open(img_item_path)  # 调用方法,拿到该图像
        label = self.label_dir  # 标签
        return img, label  # 返回img 图片 label 标签

    def __len__(self):  # 返回长度
        return len(self.img_path)

root_dir = 'D://情绪图片测试5张'  # 根目录
happy_label_dir = '开心'  # 子目录
happy_dataset = MyData(root_dir, happy_label_dir)  # 开心数据集创建完成
img, label = happy_dataset[2]  # 由上面可知返回的是两个值
print(label)  # 分别调用
img.show()

各行代码的意思解释以及给大家注释了,还有不明白的地方欢迎留言私信哦

这里的D://情绪图片5张是我自己创建的文件夹,开心是根文件下的子文件,数据集(就是里面的图片)是通过爬虫获取的,关于爬虫的相关知识可以看我的第一篇文章哦~

来看结果

 二、总结

这次主要是学习了一些图像操作的库以及数据集的简单创建,为日后训练做准备,初次学习,欢迎大家批评指正哦~~

猜你喜欢

转载自blog.csdn.net/m0_60964321/article/details/122288600