Today's article comes from EMNLP 2021 Findings; the paper is titled "AEDA: An Easier Data Augmentation Technique for Text Classification". The whole paper can be summarized in a single sentence: for text classification, inserting a few punctuation marks into a sentence is about the strongest data augmentation method there is.
AEDA Augmentation
Reading this, you will naturally ask: which punctuation marks should be added, and how many? The paper answers these questions in detail, and this part is also its only real contribution; the rest of the text mostly rehashes basic concepts and prior work.
First, six punctuation marks are eligible for insertion: {".", ";", "?", ":", "!", ","}. Second, let $n$ denote the number of punctuation marks added to a sentence; then
$$ n\in [1, \frac{1}{3}l] $$
where $l$ is the sentence length and $n$ is drawn at random from this range. Here are a few augmentation examples:
$$ \begin{array}{cc} \hline \textbf{Original} & \text{a sad , superior human comedy played out on the back roads of life .} \\ \hline \textbf{Aug 1} & \text{a sad , superior human comedy played out on the back roads ; of life ; .}\\ \hline \textbf{Aug 2} & \text{a , sad . , superior human ; comedy . played . out on the back roads of life .}\\ \hline \textbf{Aug 3} & \text{: a sad ; , superior ! human : comedy , played out ? on the back roads of life .}\\ \hline \end{array} $$
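To make the rule concrete, here is a minimal sketch for an English, space-tokenized sentence. The function name `aeda` and the `ratio` argument are my own naming, not the official implementation (which, adapted for Chinese, appears later in this post):

```python
import random

# The six punctuation marks used in the paper
PUNCS = ['.', ';', '?', ':', '!', ',']

def aeda(sentence, ratio=1/3):
    words = sentence.split(' ')
    l = len(words)
    n = random.randint(1, max(1, int(ratio * l)))  # n in [1, l/3]
    positions = set(random.sample(range(l), n))    # slots that get a mark
    out = []
    for i, w in enumerate(words):
        if i in positions:
            out.append(random.choice(PUNCS))       # insert before the word
        out.append(w)
    return ' '.join(out)

print(aeda('a sad , superior human comedy played out on the back roads of life .'))
```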
Talk is cheap, so how well does it actually work? The paper runs experiments on a large number of text classification tasks and compares against EDA. Amusingly, the AEDA repo on GitHub is forked from the EDA paper's repo, which gives the whole thing a flavor of beating EDA with its own code.
First, the paper presents a set of plots comparing the two methods on five datasets (the model is an RNN).
The results on BERT are shown in the table below. Why were all five datasets tested above, but only two reported for BERT? My bold guess is that the results on the other datasets were not good; otherwise there would be no reason to leave them out.
$$ \begin{array}{c|cc} \text{Model} & \text{SST2} & \text{TREC} \\ \hline \text{BERT} & 91.10 & 97.00\\ \hline \text{+EDA} & 90.99 & 96.00\\ \hline \text{+AEDA} & \pmb{91.76} & \pmb{97.20}\\ \end{array} $$
The official open-source code can be found here. Below I give the core code, ready to run; my dataset consists of sentence pairs, and I have switched the punctuation marks to their full-width Chinese forms to make the code better suited to Chinese data.
import random

random.seed(0)

# Five full-width marks for Chinese text (the paper uses six half-width marks for English)
PUNCTUATIONS = ['。', ',', '!', ';', ':']
PUNC_RATIO = 0.3

# Insert punctuation marks into a given sentence with the given ratio "punc_ratio"
def insert_punctuation_marks(sentence, punc_ratio=PUNC_RATIO):
    # words = sentence.split(',')
    words = sentence  # character-level "tokens" for Chinese text
    new_line = []
    # number of marks to insert, roughly in [1, punc_ratio * len(words)]
    q = random.randint(1, int(punc_ratio * len(words) + 1))
    # positions in front of which a mark is inserted
    qs = random.sample(range(0, len(words)), q)
    for j, word in enumerate(words):
        if j in qs:
            new_line.append(PUNCTUATIONS[random.randint(0, len(PUNCTUATIONS) - 1)])
        new_line.append(word)
    return ''.join(new_line)

data_aug = []
with open('train.txt') as f:
    for line in f:
        s1, s2, label = line.split('\t')
        s1, s2 = insert_punctuation_marks(s1.strip()), insert_punctuation_marks(s2.strip())
        line_aug = '\t'.join([s1, s2, label])
        data_aug.append(line_aug)  # augmented copy
        data_aug.append(line)      # keep the original line as well

with open('train_aug_marks.txt', 'w') as train_orig_plus_augs:
    train_orig_plus_augs.writelines(data_aug)
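A quick usage check (the sample sentence is my own; the exact output depends on the random state):

```python
print(insert_punctuation_marks('今天天气不错'))
# e.g. '今天。天气,不错' — positions and marks vary with the random state
```

Note that the loop above writes both the augmented line and the original line, so the resulting training file is exactly twice the size of `train.txt`.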
I vaguely remember a paper saying that pretrained language models like BERT are not that sensitive to whether a sentence reads fluently, nor to punctuation marks. If so, that would help explain why punctuation insertion adds noise without destroying the information the classifier needs.
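If you want to sanity-check that intuition yourself, one quick (and unscientific) probe is to compare BERT's [CLS] representation of a sentence before and after AEDA-style insertion. The sketch below is mine, not from the paper; it assumes the Hugging Face `transformers` library, the `bert-base-chinese` checkpoint, and the `insert_punctuation_marks` function defined above:

```python
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
model = AutoModel.from_pretrained('bert-base-chinese')
model.eval()

def cls_embedding(sentence):
    # Return the final-layer [CLS] vector for the sentence
    inputs = tokenizer(sentence, return_tensors='pt')
    with torch.no_grad():
        return model(**inputs).last_hidden_state[:, 0]

original = '今天天气不错'
augmented = insert_punctuation_marks(original)
sim = torch.cosine_similarity(cls_embedding(original), cls_embedding(augmented))
print(augmented, sim.item())  # if the intuition holds, the similarity should stay high
```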