TensorFlow函數(shù)教程:tf.nn.embedding_lookup

2019-01-31 13:47 更新

tf.nn.embedding_lookup函數(shù)

tf.nn.embedding_lookup(
    params,
    ids,
    partition_strategy='mod',
    name=None,
    validate_indices=True,
    max_norm=None
)

定義在:tensorflow/python/ops/embedding_ops.py.

請(qǐng)參閱指南:神經(jīng)網(wǎng)絡(luò)>Embeddings(嵌套)

在 embedding 張量列表中查找 ids.

此函數(shù)用于在 params 的張量列表中執(zhí)行并行查找.它是tf.gather的概括,其中params解釋為大型 embedding 張量的分區(qū).params 可以是使用帶分區(qū)的 tf.get_variable() 返回的 PartitionedVariable .

如果 len(params) > 1,ids 的每個(gè)元素 id 根據(jù) partition_strategy 在 params 元素之間被分區(qū).在所有策略中,如果 id 空間不均勻地劃分分區(qū)數(shù),則前(max_id + 1)% len(params)個(gè)分區(qū)中的每個(gè)分區(qū)將再分配一個(gè)id.

如果 partition_strategy 是 "mod",我們將每個(gè) id 分配給分區(qū) p = id % len(params).例如,13個(gè) id 分為5個(gè)分區(qū):[[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]

如果 partition_strategy 是 "div",我們以連續(xù)的方式將 id 分配給分區(qū).在這種情況下,13 個(gè) id 分為5個(gè)分區(qū):[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]

查詢的結(jié)果被連接成一個(gè)密集的張量.返回的張量的 shape 為 shape(ids) + shape(params)[1:].

參數(shù):

  • params:表示完整的 embedding 張量的單張量,或除了第一維之外全部具有相同 shape 的 P 張量列表,表示切分的 embedding 張量.或者,一個(gè) PartitionedVariable,通過沿維度0進(jìn)行分區(qū)創(chuàng)建.對(duì)于給定的 partition_strategy,每個(gè)元素的大小必須適當(dāng).
  • ids:一個(gè) int32 或 int64 類型的 Tensor,包含要在 params 中查找的 id.
  • partition_strategy:指定切分策略的字符串,在 len(params) > 1 的情況下使用.目前支持兩種切分方式:"div"和"mod",默認(rèn)是"mod".
  • name:操作的名稱(可選).
  • validate_indices:已棄用.如果將此操作分配給 CPU,則 indices 中的值始終被驗(yàn)證為在范圍內(nèi).如果分配給 GPU,則超出范圍的 indices 會(huì)導(dǎo)致安全的但未指定的行為,這可能包括引發(fā)錯(cuò)誤.
  • max_norm:如果提供該參數(shù),embedding 值將被 L2-normalize 為 max_norm 的值.

返回:

該函數(shù)與 params 中的張量具有相同類型的 Tensor.

可能引發(fā)的異常:

  • ValueError:如果 params 是空的.
以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)