MENU

基于去噪Transformer的无监督句子编码

November 17, 2021 • Read: 4657 • Deep Learning阅读设置

EMNLP2021 Findings上有一篇名为TSDAE: Using Transformer-based Sequential Denoising Auto-Encoder for Unsupervised Sentence Embedding Learning的论文,利用Transformer结构无监督训练句子编码,网络架构如下所示

具体来说,输入的文本添加了一些确定的噪声,例如删除、交换、添加、Mask一些词等方法。Encoder需要将含有噪声的句子编码为一个固定大小的向量,然后利用Decoder将原本的不带噪声的句子还原。说是这么说,但是其中有非常多细节,首先是训练目标

$$ \begin{aligned} J_{\text{SDAE}}(\theta) &= \mathbb{E}_{x\sim D}[\log P_{\theta}(x\mid \tilde{x})]\\ &=\mathbb{E}_{x\sim D}[\sum_{t=1}^l \log P_{\theta}(x_t\mid \tilde{x})]\\ &=\mathbb{E}_{x\sim D}[\sum_{t=1}^l \log \frac{\exp(h_t^T e_t)}{\sum_{i=1}^N \exp(h_t^T e_i)}] \end{aligned} $$

其中,$D$是训练集;$x = x_1x_2\cdots x_l$是长度为$l$的输入句子;$\tilde{x}$是$x$添加噪声之后的句子;$e_t$是词$x_t$的word embedding;$N$为Vocabulary size;$h_t$是Decoder第$t$步输出的hidden state

不同于原始的Transformer,作者提出的方法,Decoder只利用Encoder输出的固定大小的向量进行解码,具体来说,Encoder-Decoder之间的cross-attention形式化地表示如下:

$$ \begin{aligned} &H^{(k)}=\text{Attention}(H^{(k-1)}, [s^T], [s^T])\\ &\text{Attention}(Q,K,V) = \text{Softmax}(\frac{QK^T}{\sqrt{d}})V \end{aligned} $$

其中,$H^{(k)}\in \mathbb{R}^{t\times d}$是Decoder第$k$层$t$个解码步骤内的hidden state;$d$是句向量的维度(Encoder输出向量的维度);$[s^T]\in \mathbb{R}^{1\times d}$是Encoder输出的句子(行)向量。从上面的公式我们可以看出,不论哪一层的cross-attention,$K$和$V$永远都是$s^T$,作者这样设计的目的是为了人为给模型添加一个瓶颈,如果Encoder编码的句向量$s^T$不够准确,Decoder就很难解码成功,换句话说,这样设计是为了使得Encoder编码的更加准确。训练结束后如果需要提取句向量只需要用Encoder即可

作者通过在STS数据集上调参,发现最好的组合方法如下:

  1. 采用删除单词这种添加噪声的方法,并且比例设置为60%
  2. 使用[CLS]位置的输出作为句向量

Results

从TSDAE的结果来看,基本上是拳打SimCSE,脚踢BERT-flow

个人总结

如果我是reviewer,我特别想问的一个问题是:"你们这种方法,与BART有什么区别?"

论文源码在UKPLab/sentence-transformers/,其实sentence-transformers已经把TSDAE封装成pip包,完整的训练流程可以参考Sentence-Transformer的使用及fine-tune教程,在此基础上只需要修改dataset和loss就可以轻松的训练TSDAE

# 创建可即时添加噪声的特殊去噪数据集
train_dataset = datasets.DenoisingAutoEncoderDataset(train_sentences)

# DataLoader 批量处理数据
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)

# 使用去噪自动编码器损失
train_loss = losses.DenoisingAutoEncoderLoss(model, decoder_name_or_path=model_name, tie_encoder_decoder=True)

# 模型训练
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    weight_decay=0,
    scheduler='constantlr',
    optimizer_params={'lr': 3e-5},
    show_progress_bar=True
)
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

5 Comments
  1. maple maple

    up你好,不小心右键点到把你博客的公式的渲染关掉了,现在的公式部分只能看到不知道是katex还是什么格式的代码了,右键又点不出来那个界面了,请问下怎么弄成原来的渲染好的显示格式啊@(汗)

    1. mathor mathor

      @mapleI have no idea,我尝试把你说的公式渲染什么的关掉,我都不知道该点哪里

    2. maple maple

      @mathor老哥我邮箱发给你截图了,真不好意思用这种问题打扰你@(笑尿)

    3. mathor mathor

      @maple我对比了一下我和你打勾的那些项,没有区别,不知道为什么你的就不能显示

    4. 张

      @maple清除浏览器的缓存就可以啦