Implementación manual del paso hacia adelante (forward) de Batch Normalization (BN)

import torch
import torch.nn as nn

# Example input dimensions (N, C, H, W). These names were referenced but never
# defined, which made the script fail with NameError before reaching BN.
batch_size, channel, height, width = 4, 3, 8, 8

# Random input batch laid out as (batch, channel, height, width).
x = torch.randn(batch_size, channel, height, width)

# Initialize the per-channel scale parameter gamma and shift parameter beta.
gamma = torch.ones(channel, dtype=torch.float32)
beta = torch.zeros(channel, dtype=torch.float32)

def BN(x, gamma, beta, epsilon = 1e-8):
    """Apply the training-mode batch-normalization forward pass.

    Each channel (dimension 1 of ``x``) is normalized with the mean and the
    biased variance computed over every other dimension, then scaled by
    ``gamma`` and shifted by ``beta``.

    Args:
        x: Input tensor with the channel axis at dimension 1, e.g.
            ``(N, C, H, W)``; any ``(N, C, *)`` layout is accepted.
        gamma: Per-channel scale holding ``C`` values (or a single value).
        beta: Per-channel shift holding ``C`` values (or a single value).
        epsilon: Small constant added to the variance before the square root
            to avoid division by zero.

    Returns:
        A tensor with the same shape as ``x``.

    Raises:
        ValueError: If ``x`` has fewer than 2 dimensions.
    """
    if x.dim() < 2:
        raise ValueError("BN expects an input with at least 2 dimensions (N, C, ...)")

    # Reduce over every axis except the channel axis.
    reduce_dims = tuple(d for d in range(x.dim()) if d != 1)
    # Shape that lines the parameters up with the channel axis; a plain (C,)
    # tensor would broadcast against the LAST axis instead.
    param_shape = (1, -1) + (1,) * (x.dim() - 2)

    mean = torch.mean(x, dim=reduce_dims, keepdim=True)
    # Batch norm normalizes with the biased (population) variance; torch.var
    # applies Bessel's correction unless told otherwise.
    var = torch.var(x, dim=reduce_dims, unbiased=False, keepdim=True)
    x_hat = (x - mean) / torch.sqrt(var + epsilon)

    scale = torch.as_tensor(gamma, device=x.device).reshape(param_shape)
    shift = torch.as_tensor(beta, device=x.device).reshape(param_shape)
    return x_hat * scale + shift

Supongo que te gusta

Origin blog.csdn.net/slamer111/article/details/132799892
Recomendado
Clasificación