MENU

Sentence-Transformer 的使用及 fine-tune 教程

October 14, 2020 • Read: 39029 • Deep Learning阅读设置

简述

Sentence-Transformer 官方文档写的很详细,里面有各种你可能会用到的示例代码,并且都有比较详细的说明,如果有什么问题,应该先去看官方文档

本文主要从两种情况来介绍如何使用 Sentence-Transformer,一种是直接使用,另一种是在自己的数据集上 fine-tune

首先,无论何种场景,您都应该先安装以下两个库

  • pip install -U sentence-transformers
  • pip install -U transformers

直接使用

Sentence-Transformer 提供了非常多的预训练模型供我们使用,对于 STS(Semantic Textual Similarity)任务来说,比较好的模型有以下几个

  • roberta-large-nli-stsb-mean-tokens - STSb performance: 86.39
  • roberta-base-nli-stsb-mean-tokens - STSb performance: 85.44
  • bert-large-nli-stsb-mean-tokens - STSb performance: 85.29
  • distilbert-base-nli-stsb-mean-tokens - STSb performance: 85.16

全部 STS 模型列表(需要科学的力量)

这里我就选择最好的模型做一下语义文本相似度任务

  • from sentence_transformers import SentenceTransformer
  • model = SentenceTransformer('roberta-large-nli-stsb-mean-tokens')

语义文本相似度任务指的是给定一个句子(query),在整个语料库中寻找和该句子语义上最相近的几个句子

用一个 list 来代表整个语料库,list 中存的是 str 类型的句子

  • sentences = ['Lack of saneness',
  • 'Absence of sanity',
  • 'A man is eating food.',
  • 'A man is eating a piece of bread.',
  • 'The girl is carrying a baby.',
  • 'A man is riding a horse.',
  • 'A woman is playing violin.',
  • 'Two men pushed carts through the woods.',
  • 'A man is riding a white horse on an enclosed ground.',
  • 'A monkey is playing drums.',
  • 'A cheetah is running behind its prey.']
  • sentence_embeddings = model.encode(sentences)
  • for sentence, embedding in zip(sentences, sentence_embeddings):
  • print("Sentence:", sentence)
  • print("Embedding:", embedding)
  • print("")

非常简单几行代码就获得了一个句子的向量表示

下面定义 query 句子,并获得它的向量表示

  • query = 'Nobody has sane thoughts' # A query sentence uses for searching semantic similarity score.
  • queries = [query]
  • query_embeddings = model.encode(queries)

scipy 库计算两个向量的余弦距离,找出与 query 句子余弦距离最小的前三个句子

  • import scipy
  • print("Semantic Search Results")
  • number_top_matches = 3
  • for query, query_embedding in zip(queries, query_embeddings):
  • distances = scipy.spatial.distance.cdist([query_embedding], sentence_embeddings, "cosine")[0]
  • results = zip(range(len(distances)), distances)
  • results = sorted(results, key=lambda x: x[1])
  • print("Query:", query)
  • print("\nTop {} most similar sentences in corpus:".format(number_top_matches))
  • for idx, distance in results[0:number_top_matches]:
  • print(sentences[idx].strip(), "(Cosine Score: %.4f)" % (1-distance))

distance 表示两个句子的余弦距离,1-distance 可以理解为两个句子的余弦分数,分数越大表示两个句子的语义越相近

Fine-Tune

Fine-Tune 仍然是 STS 任务,我使用的数据集是中 - 韩句子对,如下所示

每一行第一列是韩语句子,第二列是对应的中文,所有这些行就构成了正样本,负样本只需要将左右两列的行顺序打乱即可

这里我自己实现了一个 shuffle 函数,因为 random.shuffle() 会改变原有数据,而我不希望改变原有数据

  • from copy import deepcopy
  • from random import randint
  • def shuffle(lst):
  • temp_lst = deepcopy(lst)
  • m = len(temp_lst)
  • while (m):
  • m -= 1
  • i = randint(0, m)
  • temp_lst[m], temp_lst[i] = temp_lst[i], temp_lst[m]
  • return temp_lst

首先读取数据

  • import xlrd
  • f = xlrd.open_workbook('Ko2Cn.xlsx').sheet_by_name('Xbench QA')
  • Ko_list = f.col_values(0) # 所有的中文句子
  • Cn_list = f.col_values(1) # 所有的韩语句子
  • shuffle_Cn_list = shuffle(Cn_list) # 所有的中文句子打乱排序
  • shuffle_Ko_list = shuffle(Ko_list) # 所有的韩语句子打乱排序

Sentence-Transformer 在 fine-tune 的时候,数据必须保存到 list 中,list 里是 Sentence-Transformer 库的作者自己定义的 InputExample() 对象

InputExample() 对象需要传两个参数 textslabel,其中,texts 也是个 list 类型,里面保存了一个句子对,label 必须为 float 类型,表示这个句子对的相似程度

比方说下面的示例代码

  • train_examples = [InputExample(texts=['My first sentence', 'My second sentence'], label=0.8),
  • InputExample(texts=['Another pair', 'Unrelated sentence'], label=0.3)]

以下是构建数据集的代码

  • from sentence_transformers import SentenceTransformer, SentencesDataset, InputExample, evaluation, losses
  • from torch.utils.data import DataLoader
  • train_size = int(len(Ko_list) * 0.8)
  • eval_size = len(Ko_list) - train_size
  • # Define your train examples.
  • train_data = []
  • for idx in range(train_size):
  • train_data.append(InputExample(texts=[Ko_list[idx], Cn_list[idx]], label=1.0))
  • train_data.append(InputExample(texts=[shuffle_Ko_list[idx], shuffle_Cn_list[idx]], label=0.0))
  • # Define your evaluation examples
  • sentences1 = Ko_list[train_size:]
  • sentences2 = Cn_list[train_size:]
  • sentences1.extend(list(shuffle_Ko_list[train_size:]))
  • sentences2.extend(list(shuffle_Cn_list[train_size:]))
  • scores = [1.0] * eval_size + [0.0] * eval_size
  • evaluator = evaluation.EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
  • # Define your train dataset, the dataloader and the train loss
  • train_dataset = SentencesDataset(train_data, model)
  • train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32)
  • train_loss = losses.CosineSimilarityLoss(model)

然后定义模型开始训练,这里我是用的是 multilingual 预训练模型,因为这个数据集既包含中文,也有韩语。Sentence-Transformer 提供了三个可使用的多语言预训练模型

  • distiluse-base-multilingual-cased: Supported languages: Arabic, Chinese, Dutch, English, French, German, Italian, Korean, Polish, Portuguese, Russian, Spanish, Turkish. Model is based on DistilBERT-multi-lingual.
  • xlm-r-base-en-ko-nli-ststb: Supported languages: English, Korean. Performance on Korean STSbenchmark: 81.47
  • xlm-r-large-en-ko-nli-ststb: Supported languages: English, Korean. Performance on Korean STSbenchmark: 84.05
  • #Define the model. Either from scratch of by loading a pre-trained model
  • model = SentenceTransformer('distiluse-base-multilingual-cased')
  • # Tune the model
  • model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=100, evaluator=evaluator, evaluation_steps=100, output_path='./Ko2CnModel')

每隔 100 次训练集的迭代,进行一次验证,并且它会自动将在验证集上表现最好的模型保存到 output_path

如果要加载模型做测试,使用如下代码即可

  • from sentence_transformers import SentenceTransformer, util
  • model = SentenceTransformer('./Ko2CnModel')
  • # Sentences are encoded by calling model.encode()
  • emb1 = model.encode("터너를 이긴 푸들.")
  • emb2 = model.encode("战胜特纳的泰迪。")
  • cos_sim = util.pytorch_cos_sim(emb1, emb2)
  • print("Cosine-Similarity:", cos_sim)
Last Modified: April 20, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

37 Comments
  1. leo leo

    Up 主,你好,很喜欢你的博客和视频,看了很多,三连有了!By the way,想问问你的视频中的护眼软件是啥,可以分享一波吗?@(呵呵)

    1. mathor mathor

      @leo 额,我好像没用过任何护眼软件,应该是我博客自带的配色吧

    2. leo leo

      @mathor 就是视频里电脑桌面渐变出来的,是定时提醒?

    3. mathor mathor

      @leo 您具体发个视频链接,说一下几分几秒,我去看看,我记得我真的没用护眼软件

    4. leo leo

      @mathor 尴尬啊,一下还真没找到,我模糊的记得中间是很大的时钟?也有可能看串了...

    5. mathor mathor

      @leo@(不高兴) 你肯定是看的别人的视频

    6. leo leo

      @mathor 奥利给

  2. mrchen mrchen

    up 主是哪个学校的呀?

    1. mathor mathor

      @mrchenb 区一所 211

    2. mrchen mrchen

      @mathor 我浏览一下了 up 主的博客,感觉太棒了,能这么详细的记录自己的学习经历。我也是研一狗,方向也是自然语言处理(但不是关系抽取),不知能否一起交流学习?

    3. mathor mathor

      @mrchen 可以的,欢迎加入计算机养生群 203827195@(呵呵)

    4. mrchen mrchen

      @mathor 已加。不知道 up 组平时有参加一些比赛吗?不知有没有兴趣一起组队(抱大腿嘻嘻)

    5. mathor mathor

      @mrchen 我都是自己小打小闹,不成气候的,这队,不组也罢√

  3. 实时 实时

    你好,distiluse-base-multilingual-cased 这个模型是如何训练得到的?这个模型是知识蒸馏之后的模型吗?

    1. mathor mathor

      @实时是蒸馏得到的,官方预训练好的模型

    2. 实时 实时

      @mathor 那 distiluse-base-multilingual-cased 训练使用的 teacher 模型是哪个?一直查不到

    3. mathor mathor

      @实时这我没注意过,您可以去他们 github 页面提 issue 问一下

    4. 实时 实时

      @mathor 好吧

  4. 谈情说哎 谈情说哎

    老哥 finetune 时候如果两个句子都是中文句子的话,需要进行分词吗

    1. mathor mathor

      @谈情说哎分不分词好像都可以

  5. yan yan

    Up cos 计算怎么没用官方给的 demo,cos_sim = util.pytorch_cos_sim (embeddings [0], embeddings [1]),跟 scipy 里面有什么区别么?谢谢!

    1. mathor mathor

      @yan 都是计算 cos,应该没什么区别,因为当时我没看到这个官方的用法,所以就用的 scipy

  6. 圈

    help,想问下你的数据集的规模有多大 ,参数想参考下,自己用的数据集不知道设置多少参数~~

  7. johnbager johnbager

    合成数据集的代码里 scores 因该写错了把

    1. sun sun

      @johnbager 我感觉也是

    2. sun sun

      @johnbager 我感觉也是

  8. upqing upqing

    请问,这个数据集在哪里可以获得呢,非常感谢

  9. 发蓝 发蓝

    楼主,请教一下,Sentence-Transformer 不能直接用于中文么?还是自己训练?

  10. 卡卡罗特 卡卡罗特

    微调之后训练好的模型不能保存,报错:no entry found for key。请问如何解决?

  11. 博主高人咦 博主高人咦

    博主能发一下 Ko2Cn.xlsx 文件嘛

    1. mathor mathor

      @博主高人咦不能,因为这个文件是我自己的数据集

  12. liusssyang liusssyang

    先用 model = SentenceTransformer ('roberta-large-nli-stsb-mean-tokens') 这个模型做出词嵌入计算损失然后对 model = SentenceTransformer ('distiluse-base-multilingual-cased') 这个模型做微调?我没咋看懂这个操作

    1. mathor mathor

      @liusssyang 睁大你的眼睛

    2. liusssyang liusssyang

      @mathor 没懂 train_loss = losses.CosineSimilarityLoss (model) ,求指教

    3. liusssyang liusssyang

      @liusssyang 你这里的 model 不是前面定义的吗,下面又 model = SentenceTransformer ('distiluse-base-multilingual-cased'),对 这个模型做微调?

  13. bonbonx bonbonx

    请问 eval_size 这个参数根据什么设置的呀?

  14. qihang qihang

    我能用他训练匹配相似信息吗?