[Paper understanding] CapsuleNet

Foreword

I searched through a lot of material before I finally understood the whole process. The operation itself is actually not hard to understand; what was hard for me was how to implement it in code. The implementations I found on GitHub all felt a bit long-winded, with variables scattered so far apart that my head nearly exploded, so I went looking for videos on Bilibili that walk through the code. It turned out not to be that hard after all, and I finally managed to put together a partial implementation.

I'm not going to write a paper interpretation, because the original is too hard for me to read (I basically have to look up a translation every few sentences of the English, which is rough), and there are already plenty of tutorials and analyses online. So I'll just leave the code here, plus whatever notes I can remember.

What a capsule does

A capsule is a different way of expressing a neuron. Previously each neuron was expressed as a scalar; in a capsule, a neuron is represented by a vector. The advantage is a multi-dimensional description of the neuron: in a capsule, the length (norm) of the vector expresses the probability that the entity it represents exists, while the individual dimensions characterize its properties. For example, if one dimension encodes orientation, then when the orientation changes the vector's length does not change, but the value of that dimension does. That part is easy to understand.
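
Here is a tiny sketch of that idea (my own illustration, not code from the paper): the squash non-linearity keeps a vector's direction but maps its length into [0, 1), so the length can be read as an existence probability.

import torch

def squash(s, dim=-1, eps=1e-8):
    # squash(s) = (|s|^2 / (1 + |s|^2)) * (s / |s|): same direction, length squashed into [0, 1)
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq / (1.0 + sq)) * (s / torch.sqrt(sq + eps))

caps = torch.randn(4, 10, 16)      # hypothetical batch: 4 samples, 10 capsules, 16-d vectors
v = squash(caps)
lengths = v.norm(dim=-1)           # (4, 10): one "existence probability" per capsule
print(lengths.min().item(), lengths.max().item())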

There is already more than enough material about this online; the above is just my personal take, so feel free to read someone else's version.

How to write the capsule code

I'll still post the network structure diagram (the figure from the paper).

Overall the network has three layers. The first is a convolution layer that maps the input (1,28,28) to (256,20,20). The second is called primary_caps: it applies 8 convolution units, each with 32 filters (kernel 9, stride 2), giving an output of (32,6,6,8), which is then reshaped to (1152,1,8). This is the preparation for the vector-in, vector-out stage that follows; a quick shape check is sketched below.
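
A standalone shape check of that flow (my own sketch, using the same layer hyper-parameters as the full code below):

import torch
import torch.nn as nn

x = torch.randn(1, 1, 28, 28)                                # MNIST-sized input
conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
feat = torch.relu(conv1(x))                                  # (1, 256, 20, 20)

# primary capsules: 8 parallel convs, each 256 -> 32 channels, kernel 9, stride 2
units = [nn.Conv2d(256, 32, kernel_size=9, stride=2) for _ in range(8)]
u = torch.stack([unit(feat) for unit in units], dim=1)       # (1, 8, 32, 6, 6)
u = u.view(1, 8, 32 * 6 * 6)                                 # (1, 8, 1152): 1152 capsules, 8-d each
print(feat.shape, u.shape)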

So the meaning here is: there are 1152 capsules, and each capsule is an 8-dimensional vector. Pretty neat.

Then comes the digit_caps layer. The target output is a (10,1,16) vector and the input is (1152,1,8), so the question is how to build that mapping.

With the dynamic routing algorithm, we manage to get v; a rough sketch of the procedure follows.
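
This is my own simplified sketch of the routing-by-agreement loop from the paper, not the author's code. Here `u_hat` stands for the prediction vectors obtained by mapping each 8-d primary capsule to a 16-d vector for every digit capsule (in the full code below this is done with a linear layer plus a learned per-pair weight):

import torch

def squash(s, dim=-1, eps=1e-8):
    sq = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq / (1.0 + sq)) * (s / torch.sqrt(sq + eps))

def dynamic_routing(u_hat, num_iters=3):
    # u_hat: (batch, 1152, 10, 16) prediction vectors u_hat_{j|i}
    b = torch.zeros(u_hat.size(0), u_hat.size(1), u_hat.size(2), 1)   # routing logits b_ij
    for _ in range(num_iters):
        c = torch.softmax(b, dim=2)                    # coupling coefficients, sum to 1 over the 10 digit capsules
        s = (c * u_hat).sum(dim=1, keepdim=True)       # weighted sum over the 1152 input capsules -> (batch, 1, 10, 16)
        v = squash(s)                                  # (batch, 1, 10, 16)
        b = b + (u_hat * v).sum(dim=-1, keepdim=True)  # agreement <u_hat, v> strengthens matching routes
    return v.squeeze(1)                                # (batch, 10, 16)

u_hat = torch.randn(2, 1152, 10, 16)
print(dynamic_routing(u_hat).shape)                    # torch.Size([2, 10, 16])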

Well, that's it. I won't write the reconstruction part of the code.

The full code is attached:

import torch
import torch.nn as nn

class CapsuleLayer(nn.Module):
    def __init__(self,routing = False):
        super(CapsuleLayer,self).__init__()
        self.routing = routing
        # eight parallel conv units for the primary capsules (each: 256 -> 32 channels);
        # only one of the two branches below is actually used, depending on `routing`
        self.conv_units = nn.ModuleList(
            [nn.Conv2d(256, 32, kernel_size=9, stride=2) for _ in range(8)]
        )
        # learned per (input capsule, digit capsule) scaling, plus an 8 -> 16 lift
        self.w = nn.Parameter(torch.randn(1, 1152, 10, 16))
        self.fc = nn.Linear(8, 16)
    def forward(self,x):
        if self.routing:
            return self.use_routing(x)
        else:
            return self.no_routing(x)
    @staticmethod
    def squash(x, dim=-1, eps=1e-8):
        # squash(s) = (|s|^2 / (1 + |s|^2)) * (s / |s|): keeps direction, squashes length into [0, 1)
        sq_norm = torch.sum(x ** 2, dim=dim, keepdim=True)
        return (sq_norm / (1.0 + sq_norm)) * (x / torch.sqrt(sq_norm + eps))
    def use_routing(self, x):  # x: (-1, 8, 32*6*6) from the primary capsules
        # reorder to (b, 1152, 1, 8): 1152 input capsules, each an 8-d vector
        x = x.transpose(1, 2).contiguous().view(-1, 32 * 6 * 6, 1, 8)
        # prediction vectors: lift each capsule to 16-d, then scale per digit capsule
        x = self.fc(x)     # (b, 1152, 1, 16)
        u = self.w * x     # (b, 1152, 10, 16), broadcast over the 10 digit capsules
        # routing logits, one per (input capsule, digit capsule) pair
        b = torch.zeros(x.size(0), x.size(1), 10, 1, 1, device=x.device)

        for _ in range(3):
            c = torch.softmax(b, dim=2)                                       # coupling coefficients over the 10 digit capsules
            s = torch.sum(c.view(-1, 1152, 10, 1) * u, dim=1, keepdim=True)   # (b, 1, 10, 16)
            v = self.squash(s, dim=-1)                                        # (b, 1, 10, 16)
            # agreement u . v updates the routing logits
            b = b + u.view(-1, 1152, 10, 1, 16) @ v.view(-1, 1, 10, 16, 1)

        return v.view(x.size(0), 10, 16)
        
    def no_routing(self, x):  # x: (-1, 256, 20, 20) from the first conv layer
        # eight conv units, each producing one component of the 8-d capsule vectors
        u = [conv(x) for conv in self.conv_units]   # each: (-1, 32, 6, 6)
        u = torch.stack(u, dim=1)                   # (-1, 8, 32, 6, 6)
        u = u.view(-1, 8, 32 * 6 * 6)               # (-1, 8, 1152)
        # squash along the 8-d capsule dimension (dim=1 in this layout)
        return self.squash(u, dim=1)
class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1,256,kernel_size = 9,stride = 1),
            nn.ReLU()
        )
        self.pri_caps = CapsuleLayer()
        self.digit_caps = CapsuleLayer(routing = True) 
    def forward(self,x):
        x = self.conv(x) # (-1,256,20,20)
        x = self.pri_caps(x)   # (-1,8,1152) primary capsules, no routing
        x = self.digit_caps(x) # (-1,10,16) digit capsules via dynamic routing
        return x
if __name__ == "__main__":
    x = torch.randn(2,1,28,28)
    net = CapsuleNet()
    y = net(x)
    print(y.size())
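
If everything is wired up correctly, this should print torch.Size([2, 10, 16]): 2 samples, 10 digit capsules, each a 16-dimensional vector whose length scores the corresponding class.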

Origin www.cnblogs.com/aoru45/p/11669355.html