BatchNorm2d
Normalice por lotes la matriz bidimensional, mean
que es la batch
media actual y la desviación estándar std
actual batch
. El uso de la normalización por lotes puede asignar datos con diferentes rangos de valores al intervalo de la distribución normal estándar, reduciendo la brecha entre los datos, para facilitar la convergencia rápida del modelo. La normalización por lotes esencialmente reduce el error absoluto entre muestras, pero no cambia el error relativo. Por ejemplo, para [1,2,3,4]
la normalización, aunque el tamaño del número cambia, la relación de tamaño entre los números no cambiará. En general, se recomienda seguir una normalización por lotes después del kernel de convolución.
oficial
-
fórmula de normalización
-
全局均值估计:running_mean
和全局方差估计:running_var
xnuevo = ( 1 − impulso ) × xanterior + impulso × xt x_{nuevo}=(1-impulso) \times x_{antiguo}+impulso \times x_{t}Xnuevo _ _=( 1−impulso ) _ _ _ _ _ _ _×Xviejo _ _+impulso _ _ _ _ _ _ _×Xt
xnuevo x_{nuevo}Xnuevo _ _para el actualizadorunning_mean/running_var
, xold x_{old}Xviejo _ _Antes de la actualizaciónrunning_mean/running_var
, xt x_{t}XtPara el lote actualmean和var
,momentum
es el factor de peso, que generalmente se toma como0.1
BatchNorm2d
batchnorm=torch.nn.BatchNorm2d(num_features=通道的数量)
No se recomienda cambiar otros parámetros en pytorch
Verificación experimental en BatchNorm2d
- Validación de la fórmula de normalización
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opti
from torchvision.transforms import RandomRotation
import torchsummary
import time
import datetime
import numpy as np
import copy
import torch.nn as nn
data=torch.tensor(
[[[[1,2],
[3,4]]]],dtype=torch.float32
)
batchnorm=nn.BatchNorm2d(num_features=1,momentum=0.1)
print('------------1--------------')
print("初始状态下的running_mean,running_var")
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------2--------------')
print("输入data后状态下的running_mean,running_var")
test=batchnorm(data)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('训练状态下对data进行batchNorm')
print(test)
print('手动计算的batchNorm')
mean=torch.mean(data)
std=torch.var(data,False)
print((data[0][0]-mean)/torch.sqrt(std+1e-5))
Conclusión, la media y estándar normalizados son la media y estándar del lote actual
running_mean
running_var
Validación de fórmula para suma
print('------------3--------------')
print("人工计算的running_mean,running_var")
running_mean=torch.tensor(0)
running_var=torch.tensor(1)
running_mean=0.9*running_mean+0.1*mean
running_var=0.9*running_var+0.1*std
print(running_mean)
print(running_var)
print('测试状态下对data进行batchNorm')
batchnorm.training=False
test=batchnorm(data)
print(test)
#得出如下结论:
#running_mean=(1-momentum)*running_mean+momentum*batch_mean
#running_var=(1-momentum)*running_var+momentum*batch_var
running_mean y running_var solo tienen un impacto en la prueba y no tienen ningún impacto en el entrenamiento. Los datos de la prueba se usan running_mean
y running_var
normalizan.
当track_running_stats=False
impacto del tiempo
print('------------4--------------')
print('track_running_stats设置为False时,输入data前得running_mean,running_var')
batchnorm=nn.BatchNorm2d(num_features=1,momentum=0.1,track_running_stats=False)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------5--------------')
print('track_running_stats设置为False时,输入data后得running_mean,running_var')
test=batchnorm(data)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------6--------------')
print('track_running_stats设置为False时,训练状态下对data进行batchnorm')
print(test)
print('------------7--------------')
print('track_running_stats设置为False时,测试状态下对data进行batchnorm')
batchnorm.training=False
test=batchnorm(data)
print(test)
#得出如下结论
#running_mean和running_var是用于对测试集进行归一化,如果track_running_stats设置为False,则测试集进行归一化时不会使用running_mean和running_var
#而是直接用自身得mean和std
no track_running_stats
establecer enFalse