TensorFlow函數(shù)教程:tf.nn.embedding_lookup_sparse

2019-01-31 13:47 更新

tf.nn.embedding_lookup_sparse函數(shù)

tf.nn.embedding_lookup_sparse(
    params,
    sp_ids,
    sp_weights,
    partition_strategy='mod',
    name=None,
    combiner=None,
    max_norm=None
)

定義在:tensorflow/python/ops/embedding_ops.py.

請參閱指南:神經(jīng)網(wǎng)絡(luò)>Embeddings(嵌套)

計(jì)算給定 id 和 weight 的 embedding.

此操作假定由 sp_ids 表示的密集張量中的每一行至少有一個(gè) id(即,沒有具有空 feature 的行),并且 sp_ids 的所有 indice 都是規(guī)范的 row-major 順序.

該函數(shù)還假設(shè)所有 id 值都在[0,p0]范圍內(nèi),其中 p0 是沿著維度0的參數(shù)大小的總和.

參數(shù):

  • params:表示完整的 embedding 張量的單張量,或除了第一維之外全部具有相同 shape 的 P 張量列表,表示切分的 embedding 張量.或者,一個(gè) PartitionedVariable,通過沿維度0進(jìn)行分區(qū)創(chuàng)建.對于給定的 partition_strategy,每個(gè)元素的大小必須適當(dāng).
  • sp_ids:int64 類型的 id 的 N x M SparseTensor(通常來自FeatureValueToId),其中 N 通常是批次大小,M 是任意的.
  • sp_weights:可以是具有 float/double weight 的 SparseTensor,或者是 None 以表示所有 weight 應(yīng)為1.如果指定,則 sp_weights 必須具有與 sp_ids 完全相同的 shape 和 indice.
  • partition_strategy:指定切分策略的字符串,在 len(params) > 1 的情況下使用.目前支持兩種切分方式:"div"和"mod",默認(rèn)是"mod".查看tf.nn.embedding_lookup獲取更多信息.
  • name:操作的可選名稱.
  • combiner:指定 reduction 操作的字符串.目前支持“mean”,“sqrtn”和“sum”.“sum”計(jì)算每行的 embedding 結(jié)果的加權(quán)和.“mean”是加權(quán)和除以總 weight.“sqrtn”是加權(quán)和除以 weight 平方和的平方根.
  • max_norm:如果提供,則在組合之前將每個(gè) embedding 規(guī)范化為具有等于 max_norm 的 l2 范數(shù). 

返回:

表示稀疏 id 的組合 embedding 的密集張量.對于由 sp_ids 表示的密集張量中的每一行,操作查找該行中所有 id 的 embedding,將它們乘以相應(yīng)的 weight,并按指定的方式組合這些 embedding.

換句話說,如果

shape(combined params) = [p0, p1, ..., pm]

并且:

shape(sp_ids) = shape(sp_weights) = [d0, d1, ..., dn]

然后:

shape(output) = [d0, d1, ..., dn-1, p1, ..., pm]

例如,如果 params 是一個(gè) 10x20 矩陣,則 sp_ids / sp_weights 是

[0, 0]: id 1, weight 2.0 [0, 1]: id 3, weight 0.5 [1, 0]: id 0, weight 1.0 [2, 3]: id 1, weight 3.0

如果 combiner=“mean”,那么輸出將是3x20矩陣,其中:

output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5) output[1, :] = (params[0, :] * 1.0) / 1.0 output[2, :] = (params[1, :] * 3.0) / 3.0

可能引發(fā)的異常:

  • TypeError:如果 sp_ids 不是 SparseTensor,或者 sp_weights 既不是 None 也不是 SparseTensor.
  • ValueError:如果 combiner 不是 {“mean”,“sqrtn”,“sum”} 中的一個(gè).
以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號