App下載

pytorch中常用的損失函數(shù)用法說明

幼兒園搶飯第一名 2021-08-17 15:09:43 瀏覽數(shù) (8759)
反饋

在機器學(xué)習(xí)中,損失函數(shù)通常作為學(xué)習(xí)準則與優(yōu)化問題相聯(lián)系,不同的場景運用的損失函數(shù)也不相同,今天小編就帶來了pytorch中常用的一些損失函數(shù)及其用法說明,趕快收藏起來吧。

1. pytorch中常用的損失函數(shù)列舉

pytorch中的nn模塊提供了很多可以直接使用的loss函數(shù), 比如MSELoss(), CrossEntropyLoss(), NLLLoss() 等

官方鏈接: https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html

pytorch中常用的損失函數(shù)
損失函數(shù) 名稱 適用場景
torch.nn.MSELoss() 均方誤差損失 回歸
torch.nn.L1Loss() 平均絕對值誤差損失 回歸
torch.nn.CrossEntropyLoss() 交叉熵損失 多分類
torch.nn.NLLLoss() 負對數(shù)似然函數(shù)損失 多分類
torch.nn.NLLLoss2d() 圖片負對數(shù)似然函數(shù)損失 圖像分割
torch.nn.KLDivLoss() KL散度損失 回歸
torch.nn.BCELoss() 二分類交叉熵損失 二分類
torch.nn.MarginRankingLoss() 評價相似度的損失
torch.nn.MultiLabelMarginLoss() 多標簽分類的損失 多標簽分類
torch.nn.SmoothL1Loss() 平滑的L1損失 回歸
torch.nn.SoftMarginLoss() 多標簽二分類問題的損失

多標簽二分類

2. 比較CrossEntropyLoss() 和NLLLoss()

(1). CrossEntropyLoss():

torch.nn.CrossEntropyLoss(weight=None,   # 1D張量,含n個元素,分別代表n類的權(quán)重,樣本不均衡時常用
                          size_average=None, 
                          ignore_index=-100, 
                          reduce=None, 
                          reduction='mean' )

參數(shù):

weight: 1D張量,含n個元素,分別代表n類的權(quán)重,樣本不均衡時常用, 默認為None.

計算公式:

weight = None時:

weight ≠ None時:

輸入:

output: 網(wǎng)絡(luò)未加softmax的輸出

target: label值(0,1,2 不是one-hot)

代碼:

loss_func = CrossEntropyLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)

(2). NLLLoss():

torch.nn.NLLLoss(weight=None, 
                size_average=None, 
                ignore_index=-100,
                reduce=None, 
                reduction='mean')

輸入:

output: 網(wǎng)絡(luò)在logsoftmax后的輸出

target: label值(0,1,2 不是one-hot)

代碼:

loss_func = NLLLoss(weight=torch.from_numpy(np.array([0.03,0.05,0.19,0.26,0.47])).float().to(device) ,size_average=True)
loss = loss_func(output, target)


(3). 二者總結(jié)比較:

總之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具體等價應(yīng)用如下:

####################---CrossEntropyLoss()---#######################
 
loss_func = CrossEntropyLoss()
loss = loss_func(output, target)
 
####################---Softmax+log+NLLLoss()---####################
 
self.softmax = nn.Softmax(dim = -1)
 
x = self.softmax(x)
output = torch.log(x)
 
loss_func = NLLLoss()
loss = loss_func(output, target)
 
####################---LogSoftmax+NLLLoss()---######################
 
self.log_softmax = nn.LogSoftmax(dim = -1)
 
output = self.log_softmax(x)
 
loss_func = NLLLoss()
loss = loss_func(output, target)

補充:常用損失函數(shù)用法小結(jié)之Pytorch框架

在用深度學(xué)習(xí)做圖像處理的時候,常用到的損失函數(shù)無非有四五種,為了方便Pytorch使用者,所以簡要做以下總結(jié)

1)L1損失函數(shù)

預(yù)測值與標簽值進行相差,然后取絕對值,根據(jù)實際應(yīng)用場所,可以設(shè)置是否求和,求平均,公式可見下,Pytorch調(diào)用函數(shù):nn.L1Loss

2)L2損失函數(shù)

預(yù)測值與標簽值進行相差,然后取平方,根據(jù)實際應(yīng)用場所,可以設(shè)置是否求和,求平均,公式可見下,Pytorch調(diào)用函數(shù):nn.MSELoss

3)Huber Loss損失函數(shù)

簡單來說就是L1和L2損失函數(shù)的綜合版本,結(jié)合了兩者的優(yōu)點,公式可見下,Pytorch調(diào)用函數(shù):nn.SmoothL1Loss

4)二分類交叉熵損失函數(shù)

簡單來說,就是度量兩個概率分布間的差異性信息,在某一程度上也可以防止梯度學(xué)習(xí)過慢,公式可見下,Pytorch調(diào)用函數(shù)有兩個,一個是nn.BCELoss函數(shù),用的時候要結(jié)合Sigmoid函數(shù),另外一個是nn.BCEWithLogitsLoss()

5)多分類交叉熵損失函數(shù)

也是度量兩個概率分布間的差異性信息,Pytorch調(diào)用函數(shù)也有兩個,一個是nn.NLLLoss,用的時候要結(jié)合log softmax處理,另外一個是nn.CrossEntropyLoss

以上就是pytorch中常用的一些損失函數(shù)及其用法說明,希望能給大家一個參考,也希望大家多多支持W3Cschool。



0 人點贊