三元組損失Triplet loss 詳解

深度神經網絡在識別模式和進行預測方面表現出色,但在涉及圖像識別任務時,它們常常難以區分相似個體的圖像。三元組損失是一種強大的訓練技術,可以解決這個問題,它通過學習相似度度量,在高維空間中將相似圖像準確地嵌入到彼此接近的位置。 在這篇文章中,我們將以簡單的技術術語解析三元組損失及其變體批量三元組損失,並提供一個相關的例子來幫助你理解這些概念。

三元組損失

三元組損失是一種用於訓練神經網絡的損失函數,可以用於執行諸如人臉識別或目標分類等任務。三元組損失的目標是在高維嵌入空間(也稱為特徵空間)中學習一種相似度度量,在這個空間中,相似對象(例如,同一個人的圖像)的表示彼此接近,而不相似對象的表示則相距較遠。

三元組損失的核心概念是使用三元組,它由一個錨點樣本、一個正樣本和一個負樣本組成。錨點樣本和正樣本是相似的實例,而負樣本則是不相似的。算法學習以這樣一種方式嵌入這些樣本:錨點樣本與正樣本之間的距離小於錨點樣本與負樣本之間的距離。

在實踐中,三元組損失通常與一種稱為孿生網絡的神經網絡架構一起使用,該架構在處理相同輸入的兩個或多個分支之間共享權重。這種共享表示允許網絡在嵌入空間中學習一個穩健的相似度度量。

當錨點樣本和正樣本在嵌入空間中不夠接近,或者錨點樣本和負樣本太接近時,三元組損失函數會對網絡進行懲罰。這鼓勵網絡學習輸入數據的有意義表示,捕捉相關樣本之間的相似性。

三元組損失的例子

假設有一組不同人的照片,我們想訓練一個人臉識別系統。目標是識別兩張圖像是否屬於同一個人。三元組損失可以用來學習一個相似度度量,使系統能夠準確識別人臉。

一個三元組由三張照片組成:一個錨點、一個正樣本和一個負樣本。錨點是特定人的照片,正樣本是同一個人的另一張照片,負樣本是不同人的照片。

在訓練過程中,網絡會呈現三元組,三元組損失函數計算錨點、正樣本和負樣本嵌入(高維特徵表示)之間的距離。如果錨點和正樣本嵌入之間的距離太大,或者錨點和負樣本嵌入之間的距離太小,三元組損失函數就會懲罰網絡。

通過基於這個損失函數迭代調整網絡的權重,網絡學會將相似的人臉(即錨點和正樣本)嵌入到嵌入空間中彼此接近的位置,而不相似的人臉(即錨點和負樣本)則被分開。

例如,如果同一個人的兩張照片(錨點和正樣本)的嵌入彼此接近,系統就能準確識別它們屬於同一個人。相反,如果不同人的照片(錨點和負樣本)的嵌入相距較遠,系統就能自信地將它們歸類為屬於不同的個體。

批量三元組損失

批量三元組損失是傳統三元組損失的一種變體,它在訓練過程中對數據批次進行操作。在標準三元組損失中,一個批次由三張圖像組成:一個錨點、一個正樣本和一個負樣本。目標是學習一個相似度度量,例如能夠準確識別人臉。

而批量三元組損失,不是一次處理一個三元組,而是在一個批次中一起處理多個三元組。這種方法在計算上可能更高效,並且可以利用現代 gpu 的能力更快地訓練深度神經網絡。

在訓練過程中,網絡會呈現一批三元組,三元組損失函數計算每個三元組內錨點、正樣本和負樣本嵌入(高維特徵表示)之間的距離。如果錨點和正樣本嵌入之間的距離太大,或者錨點和負樣本嵌入之間的距離太小,批量三元組損失函數就會懲罰網絡。

通過基於這個損失函數迭代調整網絡的權重,網絡學會將相似的特徵(即錨點和正樣本)嵌入到嵌入空間中彼此接近的位置,而不相似的特徵(即錨點和負樣本)則被分開。

例如,如果同一個人的兩張照片(錨點和正樣本)的嵌入彼此接近,系統就能準確識別它們屬於同一個人。相反,如果不同人的照片(錨點和負樣本)的嵌入相距較遠,系統就能自信地將它們歸類為屬於不同的個體。

批量三元組損失是一種有效的方法,用於訓練深度神經網絡進行人臉識別和其他需要相似度度量的應用。

批量三元組損失的例子

假設你是機場的一名安保人員,你的任務是在安檢站識別經過的個人。我們有一個手持設備,一次顯示三張照片:一個錨點、一個正樣本和一個負樣本。目標是快速確定錨點照片中的人是否與正樣本照片中的人相同,如果不同,還需要識別負樣本照片中的人。

這個場景可以被構建為一個批量三元組損失問題。手持設備本質上是在執行一個使用批量三元組損失訓練的深度神經網絡。錨點、正樣本和負樣本圖像是網絡的輸入,輸出是一組嵌入(高維特徵表示),捕捉圖像之間的相似性。網絡被訓練以最小化同一個人的嵌入之間的距離(正對),同時最大化不同人的嵌入之間的距離(負對)。

在這個安保場景中,當手持設備向你呈現一批三元組時,網絡計算每個三元組內錨點、正樣本和負樣本圖像嵌入之間的距離。如果錨點和正樣本圖像的嵌入之間的距離很小,你就可以自信地說它們屬於同一個人。如果距離很大,你就可以將負樣本圖像中的人識別為一個不同的個體。

通過使用批量三元組損失和大型圖像數據集訓練網絡,它學會將相似的圖像(即同一個人的圖像)嵌入到嵌入空間中彼此接近的位置,而不相似的圖像(即不同人的圖像)則被分開。

總結

本文介紹了三元組損失,這是一種用於訓練深度神經網絡的技術,主要應用於圖像識別任務。三元組損失通過學習高維嵌入空間中的相似度度量,使相似圖像的表示彼此接近,不相似圖像的表示相距較遠。

三元組損失的核心概念是使用由錨點、正樣本和負樣本組成的三元組進行訓練。網絡學習將錨點與正樣本的距離最小化,同時最大化與負樣本的距離。而批量三元組損失,這是一種在單個批次中處理多個三元組的變體,提高了計算效率。

作者:jyoti dabass, ph.d


more
kaggle比賽交流和組隊

喜歡就關注一下吧:

點個 在看 你最好看!