MENU

Sentence-Transformer的使用及fine-tune教程

October 14, 2020 • Read: 192 • Deep Learning阅读设置

简述

Sentence-Transformer官方文档写的很详细,里面有各种你可能会用到的示例代码,并且都有比较详细的说明,如果有什么问题,应该先去看官方文档

本文主要从两种情况来介绍如何使用Sentence-Transformer,一种是直接使用,另一种是在自己的数据集上fine-tune

首先,无论何种场景,您都应该先安装以下两个库

pip install -U sentence-transformers
pip install -U transformers

直接使用

Sentence-Transformer提供了非常多的预训练模型供我们使用,对于STS(Semantic Textual Similarity)任务来说,比较好的模型有以下几个

  • roberta-large-nli-stsb-mean-tokens - STSb performance: 86.39
  • roberta-base-nli-stsb-mean-tokens - STSb performance: 85.44
  • bert-large-nli-stsb-mean-tokens - STSb performance: 85.29
  • distilbert-base-nli-stsb-mean-tokens - STSb performance: 85.16

全部STS模型列表(需要科学的力量)

这里我就选择最好的模型做一下语义文本相似度任务

from sentence_transformers import SentenceTransformer
model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')

语义文本相似度任务指的是给定一个句子(query),在整个语料库中寻找和该句子语义上最相近的几个句子

用一个list来代表整个语料库,list中存的是str类型的句子

sentences = ['Lack of saneness',
        'Absence of sanity',
        'A man is eating food.',
        'A man is eating a piece of bread.',
        'The girl is carrying a baby.',
        'A man is riding a horse.',
        'A woman is playing violin.',
        'Two men pushed carts through the woods.',
        'A man is riding a white horse on an enclosed ground.',
        'A monkey is playing drums.',
        'A cheetah is running behind its prey.']
sentence_embeddings = model.encode(sentences)

for sentence, embedding in zip(sentences, sentence_embeddings):
    print("Sentence:", sentence)
    print("Embedding:", embedding)
    print("")

非常简单几行代码就获得了一个句子的向量表示

下面定义query句子,并获得它的向量表示

query = 'Nobody has sane thoughts'  #  A query sentence uses for searching semantic similarity score.
queries = [query]
query_embeddings = model.encode(queries)

scipy库计算两个向量的余弦距离,找出与query句子余弦距离最小的前三个句子

import scipy

print("Semantic Search Results")
number_top_matches = 3
for query, query_embedding in zip(queries, query_embeddings):
    distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]
    results = zip(range(len(distances)), distances)
    results = sorted(results, key=lambda x: x[1])
    print("Query:", query)
    print("\nTop {} most similar sentences in corpus:".format(number_top_matches))

    for idx, distance in results[0:number_top_matches]:
        print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))

distance表示两个句子的余弦距离,1-distance可以理解为两个句子的余弦分数,分数越大表示两个句子的语义越相近

Fine-Tune

Fine-Tune仍然是STS任务,我使用的数据集是中-韩句子对,如下所示

每一行第一列是韩语句子,第二列是对应的中文,所有这些行就构成了正样本,负样本只需要将左右两列的行顺序打乱即可

这里我自己实现了一个shuffle函数,因为random.shuffle()会改变原有数据,而我不希望改变原有数据

from copy import deepcopy
from random import randint

def shuffle(lst):
  temp_lst = deepcopy(lst)
  m = len(temp_lst)
  while (m):
    m -= 1
    i = randint(0, m)
    temp_lst[m], temp_lst[i] = temp_lst[i], temp_lst[m]
  return temp_lst

首先读取数据

import xlrd
f = xlrd.open_workbook('Ko2Cn.xlsx').sheet_by_name('Xbench QA')
Ko_list = f.col_values(0) # 所有的中文句子
Cn_list = f.col_values(1) # 所有的韩语句子

shuffle_Cn_list = shuffle(Cn_list) # 所有的中文句子打乱排序
shuffle_Ko_list = shuffle(Ko_list) # 所有的韩语句子打乱排序

Sentence-Transformer在fine-tune的时候,数据必须保存到list中,list里是Sentence-Transformer库的作者自己定义的InputExample()对象

InputExample()对象需要传两个参数textslabel,其中,texts也是个list类型,里面保存了一个句子对,label必须为float类型,表示这个句子对的相似程度

比方说下面的示例代码

train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
   InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]

以下是构建数据集的代码

from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, evaluation, losses
from torch.utils.data import DataLoader

train_size = int(len(Ko_list) * 0.8)
eval_size = len(Ko_list) - train_size

# Define your train examples.
train_data = []
for idx in range(train_size):
  train_data.append(InputExample(texts=[Ko_list[idx], Cn_list[idx]], label=1.0))
  train_data.append(InputExample(texts=[shuffle_Ko_list[idx], shuffle_Cn_list[idx]], label=0.0))

# Define your evaluation examples
sentences1 = Ko_list[train_size:]
sentences2 = Cn_list[train_size:]
sentences1.extend(list(shuffle_Ko_list[train_size:]))
sentences2.extend(list(shuffle_Cn_list[train_size:]))
scores = [1.0] * eval_size + [0.0] * eval_size

evaluator = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
# Define your train dataset, the dataloader and the train loss
train_dataset = SentencesDataset(train_data, model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32)
train_loss = losses.CosineSimilarityLoss(model)

然后定义模型开始训练,这里我是用的是multilingual预训练模型,因为这个数据集既包含中文,也有韩语。Sentence-Transformer提供了三个可使用的多语言预训练模型

  • distiluse-base-multilingual-cased: Supported languages: Arabic, Chinese, Dutch, English, French, German, Italian, Korean, Polish, Portuguese, Russian, Spanish, Turkish. Model is based on DistilBERT-multi-lingual.
  • xlm-r-base-en-ko-nli-ststb: Supported languages: English, Korean. Performance on Korean STSbenchmark: 81.47
  • xlm-r-large-en-ko-nli-ststb: Supported languages: English, Korean. Performance on Korean STSbenchmark: 84.05
#Define the model. Either from scratch of by loading a pre-trained model
model = SentenceTransformer('distiluse-base-multilingual-cased')

# Tune the model
model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100, evaluator=evaluator, evaluation_steps=100, output_path='./Ko2CnModel')

每隔100次训练集的迭代,进行一次验证,并且它会自动将在验证集上表现最好的模型保存到output_path

如果要加载模型做测试,使用如下代码即可

from sentence_transformers import SentenceTransformer, util
model = SentenceTransformer('./Ko2CnModel')

# Sentences are encoded by calling model.encode()
emb1 = model.encode("터너를 이긴 푸들.")
emb2 = model.encode("战胜特纳的泰迪。")

cos_sim = util.pytorch_cos_sim(emb1, emb2)
print("Cosine-Similarity:", cos_sim)
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

13 Comments
  1. leo leo

    Up主,你好,很喜欢你的博客和视频,看了很多,三连有了!By the way,想问问你的视频中的护眼软件是啥,可以分享一波吗?@(呵呵)

    1. mathor mathor

      @leo额,我好像没用过任何护眼软件,应该是我博客自带的配色吧

    2. leo leo

      @mathor就是视频里电脑桌面渐变出来的,是定时提醒?

    3. mathor mathor

      @leo您具体发个视频链接,说一下几分几秒,我去看看,我记得我真的没用护眼软件

    4. leo leo

      @mathor尴尬啊,一下还真没找到,我模糊的记得中间是很大的时钟?也有可能看串了...

    5. mathor mathor

      @leo@(不高兴)你肯定是看的别人的视频

    6. leo leo

      @mathor奥利给

  2. mrchen mrchen

    up主是哪个学校的呀?

    1. mathor mathor

      @mrchenb区一所211

    2. mrchen mrchen

      @mathor我浏览一下了up主的博客,感觉太棒了,能这么详细的记录自己的学习经历。我也是研一狗,方向也是自然语言处理(但不是关系抽取),不知能否一起交流学习?

    3. mathor mathor

      @mrchen可以的,欢迎加入计算机养生群203827195@(呵呵)

    4. mrchen mrchen

      @mathor已加。不知道up组平时有参加一些比赛吗?不知有没有兴趣一起组队(抱大腿嘻嘻)

    5. mathor mathor

      @mrchen我都是自己小打小闹,不成气候的,这队,不组也罢√