Pytorch个人学习记录总结 01

 目录

函数-dir()、help()

Dataset类


函数-dir()、help()

dir() 函数,打开工具箱(例如PyTorch,进一步打开某一些分隔区)

help() 函数,查看工具包中某一个工具函数的用法

(1) 查看torch工具包有哪些分割区

dir(torch)
# ['AVG', 'AggregationType', 'AnyType', 'Argument', 'ArgumentSpec', 'BFloat16Storage', 'BFloat16Tensor',...]

(2) 查看torch.cuda有哪些分隔区

dir(torch.cuda)
# ['Any', 'BFloat16Storage', 'BFloat16Tensor', 'BoolStorage', 'BoolTensor', 'ByteStorage', ...]

(3) 查看torch.cuda.is_available()有哪些分隔区

dir(torch.cuda.is_available())	# 函数后面的()去掉,效果一样
# ['__abs__', '__add__', '__and__', '__bool__', '__ceil__', '__class__', ...]

此时发现前后都是带有两个下划线的:__这说明是规定好不可更改的,也就说明是torch.cuda.is_available不再是一个分隔区而是一个函数,因此可调用help()来查看该函数的基本作用。

扫描二维码关注公众号,回复: 15866876 查看本文章
help(torch.cuda.is_available) 	# 注意这后面不能跟有()

# 打印结果,该函数会返回一个bool值
# Help on function is_available in module torch.cuda:
# is_available() -> bool
#     Returns a bool indicating if CUDA is currently available.

Dataset类

手写加载数据集的类:MyData。主要是要重写__init__()、__getitem__()、__len()__这3个类


get到一个小技巧,可以直接用+对两个Data类进行拼接(可用于数据集不足时,直接将两个数据集这样加起来一起使用)


new_path = os.path.join(path1,path2,...)将所有路径联合起来,返回一个整合路径(str)


file_name_list = os.listdir(path)读取path路径中的所有文件名称,返回一个名称列表(list)

from torch.utils.data import Dataset
from PIL import Image
import os

# 构造一个子文件夹数据集类MyData
class MyData(Dataset):
    def __init__(self, root_dir, label_dir):    # root_dir是指整个数据集的根目录,label_dir是指具体某一个类的子目录
        # 在init初始化函数中,定义一些类中的全局变量,即跟在self.后的变量们
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_list = os.listdir(self.path)

    def __getitem__(self, index):   # 传入下标获取元素
        img_name = self.img_list[index]
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label[:-6]	# 返回的是一个元组
        # 这里进行了截取,因为我不想要label_dir最后面的'_image'这6个元素

    def __len__(self):
        return len(self.img_list)

# --------------实例化ants_data和bees_data------------- #
root_dir = 'hymenoptera_data/train'
ants_dir = 'ants'
bees_dir = 'bees'
ants_data = Mydata(root_dir, ants_dir)
bees_data = Mydata(root_dir, bees_dir)
# ---------------------------------------------------- #

# -------------返回一个元组,分别赋值给img和label------- #
img, label = ants_data[0]
img.show()
# ----------------------------------------------------- #

# ---因为是元组,所以可用[0]、[1]直接提取出img、label---- #
print(label == ants_data[0][1])		# true
# ----------------------------------------------------- #

# ----------将ants_data和bees_data相加起来使用---------- #
sum = ants_data + bees_data
len_ants = len(ants_data)	# 124
len_bees = len(bees_data)	# 121
len_sum = len(sum)				# 245
print(len_sum == len_ants+len_bees)	# True
print(sum[123][1])			# ants
print(sum[124][1])			# bees


 

猜你喜欢

转载自blog.csdn.net/timberman666/article/details/131841895
今日推荐