MENU

Gumbel-Softmax 完全解析

November 13, 2021 • Read: 12566 • 数学 ,Deep Learning阅读设置

写在前面

本文对大部分人来说可能仅仅起到科普的作用,因为 Gumbel-Max 仅在部分领域会用到,例如 GAN、VAE 等。笔者是在研究 EMNLP 上的一篇论文时,看到其中有用 Gumbel-Softmax 公式解决对一个概率分布进行采样无法求导的问题,故想到对 Gumbel-Softmax 做一个总结,由此写下本文

为什么我们需要 Gumbel-Softmax ?

假设现在我们有一个离散随机变量 $Z$ 的分布

$$ p_1 = p(Z=1)=\pi_1\\ p_2 = p(Z=2) = \pi_2\\ p_3 = p(Z=3) = \pi_3\\ ...\\ p_x = p(Z=x) = \pi_x\\ $$

其中,$\sum_i \pi_i=1$。我们想根据 $p_1,p_2,...,p_x$ 的概率采样得到一系列离散 $z$ 的值。但是这么做有一个问题,我们采样出来的 $z$ 只有值,没有生成 $z$ 的式子。例如我们要求 $Z$ 的期望,那么就有公式

$$ \mathbb{E}(Z) = p_1 + 2p_2 + \cdots +xp_x $$

$Z$ 对 $p_1,p_2,...,p_x$ 的导数都很清楚。但是现在我们的需求是采样一些具体的 $z$ 值,采样这个操作没有任何公式,因此也就无法求导。于是一个很自然的想法就产生了,我们能不能给一个以 $p_1,p_2,...,p_z$ 为参数的公式,让这个公式返回的结果是 $z$ 采样的结果呢?

Gumbel-Softmax

一般来说 $\pi_i$ 是通过神经网络预测对于类别 $i$ 的概率,这在分类问题中非常常见,假设我们将一个样本送入模型,最后输出的概率分布为 $[0.2, 0.4,0.1,0.2,0.1]$,表明这是一个 5 分类问题,其中概率最大的是第 2 类,到这一步,我们直接通过 argmax 就能获得结果了,但现在我们不是预测问题,而是一个采样问题。对于模型来说,直接取出概率最大的就可以了,但对我们来说,每个类别都是有一定概率的,我们想根据这个概率来进行采样,而不是直接简单无脑的输出概率最大的值

最常见的采样 $\mathbf {z}$ 的 onehot 公式为

$$ \mathbf{z} = \text{onehot}(\max \{i\mid \pi_1 + \pi_2+\cdots +\pi_{i-1} \leq u\})\tag{1} $$

其中 $i=1,2,..,x$ 是类别的下标,随机变量 $u$ 服从均匀分布 $U (0,1)$

上面这个过程实际上是很巧妙的,我们将概率分布从前往后不断加起来,当加到 $\pi_i$ 时超过了某个随机值 $ 0\leq u \leq 1$,那么这一次随机采样过程,$z$ 就被随机采样为第 $i$ 类,最后通过一个 onehot 变换

但是上述公式存在一个致命的问题:max 函数是不可导的

Gumbel-Max Trick

Gumbel-Max 技巧就是解决 max 函数不可导问题的,我们可以用 argmax 替换 max,即

$$ \mathbf{z} = \text{onehot}(\mathop{\text{argmax}}\limits_{i} \{g_i + \log \pi_i\})\tag{2} $$

其中,$g_i=-\log (-\log (u_i)), u_i \sim U (0,1)$,这一项名为 Gumbel 噪声,或者叫 Gumbel 分布,目的是使得 $\mathbf {z}$ 的返回结果不固定

可以看到式 $(2)$ 的整个过程中,不可导的部分只有 argmax,实际上我们可以用可导的 softmax 函数,在参数 $\tau$ 的控制下逼近 argmax,最终 $z_i$ 的公式为

$$ z_i = \frac{\exp(\frac{g_i + \log \pi_i}{\tau})}{\sum_{j}^x\exp(\frac{g_j + \log \pi_j}{\tau})}\tag{3} $$

其中,$\tau$ 越小 $(\tau \to 0)$,整个 softmax 越光滑逼近 argmax,并且 $\mathbf {z} = \{z_i\mid i=1,2,...,x\}$ 也越接近 onehot 向量;$\tau$ 越大 $(\tau \to \infty)$,$\mathbf {z}$ 向量越接近于均匀分布

总结

整个过程相当于我们把不可导的取样过程,从 $\mathbf {z}$ 本身转移到了求 $\mathbf {z}$ 的公式中的一项 $g_i$ 中,而 $g_i$ 本身不依赖 $p_1,..,p_x$,所以 $z$ 对 $p_1,...,p_x$ 就可以到了,而且我们得到的 $\mathbf {z}$ 仍然是离散概率分布的采样。这种采样过程转嫁的技巧有一个专有名词,叫重参数化技巧(Reparameterization Trick)

References

Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

2 Comments
  1. 鱼

    大佬,公式 1 是不是错了 pi1+pi2+...+pi-1 <= u.

    1. mathor mathor

      @鱼是的,感谢提醒,已修改