TensorFlow函數(shù):tf.losses.mean_squared_error

2018-08-27 11:50 更新

tf.losses.mean_squared_error函數(shù)

tf.losses.mean_squared_error(
    labels,
    predictions,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

定義在:tensorflow/python/ops/losses/losses_impl.py.

在訓(xùn)練過程中增加了平方和loss.

在這個函數(shù)中,weights作為loss的系數(shù).如果提供了標(biāo)量,那么loss只是按給定值縮放.如果weights是一個大小為[batch_size]的張量,那么批次的每個樣本的總損失由weights向量中的相應(yīng)元素重新調(diào)整.如果weights的形狀與predictions的形狀相匹配,則predictions中每個可測量元素的loss由相應(yīng)的weights值縮放.

參數(shù):

  • labels:真實的輸出張量,與“predictions”相同.
  • predictions:預(yù)測的輸出.
  • weights:可選的Tensor,其秩為0或與labels具有相同的秩,并且必須可廣播到labels(即,所有維度必須為1與相應(yīng)的losses具有相同的維度).
  • scope:計算loss時執(zhí)行的操作范圍.
  • loss_collection:將添加loss的集合.
  • reduction:適用于loss的減少類型.

返回:

加權(quán)損失浮動Tensor.如果reduction是NONE,則它的形狀與labels相同;否則,它是標(biāo)量.

可能引發(fā)的異常:

  • ValueError:如果predictions與labels的形狀不匹配,或者形狀weights是無效,亦或,如果labels或是predictions為None,則會引發(fā)此類異常.
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號