Keras自定义网络

数据集:CIFAR10这是一个加拿大组织制作的

代码过程:

import tensorflow as tf
from tensorflow.keras import datasets,layers,optimizers,Sequential,metrics

from tensorflow import keras
import os

os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

def preprocess(x,y):
    #  #这里我们把它放在0-1之间其实不是最好的,[0~255] => [-1~1],这个范围可能是最适合的范围。
    x = tf.cast(x,dtype=tf.float32)/255.-1
    y = tf.cast(y,dtype=tf.int32)
    return x,y

#一次并行计算128 个样本的数据
batchsz = 128
#[32,32,3]
(x,y),(x_val,y_val) = datasets.cifar10.load_data() #下载数据集

y = tf.squeeze(y)   #tensor中删除所有1维数据,(50000, 1) 变成(50000)
y_val = tf.squeeze(y_val)
y = tf.one_hot(y,depth=10)#tf.one_hot()函数是将y转化为one-hot类型数据输出,维度是10
y_val = tf.one_hot(y_val,depth=10)
print('datasets:',x.shape,y.shape,x.min(),x.max())

#将python列表和numpy数组转换成tensorflow的dataset 只有dataset才能被model.fit函数训练
train_db =tf.data.Dataset.from_tensor_slices((x,y))
#预处理 map(preprocess)都经过标准化处理到[0-1]之间
train_db = train_db.map(preprocess).shuffle(10000).batch(batchsz)
test_db = tf.data.Dataset.from_tensor_slices((x_val,y_val))
test_db = test_db.map(preprocess).batch(batchsz)

sample = next(iter(train_db))
print("batch",sample[0].shape,sample[1].shape)

#自定义的连接层
class MyDense(layers.Layer):
    def __init__(self,inp_dim,outp_dim):
        super(MyDense,self).__init__()
        self.kernel = self.add_variable('w',[inp_dim,outp_dim])
        # self.bias = self.add_variable('b',[outp_dim])

    def call(self,inputs,training=None):
        x = inputs @ self.kernel
        return x

#自定义连接网络
class MyNetWork(keras.Model):
    def __init__(self):
        super(MyNetWork,self).__init__()
        # #5层网络
        # self.fac1=MyDense(32*32*3,256)
        # self.fac2=MyDense(256,128)
        # self.fac3=MyDense(128,64)
        # self.fac4=MyDense(64,32)
        # self.fac5=MyDense(32,10)
        #5层网络 优化网络参数量放大一点
        self.fac1=MyDense(32*32*3,256)
        self.fac2=MyDense(256,256)
        self.fac3=MyDense(256,256)
        self.fac4=MyDense(256,256)
        self.fac5=MyDense(256,10)

    def call(self,inputs,training=None):
        #前向传播逻辑
        x = tf.reshape(inputs,[-1,32*32*3])
        #[b,32*32*3] => [b,256]
        x = self.fac1(x)
        x = tf.nn.relu(x) #添加激活函数relu
        #[b,256] => [b,128]
        x = self.fac2(x)
        x=tf.nn.relu(x)
        # [b,128] => [b,64]
        x = self.fac3(x)
        x = tf.nn.relu(x)
        # [b,64] => [b,32]
        x = self.fac4(x)
        x = tf.nn.relu(x)
        # [b,32] => [b,10]
        x = self.fac5(x)

        return x

#keras进网络装配起来
network =MyNetWork()
#开始编译optimizers优化,loss损失函数,metrics模型评价函数,accuracy准确率
network.compile(optimizer=optimizers.Adam(lr=1e-3),
                loss =tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])

#开始训练模型,训练数据,5轮,验证数据,验证频率
network.fit(train_db,epochs=5,validation_data=test_db,validation_freq=1)

#模型的保存方式
network.evaluate(test_db) #验证测试数据
network.save_weights('ckpt/weights.ckpt') #模型的保存路径
del network #以后删除网络对象
print('saved to ckpt/weightts.ckpt') #保存了单纯的权值

#网络更新一下权值

#keras进网络装配起来
network =MyNetWork()
#开始编译optimizers优化,loss损失函数,metrics模型评价函数,accuracy准确率
network.compile(optimizer=optimizers.Adam(lr=1e-3),
                loss =tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
network.load_weights('ckpt/weights.ckpt')
print('saved to ckpt/weightts.ckpt') #保存了单纯的权值
network.evaluate(test_db) #验证测试数据

 测试结果:

准确率不高,0.45 

猜你喜欢

转载自blog.csdn.net/chehec2010/article/details/127029927