Build Vision Transformers (TensorFlow version) from scratch

Preface

After the Vision Transformer (ViT) was proposed by Dosovitskiy et al. in 2020, it gradually came to occupy a dominant position in computer vision, achieving strong performance on downstream tasks such as image classification, object detection, and semantic segmentation, and setting off a wave of Transformer models in the CV field. Here I will show how to implement the ViT model step by step from scratch, based on the TensorFlow framework.

Foreword

In the previous article, we implemented the ViT model in PyTorch; if you are interested, click here. By popular request, today I will show how to implement my first ViT from scratch, this time in TensorFlow. If you are not familiar with the Transformer model used in natural language processing (NLP), you may be a little confused about how Transformers apply to the CV field and how the ViT model is used on images. Don't worry, start here!

Define tasks

Because the previous article already explains the ViT model in some detail, here we will build the framework and model directly. Also, while the previous article used the MNIST dataset, this time we use the somewhat richer CIFAR-100 dataset. Although the goal is simple, this image classification task lets us trace the whole structure of the ViT model.
First of all, the environment configured here is:

python==3.7 
tensorflow==2.7.0 
tensorflow_addons==0.16.1

First import some modules that need to be used:

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa  # community add-ons for TensorFlow (provides the AdamW optimizer)

The tensorflow_addons module is imported in order to use the AdamW optimizer [1]; of course, other optimizers such as Adam could be used here instead.
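If you prefer to avoid the extra dependency, the built-in Adam optimizer is a drop-in substitute (a sketch, not what this post trains with; note that you then lose AdamW's decoupled weight decay):

optimizer = keras.optimizers.Adam(learning_rate=0.001)  # alternative to tfa.optimizers.AdamW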

Next, let's create the main function. Since TensorFlow 2 has the Keras API built in, we first load the CIFAR-100 dataset directly through the keras.datasets module, split it into train and test sets for preprocessing, and define hyperparameters such as the learning rate, batch_size, and number of epochs;

We then define the AdamW optimizer, the loss, and the evaluation metrics through model.compile; save weight files during training through keras.callbacks.ModelCheckpoint; run the training through model.fit for 100 epochs; and finally compute the accuracy on the test set.

def main():
    # Downloading dataset
    num_classes = 100
    input_shape = (32, 32, 3)
    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

    print(f"x_train shape: {
      
      x_train.shape} - y_train shape: {
      
      y_train.shape}")
    print(f"x_test shape: {
      
      x_test.shape} - y_test shape: {
      
      y_test.shape}")

    # Hyper-parameters
    learning_rate = 0.001
    weight_decay = 0.0001
    batch_size = 256
    num_epochs = 100
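    # The following hyperparameters are used by create_vit_classifier below
    # but were never defined in the original post. patch_size and num_patches
    # follow the 72x72 resize and 6x6 patches described later; num_heads,
    # transformer_units, transformer_layers and mlp_head_units are assumed values.
    patch_size = 6                            # each patch is 6x6 pixels
    num_patches = (72 // patch_size) ** 2     # 144 patches per image
    num_heads = 4                             # assumed
    transformer_units = [64 * 2, 64]          # assumed; 64 = projection_dim
    transformer_layers = 8                    # assumed encoder depth
    mlp_head_units = [2048, 1024]             # assumed classifier head sizes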

    def create_vit_classifier():
        # TODO: built up step by step in the sections below
        pass

    def run_experiment(model):
        #define optimizer
        optimizer = tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        )

        # define optimizer, loss and metrics
        model.compile(
            optimizer=optimizer,
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            metrics=[
                keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
                keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5 accuracy'),
            ],
        )
        # saved model path
        checkpoint_filepath = './tmp/checkpoint'
        # save weights during training, keeping only the best model by monitoring val_accuracy
        checkpoint_callback = keras.callbacks.ModelCheckpoint(
            checkpoint_filepath,
            monitor="val_accuracy",
            save_best_only=True,
            save_weights_only=True
        )
        # train process
        history = model.fit(
            x=x_train,
            y=y_train,
            batch_size=batch_size,
            epochs=num_epochs,
            validation_split=0.1,
            callbacks=[checkpoint_callback]
        )
        # test process
        model.load_weights(checkpoint_filepath)
        _, accuracy, top5_accuracy = model.evaluate(x_test, y_test)
        print(f"Test Accuracy:       {
      
      accuracy}")
        print(f"Test Top-5 Accuracy: {
      
      top5_accuracy}")
        return history

    model = create_vit_classifier()
    run_experiment(model)

Having built the entire training and testing framework, we now focus on the create_vit_classifier function, that is, on constructing the ViT model itself. The model's task is to classify the CIFAR-100 images.

ViT architecture

Since TensorFlow, like most DL frameworks, provides automatic differentiation, we only need to inherit the layers of the ViT network from the keras layers.Layer class and define the optimizer in the training framework; TensorFlow takes care of backpropagating gradients and training the model parameters, so we only have to implement the forward pass of the ViT model. The ViT model was introduced in the previous article; the main network diagram is reproduced here.
[Figure: the ViT network architecture]

Data augmentation

For the CIFAR-100 dataset, we use the keras layers module for data augmentation:

    image_size = 72
    # Data augmentation
    data_augmentation = keras.Sequential(
        [
            layers.experimental.preprocessing.Normalization(),
            layers.experimental.preprocessing.Resizing(image_size, image_size),
            layers.experimental.preprocessing.RandomFlip("horizontal"),
            layers.experimental.preprocessing.RandomRotation(factor=0.02),
            layers.experimental.preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2)
        ],
        name='data_augmentation'
    )

    # Normalizing based on training data
    data_augmentation.layers[0].adapt(x_train)

    def create_vit_classifier():
        # Creating classifier
        inputs = layers.Input(shape=input_shape)
        augmented = data_augmentation(inputs)
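
A note on normalization: calling adapt computes the per-channel mean and variance from x_train, so the Normalization layer standardizes inputs the same way at training and test time. As an illustrative check (my addition; keras preprocessing layers expose their adapted statistics):

norm_layer = data_augmentation.layers[0]
print(norm_layer.mean.numpy().ravel())      # per-channel means learned from x_train
print(norm_layer.variance.numpy().ravel())  # per-channel variances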

Patchifying

In ViT, an image is first decomposed into multiple sub-images (patches), and each patch is then mapped to a vector. After resizing the image to 72x72, we divide each image into a 12x12 grid of patches, each of size 6x6 (if the image size is not evenly divisible by the patch size, the image needs padding). A single image therefore yields 144 patches, and the batch is reshaped from (N, H, W, C) to:
(N, (H/P) x (W/P), P x P x C) = (N, 12x12, 6x6x3) = (N, 144, 108)
Note that although each patch is 6x6x3 in size, we flatten it to a 108-dimensional vector.
We implement this in code:

class Patches(layers.Layer):
    """Gets some images and returns the patches for each image"""

    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
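
As a quick sanity check (illustrative, not part of the original post), running the Patches layer on one dummy 72x72 image reproduces the (N, 144, 108) shape derived above:

dummy = tf.random.uniform((1, 72, 72, 3))   # one fake resized RGB image
print(Patches(patch_size=6)(dummy).shape)   # expected: (1, 144, 108)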

Linear Mapping and Adding Positional Encoding

After getting the flattened patches (vectors), we change their dimension through layers.Dense. This linear mapping can project to any vector size; we set it to 64 here, but you can choose any dimension. The tensor thus becomes (N, 144, 64), and a learnable positional encoding is added through layers.Embedding. Note: there is a slight change from the standard ViT here: no classification token is added, and the final output of the corresponding MLP head changes as well, mainly to make training faster. If you want to add it, you can do so in the call function of the PatchEncoder class.

projection_dim = 64
class PatchEncoder(layers.Layer):
    """Adding (learnable) positional encoding to input patches"""

    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

def create_vit_classifier():
    # Creating classifier
    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
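
And if you do want the paper's classification token, here is one possible sketch (my own addition, not the model trained in this post) of a PatchEncoder variant that prepends a learnable [class] token and extends the positional embedding by one position:

class PatchEncoderWithCLS(layers.Layer):
    """Sketch: PatchEncoder variant with a learnable [class] token."""

    def __init__(self, num_patches, projection_dim):
        super(PatchEncoderWithCLS, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        # one extra position for the [class] token
        self.position_embedding = layers.Embedding(
            input_dim=num_patches + 1, output_dim=projection_dim
        )
        self.cls_token = self.add_weight(
            name="cls_token", shape=(1, 1, projection_dim),
            initializer="random_normal", trainable=True
        )

    def call(self, patch):
        batch_size = tf.shape(patch)[0]
        # broadcast the [class] token across the batch and prepend it
        cls = tf.tile(self.cls_token, [batch_size, 1, 1])
        tokens = tf.concat([cls, self.projection(patch)], axis=1)
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        return tokens + self.position_embedding(positions)

The classification head would then read the first token (tokens[:, 0]) instead of flattening all of them.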

Add LN, transformer

After getting the encoded patches, we first layer-normalize the tokens, then apply the multi-head attention mechanism, and finally add a residual connection (connecting the input before the LN to the output of the multi-head attention).


def create_vit_classifier():
        # Creating classifier
        inputs = layers.Input(shape=input_shape)
        augmented = data_augmentation(inputs)
        patches = Patches(patch_size)(augmented)
        encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

        for _ in range(transformer_layers):
            # Layer normalization and self-attention
            x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
            attention_output = layers.MultiHeadAttention(
                num_heads, key_dim=projection_dim, dropout=0.1
            )(x1, x1)

            # Residual connection
            x2 = layers.Add()([attention_output, encoded_patches])

LN, MLP and Residual Connections

Continuing with the network, we pass the current tensor through another LN and an MLP, add a residual connection, and stack blocks in this way. Compared with the PyTorch version, the gelu activation function is used here, and multiple Transformer layers are stacked.

def mlp(x, hidden_units, dropout_rate):
    """MLP with dropout and skip-connections"""
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

def create_vit_classifier():
        # Creating classifier
        inputs = layers.Input(shape=input_shape)
        augmented = data_augmentation(inputs)
        patches = Patches(patch_size)(augmented)
        encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

        for _ in range(transformer_layers):
            # Layer normalization and self-attention
            x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
            attention_output = layers.MultiHeadAttention(
                num_heads, key_dim=projection_dim, dropout=0.1
            )(x1, x1)

            # Residual connection
            x2 = layers.Add()([attention_output, encoded_patches])

            # Normalization and MLP
            x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
            x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

            # Residual connection
            encoded_patches = layers.Add()([x3, x2])

LN, MLP, Dense classification

First, an LN normalizes the output, which is then flattened. To prevent overfitting, we add dropout and then an MLP. The PyTorch version used a classification token and took the first value of the MLP output as the classification result; since that token was removed here, the features are instead passed through a final Dense layer to produce the output.

def create_vit_classifier():
        # Creating classifier
        inputs = layers.Input(shape=input_shape)
        augmented = data_augmentation(inputs)
        patches = Patches(patch_size)(augmented)
        encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

        for _ in range(transformer_layers):
            # Layer normalization and self-attention
            x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
            attention_output = layers.MultiHeadAttention(
                num_heads, key_dim=projection_dim, dropout=0.1
            )(x1, x1)

            # Residual connection
            x2 = layers.Add()([attention_output, encoded_patches])

            # Normalization and MLP
            x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
            x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)

            # Residual connection
            encoded_patches = layers.Add()([x3, x2])

        # Flatten to a [batch_size, num_patches * projection_dim] tensor
        representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        representation = layers.Flatten()(representation)
        representation = layers.Dropout(0.5)(representation)

        # Add MLP
        features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)

        # Classify output
        logits = layers.Dense(num_classes)(features)

        # Create the whole model
        model = keras.Model(inputs=inputs, outputs=logits)
        return model

The output of our model is now an (N, 100) tensor. OK, you're done!
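
As a quick check (illustrative only), building the model and running a dummy batch through it confirms the logits shape:

model = create_vit_classifier()
dummy = tf.random.uniform((2, 32, 32, 3))  # two fake CIFAR-100 images
print(model(dummy).shape)                  # expected: (2, 100)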

Now let's see how our model behaves, running on a CPU:
[Figure: training progress on CPU]
Well, the trends all look right; it's just a little slow. And that's a wrap!

Epilogue

Compared with the earlier PyTorch version, this TensorFlow version uses the GELU activation function and stacks multiple Transformer encoder blocks together. Later ViT variants improve on the design in many ways, and you can build on this code. Follow the official account and reply "vit_tf" in the background to get the complete code.


Paper: https://arxiv.org/abs/2010.11929
References:
[1] https://blog.csdn.net/u012744245/article/details/112671504
[2] https://zhuanlan.zhihu.com/p/340149804

For more content, please follow the official account "Invincible Zhang Dadao".
