TensorFlow函數(shù):tf.unsorted_segment_sum

2020-07-27 17:07 更新

tf.unsorted_segment_sum函數(shù)

tf.unsorted_segment_sum(
    data,
    segment_ids,
    num_segments,
    name=None
)

請參閱指南:數(shù)學(xué)函數(shù)>分段

計算張量片段的和.

計算一個張量,使得 (output[i] = sum_{j...} data[j...] 總和超過元組 j...,例如,segment_ids[j...] == i.與 SegmentSum 不同,segment_ids 不需要排序,不需要覆蓋整個有效值范圍內(nèi)的所有值.

如果給定段 ID i 的和為空,則 output[i] = 0.如果給定的分段 ID i 為負值,則該值將被刪除并且不會被添加到該段的總和中.

num_segments 應(yīng)等于不同的段 ID 的數(shù)量.

TensorFlow函數(shù)

函數(shù)參數(shù):

  • data:一個 Tensor,必須是下列類型之一:float32,float64,int32,uint8,int16,int8,complex64,int64,qint8,quint8,qint32,bfloat16,uint16,complex128,half,uint32,uint64.
  • segment_ids:一個 Tensor,必須是以下類型之一:int32,int64,張量的形狀是一個 data.shape 的前綴.
  • num_segments:一個 Tensor,必須是以下類型之一:int32,int64.
  • name:操作的名稱(可選).

函數(shù)返回值:

tf.unsorted_segment_sum函數(shù)返回一個 Tensor,它與 data 的類型相同.


例子:

a = np.arange(1,10).reshape(3,3)

print(a)

print('----------')

print((sess.run(tf.unsorted_segment_sum(data=a,segment_ids=[0,1,0],num_segments=2))))


輸出:

[[1 2 3]

 [4 5 6]

 [7 8 9]]

----------

[[ 8 10 12]

 [ 4  5  6]



以上內(nèi)容是否對您有幫助:
在線筆記
App下載
App下載

掃描二維碼

下載編程獅App

公眾號
微信公眾號

編程獅公眾號