The role of super(__class__, self).__init__() when PyTorch builds a network model

Table of contents

0 Preface

1 Description of the super() method

2 What does it inherit from torch.nn.Module?


0 Preface

In accordance with international practice, let me first declare: this article reflects only my own understanding. Although I have consulted others' valuable insights, it may well contain inaccuracies; criticism and corrections are welcome so we can improve together.

When defining a neural network model class with the PyTorch framework, the first line inside the model's __init__() method is conventionally super(__class__, self).__init__(). For example:

class ClassName(torch.nn.Module):
    def __init__(self):
        super(ClassName, self).__init__()

In virtually every tutorial this line has become an unwritten rule: even without quite understanding what it does, over time one simply accepts that it must be added.

So this article explains its role separately, and deepens my own understanding along the way.

1 Description of the super() method

Every introductory Python tutorial mentions the super() method when covering object-oriented programming and classes, stating that it is used for class inheritance, but rarely explains it any more deeply. To understand how super() actually works, first look at the following code:

class A:

    def __init__(self):
        self.ten = 10

    def hello(self):
        return 'hello world'


class B(A):

    def __init__(self, x):
        # super(B, self).__init__()
        self.x = x

    def multi_ten(self):
        return self.x * self.ten


b = B(8)

print(b.hello())
print(b.multi_ten())
-------------------------------------------------
C:\Users\Lenovo\Desktop\DL\Pytest\Scripts\python.exe C:/Users/Lenovo/Desktop/DL/Pytest/test_main.py
hello world
Traceback (most recent call last):
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 23, in <module>
    print(b.multi_ten())
  File "C:\Users\Lenovo\Desktop\DL\Pytest\test_main.py", line 18, in multi_ten
    return self.x * self.ten
AttributeError: 'B' object has no attribute 'ten'

Process finished with exit code 1

With super(B, self).__init__() commented out, the hello() method still runs. In other words, the super() method is not required for class inheritance itself.

So when must super() be used? When the parent class's automatically-run magic methods matter. For example, the multi_ten() method above refers to self.ten, which is set in the __init__() method of the parent class A. But class B defines its own __init__(), which overrides A's, so A.__init__() never runs and self.ten is never set. Calling super() inside class B explicitly runs the __init__() method of class A. Otherwise the code reports the error shown above: class B has no attribute ten (it was never initialized).
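For comparison, here is the same example with the super() call restored; both methods now work (the class and method names match the listing above):

```python
class A:
    def __init__(self):
        self.ten = 10

    def hello(self):
        return 'hello world'


class B(A):
    def __init__(self, x):
        super().__init__()  # runs A.__init__, so self.ten is set
        self.x = x

    def multi_ten(self):
        return self.x * self.ten


b = B(8)
print(b.hello())      # hello world
print(b.multi_ten())  # 80
```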

Magic method: a method that Python calls automatically at certain points; for example, __init__() runs when a class is instantiated. These methods follow the naming pattern __xxxx__().

In addition, there is another detail: the contents of the parentheses in super() can be omitted. To see why, jump to the definition of super() (F4 in the IDE I use); its docstring contains this note:

"super() -> same as super(__class__, <first argument>)"

Here __class__ is the current class, and <first argument> is self.

The Python interpreter I use is Python 3.9. In Python 2, super() required explicit arguments, which is why early tutorials wrote super(__class__, self).__init__(); Python 3 supports the zero-argument form, so today the arguments are no longer needed.
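A quick check that the two spellings behave identically in Python 3 (the class names Base, WithArgs, and ZeroArg are made up for this illustration):

```python
class Base:
    def __init__(self):
        self.tag = 'base'


class WithArgs(Base):
    def __init__(self):
        super(WithArgs, self).__init__()  # explicit, Python-2-style form


class ZeroArg(Base):
    def __init__(self):
        super().__init__()  # equivalent zero-argument form (Python 3)


print(WithArgs().tag, ZeroArg().tag)  # base base
```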

2 What does it inherit from torch.nn.Module?

Let's start with the simplest linear neural network model:

import torch

a = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)

class test(torch.nn.Module):
    def __init__(self):
        # super().__init__()
        self.lin = torch.nn.Linear(5, 2)

    def forward(self, x):
        return self.lin(x)

TEST = test()

print(TEST(a))

If the super() call is removed here as well (as in the commented-out line), an error is reported:

AttributeError: cannot assign module before Module.__init__() call

As expected: the magic method __init__() of the parent class torch.nn.Module was never inherited (called), so the model's internal state was never set up.
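The error message itself hints at the mechanism. Below is a simplified, hypothetical sketch (not the real torch source; the names FakeModule, Good, and Bad are made up) of how a __setattr__ that depends on state created in __init__ produces exactly this kind of error:

```python
class FakeModule:
    def __init__(self):
        # what super().__init__() sets up: the registry of submodules
        object.__setattr__(self, '_modules', {})

    def __setattr__(self, name, value):
        if isinstance(value, FakeModule):
            # submodules get special handling: they are registered in _modules
            modules = self.__dict__.get('_modules')
            if modules is None:
                raise AttributeError(
                    "cannot assign module before Module.__init__() call")
            modules[name] = value
        else:
            object.__setattr__(self, name, value)


class Good(FakeModule):
    def __init__(self):
        super().__init__()       # creates _modules first
        self.sub = FakeModule()  # registered successfully


class Bad(FakeModule):
    def __init__(self):
        self.sub = FakeModule()  # no _modules yet -> AttributeError


Good()  # works
try:
    Bad()
except AttributeError as e:
    print(e)  # cannot assign module before Module.__init__() call
```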

So what exactly does it define?

The source code of torch.nn.Module.__init__() can likewise be found by jumping to its definition:

class Module:

...

    def __init__(self) -> None:
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        """
        Calls super().__setattr__('a', a) instead of the typical self.a = a
        to avoid Module.__setattr__ overhead. Module's __setattr__ has special
        handling for parameters, submodules, and buffers but simply calls into
        super().__setattr__ for all other attributes.
        """
        super().__setattr__('training', True)
        super().__setattr__('_parameters', OrderedDict())
        super().__setattr__('_buffers', OrderedDict())
        super().__setattr__('_non_persistent_buffers_set', set())
        super().__setattr__('_backward_hooks', OrderedDict())
        super().__setattr__('_is_full_backward_hook', None)
        super().__setattr__('_forward_hooks', OrderedDict())
        super().__setattr__('_forward_pre_hooks', OrderedDict())
        super().__setattr__('_state_dict_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_pre_hooks', OrderedDict())
        super().__setattr__('_load_state_dict_post_hooks', OrderedDict())
        super().__setattr__('_modules', OrderedDict())

    forward: Callable[..., Any] = _forward_unimplemented

As the docstring says, the job of torch.nn.Module.__init__() is to "Initialize internal Module state". Concretely, it sets up training, _parameters, _buffers, the various hook dictionaries, and _modules: the attributes PyTorch uses internally.

Here super().__setattr__() calls the __setattr__() method of object, the parent class of torch.nn.Module. Its effect is essentially assignment: super().__setattr__('_parameters', OrderedDict()) acts like self._parameters = OrderedDict(). Why not just use plain assignment, then? The comment in the source explains this too: it "avoids Module.__setattr__ overhead", because Module's own __setattr__ has special handling for parameters, submodules, and buffers, and only falls back to super().__setattr__ for ordinary attributes. In short, Module.__setattr__ does more than simple assignment, and __init__ deliberately bypasses it.
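The bypass can be demonstrated in plain Python (the Logger class is made up for this illustration): object.__setattr__ skips a subclass's custom __setattr__ entirely, just as super().__setattr__ does inside Module.__init__:

```python
class Logger:
    def __setattr__(self, name, value):
        # custom hook, analogous to Module.__setattr__'s special handling
        print('custom __setattr__ for', name)
        object.__setattr__(self, name, value)


obj = Logger()
obj.a = 1                        # goes through Logger.__setattr__ (prints)
object.__setattr__(obj, 'b', 2)  # bypasses it silently, like super().__setattr__
print(obj.a, obj.b)              # 1 2
```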

Therefore, under the PyTorch framework, every neural network model subclass must inherit this initialization of the internal attributes, and that is exactly what the super().__init__() line does.
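Putting it all together, here is the earlier model with the super() call restored (assuming torch is installed; the class name Tiny is made up). Once __init__ has created the internal registries, assigning a Linear layer records it in _modules, which in turn is what makes model.parameters() work:

```python
import torch

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()                # fills in _parameters, _modules, ...
        self.lin = torch.nn.Linear(5, 2)  # registered into self._modules

    def forward(self, x):
        return self.lin(x)


model = Tiny()
print(list(model._modules))                         # ['lin']
print(sum(p.numel() for p in model.parameters()))   # 12 (10 weights + 2 biases)
```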


Origin blog.csdn.net/m0_49963403/article/details/129573033