TensorFlow 2.0：训练模型并进行预测

# Data preparation
import tensorflow as tf
import random
import pathlib
import numpy as np

# Expected layout: one sub-directory per class directly under data_path,
# holding that class's image files; the sub-directory name is the class name.
data_path = pathlib.Path('c:/users/hb/.keras/datasets/G')
# Keep only regular files so a stray nested directory never reaches
# tf.io.read_file (which would fail on it).
all_image_paths = [str(path) for path in data_path.glob('*/*') if path.is_file()]
if not all_image_paths:
    # Fail early with a clear message instead of an obscure error at fit time.
    raise FileNotFoundError('No image files found under {}'.format(data_path))
random.shuffle(all_image_paths)  # shuffle in place so classes are interleaved

image_count = len(all_image_paths)

# Sorted class-directory names give a stable name -> integer-label mapping.
label_names = sorted(item.name for item in data_path.iterdir() if item.is_dir())
label_to_index = {name: index for index, name in enumerate(label_names)}
# Each image's label is the index of its parent directory's name.
all_image_labels = [label_to_index[pathlib.Path(path).parent.name] for path in all_image_paths]
ds = tf.data.Dataset.from_tensor_slices((all_image_paths, all_image_labels))

def load_and_preprocess_from_path_label(path, label, size=(60, 60)):
    """Load one image file and turn it into a normalized RGB tensor.

    Args:
        path: Path of the image file (a Python string, or a scalar string
            tensor when called through ``Dataset.map``).
        label: Class label; returned unchanged.
        size: Target ``(height, width)`` passed to ``tf.image.resize``.
            Defaults to ``(60, 60)``.

    Returns:
        A tuple ``(image, label)`` where ``image`` is a float32 tensor of
        shape ``(height, width, 3)`` with values in ``[0, 1]``.
    """
    image = tf.io.read_file(path)  # raw encoded bytes
    # decode_image dispatches on the file's actual format (JPEG, PNG, GIF,
    # BMP), so PNG inputs are handled explicitly. expand_animations=False
    # guarantees a 3-D result, which tf.image.resize requires; channels=3
    # forces RGB output.
    image = tf.io.decode_image(image, channels=3, expand_animations=False)
    # Resize to a fixed size so every sample has the same shape; the default
    # (bilinear) method returns float32.
    image = tf.image.resize(image, size)
    image /= 255.0  # scale pixel values from [0, 255] to [0, 1]
    return image, label

# Lazy tf.data pipeline of (image, label) pairs; the arrays below are built
# by a separate eager pass rather than by consuming this dataset.
image_label_ds = ds.map(load_and_preprocess_from_path_label)

# Decode every image eagerly, in the same (shuffled) order as all_image_paths.
decoded = [
    load_and_preprocess_from_path_label(img_path, img_label)
    for img_path, img_label in zip(all_image_paths, all_image_labels)
]
# Split the (image, label) pairs into two parallel lists.
train_image = [img for img, _ in decoded]
train_label = [lbl for _, lbl in decoded]
# Convert to NumPy arrays so they can be passed straight to model.fit.
train_images = np.array(train_image)
train_labels = np.array(train_label)

from tensorflow import keras
from tensorflow.keras import layers

# Model definition and training.
# One output unit per class directory found in the data folder, so every
# integer label in all_image_labels indexes a real softmax unit (sparse
# categorical cross-entropy requires labels in [0, num_classes)).
num_classes = len(label_names)

model = keras.Sequential(
[
    # Input shape must match the 60x60 RGB images produced by
    # load_and_preprocess_from_path_label.
    layers.Flatten(input_shape=[60, 60, 3]),
    layers.Dense(128, activation='relu'),
    # Softmax over num_classes units with sparse_categorical_crossentropy
    # works for any num_classes >= 2, including binary problems; a single
    # sigmoid unit with binary_crossentropy is an alternative, not a requirement.
    layers.Dense(num_classes, activation='softmax'),
])

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])
# Train for 500 passes over the full in-memory set; no validation split is held out.
model.fit(train_images, train_labels, epochs=500)

# Save the trained model in TensorFlow SavedModel format, into the
# relative directory 'circle_model' (resolved against the working directory).
saved_model_dir = 'circle_model'
model.save(saved_model_dir, save_format='tf')

# Use the trained model to predict the class of a single image.
# The label argument (1) is only required by the preprocessing function's
# signature; it plays no part in the prediction.
a1, b1 = load_and_preprocess_from_path_label("c:/users/hb/desktop/test1/nbbbb/nbya/49.png", 1)
# model.predict expects a batch: add a leading batch axis,
# (height, width, 3) -> (1, height, width, 3).
itt = np.expand_dims(np.asarray(a1), axis=0)
predictions = model.predict(itt)  # one row of class probabilities
predicted_index = int(np.argmax(predictions[0]))
# Guard the lookup in case the output layer has more units than classes.
if predicted_index < len(label_names):
    predicted_name = label_names[predicted_index]
else:
    predicted_name = '<unknown>'
# Print explicitly: outside a notebook, a bare model.predict(...) expression
# would simply be discarded.
print(predictions)
print('predicted class:', predicted_index, predicted_name)
发布了40 篇原创文章 · 获赞 16 · 访问量 1万+

猜你喜欢

转载自：https://blog.csdn.net/Black_Friend/article/details/104926619