W3Cschool
恭喜您成為首批注冊用戶
獲得88經(jīng)驗值獎勵
用于神經(jīng)網(wǎng)絡(luò)的損失操作.
注意:默認(rèn)情況下,所有損失都將收集到 GraphKeys.LOSSES 集中。
所有的損失函數(shù)都采取一對預(yù)測和基準(zhǔn)真實標(biāo)簽,從中計算損失。假設(shè)這兩個張量的形狀是 [batch_size,d1,... dN] 的形式,其中 batch_size 是批次中的樣品數(shù)量,而 d1... dN 是其余尺寸。
在多次損失功能訓(xùn)練時,通常會調(diào)整個人損失的相對優(yōu)勢。這是通過權(quán)重傳遞給損失函數(shù)的參數(shù)重新調(diào)整損失來執(zhí)行的。例如,如果我們訓(xùn)練 log_loss 和 sum_of_squares_loss,并且我們希望 log_loss 懲罰是 sum_of_squares_loss 的兩倍,我們將實現(xiàn):
# Explicitly set the weight.
tf.contrib.losses.log(predictions, labels, weight=2.0)
# Uses default weight of 1.0
tf.contrib.losses.mean_squared_error(predictions, labels)
# All the losses are collected into the `GraphKeys.LOSSES` collection.
losses = tf.get_collection(tf.GraphKeys.LOSSES)
在指定標(biāo)量損失的同時,整個批處理中的損失可能會重新計算,我們有時希望重新調(diào)整每個批處理樣本的損失。例如,如果我們有一些比較重要的例子可以讓我們得到正確的結(jié)果,那么我們可能想要損失更多的其他錯誤的樣本。在這種情況下,我們可以提供長度的權(quán)重向量 batch_size,導(dǎo)致批處理中每個樣本的損失由相應(yīng)的權(quán)重元素縮放。例如,考慮分類問題的情況,我們希望最大化我們的準(zhǔn)確性,但我們特別有興趣獲得高精度的特定類:
inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)
# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1
onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)
最后,在某些情況下,我們可能希望為每個可衡量的值指定不同的損失。例如,如果我們執(zhí)行每像素深度預(yù)測或每像素去噪,則單個批次樣本具有P值,其中P是圖像中的像素數(shù)。對于許多損失,可測量值的數(shù)量與預(yù)測和標(biāo)簽張量中的元素數(shù)量相匹配。對于其他的,例如 softmax_cross_entropy 和 cosine_distance,損失函數(shù)減小輸入的維數(shù),以產(chǎn)生每個可測量值的張量。例如,softmax_cross_entropy 作為維度 [batch_size,num_classes] 的輸入預(yù)測和標(biāo)簽,但可測量值的數(shù)量為 [batch_size]。因此,當(dāng)通過權(quán)重張量以指定每個可測量值的不同損失時,
對于具體的例子,考慮某些地面真值深度值缺失(由于捕獲過程中的傳感器噪聲)的每像素深度預(yù)測的情況.在這種情況下,我們要為這些預(yù)測分配零權(quán)重給損失。
# 'depths' that are missing have a value of 0:
images, depths = LoadData(...)
predictions = MyModelPredictions(images)
weight = tf.cast(tf.greater(depths, 0), tf.float32)
loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)
注意,當(dāng)使用權(quán)重作為損失時,最終的平均值是通過將權(quán)重重新分配權(quán)重來計算的,然后除以非零樣本的總數(shù)。對于任意一組權(quán)重,這可能不一定產(chǎn)生加權(quán)平均值。相反,在平均觀測數(shù)量之前,它簡單而透明地調(diào)整了每個元素的損失。例如,如果由損失函數(shù)計算的損失是數(shù)組[4,1,2,3],權(quán)重是數(shù)組[1,0.5,3,9],那么平均損失是:
(4*1 + 1*0.5 + 2*3 + 3*9) / 4
然而,利用單個損失函數(shù)和任意權(quán)重集合,仍然可以容易地創(chuàng)建損失函數(shù),使得所得到的損失是相對于各個預(yù)測誤差的加權(quán)平均值:
images, labels = LoadData(...)
predictions = MyModelPredictions(images)
weight = MyComplicatedWeightingFunction(labels)
weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)
以下是不推薦使用的:mean_pairwise_squared_error 和 mean_squared_error.
Copyright©2021 w3cschool編程獅|閩ICP備15016281號-3|閩公網(wǎng)安備35020302033924號
違法和不良信息舉報電話:173-0602-2364|舉報郵箱:jubao@eeedong.com
掃描二維碼
下載編程獅App
編程獅公眾號
聯(lián)系方式:
更多建議: