preface
Since Dosovitskiy et al. proposed the Vision Transformer (ViT) in 2020, it has gradually taken a dominant position in computer vision. It achieves strong results on downstream tasks such as image classification, object detection, and semantic segmentation, and it has set off a wave of Transformer-based models in the CV field. Here I will show how to implement the ViT model step by step, from scratch, using the TensorFlow framework.
foreword
In the previous article, we implemented the ViT model in PyTorch; have a look at it if you are interested. By popular request, today I will implement my first ViT from scratch again, this time in TensorFlow. If you are not familiar with the Transformer model used in natural language processing (NLP), you may be a little confused about how Transformers apply to CV and unclear about how ViT works on images. Don't worry, start here!
define tasks
Because the previous article already explained the ViT model, here we go straight to building the model. The previous article used the MNIST dataset; here we use something a bit richer, the CIFAR-100 dataset. Although the goal is simple, this image classification task lets us walk through the whole ViT model.
First of all, the environment configured here is:
python==3.7
tensorflow==2.7.0
tensorflow_addons==0.16.1
First import some modules that need to be used:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa # Community-made Tensorflow code (AdamW optimizer)
The tensorflow_addons module is imported to use the AdamW optimizer [1]. Of course, other optimizers such as Adam can be used here.
Next, let's create the main function. Because TensorFlow 2 ships with Keras built in, we first load the dataset directly through the keras.datasets module, which returns it already split into train and test sets, and then define hyperparameters such as the learning rate, batch_size, and epochs.
We create the AdamW optimizer and pass it to model.compile together with the loss and the evaluation metrics, save the weights during training through keras.callbacks.ModelCheckpoint, run training for 100 epochs through model.fit, and then compute the accuracy on the test set.
def main():
    # Downloading dataset
    num_classes = 100
    input_shape = (32, 32, 3)
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()
    print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
    print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
    # Hyper-parameters
    learning_rate = 0.001
    weight_decay = 0.0001
    batch_size = 256
    num_epochs = 100
    def create_vit_classifier():
        # TODO
        pass
    def run_experiment(model):
        # define optimizer
        optimizer = tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        )
        # compile the model with optimizer, loss and metrics
        model.compile(
            optimizer=optimizer,
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[
                keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
                keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5 accuracy'),
            ],
        )
        # saved model path
        checkpoint_filepath = './tmp/checkpoint'
        # save weights during training, keeping only the best model by val_accuracy
        checkpoint_callback = keras.callbacks.ModelCheckpoint(
            checkpoint_filepath,
            monitor="val_accuracy",
            save_best_only=True,
            save_weights_only=True
        )
        # train process
        history = model.fit(
            x=x_train,
            y=y_train,
            batch_size=batch_size,
            epochs=num_epochs,
            validation_split=0.1,
            callbacks=[checkpoint_callback]
        )
        # test process: reload the best weights, then evaluate
        model.load_weights(checkpoint_filepath)
        _, accuracy, top5_accuracy = model.evaluate(x_test, y_test)
        print(f"Test Accuracy: {accuracy}")
        print(f"Test Top-5 Accuracy: {top5_accuracy}")
        return history
    model = create_vit_classifier()
    run_experiment(model)
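To make the loss and metrics passed to model.compile more concrete, here is a small NumPy sketch of what sparse categorical cross-entropy on logits and top-k accuracy compute. It is only an illustration: the helper names and the toy batch are hypothetical, and this is not the Keras implementation.

```python
import numpy as np

def sparse_categorical_crossentropy_from_logits(logits, labels):
    """Mean cross-entropy for integer labels, computed from raw logits."""
    # Subtract the row-wise max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for every sample.
    return -log_probs[np.arange(len(labels)), labels].mean()

def top_k_accuracy(logits, labels, k=5):
    """Fraction of samples whose true label is among the k largest logits."""
    top_k = np.argsort(logits, axis=1)[:, -k:]
    return float(np.mean([label in row for label, row in zip(labels, top_k)]))

# Hypothetical toy batch: 3 samples, 6 classes.
toy_logits = np.array([
    [5.0, 1.0, 0.0, 0.0, 0.0, 0.0],   # true class 0 has the largest logit
    [0.0, 1.0, 2.0, 3.0, 4.0, 5.0],   # true class 1 is only in the top 5
    [1.0, 1.0, 1.0, 1.0, 1.0, 9.0],   # true class 5 has the largest logit
])
toy_labels = np.array([0, 1, 5])

loss = sparse_categorical_crossentropy_from_logits(toy_logits, toy_labels)
top1 = top_k_accuracy(toy_logits, toy_labels, k=1)
top5 = top_k_accuracy(toy_logits, toy_labels, k=5)
print(f"loss={loss:.4f} top-1={top1:.4f} top-5={top5:.4f}")
```

The second sample shows why top-5 accuracy is reported next to plain accuracy on a 100-class problem: the prediction is wrong at top-1 but still counts at top-5.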
With the training and testing framework in place, we now focus on the create_vit_classifier function, that is, on building the ViT model itself. The task of the model is to classify the CIFAR-100 images.
ViT architecture
Since TensorFlow, like most DL frameworks, provides automatic differentiation, we only need to build the layers of the ViT network on top of the Keras layers classes and define the optimizer in the training framework. TensorFlow then takes care of backpropagating the gradients and training the model parameters, so we only need to care about the forward pass to implement the ViT model. The ViT model was introduced in the previous article, and the main network diagram is reproduced here.
data augmentation
On the CIFAR-100 dataset, we use the Keras layers module for data augmentation:
image_size = 72
# Data augmentation
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.Normalization(),
        layers.experimental.preprocessing.Resizing(image_size, image_size),
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(factor=0.02),
        layers.experimental.preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2)
    ],
    name='data_augmentation'
)
# Normalizing based on training data
data_augmentation.layers[0].adapt(x_train)
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
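As a rough picture of what the Normalization layer does after adapt, the NumPy sketch below estimates a per-channel mean and variance from a batch and standardizes the images with them. The random array is a hypothetical stand-in for x_train, and the epsilon value is my own choice, so treat this as a sketch of the idea rather than the layer's exact arithmetic.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for x_train: 8 random "images" of shape (32, 32, 3).
fake_images = rng.integers(0, 256, size=(8, 32, 32, 3)).astype("float64")

# "adapt" step: one mean and one variance per channel (the last axis).
mean = fake_images.mean(axis=(0, 1, 2))
var = fake_images.var(axis=(0, 1, 2))

# Normalization step; the small epsilon guards against division by zero.
normalized = (fake_images - mean) / np.sqrt(var + 1e-7)

# A horizontal flip, as in RandomFlip("horizontal"), reverses the width axis.
flipped = normalized[:, :, ::-1, :]

print(normalized.mean(axis=(0, 1, 2)), normalized.std(axis=(0, 1, 2)))
```

After this step each channel has roughly zero mean and unit variance, which is why adapt must see the training data before the model is used.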
Patchifying
In ViT, an image is first split into multiple sub-images (patches), and each patch is mapped to a vector. After resizing the image to 72x72, we divide it into a 12x12 grid of patches, each 6x6 pixels (if the image size is not evenly divisible by the patch size, the image needs to be padded), so we get 144 patches from a single image. With P patches per side, the original image is reshaped into:
(N, PxP, H/P x W/P x C) = (N, 12x12, 6x6x3) = (N, 144, 108)
Note that although each patch has size 6x6x3, we flatten it into a 108-dimensional vector.
The code below implements this:
class Patches(layers.Layer):
    """Gets some images and returns the patches for each image"""
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size
    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
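To check the shape arithmetic without TensorFlow, the same patch split can be sketched with plain NumPy reshapes. The patchify helper below is a hypothetical illustration of the idea, not a replacement for tf.image.extract_patches; it assumes channels-last images whose height and width are multiples of the patch size.

```python
import numpy as np

def patchify(images, patch_size):
    """Split (N, H, W, C) images into flattened, non-overlapping patches."""
    n, h, w, c = images.shape
    assert h % patch_size == 0 and w % patch_size == 0
    rows, cols = h // patch_size, w // patch_size
    # Separate the grid position (rows, cols) from the pixel position
    # inside each patch, then bring the two grid axes next to each other.
    x = images.reshape(n, rows, patch_size, cols, patch_size, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(n, rows * cols, patch_size * patch_size * c)

# Two dummy 72x72 RGB images with distinct values at every position.
images = np.arange(2 * 72 * 72 * 3, dtype="float32").reshape(2, 72, 72, 3)
patches = patchify(images, patch_size=6)
print(patches.shape)
```

Each row of the result is one 6x6x3 patch flattened to 108 values, and the 144 rows follow the grid from left to right, top to bottom.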
Linear Mapping and Adding Positional Encoding
After getting the flattened patches, we change their dimension with layers.Dense. The linear mapping can project to any vector size; we set it to 64 here, but you can choose another dimension. layers.Embedding then adds a learnable positional encoding, so the tensor shape becomes (N, 144, 64). Note: there is a slight change here. No classification token is added, and the final MLP output changes accordingly, mainly to train faster. If you want to add one, you can do so in the call method of the PatchEncoder class.
projection_dim = 64
class PatchEncoder(layers.Layer):
    """Adding (learnable) positional encoding to input patches"""
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
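The patch encoder boils down to one matrix multiplication plus a lookup table that is broadcast over the batch. The NumPy sketch below uses randomly initialized weights as hypothetical stand-ins for the trained Dense and Embedding parameters, purely to show the shapes involved.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, projection_dim = 144, 108, 64

# Hypothetical inputs and parameters (random, not trained).
patches = rng.normal(size=(2, num_patches, patch_dim))
w = rng.normal(size=(patch_dim, projection_dim)) * 0.02   # Dense kernel
b = np.zeros(projection_dim)                              # Dense bias
position_table = rng.normal(size=(num_patches, projection_dim)) * 0.02

# Linear projection of every patch, then add one learned vector per position.
positions = np.arange(num_patches)
encoded = patches @ w + b + position_table[positions]
print(encoded.shape)
```

The position table has shape (144, 64) and is added to every image in the batch, so two patches with identical pixels still get different encodings if they sit at different grid positions.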
Add LN, transformer
After getting the encoded patches, we first apply layer normalization (LN) to the tokens, then apply multi-head attention, and finally add a residual connection (adding the input before LN to the output of multi-head attention).
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    for _ in range(transformer_layers):
        # Layer normalization and self-attention
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Residual connection
        x2 = layers.Add()([attention_output, encoded_patches])
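For readers who want to see the arithmetic behind this step, here is a simplified NumPy sketch of pre-norm self-attention with a residual connection. It makes several simplifying assumptions that differ from layers.MultiHeadAttention: a single head, no dropout, no bias terms, no learned scale and shift in the layer norm, and random weights.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize each token over its feature axis (no learned scale/shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(x, wq, wk, wv, wo):
    """Single-head scaled dot-product self-attention."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(q.shape[-1])
    weights = softmax(scores)          # (N, tokens, tokens), rows sum to 1
    return (weights @ v) @ wo, weights

rng = np.random.default_rng(0)
dim = 64
encoded_patches = rng.normal(size=(2, 144, dim))
# Hypothetical random projection matrices for query, key, value and output.
wq, wk, wv, wo = (rng.normal(size=(dim, dim)) * 0.05 for _ in range(4))

x1 = layer_norm(encoded_patches)                   # LN before attention
attention_output, attention_weights = self_attention(x1, wq, wk, wv, wo)
x2 = attention_output + encoded_patches            # residual connection
print(x2.shape, attention_weights.shape)
```

Note that the residual adds the input from before the normalization, which matches the order used in the code above: LN, attention, then Add with encoded_patches.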
LN, MLP and Residual Connections
Continuing through the network, the current tensor passes through another LN and an MLP, followed by another residual connection; stacking such blocks builds the encoder. Compared with the PyTorch version, the GELU activation function is used here, and multiple Transformer layers are stacked.
def mlp(x, hidden_units, dropout_rate):
    """MLP: Dense layers with GELU activation, each followed by dropout"""
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    for _ in range(transformer_layers):
        # Layer normalization and self-attention
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Residual connection
        x2 = layers.Add()([attention_output, encoded_patches])
        # Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Residual connection
        encoded_patches = layers.Add()([x3, x2])
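The MLP half of the block can be sketched the same way. The NumPy code below implements the exact (erf-based) GELU and a two-layer MLP with a residual connection; I am assuming here that tf.nn.gelu is used in its default, non-approximate form. The hidden sizes and random weights are hypothetical, and dropout is left out.

```python
import math
import numpy as np

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def layer_norm(x, eps=1e-6):
    """Normalize each token over its feature axis (no learned scale/shift)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def mlp(x, layer_params):
    """Apply Dense + GELU for every (kernel, bias) pair; dropout omitted."""
    for w, b in layer_params:
        x = gelu(x @ w + b)
    return x

rng = np.random.default_rng(0)
dim = 64
hidden_units = [2 * dim, dim]   # hypothetical sizes; the last must equal dim
x2 = rng.normal(size=(2, 4, dim))

# Random stand-ins for the trained Dense parameters.
layer_params, in_dim = [], dim
for units in hidden_units:
    layer_params.append((rng.normal(size=(in_dim, units)) * 0.05, np.zeros(units)))
    in_dim = units

x3 = mlp(layer_norm(x2), layer_params)
block_output = x3 + x2          # residual connection around LN + MLP
print(block_output.shape)
```

The last hidden size has to match the token dimension, otherwise the residual addition would fail on mismatched shapes.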
LN, MLP, Dense classification
First, LN normalizes the encoder output, which is then flattened. To reduce overfitting, dropout is applied, followed by an MLP. The PyTorch version used a classification token and took the first token's output after the MLP as the classification result; in this modified version, the features go through a final Dense layer instead.
def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
    for _ in range(transformer_layers):
        # Layer normalization and self-attention
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Residual connection
        x2 = layers.Add()([attention_output, encoded_patches])
        # Normalization and MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Residual connection
        encoded_patches = layers.Add()([x3, x2])
    # Create a [batch_size, num_patches * projection_dim] tensor
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify output
    logits = layers.Dense(num_classes)(features)
    # create whole net
    model = keras.Model(inputs=inputs, outputs=logits)
    return model
The output of our model is now an (N, 100) tensor. OK, you're done!
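The code above refers to several names (patch_size, num_patches, num_heads, transformer_layers, transformer_units, mlp_head_units) whose values are not shown in this article. The sketch below lists one plausible set. Only image_size, patch_size, num_patches, and projection_dim follow from the text; the remaining values are assumptions chosen for illustration, so adjust them to your own setup.

```python
# Values that follow from the article text.
image_size = 72          # images are resized to 72x72
patch_size = 6           # each patch is 6x6 pixels
projection_dim = 64      # size of every token vector
num_channels = 3         # RGB images

# Derived quantities.
num_patches = (image_size // patch_size) ** 2            # 12 * 12 patches
patch_dim = patch_size * patch_size * num_channels       # flattened patch size
flattened_dim = num_patches * projection_dim             # size after Flatten

# Assumed values (not stated in the article).
num_heads = 4
transformer_layers = 8
# The last entry must equal projection_dim so the residual Add is valid.
transformer_units = [projection_dim * 2, projection_dim]
mlp_head_units = [2048, 1024]

print(num_patches, patch_dim, flattened_dim)
```

Because there is no classification token, the head flattens all 144 tokens, which makes the first Dense layer of the MLP head fairly large.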
Now let's see how our model behaves when running on a CPU:
Well, the trends all look right; training is just a little slow. That's a wrap!
epilogue
Compared with the previous PyTorch version, the TensorFlow version uses the GELU activation function and stacks multiple Transformer encoder blocks. Later ViT variants bring various improvements that you can add on top of this. To get the complete code, follow the official account and send it the reply "vit_tf".
Paper: https://arxiv.org/abs/2010.11929
References:
[1] https://blog.csdn.net/u012744245/article/details/112671504?utm_medium=distribute.pc_relevant.none-task-blog-2defaultbaidujs_utm_term~default-9.pc_relevant_aa&spm=1001.2101.3001.4242.6&utm_relevant_index=12
[2] https://zhuanlan.zhihu.com/p/340149804