读取数据集代码注释

文章目录


注意:只需看对应行的注释即可,代码本身并未完善。

1、猫狗数据集(cat_dog_100):按文件名前缀打标签,并划分训练集/验证集

import os
import random
# Build the (image_path, label) list for the cat/dog dataset and take a
# reproducible train/valid split of it.
rng_seed = 620                          # seed for the shuffle -> same split every run
data_dir = r"D:\deepshare\cat_dog_100"  # folder holding the cat.*.jpg / dog.*.jpg files
split_n = 0.9                           # fraction of the images used for training
mode = "train"                          # which part of the split to keep: "train" or "valid"

# os.listdir returns names (e.g. ["cat.40.jpg", "cat.47.jpg", "cat.93.jpg"])
# in an arbitrary, filesystem-dependent order. Keep only .jpg files and sort
# them so the seeded shuffle below produces the same split on every machine;
# a fixed seed alone is not enough when the input order can vary.
img_names = sorted(n for n in os.listdir(data_dir) if n.endswith('.jpg'))

random.seed(rng_seed)
random.shuffle(img_names)

# Label from the file-name prefix: "cat*" -> 0, anything else -> 1 (dog).
img_labels = [0 if n.startswith('cat') else 1 for n in img_names]  # e.g. [0, 0, 1]

# Index separating the training part from the validation part,
# e.g. 100 images * 0.9 -> 90.
split_idx = int(len(img_labels) * split_n)
if mode == "train":
    # First split_n of the shuffled data is used for training.
    img_set = img_names[:split_idx]
    label_set = img_labels[:split_idx]
elif mode == "valid":
    # The remainder is used for validation.
    img_set = img_names[split_idx:]
    label_set = img_labels[split_idx:]
else:
    # ValueError is a subclass of Exception, so existing handlers still catch it.
    raise ValueError("mode 无法识别,仅支持(train, valid)")

# Join each file name with its folder: ["D:\\...\\cat.40.jpg", "D:\\...\\cat.41.jpg", ...]
path_img_set = [os.path.join(data_dir, n) for n in img_set]
# zip pairs each path with its label: [("D:\\...\\cat.40.jpg", 0), ("D:\\...\\cat.41.jpg", 0), ...]
data_info = list(zip(path_img_set, label_set))

2、交通标志数据集(myGTSRB):以子文件夹名作为类别标签

import os
# Build the (image_path, label) list for a dataset laid out as
# data_dir/<class folder>/<image files>, where the class folder name is looked
# up in tsr_label to get the integer label.
data_dir = r"D:\dataset\myGTSRB"
# Folder name -> integer class id for the 43 folders "0" .. "42".
tsr_label = {str(i): i for i in range(43)}
data_info = list()

# Only the top level of data_dir is needed: its sub-folders are the classes.
# (Looping over the whole os.walk would also descend into every class folder,
# doing useless work and, if a class folder contained sub-folders, looking
# those names up in tsr_label.) The default tuple reproduces os.walk's
# behaviour of yielding nothing -> empty data_info when data_dir is missing.
root, dirs, _ = next(os.walk(data_dir), (data_dir, [], []))

# Iterate over class folders, e.g. sub_dir == "0"; sorted for a deterministic order.
for sub_dir in sorted(dirs):
    class_dir = os.path.join(root, sub_dir)
    # os.listdir also returns sub-folders, so keep regular files only.
    # If the folders may contain non-image files (e.g. annotation .csv files),
    # also filter by extension here, e.g. n.endswith('.png').
    img_names = sorted(
        n for n in os.listdir(class_dir)
        if os.path.isfile(os.path.join(class_dir, n))
    )
    for img_name in img_names:
        path_img = os.path.join(class_dir, img_name)  # full path of one image
        label = tsr_label[sub_dir]  # KeyError for a folder name outside "0" .. "42"
        data_info.append((path_img, label))

照猫画虎:把每次遍历得到的变量的数据类型和值都打印出来,对照注释理解。

猜你喜欢

转载自blog.csdn.net/weixin_42630613/article/details/107839968