Convert VGG-Face model weights from Caffe to PyTorch

#coding=UTF-8
import torch 
import torch.nn as nn
import math
import numpy as np
from PIL import Image,ImageDraw
import matplotlib.pyplot as plt
import collections
import matplotlib.cm as cm

from torch.autograd import Variable
from torchvision import models
import caffe


def vis_square(data):
    """Tile a stack of 2-D maps into one roughly square grid and draw it.

    Args:
        data: array of shape (n, height, width) or (n, height, width, 3).
            Each (height, width) slice becomes one tile of a
            ceil(sqrt(n)) x ceil(sqrt(n)) grid.

    Values are min-max normalised to [0, 1] over the whole array before
    tiling. A constant array is mapped to all zeros instead of dividing by
    zero, which would produce NaN. Each tile gets a 1-pixel white (value 1)
    border on its bottom/right edge. The grid is filled up to a full square
    with white tiles. The result is drawn with ``plt.imshow`` into the current
    axes with the axis turned off; nothing is returned.
    """
    data = np.asarray(data)

    # Normalise for display, guarding against a zero range (0/0 -> NaN).
    lo, hi = data.min(), data.max()
    span = hi - lo
    if span > 0:
        data = (data - lo) / span
    else:
        data = np.zeros_like(data, dtype=float)

    # Pad the tile count up to n*n and add a 1-px gap after every tile.
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
                (0, 1), (0, 1))                 # space between tiles
               + ((0, 0),) * (data.ndim - 3))   # never pad a colour axis
    data = np.pad(data, padding, mode='constant', constant_values=1)  # white

    # (n*n, h, w[, c]) -> (n, n, h, w[, c]) -> (n, h, n, w[, c])
    # -> (n*h, n*w[, c]). `data.ndim` inside transpose() is still the
    # pre-reshape rank, so only a trailing colour axis is appended.
    data = data.reshape((n, n) + data.shape[1:]).transpose(
        (0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

    plt.imshow(data)
    plt.axis('off')

def vis_activation(activations, img_arr, size=(224, 224)):
    """Draw a heat map of `activations` alpha-blended over `img_arr`.

    Args:
        activations: array of shape (n_channels, height, width). Any other
            rank fails the assertion below.
        img_arr: image array accepted by ``PIL.Image.fromarray``.
        size: (width, height) passed to PIL ``resize`` for both the heat map
            and the image. Defaults to the previously hard-coded (224, 224).

    Steps:
        1. Sum the channels with unit weight and clip at 0 (ReLU).
        2. Upsample bicubically to `size`.
        3. Scale so the peak is 1. An all-zero map is left at zero instead
           of becoming NaN through 0/0.
        4. Colour with matplotlib's ``jet`` colormap, convert to RGB and
           blend 50/50 with the image.

    The result is drawn with ``plt.imshow`` into the current axes with the
    axis turned off; nothing is returned.
    """
    assert(len(activations.shape) == 3),'make sure : nchannels * height * width '
    n_nodes = activations.shape[0]  # number of channels
    print('n_nodes: ', n_nodes)
    vis_size = activations.shape[1:]  # spatial size of the activation map
    print('visual_size: ', vis_size)

    # Unit-weight sum over channels, then ReLU. Kept as float32 so PIL builds
    # a single-channel float image from it.
    vis = np.maximum(np.sum(activations, axis=0, dtype=np.float32), 0)

    heat = Image.fromarray(vis).resize(size, Image.BICUBIC)
    # Work on an ndarray explicitly rather than dividing a PIL Image object.
    heat = np.asarray(heat, dtype=np.float32)
    peak = heat.max()
    if peak > 0:
        heat = heat / peak

    # cm.jet yields RGBA floats; scale to uint8 and drop alpha so the result
    # can be blended with the RGB image.
    colored = Image.fromarray(np.uint8(cm.jet(heat) * 255)).convert('RGB')

    photo = Image.fromarray(img_arr).resize(size).convert('RGB')
    heat_map = Image.blend(photo, colored, 0.5)
    plt.imshow(heat_map)
    plt.axis('off')

class vgg16(nn.Module):
    """VGG-16 network: 13 conv layers in five stages plus three linear layers.

    ``features``: every conv is 3x3 / stride 1 / padding 1 and is followed by
    an in-place ReLU. A 2x2 / stride-2 max-pool closes each of the five
    stages.

    ``classifier``: Linear(25088, 4096) -> ReLU -> Dropout(0.5) ->
    Linear(4096, 4096) -> ReLU -> Dropout(0.5) -> Linear(4096, num_classes).

    Resulting module indices:
        - conv layers: features.0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28
        - last max-pool: features.30
        - linear layers: classifier.0 / 3 / 6

    Code elsewhere in this script indexes these positions directly.

    Args:
        num_classes: output size of the final linear layer.
    """

    # Output channels of each conv layer; 'M' marks a 2x2 max-pool.
    _CFG = (64, 64, 'M',
            128, 128, 'M',
            256, 256, 256, 'M',
            512, 512, 512, 'M',
            512, 512, 512, 'M')

    def __init__(self, num_classes=1000):
        super(vgg16, self).__init__()
        self.features = nn.Sequential(*self._make_features(self._CFG))

        head = []
        for fan_in, fan_out in ((512 * 7 * 7, 4096), (4096, 4096)):
            head.extend([nn.Linear(in_features=fan_in, out_features=fan_out, bias=True),
                         nn.ReLU(inplace=True),
                         nn.Dropout(p=0.5)])
        head.append(nn.Linear(in_features=4096, out_features=num_classes, bias=True))
        self.classifier = nn.Sequential(*head)

        self._initialize_weights()

    @staticmethod
    def _make_features(cfg):
        """Expand a channel config into a list of conv / ReLU / max-pool layers."""
        layers = []
        in_channels = 3
        for entry in cfg:
            if entry == 'M':
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2),
                                           dilation=(1, 1), ceil_mode=False))
                continue
            layers.append(nn.Conv2d(in_channels, entry, kernel_size=(3, 3),
                                    stride=(1, 1), padding=(1, 1)))
            layers.append(nn.ReLU(inplace=True))
            in_channels = entry
        return layers

    def forward(self, x):
        """Run the conv stack, flatten per sample, then the classifier."""
        feats = self.features(x)
        flat = feats.view(feats.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        """Initialise parameters as follows:

        - Conv: N(0, sqrt(2 / (kh * kw * out_channels))) weights, zero bias.
        - BatchNorm2d (if any): unit weight, zero bias.
        - Linear: N(0, 0.01) weights, zero bias.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                fan_out = layer.out_channels * layer.kernel_size[0] * layer.kernel_size[1]
                layer.weight.data.normal_(0, math.sqrt(2. / fan_out))
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                layer.weight.data.normal_(0, 0.01)
                layer.bias.data.zero_()

# Build the PyTorch network and list its state_dict entries so their names
# can be lined up against the Caffe layer names.
model = vgg16()
params = model.state_dict()
print(type(params))
for param_name, value in params.items():
    print(param_name, ': ', type(value))

# Load the pretrained Caffe VGG-Face network (deploy prototxt + weights).
filepath = '../my_headpose_using_VGG_version2/pretrained_model/vgg_face_caffe/'
deploy_pro = '{}VGG_FACE_deploy.prototxt'.format(filepath)
pretrained_weight = '{}VGG_FACE.caffemodel'.format(filepath)
net = caffe.Net(deploy_pro, pretrained_weight, caffe.TEST)
print('load caffemodel over ....')

# PyTorch module index -> Caffe layer name, for model.features and
# model.classifier respectively (state_dict prefixes features.<i> and
# classifier.<i>). Only parameterised modules appear; the ReLU, max-pool and
# dropout modules in between have no weights.
dict_feature = {0: 'conv1_1',
                2: 'conv1_2',
                5: 'conv2_1',
                7: 'conv2_2',
                10: 'conv3_1',
                12: 'conv3_2',
                14: 'conv3_3',
                17: 'conv4_1',
                19: 'conv4_2',
                21: 'conv4_3',
                24: 'conv5_1',
                26: 'conv5_2',
                28: 'conv5_3'}

# fc8 (classifier.6) is deliberately not copied, so the last layer keeps its
# random initialisation.
# NOTE(review): presumably because the Caffe fc8 output size differs from
# the num_classes=1000 used to build `model`. Confirm before adding
# 6: 'fc8' here.
dict_classifier = {0: 'fc6',
                   3: 'fc7'}

for container, index_to_name in ((model.features, dict_feature),
                                  (model.classifier, dict_classifier)):
    for i, layer_name in index_to_name.items():
        layer = container[i]
        blobs = net.params[layer_name]
        # Caffe blob 0 holds the weights, blob 1 the bias.
        for param, blob in ((layer.weight, blobs[0].data), (layer.bias, blobs[1].data)):
            # Fail loudly on a mismatch instead of letting copy_ broadcast
            # or raise an opaque shape error.
            if blob.size != param.numel():
                raise ValueError(
                    'cannot copy Caffe layer {!r} {} into {} parameter of shape {}'.format(
                        layer_name, tuple(blob.shape), layer, tuple(param.shape)))
            # view_as only reinterprets the shape, so the element order of the
            # Caffe blob must already match PyTorch's parameter layout.
            param.data.copy_(torch.from_numpy(np.ascontiguousarray(blob)).view_as(param))

print('copy weight from caffemodel to pytorch over....')

########## check #####################
# Feed the same image through both networks and compare their outputs.
# imgSize is [height, width]; PIL's resize() takes (width, height).
imgSize = [224, 224]

imgpath = '../copy_caffemodel_2_pytorch/cat.jpg'
with Image.open(imgpath) as img:
    # Force 3-channel RGB: a grayscale, RGBA or palette file would otherwise
    # break the channel swap below.
    res_img = img.convert('RGB').resize((imgSize[1], imgSize[0]))
img = np.double(res_img)            # H x W x 3, RGB, float64
temp_img = np.uint8(res_img)        # uint8 copy kept for the heat-map overlays
img = img[:, :, (2, 1, 0)]          # RGB -> BGR
img = np.transpose(img, (2, 0, 1))  # H x W x C -> C x H x W

print(img.shape)

# Batch of one in NCHW float32 layout. No mean subtraction is applied. Both
# networks receive this same array, so their outputs stay directly comparable.
data_arr = np.zeros(shape=(1, 3, imgSize[0], imgSize[1]), dtype=np.float32)
data_arr[0, ...] = img
input_data = Variable(torch.from_numpy(data_arr).type(torch.FloatTensor))
feat_result  = []  # arrays recorded by get_features_hook, in the order the hooks fire
def get_features_hook(self,input,output):
    """Forward hook: print input/output shapes and record the layer output.

    `self` is the module the hook is registered on. `input` is the tuple of
    positional arguments passed to its forward(). `output` is its result
    tensor. Appends `output.data.cpu().numpy()` to the global `feat_result`.

    NOTE(review): for a CPU tensor `.cpu()` returns the same tensor and
    `.numpy()` shares its memory, so no copy is made. In this model,
    features[0] (conv1_1) is followed by ReLU(inplace=True) at features[1],
    and classifier[3] (fc7) by one at classifier[4]. Those in-place ReLUs
    then overwrite the recorded memory, so the stored conv1_1/fc7 arrays
    hold post-ReLU values. That appears to be what lets them match Caffe's
    'conv1_1'/'fc7' blobs, assuming the deploy prototxt uses in-place ReLU
    layers (confirm). Only add a `.copy()` if pre-activation values are
    really wanted.
    """
    # number of positional inputs passed to the module
    print('len(input): ',len(input))
    # len() of a tensor is the size of its first (batch) dimension
    print('len(output): ',len(output))
    print('###################################')
    print(input[0].shape) # e.g. torch.Size([1, 3, 224, 224]) for the conv1_1 hook

    print('###################################')
    print(output[0].shape) # shape of the first sample's output
    feat_result.append(output.data.cpu().numpy())

# Record conv1_1, pool5 and fc7 outputs (appended to feat_result in that
# order) during a single PyTorch forward pass.
handle_feat = model.features[0].register_forward_hook(get_features_hook)   # conv1_1
handle_heat = model.features[30].register_forward_hook(get_features_hook)  # pool5
handle_fc7 = model.classifier[3].register_forward_hook(get_features_hook)  # fc7

model.eval()  # eval mode disables dropout
score = model(input_data)

# Remove every hook, fc7's included, so later forward passes do not keep
# appending to feat_result.
handle_feat.remove()
handle_heat.remove()
handle_fc7.remove()

feat1 = feat_result[0]        # conv1_1
feat1_heat = feat_result[1]   # pool5
fc7_pytorch = feat_result[2]  # fc7

####################################### pytorch feature maps ###############################
# The vis_* helpers draw into the current axes. Clear the figure before each
# plot so earlier images do not pile up underneath the new one.
plt.clf()
vis_square(feat1[0, ...])
plt.savefig('feat_visual_pytorch.png')
####################################### pytorch heatmap ####################################
plt.clf()
vis_activation(feat1_heat[0, ...], temp_img)
print('show heatmap for pytorch...')
plt.savefig('heatmap_visual_pytorch.png')
############################################  for caffe ############################################
net.blobs['data'].reshape(1, 3, imgSize[0], imgSize[1])
net.blobs['data'].data[...] = data_arr
output = net.forward()

plt.clf()
feat2 = net.blobs['conv1_1'].data[0, ...]
vis_square(feat2)
plt.savefig('feat_visual_caffe.png')

############################# caffe heatmap ################################
plt.clf()
feat2_heat = net.blobs['pool5'].data[0, ...]
vis_activation(feat2_heat, temp_img)
plt.savefig('heatmap_visual_caffe.png')

######################### check fc7 layer#######################
fc7_caffe = net.blobs['fc7'].data
print(fc7_pytorch.shape)
print(fc7_caffe.shape)

# Largest element-wise absolute difference between the two fc7 outputs.
err = np.max(np.abs(fc7_pytorch - fc7_caffe))
print(err)  # 6.914139e-06 in the author's run

验证结果图:

下图分别是: feat_visual_pytorch.png 和 heatmap_visual_pytorch.png
Caffe 显示的结果与此相同,不再附图。

(图:feat_visual_pytorch.png)

(图:heatmap_visual_pytorch.png)

参考文献:
1.https://github.com/marvis/pytorch-caffe

猜你喜欢

转载自blog.csdn.net/xuluhui123/article/details/80172346
今日推荐