TensorFlow變量函數(shù):tf.scatter_nd_add

2018-01-05 11:11 更新

tf.scatter_nd_add 函數(shù)

scatter_nd_add(
    ref,
    indices,
    updates,
    use_locking=False,
    name=None
)

請(qǐng)參閱指南:變量>稀疏變量更新

在updates和單個(gè)值或切片之間應(yīng)用稀疏加法,根據(jù)indices在給定的變量?jī)?nèi).

ref是一個(gè)秩為P的Tensor,indices是一個(gè)秩為Q的Tensor.

indices必須是整數(shù)張量,包含索引到ref.它一定有形狀:[d_0, ..., d_{Q-2}, K],并且是:0<K<=P.

indices(具有長(zhǎng)度K)的最內(nèi)部維度對(duì)應(yīng)于沿著ref的K維度的元素(if K = P)或切片(if K < P)的索引.

updates是具有形狀的秩為Q-1+P-K的Tensor:

[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].

例如, 假設(shè)我們要將4分散的元素添加到 rank-1 張量到8元素.在 Python 中, 添加的內(nèi)容如下所示:

ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
add = tf.scatter_nd_add(ref, indices, updates)
with tf.Session() as sess:
  print sess.run(add)

對(duì)ref的結(jié)果更新如下所示:

[1, 13, 3, 14, 14, 6, 7, 20]

請(qǐng)參閱tf.scatter_nd有關(guān)如何更新切片的更多詳細(xì)信息.

函數(shù)參數(shù)

  • ref:一個(gè)可變的Tensor;必須是下列類型之一:float32,float64,int64,int32,uint8,uint16,int16,int8,complex64,complex128,qint8,quint8,qint32,half;一個(gè)可變的張量;應(yīng)該來(lái)自一個(gè)變量節(jié)點(diǎn).
  • indices:一個(gè)Tensor.必須是以下類型之一:int32,int64.索引到ref的一個(gè)張量.
  • updates:一個(gè)Tensor.必須與ref具有相同的類型.添加到ref的更新值的張量.
  • use_locking:可選的bool;如果為True,則賦值將受鎖定的保護(hù);否則行為是不確定的,但可能表現(xiàn)出較少的爭(zhēng)用.
  • name:操作的名稱(可選).

函數(shù)返回值

tf.scatter_nd_add函數(shù)返回一個(gè)可變的Tensor;與ref有相同的類型;與 ref 一樣,返回為希望在更新完成后使用更新的值的操作的方便性.

以上內(nèi)容是否對(duì)您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號(hào)
微信公眾號(hào)

編程獅公眾號(hào)