Batch Normalization in PyTorch

In PyTorch, the BatchNorm API is:

torch.nn.BatchNorm1d(num_features,
                     eps=1e-05,
                     momentum=0.1,
                     affine=True,
                     track_running_stats=True)
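For example (a minimal sketch, not from the original post), num_features corresponds to the feature/channel dimension of the input, and each feature gets its own gamma, beta and running statistics:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(num_features=16)   # one gamma/beta and one running mean/var per feature
x = torch.randn(8, 16)                 # (batch_size, num_features)
y = bn(x)                              # normalized output, same shape as the input
print(y.shape)                         # torch.Size([8, 16])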

Models in PyTorch generally inherit from the nn.Module class, which has a training attribute specifying whether the model is in the training state. Whether the model is training or not affects whether the parameters and statistics of some layers are fixed, such as BN layers or Dropout layers. Usually model.train() puts the current model into the training state, and model.eval() puts the current model into the testing state.
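As a rough illustration (the small model below is made up for this example), model.train() and model.eval() simply flip the training flag on the module and, recursively, on all of its submodules:

import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU())

model.train()
print(model.training)       # True  -> BN uses batch statistics and updates its running stats
print(model[1].training)    # True  -> the flag is set recursively on the BN submodule

model.eval()
print(model.training)       # False -> BN uses the stored running_mean / running_var
print(model[1].training)    # False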

Meanwhile, BN's API has several parameters that deserve extra attention: affine specifies whether an affine transformation is applied, and track_running_stats specifies whether to track the statistical properties across batches. Problems tend to arise from exactly three switches: training, affine, and track_running_stats.

  • affine specifies whether the affine transformation is applied, i.e. whether the fourth formula above is used. If affine=False, then γ = 1 and β = 0 are fixed and cannot be updated by learning. Usually it is set to affine=True. [10]
  • training and track_running_stats: track_running_stats=True means the statistical properties of batches are tracked over the whole training process to obtain running estimates of the mean and variance, rather than relying only on the statistics of the current input batch. Conversely, with track_running_stats=False, only the mean and variance of the current input batch are computed. At inference time, if track_running_stats=False and batch_size is relatively small, the batch statistics can deviate considerably from the global statistics, which may lead to poor results. (See the sketch after this list.)
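A quick way to see what these two switches control is to inspect the corresponding parameters and buffers on the layer (a small sketch; the comments describe current PyTorch behavior):

import torch.nn as nn

bn_default = nn.BatchNorm1d(4)                             # affine=True, track_running_stats=True
print(bn_default.weight.shape, bn_default.bias.shape)      # learnable gamma and beta, each of size 4
print(bn_default.running_mean, bn_default.running_var)     # buffers updated during training

bn_no_affine = nn.BatchNorm1d(4, affine=False)
print(bn_no_affine.weight, bn_no_affine.bias)              # None, None -> gamma = 1, beta = 0, fixed

bn_no_stats = nn.BatchNorm1d(4, track_running_stats=False)
print(bn_no_stats.running_mean, bn_no_stats.running_var)   # None, None -> only batch statistics are used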

In general, training and track_running_stats give four combinations [7]:

  1. training=True, track_running_stats=True. This is the setting expected during the training phase; BN tracks the statistical properties of batches over the whole training process.
  2. training=True, track_running_stats=False. BN only computes the statistical properties of the current training batch, which may not describe the global statistics well.
  3. training=False, track_running_stats=True. This is the setting expected during the testing phase; BN uses the running_mean and running_var of the trained model (assuming they have been saved) and does not update them. In general, we only need to call model.eval() on a model that contains BN layers to achieve this. [6,8]
  4. training=False, track_running_stats=False. Same effect as (2), but in the testing state. This is generally not used: it only computes the statistics of the current test batch, which is likely to shift the statistics and lead to poor results. (A small sketch of these combinations follows this list.)
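As a rough sketch of the combinations above (the input tensor is random and only for illustration): with track_running_stats=True, train() and eval() behave differently, while with track_running_stats=False there are no running buffers and both modes fall back to batch statistics:

import torch
import torch.nn as nn

x = torch.randn(32, 4)

# Combinations (1)/(3): running statistics are tracked.
bn_tracked = nn.BatchNorm1d(4)
bn_tracked.train();  _ = bn_tracked(x)     # (1) batch stats used, running_mean / running_var updated
bn_tracked.eval();   _ = bn_tracked(x)     # (3) stored running_mean / running_var used, no update

# Combinations (2)/(4): no running statistics at all.
bn_untracked = nn.BatchNorm1d(4, track_running_stats=False)
bn_untracked.train(); _ = bn_untracked(x)  # (2) current batch statistics
bn_untracked.eval();  _ = bn_untracked(x)  # (4) still the current batch statistics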

At the same time, note that the running_mean and running_var of a BN layer are updated during the forward() call, not in optimizer.step(). So if the model is in the training state, the BN statistics will change even if you never call step() manually. For example:

model.train()  # the model is in the training state

for data, label in self.dataloader:
    pred = model(data)
    # the forward pass here already updates the BN statistics: running_mean, running_var
    loss = self.loss(pred, label)
    # even without the following three lines of code, the BN statistics have already changed
    opt.zero_grad()
    loss.backward()
    opt.step()

At this point, call model.eval() to switch to the testing phase so that running_mean and running_var are frozen. Sometimes, after loading a pre-trained model and re-running the test, the results differ and there is a small loss of performance; in all likelihood training and track_running_stats were not set correctly, so extra attention is needed here. [8]
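A small check (a hedged sketch, not from the original post) that eval mode really freezes these buffers:

import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)
bn.train()
_ = bn(torch.randn(32, 4))            # training forward: running stats are updated

bn.eval()
before = bn.running_mean.clone()
_ = bn(torch.randn(32, 4))            # eval forward: running stats are NOT updated
assert torch.equal(before, bn.running_mean)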

References

[1] Pits stepped on when using PyTorch

[2] Ioffe S, Szegedy C. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift [C] // International Conference on Machine Learning. JMLR.org, 2015: 448-456.

[3] Deep learning optimization strategies (1): Batch Normalization (BN)

[4] Normalization in deep learning explained in detail: BN / LN / WN

[5] https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24

[6] https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870

[7] How to understand the BatchNorm2d parameter track_running_stats?

[8] Why is track_running_stats not set to False during eval?

[9] How to train with frozen BatchNorm?

[10] Proper way of fixing batchnorm layers during training

[11] A plain-language take on "Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift"
