How to store the gradients of Alexnet as numpy array (in each iteration) in Python?

God_Help :

I want to store the final gradient vector of a model as a numpy array. Is there an easy and intuitive way to do that using Tensorflow?

I want to store the gradient vectors of AlexNet (in a numpy array) for each iteration, until convergence.

Rohit :

Below is a model that resembles the AlexNet architecture, together with code that captures the gradients after every epoch.

# (1) Importing dependency
import keras
from keras import backend as K
from keras.models import Sequential
from keras.layers import Dense, Activation, Dropout, Flatten, Conv2D, MaxPooling2D
from keras.layers.normalization import BatchNormalization
import numpy as np
# Seed NumPy's global random generator with a fixed value.
np.random.seed(1000)

# (2) Get Data
# Load the oxflower17 dataset through the tflearn helper.
# NOTE(review): presumably x holds images shaped (N, 224, 224, 3) and, with
# one_hot=True, y holds 17-way one-hot labels -- this is what the model's
# input_shape and 17-unit softmax below expect; confirm against tflearn.
import tflearn.datasets.oxflower17 as oxflower17
x, y = oxflower17.load_data(one_hot=True)

# (3) Create a sequential model
model = Sequential()

# One entry per convolutional stage, in network order:
# (filters, kernel_size, strides, followed_by_max_pooling)
conv_stages = [
    (96, (11, 11), (4, 4), True),    # 1st Convolutional Layer
    (256, (11, 11), (1, 1), True),   # 2nd Convolutional Layer
    (384, (3, 3), (1, 1), False),    # 3rd Convolutional Layer
    (384, (3, 3), (1, 1), False),    # 4th Convolutional Layer
    (256, (3, 3), (1, 1), True),     # 5th Convolutional Layer
]

for stage_index, (n_filters, kernel, stride, pooled) in enumerate(conv_stages):
    conv_kwargs = dict(filters=n_filters, kernel_size=kernel, strides=stride, padding='valid')
    if stage_index == 0:
        # Only the very first layer declares the input image shape
        conv_kwargs['input_shape'] = (224, 224, 3)
    model.add(Conv2D(**conv_kwargs))
    model.add(Activation('relu'))
    if pooled:
        # Pooling
        model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='valid'))
    # Batch Normalisation before passing it to the next layer
    model.add(BatchNormalization())

# Passing it to a dense layer
model.add(Flatten())
# 1st Dense Layer
# No input_shape here: it is only honoured on the first layer of a Sequential
# model, and the real input size is whatever Flatten produces (256 features
# for this network), not 224*224*3.
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout to prevent overfitting
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# 2nd Dense Layer
model.add(Dense(4096))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# 3rd Dense Layer
model.add(Dense(1000))
model.add(Activation('relu'))
# Add Dropout
model.add(Dropout(0.4))
# Batch Normalisation
model.add(BatchNormalization())

# Output Layer: one unit per class (17), turned into probabilities by softmax
model.add(Dense(17))
model.add(Activation('softmax'))

# Print the layer-by-layer overview (output shapes and parameter counts)
model.summary()

# (4) Compile
# categorical_crossentropy matches the softmax output and the labels loaded
# with one_hot=True
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# (5) Define Gradient Function
def get_gradient_func(model):
    """Build a backend function that evaluates the loss gradients of `model`.

    Args:
        model: A compiled Keras model. It must expose `total_loss`,
            `trainable_weights` and the `_feed_inputs`, `_feed_targets` and
            `_feed_sample_weights` placeholder lists.

    Returns:
        A `K.function` that takes a list `[inputs..., targets...,
        sample_weights...]` and returns one gradient per entry of
        `model.trainable_weights`, in the same order.
    """
    # `Sequential.model` is deprecated (Sequential is itself a Model) and
    # emits a UserWarning on every access, so read the feed lists from the
    # model directly. Fall back to `.model` only when the model does not
    # carry them itself.
    feed_source = model if hasattr(model, '_feed_inputs') else model.model
    grads = K.gradients(model.total_loss, model.trainable_weights)
    inputs = (feed_source._feed_inputs
              + feed_source._feed_targets
              + feed_source._feed_sample_weights)
    return K.function(inputs, grads)

# (6) Train the model such that gradients are captured for every epoch
# Build the gradient function once: it depends only on the compiled model, and
# rebuilding it inside the loop would add a fresh set of gradient ops to the
# graph on every epoch.
get_gradient = get_gradient_func(model)
epoch_gradient = []
for epoch in range(1, 5):
    # epochs=epoch together with initial_epoch=epoch-1 makes every fit() call
    # run exactly one epoch, continuing from the current weights.
    model.fit(x, y, batch_size=64, epochs=epoch, initial_epoch=(epoch - 1), verbose=1, validation_split=0.2, shuffle=True)
    # NOTE: the gradients are evaluated on the whole of x / y in a single
    # call (including the samples fit() holds out via validation_split), with
    # every sample weighted 1.
    grads = get_gradient([x, y, np.ones(len(y))])
    epoch_gradient.append(grads)

# (7) Convert to a 2 dimensional array of (epoch, gradients) type
# The per-weight gradients have different shapes, so the result has to be an
# object array; state dtype=object explicitly because newer NumPy versions
# refuse to build ragged arrays implicitly.
gradient = np.asarray(epoch_gradient, dtype=object)
print("Total number of epochs run:", epoch)
print("Gradient Array has the shape:",gradient.shape)

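Output: `gradient` is a 2-dimensional object array of shape (number of epochs, number of trainable weight tensors) — here (4, 34). Each entry holds the gradient of one weight tensor after one epoch, so the per-layer structure of the gradients is retained.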

Model: "sequential_34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_115 (Conv2D)          (None, 54, 54, 96)        34944     
_________________________________________________________________
activation_213 (Activation)  (None, 54, 54, 96)        0         
_________________________________________________________________
max_pooling2d_83 (MaxPooling (None, 27, 27, 96)        0         
_________________________________________________________________
batch_normalization_180 (Bat (None, 27, 27, 96)        384       
_________________________________________________________________
conv2d_116 (Conv2D)          (None, 17, 17, 256)       2973952   
_________________________________________________________________
activation_214 (Activation)  (None, 17, 17, 256)       0         
_________________________________________________________________
max_pooling2d_84 (MaxPooling (None, 8, 8, 256)         0         
_________________________________________________________________
batch_normalization_181 (Bat (None, 8, 8, 256)         1024      
_________________________________________________________________
conv2d_117 (Conv2D)          (None, 6, 6, 384)         885120    
_________________________________________________________________
activation_215 (Activation)  (None, 6, 6, 384)         0         
_________________________________________________________________
batch_normalization_182 (Bat (None, 6, 6, 384)         1536      
_________________________________________________________________
conv2d_118 (Conv2D)          (None, 4, 4, 384)         1327488   
_________________________________________________________________
activation_216 (Activation)  (None, 4, 4, 384)         0         
_________________________________________________________________
batch_normalization_183 (Bat (None, 4, 4, 384)         1536      
_________________________________________________________________
conv2d_119 (Conv2D)          (None, 2, 2, 256)         884992    
_________________________________________________________________
activation_217 (Activation)  (None, 2, 2, 256)         0         
_________________________________________________________________
max_pooling2d_85 (MaxPooling (None, 1, 1, 256)         0         
_________________________________________________________________
batch_normalization_184 (Bat (None, 1, 1, 256)         1024      
_________________________________________________________________
flatten_34 (Flatten)         (None, 256)               0         
_________________________________________________________________
dense_99 (Dense)             (None, 4096)              1052672   
_________________________________________________________________
activation_218 (Activation)  (None, 4096)              0         
_________________________________________________________________
dropout_66 (Dropout)         (None, 4096)              0         
_________________________________________________________________
batch_normalization_185 (Bat (None, 4096)              16384     
_________________________________________________________________
dense_100 (Dense)            (None, 4096)              16781312  
_________________________________________________________________
activation_219 (Activation)  (None, 4096)              0         
_________________________________________________________________
dropout_67 (Dropout)         (None, 4096)              0         
_________________________________________________________________
batch_normalization_186 (Bat (None, 4096)              16384     
_________________________________________________________________
dense_101 (Dense)            (None, 1000)              4097000   
_________________________________________________________________
activation_220 (Activation)  (None, 1000)              0         
_________________________________________________________________
dropout_68 (Dropout)         (None, 1000)              0         
_________________________________________________________________
batch_normalization_187 (Bat (None, 1000)              4000      
_________________________________________________________________
dense_102 (Dense)            (None, 17)                17017     
_________________________________________________________________
activation_221 (Activation)  (None, 17)                0         
=================================================================
Total params: 28,096,769
Trainable params: 28,075,633
Non-trainable params: 21,136
_________________________________________________________________
Train on 1088 samples, validate on 272 samples
Epoch 1/1
1088/1088 [==============================] - 22s 20ms/step - loss: 3.1251 - acc: 0.2178 - val_loss: 13.0005 - val_acc: 0.1140
Train on 1088 samples, validate on 272 samples
Epoch 2/2
 128/1088 [==>...........................] - ETA: 1s - loss: 2.3913 - acc: 0.2656/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 2.2318 - acc: 0.3465 - val_loss: 9.6171 - val_acc: 0.1912
Train on 1088 samples, validate on 272 samples
Epoch 3/3
  64/1088 [>.............................] - ETA: 1s - loss: 1.5143 - acc: 0.5000/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 1.8109 - acc: 0.4320 - val_loss: 4.3375 - val_acc: 0.3162
Train on 1088 samples, validate on 272 samples
Epoch 4/4
  64/1088 [>.............................] - ETA: 1s - loss: 1.7827 - acc: 0.4688/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '
1088/1088 [==============================] - 2s 2ms/step - loss: 1.5861 - acc: 0.4871 - val_loss: 3.4091 - val_acc: 0.3787
Total number of epochs run: 4
Gradient Array has the shape: (4, 34)
/usr/local/lib/python3.6/dist-packages/keras/engine/sequential.py:111: UserWarning: `Sequential.model` is deprecated. `Sequential` is a subclass of `Model`, you can just use your `Sequential` instance directly.
  warnings.warn('`Sequential.model` is deprecated. '

Guess you like

Origin http://43.154.161.224:23101/article/api/json?id=319212&siteId=1