dropout理解(三)

本节使用的pytorch版本为1.8.1,其中的torch.nn.functional函数中的dropout方法中的参数training默认为True:
在这里插入图片描述
下面进入正文,首先创建一个简单的模型:

import torch
import torch.nn as nn
import torch.nn.functional as F


class LinearFC(nn.Module):

    def __init__(self):
        super(LinearFC, self).__init__()
        self.fc = nn.Linear(3, 2)

    def forward(self, input):
        out = self.fc(input)
        out = F.dropout(out, p=0.5, training=self.training)
        # out = F.dropout(out, p=0.5)
        return out


Net = LinearFC()
x = torch.randint(10, (2, 3)).float()  # 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据
Net.train()
# Net.eval()
output = Net(x)
print(output)

# train the Net

网络模型很简单,就是一个线性结构,然后加上一个dropout。

下面一步一步执行:
在这里插入图片描述

  1. 这里创建了模型Net,Net的对应的参数如下:
Net.fc.weight=tensor([[-0.2141, -0.0099,  0.3015],
        [-0.4614, -0.1297,  0.3278]], requires_grad=True)
Net.fc.bias=tensor([ 0.5359, -0.3696], requires_grad=True)
Net.training=True

在这里插入图片描述

  1. 随机生成不大于10的整数,转为float, 因为nn.linear需要float类型数据,x的数据内容如下:
x=tensor([[5., 1., 6.],
        [9., 5., 9.]])
  1. Net.train()的作用是将Net.training=True,如果是在test阶段,需要执行Net.eval(),这时Net.training=False
  2. 接下来要进入网络模型中的forward了:
    在这里插入图片描述
    这里的out的结果的含义是计算 x × W T + b x \times W^T+b x×WT+b,相当于执行torch.mm(x, Net.fc.weight.t()) + Net.fc.bias,从控制台中可以看出这两个的执行结果一样。
    在这里插入图片描述
  3. 接下来将执行dropout,因为当前self.training=True,所以将执行dropout,out的执行结果相当于执行(torch.mm(x, Net.fc.weight.t()) + Net.fc.bias)/(1-0.5),然后以0.5概率去除:
    在这里插入图片描述
    可以看出:
去除前:
tensor([[ 2.5288, -1.6790],
        [ 2.5452, -4.4408]], grad_fn=<DivBackward0>)
以0.5概率去除后:
tensor([[2.5288, -0.0000],
        [0.0000, -0.0000]], grad_fn=<MulBackward0>)
  1. 返回结果,执行完毕。

如果是test阶段,也就是开启Net.eval()
在执行Net.eval()之前,Net.training=Ture,当执行Net.eval()之后,Net.training=False
在这里插入图片描述
然后进入到forward中,执行 x × W T + b x \times W^T+b x×WT+b
在这里插入图片描述
接下来要准备进入dropout了,但是由于self.traing=Flase所以这一步将不会执行,也就是说,不会执行(torch.mm(x, Net.fc.weight.t()) + Net.fc.bias)/(1-0.5)以及去除操作。而是直接对torch.mm(x, Net.fc.weight.t()) + Net.fc.bias结果进行返回:
在这里插入图片描述

总结

在模型中定义好F.dropout(out, p=0.5, training=self.training),根据训练阶段和测试阶段传入的参数Model.training来判断是否执行dropout。

如果写成F.dropout(out, p=0.5),那么参数默认training=True,就无法在测试阶段正确跳过dropout了。

猜你喜欢

转载自blog.csdn.net/vincent_duan/article/details/119952072
今日推荐