版权声明:本文为博主原创文章,未经博主允许不得转载。作为分享主义者(sharism),本人所有互联网发布的图文均采用知识共享署名 4.0 国际许可协议(https://creativecommons.org/licenses/by/4.0/)进行许可。转载请保留作者信息并注明作者Jie Qiao专栏:http://blog.csdn.net/a358463121。商业使用请联系作者。 https://blog.csdn.net/a358463121/article/details/80820878
The Gumbel soft-max
Gumbel trick有两个用途,一个用途是是用来对离散分布进行采样,这是一种重参数化(reparameterization trick)的技巧,另外一个用途是用于估计normalizing partition function,也就是分布的归一化项。本文将介绍这两种方法的原理。
下面是一个使用gumbel trick来模拟离散分布采样的例子:
如上图例子,首先有
log α 1
log
α
1
,这可以看做是一个多项式分布的概率值, 然后加上一个gumbel noise G1,最后取最大值,就是我们要的离散分布的样本,样本出现的概率跟
α
α
成正比。这个过程可以形式化为,设X是离散随机分布
P ( X = k ) ∝ α k
P
(
X
=
k
)
∝
α
k
, 设
{ G k } k ≤ K
{
G
k
}
k
≤
K
是独立同分布的Gumbel 分布的随机变量。于是:
X = arg max k ( log ( α k ) + G k )
X
=
arg
max
k
(
log
(
α
k
)
+
G
k
)
为了让这个argmax可求导,于是就把中间的argmax换成softmax。 我们从这个图底下的“+”号可以看到,这是一种重参数的方法,通过加一个随机的,固定分布的噪声,从而实现采样。这个噪声的采样方法可以通过Inverse transform sampling 方法直接从均匀分布进行采样,即
G i ∼ − l o g ( − l o g ( Uniform ( 0 , 1 ) ) )
G
i
∼
−
l
o
g
(
−
l
o
g
(
Uniform
(
0
,
1
)
)
)
目前一篇论文“Categorical Reparametrization with Gumbel-Softmax ”正是用了这个方法去对离散的隐状态进行采样,从而使得里面的参数可导。
Gumbel distribution
要想知道他为什么有这样的效果,我们需要先介绍一下gumbel distribution
这一个分布,可以把看作是一个关于“最大值”的概率的分布 ,比如你想预测明年河流最大的水位是多少,那么你就可以用gumbel分布去预测,这个分布会告诉你每一个值作为“最大值“的概率是多少。一个很简单的推广,如果你对这个分布取个负号的话,你就可以去预测最小值。
他的概率密度函数:
f ( x ) = 1 β e − ( z + e − z )
f
(
x
)
=
1
β
e
−
(
z
+
e
−
z
)
其中
z = x − μ β
z
=
x
−
μ
β
他的分布函数:
F ( x ) = e − e − ( x − μ ) / β
F
(
x
)
=
e
−
e
−
(
x
−
μ
)
/
β
均值:
E ( X ) = μ + c β
E
(
X
)
=
μ
+
c
β
,方差:
π 2 6 β 2
π
2
6
β
2
,其中
c
c
是一个常数(
Euler–Mascheroni constant )
Gumbel trick用于估计归一化项
我们先考虑一下,求解normalizing partition function. 就是分布的归一化项的问题。
定义一个非标准化的mass function
p ~ : X → [ 0 , ∞ )
p
~
:
X
→
[
0
,
∞
)
这个分布是没有标准化的,也就是他加起来不等于1.而它的标准化项normalizing partition function为
Z := ∑ x ∈ X p ~ ( x )
Z
:=
∑
x
∈
X
p
~
(
x
)
,接来下我们定义
ϕ ( x ) = ln p ~ ( x )
ϕ
(
x
)
=
ln
p
~
(
x
)
对其概率密度取对数。
于是可以证明:
max x ∈ X { ϕ ( x ) + γ ( x ) } ∼ Gumbel ( − c + ln Z )
max
x
∈
X
{
ϕ
(
x
)
+
γ
(
x
)
}
∼
Gumbel
(
−
c
+
ln
Z
)
其中
γ ∼ Gumbel ( − c )
γ
∼
Gumbel
(
−
c
)
。这就意味,只要我们从
max x ∈ X { ϕ ( x ) + γ ( x ) }
max
x
∈
X
{
ϕ
(
x
)
+
γ
(
x
)
}
中采集足够多的样本,我们就能够知道Z的取值(通过求期望得到)。
具体的推导过程如下: 令
T = max x ∈ X { ϕ ( x ) + γ ( x ) }
T
=
max
x
∈
X
{
ϕ
(
x
)
+
γ
(
x
)
}
,于是他的概率分布等于
P ( T < t ) = P ( max x ∈ X { ϕ ( x ) + γ ( x ) } < t ) = ∏ x ∈ X P ( ϕ ( x ) + γ ( x ) < t ) ( 最 大 值 小 于 t 等 价 于 每 一 项 都 小 于 t ) = ∏ x ∈ X P ( γ ( x ) < t − ϕ ( x ) ) = ∏ x ∈ X F G u m b e l ( t − ϕ ( x ) ) = exp ( − ∑ x ∈ X exp ( − ( t − ϕ ( x ) + c ) ) ) = exp ( − Z exp ( − ( t + c ) ) ) = exp ( − exp ( − ( t + c − ln Z ) ) ) ⇒ F ( t ) where t ∼ Gumbel ( − c + ln Z )
P
(
T
<
t
)
=
P
(
max
x
∈
X
{
ϕ
(
x
)
+
γ
(
x
)
}
<
t
)
=
∏
x
∈
X
P
(
ϕ
(
x
)
+
γ
(
x
)
<
t
)
(
最
大
值
小
于
t
等
价
于
每
一
项
都
小
于
t
)
=
∏
x
∈
X
P
(
γ
(
x
)
<
t
−
ϕ
(
x
)
)
=
∏
x
∈
X
F
G
u
m
b
e
l
(
t
−
ϕ
(
x
)
)
=
exp
(
−
∑
x
∈
X
exp
(
−
(
t
−
ϕ
(
x
)
+
c
)
)
)
=
exp
(
−
Z
exp
(
−
(
t
+
c
)
)
)
=
exp
(
−
exp
(
−
(
t
+
c
−
ln
Z
)
)
)
⇒
F
(
t
)
where
t
∼
Gumbel
(
−
c
+
ln
Z
)
我们发现这个max的函数,最后是服从
Gumbel ( − c + ln Z )
Gumbel
(
−
c
+
ln
Z
)
分布的,也就是说,我们只要求这个分布的期望:
E = − c + ln Z + c = ln Z
E
=
−
c
+
ln
Z
+
c
=
ln
Z
就可以把
ln Z
ln
Z
还原出来!这个例子也从侧面说明了Gumbel分布用于表示最大值的概率分布的优势所在(这里优势我觉得直观来看体现在,因为max就相当于所有小于某个数的概率连乘,而因为gumbel分布的指数项的性质,所以,连乘之后指数项没有消失,从而还是服从gumbel分布)。
为什么Gumbel trick能够模拟多项式分布采样?
如果我们的p是已经标准化的p,那么Z=1,于是,这个分布只与
γ ( x )
γ
(
x
)
有关。实际上,当
γ ∼ G u m b e l ( 0 , 1 )
γ
∼
G
u
m
b
e
l
(
0
,
1
)
,而p是多项式分布的时候就是我们模拟多项式分布进行采样时所服从的分布!那么为什么这个Gumbel 分布能够模拟多项式分布?
我们来考虑一个问题,对于公式1,多项式一共有K个类别。那么第k个类别恰好是最大的概率是多少?
令
z k = log α k + G k
z
k
=
log
α
k
+
G
k
要求解这个问题,我们要先求出
z k
z
k
是最大的概率多少?然后再对z积分,从而求出第k个是最大的概率。
Pr ( log α k + G k > max i ≠ k log α i + G i ) = Pr ( max i ≠ k log α i + G i < log α k + G k ) = ∏ i ≠ k Pr ( log α i + G i < log α k + G k ) = ∏ i ≠ k Pr ( G i < log α k + G k − log α i ) = ∏ i ≠ k F ( log α k + G k − log α i ) = ∏ i ≠ k exp { − exp { − ( log α k + G k − log α i ) } }
Pr
(
log
α
k
+
G
k
>
max
i
≠
k
log
α
i
+
G
i
)
=
Pr
(
max
i
≠
k
log
α
i
+
G
i
<
log
α
k
+
G
k
)
=
∏
i
≠
k
Pr
(
log
α
i
+
G
i
<
log
α
k
+
G
k
)
=
∏
i
≠
k
Pr
(
G
i
<
log
α
k
+
G
k
−
log
α
i
)
=
∏
i
≠
k
F
(
log
α
k
+
G
k
−
log
α
i
)
=
∏
i
≠
k
exp
{
−
exp
{
−
(
log
α
k
+
G
k
−
log
α
i
)
}
}
现在我们有了
z k
z
k
是最大的那个概率值,现在我们想知道第k个元素是最大的概率值是多少,因此,我们需要对所有z的取值进行积分,从而得到第k个位置取值最大的概率。
Pr ( k is largest | { x k ′ } ) = ∫ exp { − ( z k − log α k ) − exp { − ( z k − log α k ) } } ∏ i ≠ k exp { − exp { − ( z k − log α i ) } } d z k = ∫ exp { − z k + log α k − exp { − z k } ∑ i = 1 K exp { log α i } } d z k = exp { log α k } ∑ K i = 1 exp { log α i }
Pr
(
k is largest
|
{
x
k
′
}
)
=
∫
exp
{
−
(
z
k
−
log
α
k
)
−
exp
{
−
(
z
k
−
log
α
k
)
}
}
∏
i
≠
k
exp
{
−
exp
{
−
(
z
k
−
log
α
i
)
}
}
d
z
k
=
∫
exp
{
−
z
k
+
log
α
k
−
exp
{
−
z
k
}
∑
i
=
1
K
exp
{
log
α
i
}
}
d
z
k
=
exp
{
log
α
k
}
∑
i
=
1
K
exp
{
log
α
i
}
这时候,奇迹来了,上面这条等式恰好是一个softmax的公式,也就是说,第k个位置最大的概率,恰好就是对离散概率分布的一个近似。而且一个有趣的性质是这里的
α k
α
k
是不需要归一化的,因为经过softmax之后他就自动归一化了!
参考资料
https://en.wikipedia.org/wiki/Gumbel_distribution http://irenechen.net/blog/2017/08/17/gumbel-trick.html https://www.youtube.com/watch?v=wVkLM2KKHp8 https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/