In PyTorch, the BatchNorm API is:

```python
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
```
PyTorch models generally inherit from `nn.Module`, whose class has a `training` attribute that specifies whether the model is in the training state. Whether or not the model is in the training state affects some layers, such as a BN layer or a Dropout layer, e.g. whether their parameters are fixed. Usually, `model.train()` puts the current model into the training state, and `model.eval()` puts it into the testing state. Meanwhile, BN's API has several parameters that deserve attention: one is `affine`, which specifies whether the affine transform is applied, and another is `track_running_stats`, which specifies whether to track the statistics of the batches. Problems tend to arise from exactly three parameters: `training`, `affine`, and `track_running_stats`.
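As a minimal sketch (the variable names here are illustrative, not from the original), the effect of `model.train()` and `model.eval()` can be observed directly through the `training` attribute, even on a single BN layer:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(4)            # a single BN layer is itself an nn.Module

bn.train()                        # put the layer in the training state
print(bn.training)                # True

bn.eval()                         # put the layer in the testing state
print(bn.training)                # False

# forward passes work in both states
x = torch.randn(8, 4)
y = bn(x)
print(y.shape)                    # torch.Size([8, 4])
```

Calling `train()`/`eval()` on a container module propagates to all submodules, which is why a single `model.eval()` is usually enough.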
- `affine` specifies whether the affine transform is applied, i.e. whether the fourth formula above is used. If `affine=False`, then $\gamma = 1$ and $\beta = 0$ are fixed and cannot be learned or updated. It is usually set to `affine=True`. [10]
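A small sketch of the difference (variable names are illustrative): with `affine=False`, the BN layer has no learnable parameters at all, so $\gamma$ and $\beta$ are effectively fixed at 1 and 0:

```python
import torch.nn as nn

bn_affine = nn.BatchNorm1d(3, affine=True)   # the default
bn_fixed = nn.BatchNorm1d(3, affine=False)

# with affine=True, gamma (weight) and beta (bias) are learnable Parameters
print(bn_affine.weight is not None)          # True
# with affine=False, there are no learnable parameters at all
print(bn_fixed.weight is None)               # True
print(len(list(bn_fixed.parameters())))      # 0
```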
- `training` and `track_running_stats`: with `track_running_stats=True`, BN tracks the statistics of batches over the entire training process to obtain the running mean and variance, rather than relying solely on the statistics of the current input batch. Conversely, if `track_running_stats=False`, only the mean and variance of the current input batch are computed. At inference time, if `track_running_stats=False` and the batch size is relatively small, the batch statistics may deviate substantially from the global statistics, which can lead to poor results.
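This can be checked directly; a minimal sketch (variable names are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn_track = nn.BatchNorm1d(2, track_running_stats=True)     # the default
bn_no_track = nn.BatchNorm1d(2, track_running_stats=False)

# with track_running_stats=False the running buffers do not exist at all
print(bn_no_track.running_mean)   # None

x = torch.randn(16, 2) * 5 + 3    # a batch with non-zero mean, non-unit std
bn_track.train()
bn_track(x)                       # the forward pass updates running_mean/var

# running_mean has moved toward the batch mean (momentum=0.1 by default)
print(bn_track.running_mean)
```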
In general, `training` and `track_running_stats` have four combinations [7]:

1. `training=True`, `track_running_stats=True`. This is the desired setting during the training phase; BN tracks the batch statistics over the whole training process.
2. `training=True`, `track_running_stats=False`. BN computes only the statistics of the current training batch, which may not describe the global statistics well.
3. `training=False`, `track_running_stats=True`. This is the desired setting during the test phase; BN uses the `running_mean` and `running_var` of the trained model (assuming they have been saved) and does not update them. In general, calling `model.eval()` on a `model` that contains BN layers is all that is needed to achieve this. [6,8]
4. `training=False`, `track_running_stats=False`. Same effect as (2), but in the test state. This is generally not used: it only computes the statistics of the current test batch, which is likely to bias the statistics and lead to poor results.
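The key difference between combinations 1 and 3 can be verified directly; a minimal sketch (variable names are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(2)            # affine=True, track_running_stats=True
x = torch.randn(32, 2) + 10.0

# combination 1: training state with track_running_stats=True
bn.train()
bn(x)                             # the forward pass updates the running stats
mean_after_train = bn.running_mean.clone()

# combination 3: eval state with track_running_stats=True
bn.eval()
bn(x)                             # the forward pass uses, but does not update, them
mean_after_eval = bn.running_mean.clone()

print(torch.equal(mean_after_train, mean_after_eval))  # True: frozen in eval
```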
At the same time, note that the BN layer's `running_mean` and `running_var` are updated during the `forward()` call, not during `optimizer.step()`. So if the model is in the training state, BN's statistics will change even if you never call `step()` manually. For example:
```python
model.train()  # put the model in the training state

for data, label in self.dataloader:
    # the forward pass here already updates the BN statistics of the model
    # (running_mean, running_var)
    pred = model(data)

    loss = self.loss(pred, label)

    # even without the following three lines, BN's statistics have changed
    opt.zero_grad()
    loss.backward()
    opt.step()
```
At this point, `model.eval()` should be called to switch to the testing phase and thereby fix `running_mean` and `running_var`. Sometimes, if you pretrain and save a model, then load it and rerun the test, the results come out different, with a little performance lost; in all likelihood, `training` and `track_running_stats` were not set correctly, so extra attention is needed here. [8]
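A sketch of this failure mode (all names here are illustrative): testing a loaded checkpoint while still in the training state normalizes with the current batch's statistics instead of the saved running statistics, so the outputs differ from a proper `eval()` run:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))

# simulate a short training run so the BN buffers are non-trivial
model.train()
for _ in range(3):
    model(torch.randn(8, 4))
state = model.state_dict()        # the "pretrained" checkpoint

# reload into a fresh model for testing
model2 = nn.Sequential(nn.Linear(4, 4), nn.BatchNorm1d(4))
model2.load_state_dict(state)

x = torch.randn(8, 4)
model2.train()                    # WRONG for testing: uses batch statistics
out_train_mode = model2(x)

model2.load_state_dict(state)     # restore; the forward above drifted the buffers
model2.eval()                     # RIGHT: uses the saved running_mean/running_var
out_eval_mode = model2(x)

# the two modes give different outputs on the same input
print(torch.allclose(out_train_mode, out_eval_mode))  # False
```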
References
[1] Pitfalls encountered with PyTorch
[2] Ioffe S, Szegedy C. Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. International Conference on Machine Learning. JMLR.org, 2015: 448-456.
[3] <Deep Learning Optimization Strategies - 1> Batch Normalization (BN)
[4] A detailed explanation of normalization in deep learning: BN / LN / WN
[5] https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py#L23-L24
[6] https://discuss.pytorch.org/t/what-is-the-running-mean-of-batchnorm-if-gradients-are-accumulated/18870
[7] How to understand the track_running_stats parameter added to BatchNorm2d?
[8] Why is track_running_stats not set to False during eval
[9] How to train with frozen BatchNorm?
[10] Proper way of fixing batchnorm layers during training
[11] A plain-language explanation of "Understanding the Disharmony between Dropout and Batch Normalization by Variance Shift"