TensorFlow函數(shù):tf.nn.compute_accidental_hits

2019-01-31 11:30 更新

tf.nn.compute_accidental_hits函數(shù)

tf.nn.compute_accidental_hits(
    true_classes,
    sampled_candidates,
    num_true,
    seed=None,
    name=None
)

定義在:tensorflow/python/ops/candidate_sampling_ops.py.

請(qǐng)參閱指南:神經(jīng)網(wǎng)絡(luò)>候選抽樣

計(jì)算與true_classes匹配的sampled_candidate中的位置id.

在Candidate Sampling中,此操作實(shí)際上有助于刪除恰好與目標(biāo)類(lèi)匹配的抽樣類(lèi).這在Sampled Softmax和Sampled Logistic中完成.

我們預(yù)先假定sampled_candidates是獨(dú)一無(wú)二的.

當(dāng)其中一個(gè)目標(biāo)類(lèi)與其中一個(gè)抽樣類(lèi)匹配時(shí),我們將其稱(chēng)為“意外命中”.此操作將意外命中報(bào)告為三元組(index, id, weight),其中index表示true_classes中的行號(hào),id表示sampled_candidates中的位置,權(quán)重為-FLOAT_MAX.

此op的結(jié)果應(yīng)該通過(guò)一個(gè)sparse_to_dense操作來(lái)傳遞,然后添加到抽樣類(lèi)的logits中.這消除了意外采樣真實(shí)目標(biāo)類(lèi)作為同一示例的噪聲類(lèi)的矛盾效果.

參數(shù):

  • true_classes:一個(gè)Tensor,器類(lèi)型為int64,并且形狀為[batch_size, num_true];是目標(biāo)類(lèi).
  • sampled_candidates:一個(gè)Tensor,類(lèi)型為int64,并且形狀為[num_sampled];CandidateSampler的sampled_candidates輸出.
  • num_true:int,每個(gè)訓(xùn)練示例的目標(biāo)類(lèi)數(shù).
  • seed:int,特定于操作的seed;默認(rèn)值為0.
  • name:操作的名稱(chēng)(可選).

返回:

  • indices:一個(gè)Tensor,其類(lèi)型為int32,并且形狀為[num_accidental_hits];值表示true_classes中的行.
  • ids:一個(gè)Tensor,類(lèi)型為int64,并且形狀為[num_accidental_hits];值表示sampled_candidates中的位置.
  • weights:一個(gè)Tensor,其類(lèi)型為float,并且形狀為[num_accidental_hits];每個(gè)值都是-FLOAT_MAX.
以上內(nèi)容是否對(duì)您有幫助:
在線(xiàn)筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)