TensorFlow損失(contrib)

2019-01-31 18:11 更新

用于神經(jīng)網(wǎng)絡(luò)的損失操作.

注意:默認(rèn)情況下,所有損失都將收集到 GraphKeys.LOSSES 集中。

所有的損失函數(shù)都采取一對(duì)預(yù)測(cè)和基準(zhǔn)真實(shí)標(biāo)簽,從中計(jì)算損失。假設(shè)這兩個(gè)張量的形狀是 [batch_size,d1,... dN] 的形式,其中 batch_size 是批次中的樣品數(shù)量,而 d1... dN 是其余尺寸。

在多次損失功能訓(xùn)練時(shí),通常會(huì)調(diào)整個(gè)人損失的相對(duì)優(yōu)勢(shì)。這是通過權(quán)重傳遞給損失函數(shù)的參數(shù)重新調(diào)整損失來執(zhí)行的。例如,如果我們訓(xùn)練 log_loss 和 sum_of_squares_loss,并且我們希望 log_loss 懲罰是 sum_of_squares_loss 的兩倍,我們將實(shí)現(xiàn):

# Explicitly set the weight.
tf.contrib.losses.log(predictions, labels, weight=2.0)

# Uses default weight of 1.0
tf.contrib.losses.mean_squared_error(predictions, labels)

# All the losses are collected into the `GraphKeys.LOSSES` collection.
losses = tf.get_collection(tf.GraphKeys.LOSSES)

在指定標(biāo)量損失的同時(shí),整個(gè)批處理中的損失可能會(huì)重新計(jì)算,我們有時(shí)希望重新調(diào)整每個(gè)批處理樣本的損失。例如,如果我們有一些比較重要的例子可以讓我們得到正確的結(jié)果,那么我們可能想要損失更多的其他錯(cuò)誤的樣本。在這種情況下,我們可以提供長(zhǎng)度的權(quán)重向量 batch_size,導(dǎo)致批處理中每個(gè)樣本的損失由相應(yīng)的權(quán)重元素縮放。例如,考慮分類問題的情況,我們希望最大化我們的準(zhǔn)確性,但我們特別有興趣獲得高精度的特定類:

inputs, labels = LoadData(batch_size=3)
logits = MyModelPredictions(inputs)

# Ensures that the loss for examples whose ground truth class is `3` is 5x
# higher than the loss for all other examples.
weight = tf.multiply(4, tf.cast(tf.equal(labels, 3), tf.float32)) + 1

onehot_labels = tf.one_hot(labels, num_classes=5)
tf.contrib.losses.softmax_cross_entropy(logits, onehot_labels, weight=weight)

最后,在某些情況下,我們可能希望為每個(gè)可衡量的值指定不同的損失。例如,如果我們執(zhí)行每像素深度預(yù)測(cè)或每像素去噪,則單個(gè)批次樣本具有P值,其中P是圖像中的像素?cái)?shù)。對(duì)于許多損失,可測(cè)量值的數(shù)量與預(yù)測(cè)和標(biāo)簽張量中的元素?cái)?shù)量相匹配。對(duì)于其他的,例如 softmax_cross_entropy 和 cosine_distance,損失函數(shù)減小輸入的維數(shù),以產(chǎn)生每個(gè)可測(cè)量值的張量。例如,softmax_cross_entropy 作為維度 [batch_size,num_classes] 的輸入預(yù)測(cè)和標(biāo)簽,但可測(cè)量值的數(shù)量為 [batch_size]。因此,當(dāng)通過權(quán)重張量以指定每個(gè)可測(cè)量值的不同損失時(shí),

對(duì)于具體的例子,考慮某些地面真值深度值缺失(由于捕獲過程中的傳感器噪聲)的每像素深度預(yù)測(cè)的情況.在這種情況下,我們要為這些預(yù)測(cè)分配零權(quán)重給損失。

# 'depths' that are missing have a value of 0:
images, depths = LoadData(...)
predictions = MyModelPredictions(images)

weight = tf.cast(tf.greater(depths, 0), tf.float32)
loss  = tf.contrib.losses.mean_squared_error(predictions, depths, weight)

注意,當(dāng)使用權(quán)重作為損失時(shí),最終的平均值是通過將權(quán)重重新分配權(quán)重來計(jì)算的,然后除以非零樣本的總數(shù)。對(duì)于任意一組權(quán)重,這可能不一定產(chǎn)生加權(quán)平均值。相反,在平均觀測(cè)數(shù)量之前,它簡(jiǎn)單而透明地調(diào)整了每個(gè)元素的損失。例如,如果由損失函數(shù)計(jì)算的損失是數(shù)組[4,1,2,3],權(quán)重是數(shù)組[1,0.5,3,9],那么平均損失是:

(4*1 + 1*0.5 + 2*3 + 3*9) / 4

然而,利用單個(gè)損失函數(shù)和任意權(quán)重集合,仍然可以容易地創(chuàng)建損失函數(shù),使得所得到的損失是相對(duì)于各個(gè)預(yù)測(cè)誤差的加權(quán)平均值:

images, labels = LoadData(...)
predictions = MyModelPredictions(images)

weight = MyComplicatedWeightingFunction(labels)
weight = tf.div(weight, tf.size(weight))
loss = tf.contrib.losses.mean_squared_error(predictions, depths, weight)

  • tf.contrib.losses.absolute_difference 
  • tf.contrib.losses.add_loss
  • tf.contrib.losses.hinge_loss
  • tf.contrib.losses.compute_weighted_loss
  • tf.contrib.losses.cosine_distance
  • tf.contrib.losses.get_losses 
  • tf.contrib.losses.get_regularization_losses
  • tf.contrib.losses.get_total_loss 
  • tf.contrib.losses.log_loss 
  • tf.contrib.losses.mean_pairwise_squared_error 
  • tf.contrib.losses.mean_squared_error
  • tf.contrib.losses.sigmoid_cross_entropy 
  • tf.contrib.losses.softmax_cross_entropy 
  • tf.contrib.losses.sparse_softmax_cross_entropy

以下是不推薦使用的:mean_pairwise_squared_error 和 mean_squared_error.

  • tf.contrib.losses.sum_of_pairwise_squares
  • tf.contrib.losses.sum_of_squares
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)