Verificación experimental de BatchNorm2d en pytorch

BatchNorm2d

Normalice por lotes la matriz bidimensional, meanque es la batchmedia actual y la desviación estándar stdactual batch. El uso de la normalización por lotes puede asignar datos con diferentes rangos de valores al intervalo de la distribución normal estándar, reduciendo la brecha entre los datos, para facilitar la convergencia rápida del modelo. La normalización por lotes esencialmente reduce el error absoluto entre muestras, pero no cambia el error relativo. Por ejemplo, para [1,2,3,4]la normalización, aunque el tamaño del número cambia, la relación de tamaño entre los números no cambiará. En general, se recomienda seguir una normalización por lotes después del kernel de convolución.

oficial

  • fórmula de normalización
    inserte la descripción de la imagen aquí

  • 全局均值估计:running_mean全局方差估计:running_var
    xnuevo = ( 1 − impulso ) × xanterior + impulso × xt x_{nuevo}=(1-impulso) \times x_{antiguo}+impulso \times x_{t}Xnuevo _ _=( 1impulso ) _ _ _ _ _ _ _×Xviejo _ _+impulso _ _ _ _ _ _ _×Xt
    xnuevo x_{nuevo}Xnuevo _ _para el actualizado running_mean/running_var, xold x_{old}Xviejo _ _Antes de la actualización running_mean/running_var, xt x_{t}XtPara el lote actual mean和var, momentumes el factor de peso, que generalmente se toma como0.1

  • BatchNorm2d
    batchnorm=torch.nn.BatchNorm2d(num_features=通道的数量)
    No se recomienda cambiar otros parámetros en pytorch

Verificación experimental en BatchNorm2d

  • Validación de la fórmula de normalización
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as opti
from torchvision.transforms import RandomRotation
import torchsummary
import time
import datetime
import numpy as np
import copy
import torch.nn as nn
data=torch.tensor(
   [[[[1,2],
    [3,4]]]],dtype=torch.float32
)
batchnorm=nn.BatchNorm2d(num_features=1,momentum=0.1)
print('------------1--------------')
print("初始状态下的running_mean,running_var")
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------2--------------')
print("输入data后状态下的running_mean,running_var")
test=batchnorm(data)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('训练状态下对data进行batchNorm')
print(test)
print('手动计算的batchNorm')
mean=torch.mean(data)
std=torch.var(data,False)
print((data[0][0]-mean)/torch.sqrt(std+1e-5))

Conclusión, la media y estándar normalizados son la media y estándar del lote actual

  • running_meanrunning_varValidación de fórmula para suma
print('------------3--------------')
print("人工计算的running_mean,running_var")
running_mean=torch.tensor(0)
running_var=torch.tensor(1)
running_mean=0.9*running_mean+0.1*mean
running_var=0.9*running_var+0.1*std
print(running_mean)
print(running_var)

print('测试状态下对data进行batchNorm')
batchnorm.training=False
test=batchnorm(data)
print(test)
#得出如下结论:
#running_mean=(1-momentum)*running_mean+momentum*batch_mean
#running_var=(1-momentum)*running_var+momentum*batch_var

running_mean y running_var solo tienen un impacto en la prueba y no tienen ningún impacto en el entrenamiento. Los datos de la prueba se usan running_meany running_varnormalizan.

  • 当track_running_stats=Falseimpacto del tiempo
print('------------4--------------')
print('track_running_stats设置为False时,输入data前得running_mean,running_var')
batchnorm=nn.BatchNorm2d(num_features=1,momentum=0.1,track_running_stats=False)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------5--------------')
print('track_running_stats设置为False时,输入data后得running_mean,running_var')
test=batchnorm(data)
print(batchnorm.running_mean)
print(batchnorm.running_var)
print('------------6--------------')
print('track_running_stats设置为False时,训练状态下对data进行batchnorm')
print(test)
print('------------7--------------')
print('track_running_stats设置为False时,测试状态下对data进行batchnorm')
batchnorm.training=False
test=batchnorm(data)
print(test)
#得出如下结论
#running_mean和running_var是用于对测试集进行归一化,如果track_running_stats设置为False,则测试集进行归一化时不会使用running_mean和running_var
#而是直接用自身得mean和std

no track_running_statsestablecer enFalse

Supongo que te gusta

Origin blog.csdn.net/qq_33880925/article/details/130244586
Recomendado
Clasificación