Loading [MathJax]/jax/output/SVG/jax.js
MENU

我们真的需要把训练集的损失降到零吗?

May 12, 2021 • Read: 6354 • Deep Learning阅读设置

在训练模型的时候,我们需要将损失函数一直训练到 0 吗?显然不用。一般来说,我们是用训练集来训练模型,但希望的是验证机的损失越小越好,而正常来说训练集的损失降到一定值后,验证集的损失就会开始上升,因此没必要把训练集的损失降低到 0

既然如此,在已经达到了某个阈值之后,我们可不可以做点别的事情来提升模型性能呢?ICML2020 的论文《Do We Need Zero Training Loss After Achieving Zero Training Error?》回答了这个问题,不过实际上它并没有很好的描述 "为什么",而只是提出了 "怎么做"

思路描述

论文提供的解决方案非常简单,假设原来的损失函数是 L(θ),现在改为 ˜L(θ)

˜L(θ)=|L(θ)b|+b

其中 b 是预先设定的阈值。当 L(θ)>b˜L(θ)=L(θ),这时就是执行普通的梯度下降;而 L(θ)<b˜L(θ)=2bL(θ),注意到损失函数变号了,所以这时候是梯度上升。因此,总的来说就是以 b 为阈值,低于阈值时反而希望损失函数变大。论文把这个改动称为 "Flooding"

这样做有什么效果呢?论文显示,在某些任务中,训练集的损失函数经过这样处理后,验证集的损失能出现 "二次下降(Double Descent)",如下图

左图:不加 Flooding 的训练示意图;右图:加了 Flooding 的训练示意图

简单来说,就是最终的验证集效果可能更好一些,原论文的实验结果如下:

Flooding 的实验结果:第一行 W 表示是否使用 weight decay,第二行 E 表示是否使用 early stop,第三行的 F 表示是否使用 Flooding

个人分析

如何解释这个方法呢?可以想像,当损失函数达到 b 之后,训练流程大概就是在交替执行梯度下降和梯度上升。直观想的话,感觉一步上升一步下降,似乎刚好抵消了。事实真的如此吗?我们来算一下看看。假设先下降一步后上升一步,学习率为 ε,那么:

θn=θn1εg(θn1)θn+1=θn+εg(θn)

其中 g(θ)=θL(θ),现在我们有

θn+1=θn1εg(θn1)+εg(θn1εg(θn1))θn1εg(θn1)+ε(g(θn1)εθg(θn1)g(θn1))=θn1ε22θg(θn1)2

近似那一步实际上是使用了泰勒展开,我们将 θn1 看作 xεg(θn1) 看作 Δx,由于

g(xΔx)g(x)Δx=xg(x)

所以

g(xΔx)=g(x)Δxxg(x)

最终的结果就是相当于学习率为 ε22、损失函数为梯度惩罚 g(θ)2=θL(θ)2 的梯度下降。更妙的是,改为 "先上升再下降",其表达式依然是一样的(这不禁让我想起 "先涨价 10% 再降价 10%" 和 "先降价 10% 再涨价 10% 的故事")。因此,平均而言,Flooding 对损失函数的改动,相当于在保证了损失函数足够小之后去最小化 xL(θ)2,也就是推动参数往更平稳的区域走,这通常能提高泛化性(更好地抵抗扰动),因此一定程度上就能解释 Flooding 有作用的原因了

本质上来讲,这跟往参数里边加入随机扰动、对抗训练等也没什么差别,只不过这里是保证了损失足够小后再加扰动

继续脑洞

想要使用 Flooding 非常简单,只需要在原有代码基础上增加一行即可

  • logits = model(x)
  • loss = criterion(logits, y)
  • loss = (loss - b).abs() + b # This is it!
  • optimizer.zero_grad()
  • loss.backward()
  • optimizer.step()

有心是用这个方法的读者可能会纠结于 b 的选择,原论文说 b 的选择是一个暴力迭代的过程,需要多次尝试

The flood level is chosen from b{0,0.01,0.02,...,0.50}

不过笔者倒是有另外一个脑洞:b 无非就是决定什么时候开始交替训练罢了,那如果我们从一开始就用不同的学习率进行交替训练呢?也就是自始自终都执行

θn=θn1ε1g(θn1)θn+1=θn+ε2g(θn)

其中 ε1>ε2,这样我们就把 b 去掉了(引入了 ε1,ε2 的选择,天下没有免费的午餐)。重复上述近似展开,我们就得到

θn+1=θn1ε1g(θn1)+ε2g(θn1ε1g(θn1))θn1ε1g(θn1)+ε2(g(θn1)ε1θg(θn1)g(θn1))=θn1(ε1ε2)g(θn1)ε1ε22θg(θn1)2=θn1(ε1ε2)θ[L(θn1)+ε1ε22(ε1ε2)θL(θn1)2]

这就相当于自始自终都在用学习率 ε1ε2 来优化损失函数 L(θ)+ε1ε22(ε1ε2)θL(θ)2 了,也就是说一开始就把梯度惩罚给加了进去,这样能提升模型的泛化性能吗?《Backstitch: Counteracting Finite-sample Bias via Negative Steps》里边指出这种做法在语音识别上是有效的,请读者自行测试甄别

效果检验

我随便在网上找了个竞赛,然后利用别人提供的以 BERT 为 baseline 的代码,对 Flooding 的效果进行了测试,下图分别是没有做 Flooding 和参数 b=0.7 的 Flooding 损失值变化图,值得一提的是,没有做 Flooding 的验证集最低损失值为 0.814198,而做了 Flooding 的验证集最低损失值为 0.809810

根据知乎文章一行代码发一篇 ICML?底下用户 Curry 评论所言:"通常来说 b 值需要设置成比 'Validation Error 开始上升 ' 的值更小,1/2 处甚至更小,结果更优",所以我仔细观察了下没有加 Flooding 模型损失值变化图,大概在 loss 为 0.75 到 1.0 左右的时候开始出现过拟合现象,因此我又分别设置了 b=0.4b=0.5,做了两次 Flooding 实验,结果如下图

值得一提的是,b=0.4b=0.5 时,验证集上的损失值最低仅为 0.809958 和 0.796819,而且很明显验证集损失的整体上升趋势更加缓慢。接下来我做了一个实验,主要是验证 "继续脑洞" 部分以不同的学习率一开始就交替着做梯度下降和梯度上升的效果,其中,梯度下降的学习率我设为 1e5,梯度上升的学习率为 1e6,结果如下图,验证集的损失最低仅有 0.783370

References

Last Modified: May 29, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

  • OωO
  • |´・ω・)ノ
  • ヾ(≧∇≦*)ゝ
  • (☆ω☆)
  • (╯‵□′)╯︵┴─┴
  •  ̄﹃ ̄
  • (/ω\)
  • ∠( ᐛ 」∠)_
  • (๑•̀ㅁ•́ฅ)
  • →_→
  • ୧(๑•̀⌄•́๑)૭
  • ٩(ˊᗜˋ*)و
  • (ノ°ο°)ノ
  • (´இ皿இ`)
  • ⌇●﹏●⌇
  • (ฅ´ω`ฅ)
  • (╯°A°)╯︵○○○
  • φ( ̄∇ ̄o)
  • ヾ(´・ ・`。)ノ"
  • ( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
  • (ó﹏ò。)
  • Σ(っ °Д °;)っ
  • ( ,,´・ω・)ノ"(´っω・`。)
  • ╮(╯▽╰)╭
  • o(*////▽////*)q
  • >﹏<
  • ( ๑´•ω•) "(ㆆᴗㆆ)
  • (。•ˇ‸ˇ•。)
  • 泡泡
  • 阿鲁
  • 颜文字