TensorFlow创建数据集

import os
import glob
import random
import csv
import tensorflow as tf
def load_pokemon(root, mode='train'):
    """Build the class table for ``root`` and return one split of its images.

    Each sub-directory of ``root`` is treated as one class. Labels are the
    positions (0, 1, 2, ...) of the directory names in sorted order. The
    image index itself comes from ``load_csv(root, 'images.csv', ...)`` and
    is cut by position.

    Args:
        root: Dataset root directory holding one sub-directory per class.
        mode: ``'train'`` selects the first 60% of the index, ``'val'`` the
            next 20%; any other value selects the final 20%.

    Returns:
        A tuple ``(images, labels, name2label)``: the image paths and labels
        of the requested split, plus the class-name -> label mapping.
    """
    class_dirs = [
        entry
        for entry in sorted(os.listdir(os.path.join(root)))
        if os.path.isdir(os.path.join(root, entry))
    ]
    # Labels follow the sorted directory order: first class is 0, then 1, ...
    name2label = {entry: index for index, entry in enumerate(class_dirs)}

    all_images, all_labels = load_csv(root, 'images.csv', name2label)

    # Fractions of the list length that bound the requested split; None
    # means "open-ended" on that side of the slice.
    if mode == 'train':
        start_frac, stop_frac = None, 0.6
    elif mode == 'val':
        start_frac, stop_frac = 0.6, 0.8
    else:
        start_frac, stop_frac = 0.8, None

    def take_split(items):
        count = len(items)
        start = None if start_frac is None else int(start_frac * count)
        stop = None if stop_frac is None else int(stop_frac * count)
        return items[start:stop]

    return take_split(all_images), take_split(all_labels), name2label


def load_csv(root, filename, name2label):
    """Load (image path, label) pairs from ``root/filename``, creating it if missing.

    When the CSV file does not exist yet, every ``*.png``, ``*.jpg`` and
    ``*.jpeg`` file found directly inside ``root/<name>`` (for each ``name``
    in ``name2label``) is collected, shuffled once, and written out as
    ``path,label`` rows. The file is then read back, so later calls reuse
    the order that was stored the first time.

    Args:
        root: Dataset root directory holding one sub-directory per class.
        filename: Name of the CSV index file inside ``root``.
        name2label: Mapping from class directory name to integer label.

    Returns:
        A tuple ``(images, labels)`` of two parallel lists: the image paths
        as stored in the CSV and their integer labels.

    Raises:
        ValueError: If a non-empty CSV row does not hold exactly two fields
            or its second field is not an integer.
    """
    csv_path = os.path.join(root, filename)

    if not os.path.exists(csv_path):
        images = []
        for name in name2label:
            for pattern in ('*.png', '*.jpg', '*.jpeg'):
                images += glob.glob(os.path.join(root, name, pattern))
        print(len(images), images)
        random.shuffle(images)
        with open(csv_path, mode='w', newline='') as f:
            writer = csv.writer(f)
            for img in images:
                # The class name is the directory that directly contains the
                # image, e.g. .../bulbasaur/0001.png -> "bulbasaur".
                name = os.path.basename(os.path.dirname(img))
                writer.writerow([img, name2label[name]])
            print('written into csv file:', filename)

    images, labels = [], []
    # newline='' lets the csv module do its own newline handling, as its
    # documentation requires for file objects.
    with open(csv_path, newline='') as f:
        for row in csv.reader(f):
            if not row:
                # Blank lines (e.g. a trailing empty line after manual
                # editing) carry no data; skip them instead of crashing.
                continue
            img, label = row
            images.append(img)
            labels.append(int(label))

    return images, labels


# Quick check at import time: load the training split from the 'pokeman'
# directory and print the image paths, the labels and the class-name table.
images,labels,table =load_pokemon('pokeman','train')
print('images:',len(images),images)
print('labels:',len(labels),labels)
print('table:',table)

# Load the image found at a path and prepare it together with its label.
def preprocess(x, y):
    """Read, decode, augment and normalise one image; tensorise its label.

    Args:
        x: Path of the image file, passed to ``tf.io.read_file``.
        y: Label belonging to that image.

    Returns:
        A tuple ``(image, label)``: the image is decoded with 3 channels,
        resized to 244x244, randomly flipped left/right, divided by 255 and
        passed through ``normalize``; the label is converted to a tensor.
    """
    raw_bytes = tf.io.read_file(x)
    image = tf.image.decode_jpeg(raw_bytes, channels=3)
    image = tf.image.resize(image, [244, 244])

    # Data augmentation.
    image = tf.image.random_flip_left_right(image)
    # NOTE(review): the crop size equals the resized size, so this crop has
    # no room to move -- confirm whether a smaller crop was intended.
    image = tf.image.random_crop(image, [244, 244, 3])

    # Convert to a float tensor and divide by 255 before standardising.
    image = tf.cast(image, dtype=tf.float32) / 255.
    return normalize(image), tf.convert_to_tensor(y)


# Per-channel mean and standard deviation used as the defaults of
# normalize() / denormalize().
# NOTE(review): these look like the commonly used ImageNet RGB statistics --
# confirm they are appropriate for this dataset.
img_mean = tf.constant([0.485,0.456,0.406])
img_std = tf.constant([0.229,0.224,0.225])
# Standardisation
def normalize(x, mean=img_mean, std=img_std):
    """Return ``x`` with ``mean`` subtracted and the result divided by ``std``.

    Args:
        x: Value to standardise.
        mean: Offset to subtract; defaults to the module-level ``img_mean``.
        std: Scale to divide by; defaults to the module-level ``img_std``.
    """
    return (x - mean) / std
def denormalize(x, mean=img_mean, std=img_std):
    """Undo ``normalize``: return ``x`` multiplied by ``std`` plus ``mean``.

    Args:
        x: Standardised value to map back.
        mean: Offset to add back; defaults to the module-level ``img_mean``.
        std: Scale to multiply by; defaults to the module-level ``img_std``.
    """
    return x * std + mean

# Batch size shared by the three pipelines below.
batchsz = 128 
# Build the training Dataset: shuffled with a 1000-element buffer,
# preprocessed per sample, then batched.
images, labels, table = load_pokemon('pokeman',mode='train') 
db_train = tf.data.Dataset.from_tensor_slices((images, labels))
db_train = db_train.shuffle(1000).map(preprocess).batch(batchsz) 
# Build the validation Dataset (no shuffling).
# NOTE(review): preprocess() includes a random left/right flip, so that
# augmentation is applied to the validation and test data too -- confirm
# this is intended.
images2, labels2, table = load_pokemon('pokeman',mode='val')
db_val = tf.data.Dataset.from_tensor_slices((images2, labels2)) 
db_val = db_val.map(preprocess).batch(batchsz) 
# Build the test Dataset (no shuffling).
images3, labels3, table = load_pokemon('pokeman',mode='test') 
db_test = tf.data.Dataset.from_tensor_slices((images3, labels3)) 
db_test = db_test.map(preprocess).batch(batchsz)
发布了101 篇原创文章 · 获赞 46 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/weixin_40539952/article/details/103434368