TensorFlow函數(shù):tf.losses.compute_weighted_loss

2018-08-18 15:46 更新

tf.losses.compute_weighted_loss函數(shù)

tf.losses.compute_weighted_loss(
    losses,
    weights=1.0,
    scope=None,
    loss_collection=tf.GraphKeys.LOSSES,
    reduction=Reduction.SUM_BY_NONZERO_WEIGHTS
)

定義在:tensorflow/python/ops/losses/losses_impl.py.

計(jì)算加權(quán)l(xiāng)oss.

參數(shù):

  • losses:形狀為[batch_size, d1, ... dN]的Tensor.
  • weights:可選的,其秩為0或與losses具有相同等級(jí)的Tensor,并且必須可廣播到losses(即,所有維度必須為1或者與相應(yīng)的losses維度相同).
  • scope:計(jì)算loss時(shí)執(zhí)行的操作范圍.
  • loss_collection:loss將被添加到這些集合中.
  • reduction:適用于loss的減少類(lèi)型.

返回:

與losses相同類(lèi)型的加權(quán)損失Tensor,如果reduction是NONE,它的形狀與losses相同;否則,它是標(biāo)量.

可能引發(fā)的異常:

  • ValueError:如果weights是None,或者與losses的形狀不兼容,或者是否存在losses或weights缺少維度(秩)的數(shù)量.
注意:當(dāng)計(jì)算的來(lái)自?xún)蓚€(gè)加權(quán)損失捐款的梯度losses和weights被考慮.如果你weights依賴(lài)于某些模型參數(shù),但你不希望這會(huì)影響損失梯度,則需要應(yīng)用tf.stop_gradient到weights它們傳遞給前compute_weighted_loss.

注: 當(dāng)計(jì)算從lossesweights的加權(quán)l(xiāng)oss貢獻(xiàn)的梯度時(shí)要考慮.如果您的權(quán)重依賴(lài)于某些模型參數(shù),但您不希望此影響loss漸變,則在將它們傳遞給compute_weighted_loss之前,需要將tf. stop_gradient傳遞給weights.

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)