Loading [MathJax]/jax/output/SVG/jax.js
MENU

Siamese Network & Triplet NetWork

October 12, 2020 • Read: 7150 • Deep Learning阅读设置

Siamese Network(孪生网络)

简单来说,孪生网络就是共享参数的两个神经网络

在孪生网络中,我们把一张图片 X1 作为输入,得到该图片的编码 GW(X1)。然后,我们在不对网络参数进行任何更新的情况下,输入另一张图片 X2,并得到改图片的编码 GW(X2)。由于相似的图片应该具有相似的特征(编码),利用这一点,我们就可以比较并判断两张图片的相似性

孪生网络的损失函数

传统的 Siamese Network 使用 Contrastive Loss(对比损失函数)

L=(1Y)12(DW)2+(Y)12{max(0,mDW)}2

其中 DW 被定义为孪生网络两个输入之间的欧氏距离,即

DW={GW(X1)GW(X2)}2

  • Y 值为 0 或 1,如果 X1,X2 这对样本属于同一类,则 Y=0,反之 Y=1
  • m 是边际价值(margin value),即当 Y=1,如果 X1X2 之间距离大于 m,则不做优化(省时省力);如果 X1X2 之间的距离小于 m,则调整参数使其距离增大到 m
Contrastive Loss 代码
  • import torch
  • import numpy as np
  • import torch.nn.functional as F
  • class ContrastiveLoss(torch.nn.Module):
  • "Contrastive loss function"
  • def __init__(self, m=2.0):
  • super(ContrastiveLoss, self).__init__()
  • self.m = m
  • def forward(self, output1, output2, label):
  • d_w = F.pairwise_distance(output1, output2)
  • contrastive_loss = torch.mean((1-label) * 0.5 * torch.pow(d_w, 2) +
  • (label) * 0.5 * torch.pow(torch.clamp(self.m - d_w, min=0.0), 2))
  • return contrastive_loss

其中,F.pairwise_distance(x1, x2, p=2) 函数公式如下

(ni=1(|x1x2|p))1px1,x2Rb×n

pairwise_distance(x1, x2, p) Computes the batchwise pairwise distance between vectors x1, x2 using the p-norm

孪生网络的用途

简单来说,孪生网络的直接用途就是衡量两个输入的差异程度(或者说相似程度)。将两个输入分别送入两个神经网络,得到其在新空间的 representation,然后通过 Loss Function 来计算它们的差异程度(或相似程度)

  • 词汇语义相似度分析,QA 中 question 和 answer 的匹配
  • 手写体识别也可以用 Siamese Network
  • Kaggle 上 Quora 的 Question Pair 比赛,即判断两个提问是否为同一个问题
Pseudo-Siamese Network(伪孪生网络)

对于伪孪生网络来说,两边可以是不同的神经网络(如一个是 lstm,一个是 cnn),并且如果是相同的神经网络,是不共享参数

孪生网络和伪孪生网络分别适用的场景
  • 孪生网络适用于处理两个输入比较类似的情况
  • 伪孪生网络适用于处理两个输入有一定差别的情况

例如,计算两个句子或者词汇的语义相似度,使用 Siamese Network 比较合适;验证标题与正文的描述是否一致(标题和正文长度差别很大),或者文字是否描述了一幅图片(一个是图片,一个是文字)就应该使用 Pseudo-Siamese Network

Triplet Network(三胞胎网络)

如果说 Siamese Network 是双胞胎,那 Triplet Network 就是三胞胎。它的输入是三个:一个正例 + 两个负例,或一个负例 + 两个正例。训练的目标仍然是让相同类别间的距离尽可能小,不同类别间的距离尽可能大。Triplet Network 在 CIFAR,MNIST 数据集上效果均超过了 Siamese Network

损失函数定义如下:

L=max(d(a,p)d(a,n)+margin,0)

  • a 表示 anchor 图像
  • p 表示 positive 图像
  • n 表示 negative 图像

我们希望 ap 的距离应该小于 an 的距离。margin 是个超参数,它表示 d(a,p)d(a,n) 之间应该相差多少,例如,假设 margin=0.2,并且 d(a,p)=0.5,那么 d(a,n) 应该大于等于 0.7

Reference

Last Modified: April 20, 2021
Archives Tip
QR Code for this page
Tipping QR Code
Leave a Comment

  • OωO
  • |´・ω・)ノ
  • ヾ(≧∇≦*)ゝ
  • (☆ω☆)
  • (╯‵□′)╯︵┴─┴
  •  ̄﹃ ̄
  • (/ω\)
  • ∠( ᐛ 」∠)_
  • (๑•̀ㅁ•́ฅ)
  • →_→
  • ୧(๑•̀⌄•́๑)૭
  • ٩(ˊᗜˋ*)و
  • (ノ°ο°)ノ
  • (´இ皿இ`)
  • ⌇●﹏●⌇
  • (ฅ´ω`ฅ)
  • (╯°A°)╯︵○○○
  • φ( ̄∇ ̄o)
  • ヾ(´・ ・`。)ノ"
  • ( ง ᵒ̌皿ᵒ̌)ง⁼³₌₃
  • (ó﹏ò。)
  • Σ(っ °Д °;)っ
  • ( ,,´・ω・)ノ"(´っω・`。)
  • ╮(╯▽╰)╭
  • o(*////▽////*)q
  • >﹏<
  • ( ๑´•ω•) "(ㆆᴗㆆ)
  • (。•ˇ‸ˇ•。)
  • 泡泡
  • 阿鲁
  • 颜文字

已有 1 条评论
  1. Sentence-BERT 详解 – 闪念基因 – 个人技术分享

    [...] 更多关于 Triplet Network 的内容可以看我的这篇 Siamese Network & Triplet NetWork [...]