TensorFlow函數(shù)教程:tf.nn.nce_loss

2020-10-20 15:17 更新

tf.nn.nce_loss函數(shù)

tf.nn.nce_loss(
    weights,
    biases,
    labels,
    inputs,
    num_sampled,
    num_classes,
    num_true=1,
    sampled_values=None,
    remove_accidental_hits=False,
    partition_strategy='mod',
    name='nce_loss'
)

定義在:tensorflow/python/ops/nn_impl.py.

請(qǐng)參閱指南:神經(jīng)網(wǎng)絡(luò)>候選采樣

計(jì)算并返回噪聲對(duì)比估計(jì)(NCE, Noise Contrastive Estimation)訓(xùn)練損失.

一個(gè)常見的用例是使用此方法進(jìn)行訓(xùn)練,并計(jì)算完整的S形模型損失以進(jìn)行評(píng)估或推斷.在這種情況下,您必須將partition_strategy="div",使兩個(gè)損失保持一致,如下例所示:

if mode == "train":
  loss = tf.nn.nce_loss(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      ...,
      partition_strategy="div")
elif mode == "eval":
  logits = tf.matmul(inputs, tf.transpose(weights))
  logits = tf.nn.bias_add(logits, biases)
  labels_one_hot = tf.one_hot(labels, n_classes)
  loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=labels_one_hot,
      logits=logits)
  loss = tf.reduce_sum(loss, axis=1)
注意:默認(rèn)情況下,它使用對(duì)數(shù)均勻(Zipfian)分布進(jìn)行采樣,因此必須按照頻率遞減的順序?qū)?biāo)簽進(jìn)行排序,以獲得良好的結(jié)果.有關(guān)詳細(xì)信息,請(qǐng)參閱tf.nn.log_uniform_candidate_sampler.
注意:在num_true> 1 的情況下,我們?yōu)槊總€(gè)目標(biāo)類分配目標(biāo)概率1/num_true,以便目標(biāo)概率總和為每個(gè)示例1.
注意:每個(gè)示例允許目標(biāo)類的變量數(shù)量是有用的.我們希望在將來的版本中提供此功能.現(xiàn)在,如果你有一個(gè)目標(biāo)類的變量數(shù)量,你可以通過重復(fù)它們或通過填充其他未使用的類來將它們填充到一個(gè)常數(shù). 

參數(shù):

  • weights:一個(gè)Tensor,shape[num_classes, dim],或者是Tensor對(duì)象列表,其沿著維度0的連接具有shape [num_classes,dim].(可能是分區(qū)的)類嵌入.
  • biases:一個(gè)Tensor,shape[num_classes].類偏差.
  • labels:一個(gè)Tensor,類型為int64shape [batch_size, num_true].目標(biāo)類.
  • inputs:一個(gè)Tensor,shape [batch_size, dim].輸入網(wǎng)絡(luò)的正向激活.
  • num_sampled:int,每批隨機(jī)抽樣的類數(shù).
  • num_classes:int,可能的類數(shù).
  • num_true:int,每個(gè)訓(xùn)練示例的目標(biāo)類數(shù).
  • sampled_values:由* _candidate_sampler函數(shù)返回的元組(sampled_candidates,true_expected_count,sampled_expected_count).(如果是None,我們默認(rèn)為log_uniform_candidate_sampler)
  • remove_accidental_hits:bool.是否刪除“意外命中”,其中采樣類等于其中一個(gè)目標(biāo)類.如果設(shè)置為True,則這是“Sampled Logistic”損失而不是NCE,我們正在學(xué)習(xí)生成對(duì)數(shù)賠率而不是對(duì)數(shù)概率.請(qǐng)參閱[候選采樣算法參考](https://www.tensorflow.org/extras/candidate_sampling.pdf).默認(rèn)值為False.
  • partition_strategy:指定分區(qū)策略的字符串,如果len(weights) > 1.目前"div""mod"受到支持.默認(rèn)是"mod".更多詳細(xì)信息,請(qǐng)參閱tf.nn.embedding_lookup
  • name:操作的名稱(可選).

返回:

每個(gè)示例NCE損失的batch_size 1-D張量.


以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)