The `apply` function is implemented in `nn.Module`. It recursively calls `self.children()` so that it processes the module itself and all of its submodules; each child is visited before its parent.
This method applies `fn` recursively to every submodule (as returned by `.children()`) as well as to the module itself. A typical use is initializing the parameters of a model.
from torch import nn
import torch
@torch.no_grad()  # decorator: disable autograd tracking so parameters can be modified in place
def init_weights(m):
    """Set every ``nn.Linear`` layer's weights to 1.0 and its bias to 0.

    Meant to be passed to ``nn.Module.apply``, which calls it once for each
    submodule and once for the root module. Every visited module is printed,
    so the traversal is visible. Modules that are not ``nn.Linear`` are left
    unchanged.

    Args:
        m: The module currently being visited by ``apply``.
    """
    print(m)
    if isinstance(m, nn.Linear):
        # Under no_grad the parameters can be filled in place directly;
        # going through ``.data`` is unnecessary.
        m.weight.fill_(1.0)
        # A Linear layer built with bias=False has ``bias`` set to None.
        if m.bias is not None:
            m.bias.fill_(0)
# Demo: wrap a single 2->2 linear layer in a Sequential container, run the
# initializer over every module via ``apply``, then print the parameters.
linear_layer = nn.Linear(2, 2)
model = nn.Sequential(linear_layer)
model.apply(init_weights)
params = list(model.parameters())
print(params)