W3Cschool
恭喜您成為首批注冊(cè)用戶(hù)
獲得88經(jīng)驗(yàn)值獎(jiǎng)勵(lì)
用于神經(jīng)網(wǎng)絡(luò)的損失操作.
注意:默認(rèn)情況下,所有損失都將收集到 GraphKeys.LOSSES 集中。
所有的損失函數(shù)都采取一對(duì)預(yù)測(cè)和基準(zhǔn)真實(shí)標(biāo)簽,從中計(jì)算損失。假設(shè)這兩個(gè)張量的形狀是 [batch_size,d1,... dN] 的形式,其中 batch_size 是批次中的樣品數(shù)量,而 d1... dN 是其余尺寸。
在多次損失功能訓(xùn)練時(shí),通常會(huì)調(diào)整個(gè)人損失的相對(duì)優(yōu)勢(shì)。這是通過(guò)權(quán)重傳遞給損失函數(shù)的參數(shù)重新調(diào)整損失來(lái)執(zhí)行的。例如,如果我們訓(xùn)練 log_loss 和 sum_of_squares_loss,并且我們希望 log_loss 懲罰是 sum_of_squares_loss 的兩倍,我們將實(shí)現(xiàn):
# Explicitly set the weight.
tf.contrib.losses.log(predictions, labels, weight=2.0)
# Uses default weight of 1.0
tf.contrib.losses.mean_squared_error(predictions, labels)
# All the losses are collected into the `GraphKeys.LOSSES` collection.
losses = tf.get_collection(tf.GraphKeys.LOSSES)
在指定標(biāo)量損失的同時(shí),整個(gè)批處理中的損失可能會(huì)重新計(jì)算,我們有時(shí)希望重新調(diào)整每個(gè)批處理樣本的損失。例如,如果我們有一些比較重要的例子可以讓我們得到正確的結(jié)果,那么我們可能想要損失更多的其他錯(cuò)誤的樣本。在這種情況下,我們可以提供長(zhǎng)度的權(quán)重向量 batch_size,導(dǎo)致批處理中每個(gè)樣本的損失由相應(yīng)的權(quán)重元素縮放。例如,考慮分類(lèi)問(wèn)題的情況,我們希望最大化我們的準(zhǔn)確性,但我們特別有興趣獲得高精度的特定類(lèi):
inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)
# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1
onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
最后,在某些情況下,我們可能希望為每個(gè)可衡量的值指定不同的損失。例如,如果我們執(zhí)行每像素深度預(yù)測(cè)或每像素去噪,則單個(gè)批次樣本具有P值,其中P是圖像中的像素?cái)?shù)。對(duì)于許多損失,可測(cè)量值的數(shù)量與預(yù)測(cè)和標(biāo)簽張量中的元素?cái)?shù)量相匹配。對(duì)于其他的,例如 softmax_cross_entropy 和 cosine_distance,損失函數(shù)減小輸入的維數(shù),以產(chǎn)生每個(gè)可測(cè)量值的張量。例如,softmax_cross_entropy 作為維度 [batch_size,num_classes] 的輸入預(yù)測(cè)和標(biāo)簽,但可測(cè)量值的數(shù)量為 [batch_size]。因此,當(dāng)通過(guò)權(quán)重張量以指定每個(gè)可測(cè)量值的不同損失時(shí),
對(duì)于具體的例子,考慮某些地面真值深度值缺失(由于捕獲過(guò)程中的傳感器噪聲)的每像素深度預(yù)測(cè)的情況.在這種情況下,我們要為這些預(yù)測(cè)分配零權(quán)重給損失。
# 'depths' that are missing have a value of 0:
images, depths = LoadData(...)
predictions = MyModelPredictions(images)
weight = tf.cast(tf.greater(depths, 0), tf.float32)
loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)
注意,當(dāng)使用權(quán)重作為損失時(shí),最終的平均值是通過(guò)將權(quán)重重新分配權(quán)重來(lái)計(jì)算的,然后除以非零樣本的總數(shù)。對(duì)于任意一組權(quán)重,這可能不一定產(chǎn)生加權(quán)平均值。相反,在平均觀測(cè)數(shù)量之前,它簡(jiǎn)單而透明地調(diào)整了每個(gè)元素的損失。例如,如果由損失函數(shù)計(jì)算的損失是數(shù)組[4,1,2,3],權(quán)重是數(shù)組[1,0.5,3,9],那么平均損失是:
(4*1 + 1*0.5 + 2*3 + 3*9) / 4
然而,利用單個(gè)損失函數(shù)和任意權(quán)重集合,仍然可以容易地創(chuàng)建損失函數(shù),使得所得到的損失是相對(duì)于各個(gè)預(yù)測(cè)誤差的加權(quán)平均值:
images, labels = LoadData(...)
predictions = MyModelPredictions(images)
weight = MyComplicatedWeightingFunction(labels)
weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)
以下是不推薦使用的:mean_pairwise_squared_error 和 mean_squared_error.
Copyright©2021 w3cschool編程獅|閩ICP備15016281號(hào)-3|閩公網(wǎng)安備35020302033924號(hào)
違法和不良信息舉報(bào)電話(huà):173-0602-2364|舉報(bào)郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號(hào)
聯(lián)系方式:
更多建議: