PyTorch apply function

The apply function is implemented in nn.Module. It recursively calls apply(fn) on every module returned by self.children() and then calls fn on the module itself, so the module and all of its submodules are processed.

This method applies fn recursively to each of the module's children (the results of .children()) and to the module itself, then returns the module. A typical use is initializing the parameters of a model.
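The recursion described above can be sketched roughly as follows. Note that `apply_sketch` is a hypothetical helper written only for illustration, not PyTorch's own source; it assumes nothing beyond `children()` yielding the direct submodules. The small `net` and the `visited` list are likewise made up for the demonstration.

```python
from torch import nn


def apply_sketch(module, fn):
    # Hypothetical stand-in for nn.Module.apply, for illustration only.
    # First recurse into every direct child ...
    for child in module.children():
        apply_sketch(child, fn)
    # ... then call fn on the module itself.
    fn(module)
    return module


# A nested model, so the traversal order is visible.
net = nn.Sequential(
    nn.Linear(2, 2),
    nn.Sequential(nn.ReLU(), nn.Linear(2, 1)),
)

visited = []
apply_sketch(net, lambda m: visited.append(type(m).__name__))
print(visited)

# The built-in method visits the modules in the same order.
builtin = []
net.apply(lambda m: builtin.append(type(m).__name__))
print(builtin)
```

Children are visited before their parent, so the outermost container is the last module passed to fn.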

from torch import nn
import torch
@torch.no_grad()  # decorator: disables gradient tracking inside init_weights
def init_weights(m):
    print(m)
    if isinstance(m, nn.Linear):
        # In-place fills are safe here because gradients are disabled.
        m.weight.fill_(1.0)
        m.bias.fill_(0.0)


model = nn.Sequential(
    nn.Linear(2, 2),
)
model.apply(init_weights)
print(list(model.parameters()))
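To check the result rather than read it off the printout, something like the following sketch may help. It assumes the same constant initialization as above; the three-layer `model` and the names `returned`, `linears`, `weights_ok` and `biases_ok` are illustrative and not part of the original example.

```python
import torch
from torch import nn


@torch.no_grad()  # no gradient tracking while parameters are overwritten
def init_weights(m):
    # Only nn.Linear layers are touched; other modules are skipped.
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
        m.bias.fill_(0.0)


model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
returned = model.apply(init_weights)

# Collect the Linear layers and verify their parameters.
linears = [m for m in model.modules() if isinstance(m, nn.Linear)]
weights_ok = all(bool(torch.all(m.weight == 1.0)) for m in linears)
biases_ok = all(bool(torch.all(m.bias == 0.0)) for m in linears)

print(returned is model, weights_ok, biases_ok)
```

Because apply returns the module it was called on, the call can also be chained when the model is constructed.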

Origin blog.csdn.net/qq_40107571/article/details/130467147