MENU

稀疏 Softmax(Sparse Softmax)

July 16, 2021 • Read: 5281 • Deep Learning阅读设置

本文源自于 SPACES:“抽取 - 生成” 式长文本摘要(法研杯总结),原文其实是对一个比赛的总结,里面提到了很多 Trick,其中有一个叫做稀疏 Softmax(Sparse Softmax)的东西吸引了我的注意,查阅了很多资料以后,汇总在此

Sparse Softmax 的思想源于《From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification》《Sparse Sequence-to-Sequence Models》等文章。里边作者提出了将 Softmax 稀疏化的做法来增强其解释性乃至提升效果

不够稀疏的 Softmax

前面提到 Sparse Softmax 本质上是将 Softmax 的结果稀疏化,那么为什么稀疏化之后会有效呢?我们认稀疏化可以避免 Softmax 过度学习的问题。假设已经成功分类,那么我们有 $s_{\text {max}}=s_t$(目标类别的分数最大),此时我们可以推导原始交叉熵的一个不等式:

$$ \begin{aligned} -\log \frac{e^{s_{\text{max}}}}{\sum\limits_{i=1}^n e^{s_i}}&=\log (\sum_{i=1}^n e^{s_i})-s_{\text{max}} \\&= \log (e^{s_t}+\sum_{i\neq t}e^{s_i})-s_{\text{max}}\\ &= \log (e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i})-\log (e^{s_{\text{max}}})\\ &= \log (\frac{e^{s_{\text{max}}} + \sum_{i\neq t}e^{s_i}}{e^{s_{\text{max}}}})\\ &= \log (1+ \sum_{i \neq t}e^{s_i - s_{\text{max}}})\\ & \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}}) \end{aligned}\tag{1} $$

假设当前交叉熵值为 $\varepsilon$,那么有

$$ \varepsilon \ge \log (1+ (n - 1)e^{s_{\text{min}}-s_{\text{max}}})\tag{2} $$

解得

$$ s_{\text{max}} - s_{\text{min}} \ge \log (n - 1) - \log (e^{\varepsilon} - 1)\tag{3} $$

我们以 $\varepsilon = \ln2 = 0.69...$ 为例,这时候 $\log (e^{\varepsilon} - 1)=0$,那么 $s_{\text {max}} - s_{\text {min}}\ge \log (n-1)$。也就是说,为了要 loss 降到 0.69,那么最大的 logit 与最小的 logit 的差就必须大于 $\log (n-1)$,当 $n$ 比较大时,对于分类问题来说这是一个没有必要的过大的间隔,因为我们只希望目标类的 logit 比所有非目标类都要大一点就行,但是并不一定需要大 $\log (n-1)$ 那么多,因此常规的交叉熵容易过度学习从而导致过拟合

稀疏的 Sparsemax

前面说了这么多关于 Softmax 的内容,那么 Sparse Softmax 或者说 Sparsemax 是如何做到稀疏化分布的呢?原文内容大家可以直接去看论文,写的非常复杂,这里我给出苏剑林大佬设计的一个更简单的版本

$$ \begin{array}{c|c|c} \hline & \text{Origin} & \text{Sparse} \\ \hline \text{Softmax} & p_i = \frac{e^{s_i}}{\sum\limits_{j=1}^{n} e^{s_j}} & p_i=\left\{\begin{aligned}&\frac{e^{s_i}}{\sum\limits_{j\in\Omega_k} e^{s_j}},\,i\in\Omega_k\\ &\quad 0,\,i\not\in\Omega_k\end{aligned}\right.\\ \hline \text{CrossEntropy} & \log\left(\sum\limits_{i=1}^n e^{s_i}\right) - s_t & \log\left(\sum\limits_{i\in\Omega_k} e^{s_i}\right) - s_t\\ \hline \end{array} $$

其中 $\Omega_k$ 是将 $s_1,s_2,...,s_n$ 从大到小排列后前 $k$ 个元素的下标集合。说白了,苏剑林大佬提出的 Sparse Softmax 就是在计算概率的时候,只保留前 $k$ 个,后面的直接置零,$k$ 是人为选择的超参数

代码

首先我根据苏剑林大佬的思路,给出一个简单版本的 PyTorch 代码

  • import torch
  • import torch.nn as nn
  • class Sparsemax(nn.Module):
  • """Sparsemax loss"""
  • def __init__(self, k_sparse=1):
  • super(Sparsemax, self).__init__()
  • self.k_sparse = k_sparse
  • def forward(self, preds, labels):
  • """
  • Args:
  • preds (torch.Tensor): [batch_size, number_of_logits]
  • labels (torch.Tensor): [batch_size] index, not ont-hot
  • Returns:
  • torch.Tensor
  • """
  • preds = preds.reshape(preds.size(0), -1) # [batch_size, -1]
  • topk = preds.topk(self.k_sparse, dim=1)[0] # [batch_size, k_sparse]
  • # log(sum(exp(topk)))
  • pos_loss = torch.logsumexp(topk, dim=1)
  • # s_t
  • neg_loss = torch.gather(preds, 1, labels[:, None].expand(-1, preds.size(1)))[:, 0]
  • return (pos_loss - neg_loss).sum()

再给出一个 Github 上找到的一个 PyTorch 原版代码

  • """Sparsemax activation function.
  • Pytorch implementation of Sparsemax function from:
  • -- "From Softmax to Sparsemax: A Sparse Model of Attention and Multi-Label Classification"
  • -- André F. T. Martins, Ramón Fernandez Astudillo (http://arxiv.org/abs/1602.02068)
  • """
  • import torch
  • import torch.nn as nn
  • device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  • class Sparsemax(nn.Module):
  • """Sparsemax function."""
  • def __init__(self, dim=None):
  • """Initialize sparsemax activation
  • Args:
  • dim (int, optional): The dimension over which to apply the sparsemax function.
  • """
  • super(Sparsemax, self).__init__()
  • self.dim = -1 if dim is None else dim
  • def forward(self, input):
  • """Forward function.
  • Args:
  • input (torch.Tensor): Input tensor. First dimension should be the batch size
  • Returns:
  • torch.Tensor: [batch_size x number_of_logits] Output tensor
  • """
  • # Sparsemax currently only handles 2-dim tensors,
  • # so we reshape to a convenient shape and reshape back after sparsemax
  • input = input.transpose(0, self.dim)
  • original_size = input.size()
  • input = input.reshape(input.size(0), -1)
  • input = input.transpose(0, 1)
  • dim = 1
  • number_of_logits = input.size(dim)
  • # Translate input by max for numerical stability
  • input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)
  • # Sort input in descending order.
  • # (NOTE: Can be replaced with linear time selection method described here:
  • # http://stanford.edu/~jduchi/projects/DuchiShSiCh08.html)
  • zs = torch.sort(input=input, dim=dim, descending=True)[0]
  • range = torch.arange(start=1, end=number_of_logits + 1, step=1, device=device, dtype=input.dtype).view(1, -1)
  • range = range.expand_as(zs)
  • # Determine sparsity of projection
  • bound = 1 + range * zs
  • cumulative_sum_zs = torch.cumsum(zs, dim)
  • is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
  • k = torch.max(is_gt * range, dim, keepdim=True)[0]
  • # Compute threshold function
  • zs_sparse = is_gt * zs
  • # Compute taus
  • taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
  • taus = taus.expand_as(input)
  • # Sparsemax
  • self.output = torch.max(torch.zeros_like(input), input - taus)
  • # Reshape back to original shape
  • output = self.output
  • output = output.transpose(0, 1)
  • output = output.reshape(original_size)
  • output = output.transpose(0, self.dim)
  • return output
  • def backward(self, grad_output):
  • """Backward function."""
  • dim = 1
  • nonzeros = torch.ne(self.output, 0)
  • sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
  • self.grad_input = nonzeros * (grad_output - sum.expand_as(grad_output))
  • return self.grad_input

* 补充

经过苏剑林大佬的许多实验发现,Sparse Softmax 只适用于有预训练的场景,因为预训练模型已经训练得很充分了,因此 finetune 阶段要防止过拟合;但是如果从零训练一个模型,那么 Sparse Softmax 会造成性能下降,因为每次只有 $k$ 个类别被学习到,反而会存在学习不充分的情况(欠拟合)

References

Last Modified: September 4, 2021
Archives Tip
QR Code for this page
Tipping QR Code