pytorch:nn.Softmax()

本篇主要分析softmax函数中的dim参数
首先介绍一下softmax函数:
设 x = [1,2,3]
则softmax(x)= [ e 1 e 1 + e 2 + e 3 \frac{e^1}{e^1+e^2+e^3} , e 2 e 1 + e 2 + e 3 \frac{e^2}{e^1+e^2+e^3} , e 3 e 1 + e 2 + e 3 \frac{e^3}{e^1+e^2+e^3} ]

接下来分析torch.nn里面的softmax函数

y = torch.tensor([[[1.,2.,3.],[4.,5.,6.]],[[7.,8.,9.],[10.,11.,12.]]])
#y的size是2,2,4。可以看成有两张表,每张表2行3列
net_1 = nn.Softmax(dim=0)
net_2 = nn.Softmax(dim=1)
net_3 = nn.Softmax(dim=2)
print('dim=0的结果是:\n',net_1(y),"\n")
print('dim=1的结果是:\n',net_2(y),"\n")
print('dim=2的结果是:\n',net_3(y),"\n")

在这里插入图片描述

dim = 0:

dim = 0指第一个维度,在本例中第一个维度的size是2,如前文所说,我们把“2”看成是两张表,那么0.0025和0.9975是怎么来的呢?
第一张表中6个数的平均值是:(1+2+3+4+5+6)/6 = 3.5
第二张表中6个数的平均值是:(7+8+9+10+11+12)/6 = 9.5

0.0025≈ e ( 3.5 ) e ( 3.5 ) + e ( 6.5 ) \frac{e^(3.5)}{e^(3.5)+e^(6.5)}

0.9975≈ e ( 6.5 ) e ( 3.5 ) + e ( 6.5 ) \frac{e^(6.5)}{e^(3.5)+e^(6.5)}

dim = 1:

dim = 1指第二个维度,在本例中第二个维度的size是2,我们可以看成是2行。
我们把所有表中的第一行的数据拿出来:1,2,3;7,8,9 求平均:5
我们把所有表中的第一行的数据拿出来:4,5,6;10,11,12 求平均:8

0.0474≈ e 5 e 5 + e 8 \frac{e^5}{e^5+e^8}

0.9526≈ e 8 e 5 + e 8 \frac{e^8}{e^5+e^8}

dim = 2:

dim = 2指第三个维度,在本例中第三个维度的size是3,我们可以看成是3列。
我们把所有表中的第1列的数据拿出来:1,4;7,10 求平均:5.5
我们把所有表中的第2列的数据拿出来:2,5;8,11 求平均:6.5
我们把所有表中的第3列的数据拿出来:3,6;9,12 求平均:7.5

0.09≈ e ( 5.5 ) e ( 5.5 ) + e ( 6.5 ) + e ( 7.5 ) \frac{e^(5.5)}{e^(5.5)+e^(6.5)+e^(7.5)}

0.2447≈ e ( 6.5 ) e ( 5.5 ) + e ( 6.5 ) + e ( 7.5 ) \frac{e^(6.5)}{e^(5.5)+e^(6.5)+e^(7.5)}

0.6652≈ e ( 7.5 ) e ( 5.5 ) + e ( 6.5 ) + e ( 7.5 ) \frac{e^(7.5)}{e^(5.5)+e^(6.5)+e^(7.5)}


在这里插入图片描述
在这里插入图片描述

发布了43 篇原创文章 · 获赞 1 · 访问量 746

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104823086