The ResNet network structure is implemented below, with the model and the data pipeline kept in separate files. Each BasicBlock adds a shortcut connection, so a block computes relu(F(x) + x); when the stride changes the feature-map size, a 1x1 convolution reshapes the shortcut to match. The model definition:
```python
# encoding: utf-8
import tensorflow as tf
from tensorflow.keras import layers, Sequential, Model


class BasicBlock(layers.Layer):
    # Residual block: two 3x3 convolutions plus a shortcut connection
    def __init__(self, filter_num, stride=1):
        super(BasicBlock, self).__init__()
        # First convolution
        self.conv1 = layers.Conv2D(filter_num, (3, 3), strides=stride, padding='same')
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        # Second convolution
        self.conv2 = layers.Conv2D(filter_num, (3, 3), strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()

        if stride != 1:  # match the shortcut's shape via a 1x1 convolution
            self.downsample = Sequential()
            self.downsample.add(layers.Conv2D(filter_num, (1, 1), strides=stride))
        else:  # identity shortcut
            self.downsample = lambda x: x

    # Forward pass
    def call(self, inputs, training=None):
        # First convolution
        out = self.conv1(inputs)
        out = self.bn1(out, training=training)
        out = self.relu(out)
        # Second convolution
        out = self.conv2(out)
        out = self.bn2(out, training=training)
        identity = self.downsample(inputs)
        # Add the two paths
        output = layers.add([out, identity])
        output = tf.nn.relu(output)

        return output


class ResNet(Model):
    # Generic ResNet implementation
    def __init__(self, layer_dims, num_classes=100):  # e.g. layer_dims = [2, 2, 2, 2]
        super(ResNet, self).__init__()  # initialize the parent class
        # Stem network for preprocessing
        self.stem = Sequential([layers.Conv2D(64, (3, 3), strides=(1, 1)),
                                layers.BatchNormalization(),
                                layers.Activation('relu'),
                                layers.MaxPool2D(pool_size=(2, 2), strides=(1, 1), padding='same')])
        # Stack 4 stages; each stage holds several BasicBlocks with different strides
        self.layer1 = self.build_resblock(64, layer_dims[0])
        self.layer2 = self.build_resblock(128, layer_dims[1], stride=2)
        self.layer3 = self.build_resblock(256, layer_dims[2], stride=2)
        self.layer4 = self.build_resblock(512, layer_dims[3], stride=2)
        # Reduce the spatial dimensions via global pooling
        self.avgpool = layers.GlobalAveragePooling2D()
        # Final fully connected classification layer
        self.fc = layers.Dense(num_classes)

    def call(self, inputs, training=None):  # inputs: (None, 32, 32, 3)
        # Forward pass: first through the stem
        x = self.stem(inputs, training=training)  # (None, 30, 30, 64)
        # Then through the 4 stages in turn
        x = self.layer1(x, training=training)     # (None, 30, 30, 64)
        x = self.layer2(x, training=training)     # (None, 15, 15, 128)
        x = self.layer3(x, training=training)     # (None, 8, 8, 256)
        x = self.layer4(x, training=training)     # (None, 4, 4, 512)
        # Global average pooling
        x = self.avgpool(x)                       # (None, 512)
        # Fully connected layer
        x = self.fc(x)                            # (None, 100)

        return x

    # Helper: build one feature stage by stacking `blocks` BasicBlocks at once
    def build_resblock(self, filter_num, blocks, stride=1):  # e.g. filter_num=64, blocks=2
        # Only the first block may downsample; the rest keep stride 1
        res_blocks = Sequential()
        res_blocks.add(BasicBlock(filter_num, stride))

        for _ in range(1, blocks):
            res_blocks.add(BasicBlock(filter_num, stride=1))

        return res_blocks


def resnet18():
    # Different ResNets come from different BasicBlock counts per stage
    return ResNet([2, 2, 2, 2])


def resnet34():
    return ResNet([3, 4, 6, 3])
```
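Before wiring in the data pipeline, a quick shape check is useful. The sketch below is my own addition, not part of the original post; it assumes the model file is saved as exam_resnet.py (the name used by the training script's import) and simply pushes a random batch through resnet18 to confirm the output matches the 100 CIFAR-100 classes:

```python
# Minimal sanity check (my addition): verify the forward pass shapes.
import tensorflow as tf
from exam_resnet import resnet18  # assumes the model file is exam_resnet.py

model = resnet18()
model.build(input_shape=(None, 32, 32, 3))
x = tf.random.normal([4, 32, 32, 3])  # a fake batch of 4 CIFAR-sized images
out = model(x, training=False)
print(out.shape)  # expected: (4, 100)
```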
The training code:
```python
# encoding: utf-8
import tensorflow as tf
from tensorflow.keras import optimizers, datasets
import matplotlib.pyplot as plt
from exam_resnet import resnet18

(x, y), (x_test, y_test) = datasets.cifar100.load_data()
y = tf.squeeze(y, axis=1)          # drop the extra label dimension
y_test = tf.squeeze(y_test, axis=1)
print(x.shape, y.shape, x_test.shape, y_test.shape)
# (50000, 32, 32, 3) (50000,) (10000, 32, 32, 3) (10000,)


def preprocess(x, y):
    # Scale pixels from [0, 255] to [-1, 1]
    x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1
    y = tf.cast(y, dtype=tf.int32)
    return x, y


train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.shuffle(1000).map(preprocess).batch(512)
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.shuffle(1000).map(preprocess).batch(512)
sample = next(iter(train_db))
print('sample:', sample[0].shape, sample[1].shape,
      tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))
# (512, 32, 32, 3) (512,) tf.Tensor(-1.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32)

# Record the training and test accuracy over the epochs
train_tot_acc = []
test_tot_acc = []
Epoch = 50

# [b, 32, 32, 3] => [b, 100]
model = resnet18()
model.build(input_shape=(None, 32, 32, 3))
model.summary()
optimizer = optimizers.Adam(learning_rate=1e-3)


def main():
    for epoch in range(Epoch):
        cor, tot = 0, 0
        for step, (x, y) in enumerate(train_db):
            with tf.GradientTape() as tape:
                logits = model(x, training=True)
                y_onehot = tf.one_hot(y, depth=100)
                loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)
                loss = tf.reduce_mean(loss)

            # Training accuracy: compare the argmax of the logits to the labels
            pred = tf.cast(tf.argmax(logits, axis=1), dtype=tf.int32)
            cor += float(tf.reduce_sum(tf.cast(tf.equal(y, pred), dtype=tf.float32)))
            tot += x.shape[0]

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        print('After %d Epoch' % epoch)
        print('training acc is ', cor / tot)
        train_tot_acc.append(cor / tot)

        correct, total = 0, 0
        for x, y in test_db:
            pred = model(x, training=False)
            pred = tf.nn.softmax(pred, axis=1)
            pred = tf.argmax(pred, axis=1)
            pred = tf.cast(pred, dtype=tf.int32)

            correct += float(tf.reduce_sum(tf.cast(tf.equal(y, pred), dtype=tf.int32)))
            total += x.shape[0]
        print('testing acc is : ', correct / total)
        test_tot_acc.append(correct / total)

    plt.figure()
    plt.plot(train_tot_acc, 'b', label='train')
    plt.plot(test_tot_acc, 'r', label='test')
    plt.xlabel('Epoch')
    plt.ylabel('ACC')
    plt.legend()
    plt.savefig('exam8.4_train_test_ResNet.png')
    plt.show()


if __name__ == '__main__':
    main()
```

Note two fixes relative to the first draft of this script: the training accuracy now compares the argmax of the logits against the integer labels (comparing the one-hot labels to raw logits elementwise would always yield an accuracy near zero), and the lists holding per-epoch accuracy are named accordingly rather than "loss".
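For reference, the same training can also be expressed with Keras' high-level API instead of a manual GradientTape loop. This is a sketch of my own, not the original post's method; it reuses model, train_db, test_db, and Epoch from the script above:

```python
# Equivalent high-level training (my sketch, not the original code).
# SparseCategoricalCrossentropy takes the integer labels directly,
# so no one-hot encoding is needed.
model.compile(optimizer=optimizers.Adam(learning_rate=1e-3),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
history = model.fit(train_db, validation_data=test_db, epochs=Epoch)
```

The trade-off is control versus convenience: the explicit loop makes the per-step logic visible, while fit() handles batching, metrics, and progress display internally.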
The program was debugged successfully, but I have not run the actual training or testing: the dataset is too large for my current machine, so I will leave the full training and prediction for when suitable hardware is available.
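In the meantime, the tf.data pipeline makes it easy to at least exercise the whole loop on a slice of the data. The sketch below is my own addition, not part of the original script; it reuses train_db, model, and optimizer from the training code above and shrinks the workload so a single pass finishes quickly even on a weak machine:

```python
# Smoke test on a slice of the data (my sketch, not part of the post):
# re-batch to 64 images and keep only 10 batches.
train_small = train_db.unbatch().batch(64).take(10)

for x, y in train_small:
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = tf.reduce_mean(
            tf.losses.sparse_categorical_crossentropy(y, logits, from_logits=True))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
print('smoke test done, last batch loss =', float(loss))
```

If this pass runs without errors, the full training should only be a matter of compute.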
Next update: an RNN in practice on the IMDB dataset.